In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Parameters
# -- Input location: a data directory and the two files read from it.
data_dir = '../../../bucket/wikipedia/1000docs_19513contexts_30maxtokens/'
contexts_filename = 'contexts.pickle'
acts_filename = 'activations.npz'

# -- Which layers to visualize; each entry is used as a key into the activations archive.
layers = [
    'arr_0',
    'arr_3',
    'arr_6',
    'arr_9',
    'arr_12',
]
# layers = ['arr_0']  # good for debugging

# -- Reduction settings, handed to acts_util.reduce_acts as (reduction, dim).
#    Only the first two reduced components are stored as plot coordinates.
reduction = 'PCA'
dim = 2

# -- If True, running the vis will also generate an interactive html file and open it.
view_vis_as_html = False

In [None]:
# Imports
# LOAD
import pickle
import numpy as np
import os
import sys
project_path = os.path.abspath('../../..')
sys.path.insert(0, project_path)
from src.utils import acts_util
# TAG
import nltk
import re
import pandas as pd
from src.utils import context_util
# VIS
from bokeh.plotting import figure, show, output_file
from bokeh.io import output_notebook
# Configure where Bokeh renders: always inline in the notebook, and, when
# view_vis_as_html is set in the parameters cell, additionally to an HTML file.
output_notebook()
if view_vis_as_html:
    # Relative file name, so it is resolved against the kernel's working directory.
    output_file('visualize-wiki.html')
from bokeh.models import Div
from bokeh.layouts import gridplot
from src.utils import vis_util

# Loading contexts

In [None]:
# Load contexts and acts
# contexts_list is later unpacked as (tokens, position) pairs; acts_npz is later
# indexed by layer key.
# NOTE: pickle.load can execute arbitrary code embedded in the file -- only point
# data_dir at data you trust.
contexts_path = os.path.join(os.path.abspath(data_dir), contexts_filename)
acts_path = os.path.join(data_dir, acts_filename)
with open(contexts_path, 'rb') as f:
    contexts_list = pickle.load(f)
acts_npz = np.load(acts_path)


In [None]:
# Keep only the activation arrays of the layers selected in the parameters cell,
# keyed by layer name.
layer_to_acts = dict((layer, acts_npz[layer]) for layer in layers)

In [None]:
# Reductions
# One reduced array per layer, computed over all contexts of that layer at once.
reduced_acts = {
    layer_name: acts_util.reduce_acts(layer_acts, reduction, dim)
    for layer_name, layer_acts in layer_to_acts.items()
}

# Properties of contexts

In [None]:
contexts = pd.DataFrame()
# basics
# Each entry of contexts_list is a (tokens, position) pair; one DataFrame row per entry.
contexts['tokens'] = [toks for toks, _ in contexts_list]
contexts['position'] = [position for _, position in contexts_list]

def _from_tokens_and_position(func):
    """Return a Series holding func(tokens, position) for every row of `contexts`."""
    return contexts['tokens'].combine(contexts['position'], func)

contexts['context html'] = _from_tokens_and_position(context_util.context_html)
contexts['context length'] = contexts['tokens'].apply(len)
contexts['abbreviated context'] = _from_tokens_and_position(context_util.abbreviated_context)
contexts['abbreviated context html'] = _from_tokens_and_position(context_util.abbreviated_context_html)
# The token sitting at `position` within its own token list.
contexts['token'] = _from_tokens_and_position(lambda toks, position: toks[position])
# The full token list joined with single spaces.
contexts['doc'] = contexts['tokens'].apply(' '.join)
# activations
# First pass: the first two raw activation dimensions of every layer.
for layer in layers:
    raw_acts = layer_to_acts[layer]
    contexts[f'{layer} x'] = raw_acts[:,0]
    contexts[f'{layer} y'] = raw_acts[:,1]
# Second pass (kept separate so the column order is raw-first, reduced-second):
# the first two components of every layer's reduced activations.
for layer in layers:
    layer_reduced = reduced_acts[layer]
    contexts[f'{layer} {reduction} x'] = layer_reduced[:,0]
    contexts[f'{layer} {reduction} y'] = layer_reduced[:,1]

# subspace activations
# For each inspected token, collect the rows whose focus token equals it, then
# reduce only those rows' activations (per layer) and store the first two
# components in columns named '<layer> "<token>" <reduction> x|y'.
toks_to_inspect = ['[CLS]', '[SEP]', '.', 'the', ',','born',]
subspaces_to_inspect = {
    tok: contexts.loc[contexts['token']==tok].index for tok in toks_to_inspect
}
# NOTE(review): a token with no matching rows yields an empty index here; how
# acts_util.reduce_acts treats an empty (or very small) input is not visible
# from this notebook -- confirm before adding rare tokens to toks_to_inspect.
for name, context_idxs in subspaces_to_inspect.items():
    print(f'{name} ({len(context_idxs)} contexts)')
    for layer in layers:
        subspace_xy = acts_util.reduce_acts(layer_to_acts[layer][context_idxs], reduction, dim)
        # Only the subspace rows are written; the assignment targets no other rows.
        contexts.loc[context_idxs, f'{layer} "{name}" {reduction} x'] = subspace_xy[:,0]
        contexts.loc[context_idxs, f'{layer} "{name}" {reduction} y'] = subspace_xy[:,1]

In [None]:
# properties
def reverse_position(toks, position):
    """Return how far `position` is from the end of `toks` (0 for the last token)."""
    return len(toks)-1-position
contexts['position from end'] = contexts['tokens'].combine(contexts['position'], reverse_position)

# nltk.pos_tag tags the whole token sequence on every call. Rows that carry the
# same token sequence (presumably all contexts taken from one document -- verify
# against how contexts_list is built) would otherwise re-tag it once per row, so
# the tagged sequence is memoised on the token tuple.
_pos_tag_cache = {}
def POS_tag(toks, position):
    """Return the NLTK part-of-speech tag of the token at `position` in `toks`.

    The full sequence is tagged (so the tagger sees the surrounding tokens) and
    the result is cached per distinct token sequence.
    """
    key = tuple(toks)
    tagged = _pos_tag_cache.get(key)
    if tagged is None:  # explicit None check: an empty sequence tags to []
        tagged = _pos_tag_cache[key] = nltk.pos_tag(toks)
    return tagged[position][1]
contexts['POS'] = contexts['tokens'].combine(contexts['position'], POS_tag)

# Focus-token identity flags.
contexts['CLS'] = contexts['token']=='[CLS]'
contexts['SEP'] = contexts['token']=='[SEP]'
contexts['.'] = contexts['token']=='.'
contexts['token length'] = contexts['token'].apply(len)

# Position flags counted from the start ...
contexts['1st'] = contexts['position']==0
contexts['2nd'] = contexts['position']==1
# ... and from the end (nth = last token, (n-1)th = second to last, ...).
contexts['nth'] = contexts['position']+1==contexts['context length']
contexts['(n-1)th'] = contexts['position']+2==contexts['context length']
contexts['(n-2)th'] = contexts['position']+1+2==contexts['context length']
def is_capitalized(tok):
    """True when `tok` starts with an ASCII uppercase letter."""
    return re.match('[A-Z]', tok) is not None

def is_partial(tok):
    """True when `tok` starts with the '##' prefix."""
    return tok.startswith('##')

def has_number(tok):
    """True when `tok` contains at least one ASCII digit."""
    return re.search('[0-9]', tok) is not None

def is_month(tok):
    """True when `tok` is exactly a capitalized English month name."""
    month_names = (
        'January', 'February', 'March', 'April', 'May', 'June',
        'July', 'August', 'September', 'October', 'November', 'December',
    )
    return tok in month_names

def is_year(tok):
    """True when `tok` is four digits beginning with 1 or 2."""
    return re.match('^[12][0-9]{3}$', tok) is not None

# One boolean column per predicate, evaluated on each row's focus token.
for column, predicate in (
    ('capitalized', is_capitalized),
    ('partial', is_partial),
    ('has number', has_number),
    ('is month', is_month),
    ('is year', is_year),
):
    contexts[column] = contexts['token'].apply(predicate)
def before_partial(toks, pos):
    """True when the token two places after `pos` exists and is a '##' piece.

    NOTE(review): the other before_* helpers inspect pos+1; this one inspects
    pos+2. Confirm the offset is intentional.
    """
    target = pos+2
    return (target < len(toks)) and is_partial(toks[target])

def before_double_capitals(toks, pos):
    """True when the next two tokens both exist and are both capitalized."""
    first, second = pos+1, pos+2
    return (second < len(toks)) and is_capitalized(toks[first]) and is_capitalized(toks[second])

def not_before_sep(toks, pos):
    """True when a following token exists and it is not '[SEP]'."""
    following = pos+1
    return (following < len(toks)) and toks[following]!='[SEP]'

def after_year(toks, pos):
    """True when a preceding token exists and it looks like a year."""
    preceding = pos-1
    return (preceding >= 0) and is_year(toks[preceding])

# One boolean column per predicate, evaluated on each row's (tokens, position).
for column, predicate in (
    ('before partial', before_partial),
    ('before double capitals', before_double_capitals),
    ('not_before_sep', not_before_sep),
    ('after_year', after_year),
):
    contexts[column] = contexts['tokens'].combine(contexts['position'], predicate)
def is_initial(tok):
    """True when `tok` is a single ASCII uppercase letter."""
    return re.match('^[A-Z]$', tok) is not None

def after_initial(toks, pos):
    """True when a preceding token exists and it is a single capital letter."""
    preceding = pos-1
    return (preceding >= 0) and is_initial(toks[preceding])

def is_number(tok):
    """True when `tok` consists only of ASCII digits."""
    return re.match('^[0-9]+$', tok) is not None

def number_seperator(toks, pos):
    """True when the tokens directly before and after `pos` both exist and are numbers."""
    preceding, following = pos-1, pos+1
    return (preceding >= 0) and is_number(toks[preceding]) and following < len(toks) and is_number(toks[following])

def after_partial(toks, pos):
    """True when a preceding token exists and it is a '##' piece."""
    preceding = pos-1
    return (preceding >= 0) and is_partial(toks[preceding])

def after_capitalized(toks, pos):
    """True when a preceding token exists and it is capitalized."""
    preceding = pos-1
    return (preceding >= 0) and is_capitalized(toks[preceding])

def is_CC(tok):
    """True when `tok` is 'and' or 'but'."""
    return tok in ('and', 'but')

def before_CC(toks, pos):
    """True when a following token exists and it is 'and' or 'but'."""
    following = pos+1
    return (following < len(toks)) and is_CC(toks[following])

# One boolean column per predicate, evaluated on each row's (tokens, position).
for column, predicate in (
    ('after_initial', after_initial),
    ('number_seperator', number_seperator),
    ('after_partial', after_partial),
    ('after_capitalized', after_capitalized),
    ('before_CC', before_CC),
):
    contexts[column] = contexts['tokens'].combine(contexts['position'], predicate)
def is_date(tok):
    """True when `tok` is a day-of-month number from 1 to 31, with no leading zero."""
    return re.match('^([1-9]|[12][0-9]|3[01])$', tok) is not None

def date_separator(toks, pos):
    """True when the token at `pos` is a ',' between a day number and a year."""
    preceding, following = pos-1, pos+1
    # Bounds are checked first so the neighbour look-ups below never run out of range
    # (and a negative index never wraps around to the end of the list).
    return (preceding >= 0 and
            following < len(toks) and
            toks[pos] == ',' and
            is_date(toks[preceding]) and
            is_year(toks[following]))

contexts['date_separator'] = contexts['tokens'].combine(contexts['position'], date_separator)
def before_capitalized(toks, pos):
    """True when a following token exists and it is capitalized."""
    following = pos+1
    return (following < len(toks)) and is_capitalized(toks[following])

def _between_caps(toks, pos):
    """Looser variant of between_caps: the preceding token may be capitalized
    or a '##' piece; the following token must be capitalized."""
    preceding, following = pos-1, pos+1
    return (
        (preceding >= 0) and
        (is_capitalized(toks[preceding]) or is_partial(toks[preceding])) and
        (following < len(toks) and is_capitalized(toks[following]))
    )

def between_caps(toks, pos):
    """True when the tokens directly before and after `pos` both exist and are capitalized."""
    preceding, following = pos-1, pos+1
    return (preceding >= 0) and is_capitalized(toks[preceding]) and following < len(toks) and is_capitalized(toks[following])

# One boolean column per predicate, evaluated on each row's (tokens, position).
for column, predicate in (
    ('before_capitalized', before_capitalized),
    ('_between_caps', _between_caps),
    ('between_caps', between_caps),
):
    contexts[column] = contexts['tokens'].combine(contexts['position'], predicate)

# Visualize contexts

In [None]:
# Fresh visualization
# Re-running this cell resets the grid: `visualizations` starts again with just
# the header entry (an empty corner followed by one centred label per layer).
# The cells below append to this list, so run this one first to avoid duplicates.
centered = ('center', 'center')
layer_labels = [Div(text=layer_name, align=centered) for layer_name in layers]
visualizations = [[None, *layer_labels]]

In [None]:
# Global space
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, []))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('CLS','.','SEP')))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('1st','2nd','(n-2)th','(n-1)th','nth')))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('position',)))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('token length',)))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('capitalized','partial')))
# visualizations.append(vis_util.visualize_columns(contexts, layers, reduction, ('has number','is month')))

In [None]:
# Local subspaces
# Each call plots the per-token subspace columns ('<layer> "<token>" <reduction> x|y');
# the last argument lists the property columns handed to vis_util.visualize_columns.
cls_entry = vis_util.visualize_columns(contexts, layers, f'"[CLS]" {reduction}', [])
visualizations.append(cls_entry)

# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"[SEP]" {reduction}', []))
sep_by_position_entry = vis_util.visualize_columns(contexts, layers, f'"[SEP]" {reduction}', ('position',))
visualizations.append(sep_by_position_entry)

# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"." {reduction}', []))
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"." {reduction}', ('position',)))
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"." {reduction}', (
#     'not_before_sep','after_year',
# )))
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"." {reduction}', (
#     'not_before_sep','after_capitalized','after_partial','after_year','after_initial', 'number_seperator' 
# )))

# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"," {reduction}', []))
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"," {reduction}', ('position',)))
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"," {reduction}', (
#     'after_capitalized','after_partial','after_year','before_CC', 'number_seperator', 'date_separator', 
#     '_between_caps'
# )))

In [None]:
# visualizations.append(vis_util.visualize_columns(contexts, layers, f'"born" {reduction}', []))

In [None]:
# `visualizations` holds one list per appended entry; zip(*...) transposes it so
# that the i-th items of every entry end up together in grid row i. Note that zip
# silently truncates to the shortest entry.
# gridplot is documented as taking a list of rows (list of lists), so the
# transposed grid is materialised instead of passing a one-shot zip iterator of tuples.
grid_rows = [list(row) for row in zip(*visualizations)]
show(gridplot(grid_rows))
