# Params

In [None]:
# single_doc_dir = True
# data_dir = '../../data/alice/sample3'


In [None]:
# False selects the multi-document branch of the load cell, which reads
# contexts.pickle from data_dir; True makes it read tokens.pickle instead.
single_doc_dir = False
# Directory holding the pickled tokens/contexts and the activations.npz archive.
data_dir = '../../../bucket/wikipedia/1000docs_19513contexts_30maxtokens/'
# Presumably the number of the document to visualize when single_doc_dir is
# False -- verify against the load cell.
which_doc = 101

In [None]:
# Names of the arrays to read from the activations archive; each name is used
# as a key into the loaded .npz (presumably one array per model layer).
# layers = ['arr_0','arr_3','arr_6', 'arr_9', 'arr_12']  # which layers to visualize
layers = [f'arr_{i}' for i in range(13)]
# layers = ['arr_0']  # good for debugging
# Presumably (reduction name, number of output dimensions) pairs; not read by
# the distance/plot cells of this notebook -- TODO confirm where it is used.
reductions = [('KernelPCA',2)]
view_vis_as_html = False  # If True, running the vis will also generate an interactive html file and open it

# Imports

In [None]:
import os
import pickle
import numpy as np
import sys
import math

from bokeh.plotting import figure, show, output_file
from bokeh.io import output_notebook
output_notebook()
if view_vis_as_html:
    output_file('vis.html')
from bokeh.models import Label, LabelSet, Div, ColumnDataSource, Legend, LegendItem, Range1d
from bokeh.models import HoverTool, CustomJS, PanTool, BoxZoomTool, WheelZoomTool, ResetTool, TapTool, OpenURL
from bokeh.models.glyphs import Circle
from bokeh.layouts import gridplot
from bokeh import events
from bokeh.palettes import Inferno, Category10, Category20, Category20c, Pastel1, Pastel2, Bokeh, Plasma


project_path = os.path.abspath('../../..')
sys.path.insert(0, project_path)
from src.utils import context_util, vis_util, html_util, acts_util

# Load

In [None]:
# Load the tokens of the document to visualize (`doc`), the indices of the
# activation rows that belong to those tokens (`doc_ids`), and the archive of
# per-layer activations (`layer_to_acts`).
if single_doc_dir:
    # Single-document directory: tokens.pickle holds the document itself.
    tokens_path = os.path.join(data_dir, "tokens.pickle")
    with open(tokens_path, 'rb') as f:
        doc = pickle.load(f)
    # Every activation row belongs to this document, in token order.
    doc_ids = range(len(doc))
else:
    # Multi-document directory: contexts.pickle holds the contexts of all docs.
    contexts_path = os.path.join(os.path.abspath(data_dir), 'contexts.pickle')
    with open(contexts_path, 'rb') as f:
        contexts = pickle.load(f)
    # Honor the `which_doc` parameter instead of a hard-coded document number,
    # so changing the parameter cell actually changes the visualized document.
    doc_number = which_doc
    # Which activation rows correspond to this doc:
    doc_ids = context_util.get_doc_ids(contexts, doc_number)
    # Each context is a pair whose first element is the document's tokens.
    doc, _ = contexts[doc_ids[0]]

print(doc)
# np.load on an .npz returns an archive indexed by array name (see `layers`).
acts_path = os.path.join(data_dir, "activations.npz")
layer_to_acts = np.load(acts_path)

# Calculate intertok distances

In [None]:
# For every layer, build a symmetric (len(doc) x len(doc)) matrix whose entry
# [a, b] is the Euclidean distance between the activations of the tokens at
# positions a and b. The diagonal stays 0.
num_toks = len(doc)
layer_to_intertok_distances = {}
for layer in layers:
    acts = layer_to_acts[layer]
    # Activation row of each token, in document order.
    tok_acts = [acts[doc_ids[pos]] for pos in range(num_toks)]
    intertok_distances = np.zeros((num_toks, num_toks))
    for pos_a, acts_a in enumerate(tok_acts):
        # Only visit the upper triangle and mirror the value to the lower one.
        for pos_b in range(pos_a + 1, num_toks):
            gap = np.linalg.norm(acts_a - tok_acts[pos_b])
            intertok_distances[pos_a, pos_b] = gap
            intertok_distances[pos_b, pos_a] = gap
    layer_to_intertok_distances[layer] = intertok_distances

In [None]:
# Normalize distances to be between 0 and 1, using one global maximum across
# all layers so that distances stay comparable between layers.

# Stack in `layers` order (not dict-iteration order) so that index i of the
# stacked array is guaranteed to belong to layers[i]; a layer whose distances
# were never computed fails loudly with a KeyError instead of being mislabeled.
all_layers_intertok_distances = np.array([layer_to_intertok_distances[layer] for layer in layers])
max_val = np.max(all_layers_intertok_distances)
print(all_layers_intertok_distances.shape)
# When every distance is 0 (e.g. a one-token document) dividing would yield
# 0/0 = NaN; the matrices are already within [0, 1], so leave them untouched.
if max_val > 0:
    all_layers_intertok_distances /= max_val  # scale to be between 0 and 1

layer_to_normalized_intertok_distances = dict(zip(layers, all_layers_intertok_distances))

# Visualize

In [None]:
# Part of the document to draw: tokens from start_pos up to (not including)
# end_pos; a negative end_pos counts from the end of the document.
start_pos = 0
end_pos = -2
phrase_span = slice(start_pos, end_pos)
phrase = doc[phrase_span]
phrase_ids = doc_ids[phrase_span]

In [None]:
# One strip plot per layer: the tokens of `phrase` are laid out along the x
# axis, and the gap between two consecutive tokens equals the normalized
# distance between their activations in that layer.
plots = []
stop_pos = len(doc) + end_pos  # end_pos is negative, so this trims the tail
for layer in layers:
    intertok_distances = layer_to_normalized_intertok_distances[layer]
    # Distance from each token to the token right before it.
    steps = [intertok_distances[pos - 1][pos] for pos in range(start_pos + 1, stop_pos)]

    # x positions are the running sum of those distances; the first token is at 0.
    xs = [0]
    for step in steps:
        xs.append(xs[-1] + step)
    # Same distances formatted for display (only used by the optional labels below).
    dists = [0] + [f'{step:.2f}' for step in steps]

    toks_source = ColumnDataSource({
        'x': xs,
        'dist': dists,
        'y': [1.5]*len(phrase),
        'label': phrase,
        'hover label': phrase,
    })

    p = vis_util.empty_plot(width=100, height=50)
    # Filled dots (these carry the hover tool) plus a second pass drawn with
    # only a red line color.
    tok_points = p.circle(x='x', y='y', color='red', size=5, source=toks_source)
    p.circle(x='x', y='y', color=None, size=5, line_color='red', source=toks_source)
    # Optional: uncomment one of these to print the token or the distance above each dot.
    # p.add_layout(LabelSet(x='x', y='y', y_offset='4', text='label', text_font_size='10pt', text_align='center', source=toks_source))
    # p.add_layout(LabelSet(x='x', y='y', y_offset='4', text='dist', text_font_size='10pt', text_align='center', source=toks_source))

    # Tooltip content comes from the 'hover label' column (presumably the
    # token text, as rendered by vis_util.custom_bokeh_tooltip).
    wheel_zoom = WheelZoomTool()
    hover = HoverTool(tooltips=vis_util.custom_bokeh_tooltip('hover label', border=False), renderers=[tok_points])
    p.tools = [PanTool(), wheel_zoom, ResetTool(), hover]
    p.toolbar.active_scroll = wheel_zoom

    # Same x range for every layer so the rows can be compared visually.
    p.x_range = Range1d(-1, len(phrase))
    p.outline_line_color = None
    plots.append(p)

# Grid with one row per layer: [layer name, strip plot].
layer_labels = [Div(text=layer, align='start') for layer in layers]
show(gridplot(zip(layer_labels, plots), toolbar_options={'logo': None}))