# Visualizing sample and all masked variants

## Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np
from scipy.spatial import distance_matrix
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_distances
import glob
import warnings
from typing import List
from bokeh.layouts import gridplot
from bokeh.plotting import figure, show, output_file
from bokeh.io import output_notebook
from bokeh.models import Div, HoverTool, ColumnDataSource, PanTool, BoxZoomTool, WheelZoomTool, ResetTool
from IPython.core.display import display, HTML
output_notebook()

import sys
sys.path.insert(0, os.path.abspath('../../..')) # -> vis -> src
from src.utils import acts_util, vis_util, html_util, context_util, bert_util
from src.utils.vis_util import Grid, plot_evolution
from src import references as refs

## Load tokens and activations

In [None]:
# Sample directory: given relative to the notebook's working directory and
# resolved to an absolute path once, so later cells can join file names onto it.
sample_rel_dir = '../../../data/sentences/art/'
dir_path = os.path.abspath(sample_rel_dir)

In [None]:
# load toks and acts
# The sample's tokens are stored as a pickle next to the activations. A context
# manager closes the file handle as soon as the tokens have been read.
# NOTE(review): unpickling can execute arbitrary code -- only point dir_path at
# trusted data.
with open(os.path.join(dir_path, refs.toks_fn), 'rb') as toks_file:
    doc = pickle.load(toks_file)
# Activations file. The `.files` lookup below assumes np.load returns an .npz
# archive (NpzFile) -- TODO confirm against how refs.acts_fn is written.
doc_acts = np.load(os.path.join(dir_path, refs.acts_fn))
# Names of the arrays in the archive; used as the layer names below. They are
# read from the archive that was actually loaded (`doc_acts`).
layer_names = doc_acts.files
# Integer layer indices, parallel to layer_names.
layers = list(range(len(layer_names)))
print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layer_names)}')

In [None]:
# get all variants
# Build one variant of the document per token position by passing that single
# position to bert_util.mask (presumably replacing the token with '[MASK]' --
# verify against bert_util).
# `doc` is the token list loaded above; the variants are then fed, already
# tokenized, to get_contexts_and_acts, which returns the contexts and the
# activations used by every later cell.
variants_docs = [bert_util.mask(doc, (mask_idx,)) for mask_idx in range(len(doc))]
variants_contexts, variants_acts = bert_util.get_contexts_and_acts(variants_docs, tokenized=True)

In [None]:
# Toggle: when enabled, the cluster cell scales each context's font size by
# the normalized norm of its activation.
vis_act_size = True
if vis_act_size:
    variants_acts_sizes = []
    for layer in layers:
        # norm of every activation vector in this layer (one value per row)
        layer_norms = np.linalg.norm(variants_acts[layer], axis=1)
        variants_acts_sizes.append(acts_util.normalize(layer_norms))

In [None]:
# reductions for vis
# Two reductions are computed per layer: a 2-dimensional one for the scatter
# plot coordinates, and a `dim`-dimensional one that the later cells turn into
# cluster labels and colours.
reduction_2d = 'KernelPCA'
variants_acts_2d = []
for layer in layers:
    layer_acts_2d = acts_util.reduce_acts(variants_acts[layer], reduction=reduction_2d, dim=2)
    variants_acts_2d.append(layer_acts_2d)

reduction = 'KMeans'
dim = 6
variants_acts_reduced = []
for layer in layers:
    layer_acts_reduced = acts_util.reduce_acts(variants_acts[layer], reduction=reduction, dim=dim)
    variants_acts_reduced.append(layer_acts_reduced)

## Visualize evolution of clusters

In [None]:

# For every layer, print the layer name and then display the variant contexts
# grouped by cluster. Each context is highlighted with its cluster's colour and,
# when vis_act_size is enabled, its font is scaled by its activation size.
max_font_size = 100  # multiplier applied to the normalized activation size
pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))  # one colour per channel
# Cluster label of every context, per layer: index of the largest reduced channel.
labels = [np.argmax(variants_acts_reduced[layer], axis=1) for layer in layers]
# NOTE(review): rgbs is not read in this cell; it is kept so the notebook's
# global state stays as before.
rgbs = [vis_util.channels_to_rgbs(reduced_acts) for reduced_acts in variants_acts_reduced]


def _context_sort_key(context):
    """Display order inside a cluster: unmasked contexts first, then by position.

    context[0] is the token sequence and context[1] the position of the token
    the context refers to.
    """
    tokens, position = context[0], context[1]
    return tokens[position] == '[MASK]', position


for layer in layers:
    print(layer_names[layer])
    clusters_contexts = []
    clusters_acts_sizes = []
    clusters = list(range(dim))
    for cluster in clusters:
        # Sort the member *indices* rather than the contexts themselves, so the
        # contexts and their activation sizes are gathered in the same order.
        # Sorting only the contexts would pair each one with another context's size.
        member_idxs = np.where(labels[layer] == cluster)[0]
        cluster_contexts_idxs = np.array(
            sorted(member_idxs, key=lambda context_idx: _context_sort_key(variants_contexts[context_idx])),
            dtype=int,
        )
        clusters_contexts.append([variants_contexts[context_idx] for context_idx in cluster_contexts_idxs])
        # variants_acts_sizes only exists when vis_act_size is enabled.
        if vis_act_size:
            clusters_acts_sizes.append(variants_acts_sizes[layer][cluster_contexts_idxs])
    # Order the clusters by their (sorted) contents; the colour is then taken
    # from the rank in that order, not from the raw cluster id.
    clusters.sort(key=lambda cluster: clusters_contexts[cluster])
    for i, cluster in enumerate(clusters):
        rgb = pure_rgbs[i]
        marker = html_util.highlighter(color=html_util.rgb_to_color(*rgb))
        contexts = clusters_contexts[cluster]
        for context_idx, context in enumerate(contexts):
            html = context_util.context_html(*context, marker=marker)
            if vis_act_size:
                html = html_util.font_size(html, clusters_acts_sizes[cluster][context_idx]*max_font_size)
            display(HTML(html))
    print('\n\n')

## Visualize evolution of points

In [None]:
# Plot how the variant contexts move through the layers: one grid row per
# layer, a single column named after the clustering reduction.
points = variants_acts_2d

# HTML rendering of every variant context, passed to plot_evolution as labels.
labels = []
for context in variants_contexts:
    labels.append(context_util.context_html(*context))

# Per-layer colours derived from the reduced activations.
rgbs = list(map(vis_util.channels_to_rgbs, variants_acts_reduced))

# Pairwise cosine distances between the activations of each layer.
# NOTE(review): these are computed but plot_evolution is called with
# distance_mats=None below -- confirm whether that is intentional.
distance_mats = []
for layer in layers:
    distance_mats.append(cosine_distances(variants_acts[layer]))

grid = Grid(row_names=layer_names)
evolution_plots = plot_evolution(points=points, labels=labels, colors=rgbs, distance_mats=None)
grid.add_column(reduction, evolution_plots)
grid.show()