In [None]:
# Reload edited modules (e.g. src.utils.*) automatically before running each cell
%load_ext autoreload
%autoreload 2

In [None]:
# Parameters: location of the precomputed data and where to render the visualization
data_dir = '../../bucket/wikipedia/1000docs_19513contexts_30maxtokens'
contexts_filename, acts_filename = 'contexts.pickle', 'activations.npz'
# True -> render the vis inline in the notebook;
# False -> running the vis generates an interactive html file and opens it instead
view_vis_in_notebook = True

# Loading contexts and activations

In [None]:
import pickle
import numpy as np
import os
import sys
# Make the repository root importable so `src.utils` resolves.
# Guarded so that re-running this cell does not keep prepending duplicate entries to sys.path.
project_path = os.path.abspath('../..')
if project_path not in sys.path:
    sys.path.insert(0, project_path)
from src.utils import acts_util

In [None]:
# Load the tokenised contexts (pickle) and the activation arrays (.npz archive;
# later cells treat each stored array as one layer).
# NOTE(review): pickle.load can execute arbitrary code -- only point data_dir at trusted data.
contexts_path = os.path.join(os.path.abspath(data_dir), contexts_filename)
acts_path = os.path.join(data_dir, acts_filename)
with open(contexts_path, 'rb') as contexts_file:
    contexts_list = pickle.load(contexts_file)
acts = np.load(acts_path)


In [None]:
# Pull every array out of the .npz archive into a plain dict keyed by its stored name
layer_to_acts = dict((array_name, acts[array_name]) for array_name in acts.files)

In [None]:
# Reductions: reduce each layer's activations to `dim` dimensions with the `reduction`
# method (here 'PCA') via acts_util; the per-layer results feed the x/y plot columns below.
reduction, dim = ('PCA', 2)
PCA_acts = {
    layer_name: acts_util.reduce_activations(layer_acts, reduction, dim)
    for layer_name, layer_acts in layer_to_acts.items()
}

# Properties of contexts

In [None]:
import nltk
import re
import pandas as pd
from src.utils import context_util

In [None]:
# One row per (tokens, position) pair in contexts_list, with derived properties used by the vis.
contexts = pd.DataFrame()
# basics
contexts['tokens'] = [toks for toks, position in contexts_list]
contexts['position'] = [position for toks, position in contexts_list]
contexts['context'] = contexts['tokens'].combine(contexts['position'], context_util.context_str)
contexts['context length'] = contexts['tokens'].apply(len)
contexts['abbreviated context'] = contexts['tokens'].combine(contexts['position'], context_util.abbreviated_context)
contexts['abbreviated context html'] = contexts['tokens'].combine(contexts['position'], context_util.abbreviated_context_html)
contexts['token'] = contexts['tokens'].combine(contexts['position'], lambda toks, position: toks[position])
contexts['doc'] = contexts['tokens'].apply(lambda toks: ' '.join(toks))
# activations: first two reduced components of every layer as plot coordinates
for layer in PCA_acts:
    contexts[f'{layer} PCA x'] = PCA_acts[layer][:, 0]
    contexts[f'{layer} PCA y'] = PCA_acts[layer][:, 1]
# more


def reverse_position(toks, position):
    """Return `position` counted from the end of `toks` (the last token is 0)."""
    return len(toks) - 1 - position


contexts['position from end'] = contexts['tokens'].combine(contexts['position'], reverse_position)


def POS_tag(toks, position):
    """Return the tag nltk.pos_tag assigns to toks[position] when tagging the whole token list."""
    return nltk.pos_tag(toks)[position][1]


contexts['POS'] = contexts['tokens'].combine(contexts['position'], POS_tag)
contexts['CLS'] = contexts['token'] == '[CLS]'
contexts['SEP'] = contexts['token'] == '[SEP]'
contexts['.'] = contexts['token'] == '.'
contexts['token length'] = contexts['token'].apply(len)
contexts['1st'] = contexts['position'] == 0
contexts['nth'] = contexts['position'] + 1 == contexts['context length']
contexts['(n-1)th'] = contexts['position'] + 2 == contexts['context length']


def capitalized(tok):
    """True if `tok` starts with an ASCII uppercase letter."""
    return bool(re.match('[A-Z]', tok))


# Fixed: these two columns previously called undefined `is_capitalized` / `is_partial`.
contexts['capitalized'] = contexts['token'].apply(capitalized)


def partial(tok):
    """True if `tok` starts with '##' (presumably a WordPiece continuation piece -- confirm tokenizer)."""
    return tok.startswith('##')


contexts['partial'] = contexts['token'].apply(partial)

contexts

# Visualization

In [None]:
from bokeh.plotting import figure, show, output_file
from bokeh.io import output_notebook
# Decide where bokeh renders: a standalone html file, or inline in this notebook
if not view_vis_in_notebook:
    output_file('visualize-wiki.html')
else:
    output_notebook()
from bokeh.layouts import gridplot
from bokeh.models import Div, HoverTool, ColumnDataSource, PanTool, BoxZoomTool, WheelZoomTool, ResetTool
from bokeh.palettes import Inferno, Category10, Category20, Category20c, Pastel1, Pastel2, Bokeh, Plasma
from bokeh.models.annotations import Legend, LegendItem
import math
from src.utils import vis_util

In [None]:
# Colours for the binary-column overlays, indexed by a column's position within its combo.
# Previously used without ever being defined (NameError on a fresh kernel).
# bokeh's Category10 is a dict keyed by palette size; [10] gives ten distinct colours.
default_palette = Category10[10]

# Each entry becomes one grid column. A tuple of several boolean columns is drawn as coloured
# overlays on top of all contexts in grey; a 1-tuple is drawn coloured per category value.
combos_to_visualize = [
    ('CLS', '.', 'SEP'),
    ('1st', '(n-1)th', 'nth'),
    ('token length',),
    ('capitalized',),
    ('partial',),
    ('position',),
]
# Each entry becomes one grid row; use ['arr_0'] for quick debugging.
layers_to_visualize = ['arr_0', 'arr_1', 'arr_3', 'arr_6', 'arr_9', 'arr_12']


def _context_source(df, layer, extra=None):
    """Build a ColumnDataSource of `layer`'s 2-D coordinates plus hover-text columns for rows of `df`.

    `extra` optionally maps additional column names to per-row values (e.g. colours, legend labels).
    """
    data = {
        'x': df[f'{layer} PCA x'],
        'y': df[f'{layer} PCA y'],
        'token': df['token'],
        'abbreviated context': df['abbreviated context'],
        'abbreviated context html': df['abbreviated context html'],
        'context': df['context'],
    }
    if extra:
        data.update(extra)
    return ColumnDataSource(data)


rows = []
header = [None] + [Div(text=' / '.join(combo), align='center') for combo in combos_to_visualize]
rows.append(header)

for layer_idx, layer in enumerate(layers_to_visualize):
    row = [Div(text=layer, align=('center', 'center'))]
    for combo in combos_to_visualize:
        # Make blank figure
        p = figure(width=200, height=200)
        p.axis.visible = False
        p.grid.visible = False
        show_legend = layer_idx == 0  # only the top row of the grid carries legends
        if show_legend:
            p.add_layout(Legend(orientation='horizontal', label_text_font_size='6pt', label_width=10), 'above')

        if len(combo) > 1:
            # Figure visualizes a few binary columns: every context in grey, then the rows where
            # each column is True overlaid in that column's colour.
            assert len(combo) <= len(default_palette), f'not enough palette colours for {combo}'
            p.circle('x', 'y', color='lightgrey', source=_context_source(contexts, layer))
            for col_idx, col in enumerate(combo):
                source = _context_source(contexts[contexts[col]], layer)
                legend_kwargs = {'legend_label': col} if show_legend else {}
                p.circle('x', 'y', color=default_palette[col_idx], source=source, **legend_kwargs)
        else:
            # Figure visualizes one categorical column, one colour per category value.
            col = combo[0]
            source = _context_source(contexts, layer, {
                'color': vis_util.categorical_list_to_color_list(contexts[col]),
                'legend label': contexts[col],
            })
            legend_kwargs = {'legend_group': 'legend label'} if show_legend else {}
            p.circle('x', 'y', color='color', source=source, **legend_kwargs)

        # Add hover
        hover = HoverTool(
            tooltips=
                """<div style=
                '
                border-bottom-style:solid;
                border-width:1px;
                '
                >@{abbreviated context html}</div>"""
        )
        p.tools = [PanTool(), WheelZoomTool(), BoxZoomTool(), ResetTool(), hover]
        row.append(p)

    rows.append(row)
grid = gridplot(rows)
show(grid)