# TODO: This notebook is currently broken because of some refactoring.

In [None]:
# Parameters
# Where the pickled contexts and the activations archive live.
data_dir = '../../bucket/wikipedia/1000docs_19513contexts_30maxtokens/'
contexts_filename = 'contexts.pickle'
acts_filename = 'activations.npz'

# Which layers to visualize, given as keys of the activations archive.
# layers = ['arr_0','arr_3','arr_6', 'arr_9', 'arr_12']  # a sparser subset
# layers = ['arr_0']  # good for debugging
layers = ['arr_' + str(layer_idx) for layer_idx in range(13)]

# Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
# LOAD
import pickle
import numpy as np
import math
import os
import sys
project_path = os.path.abspath('../..')
sys.path.insert(0, project_path)
from src.utils import context_util, vis_util, html_util, acts_util
# ANALYZE
from sklearn.cluster import KMeans
# VIS
from IPython.core.display import display, HTML
from bokeh.palettes import Inferno, Category10, Category20, Category20c, Pastel1, Pastel2, Bokeh, Plasma, Colorblind
from bokeh.plotting import figure, show, output_file
from bokeh.models import ColumnDataSource, Label, LabelSet, Range1d, Div, Range1d, HoverTool
from bokeh.layouts import gridplot, row, column
from bokeh.models.annotations import Legend, LegendItem
from bokeh.io import output_notebook
output_notebook()

# Loading

In [None]:
# Load contexts and acts
# Resolve the data directory once so both files are read from the same absolute location.
data_path = os.path.abspath(data_dir)
# NOTE(review): pickle.load can execute arbitrary code; only point data_dir at trusted files.
with open(os.path.join(data_path, contexts_filename), 'rb') as f:
    contexts = pickle.load(f)
# np.load on an .npz returns a lazy archive object that keeps the file open.
# Read the selected layers into memory and close the archive when done.
with np.load(os.path.join(data_path, acts_filename)) as acts_npz:
    layer_to_acts = {layer: acts_npz[layer] for layer in layers}

# Helpers

In [None]:
def shift(vals):
    """Translate `vals` so that its smallest non-NaN entry becomes 0."""
    lowest = np.nanmin(vals)
    return vals - lowest

def normalize(vals):
    """Rescale `vals` linearly so its non-NaN entries span the range [0, 1].

    The minimum is mapped to 0 and the maximum to 1; NaN entries stay NaN.
    When all non-NaN entries are equal there is no range to rescale, so the
    shifted values (all 0) are returned instead of dividing by zero, which
    would produce NaN.

    Args:
        vals: array-like of numbers (any shape); may contain NaN.

    Returns:
        numpy array of the same shape as `vals`.
    """
    shifted = np.asarray(vals) - np.nanmin(vals)  # starting at 0
    span = np.nanmax(shifted)
    # A zero (or NaN) span would turn every entry into NaN; divide by 1 instead.
    divisor = span if span > 0 else 1
    return shifted / divisor  # between 0 and 1

# # test
# n = normalize(np.random.rand(10))
# np.min(n), np.mean(n), np.max(n)

# Choose single document

In [None]:
# One doc
which_doc = 210  # presumably the index of the document to inspect — confirm against context_util.get_doc_ids
# Ids of this document's tokens; they index both `contexts` and the per-layer activation arrays below.
doc_ids = context_util.get_doc_ids(contexts, which_doc)
# Each context entry unpacks into two items; only the first (the doc's token sequence) is kept.
doc, _ = contexts[doc_ids[0]]
print(context_util.doc_str(doc))

In [None]:
# Rename every 'Who' token to 'Who<position>' so repeated occurrences get distinct labels.
# NOTE(review): presumably needed because `doc` tokens are later used as categorical plot axes,
# and hard-coded for this particular document — confirm before choosing another `which_doc`.
doc = [tok if tok!='Who' else f'Who{pos}' for pos, tok in enumerate(doc)]

In [None]:
# For every layer, build a symmetric (len(doc) x len(doc)) matrix holding the Euclidean
# distance between the activations of each pair of doc tokens; the diagonal stays 0.
n_toks = len(doc)
layer_to_intertok_distances = {}
for layer in layers:
    layer_acts = layer_to_acts[layer]
    distance_matrix = np.zeros((n_toks, n_toks))
    for row in range(n_toks):
        row_acts = layer_acts[doc_ids[row]]
        # only the upper triangle is computed; each value is mirrored into the lower one
        for col in range(row + 1, n_toks):
            gap = np.linalg.norm(row_acts - layer_acts[doc_ids[col]])
            distance_matrix[row, col] = gap
            distance_matrix[col, row] = gap
    layer_to_intertok_distances[layer] = distance_matrix

In [None]:
# Normalize distances to be between 0 and 1.
# Each layer is divided by its own largest distance (the max is taken per layer, over axes 1 and 2).
all_layers_intertok_distances = np.array(list(layer_to_intertok_distances.values()))
max_val = all_layers_intertok_distances.max(axis=(1, 2), keepdims=True)
all_layers_intertok_distances = all_layers_intertok_distances / max_val  # scale to be between 0 and 1
layer_to_normalized_intertok_distances = {
    layer: all_layers_intertok_distances[layer_idx]
    for layer_idx, layer in enumerate(layers)
}

# Token-token similarity

In [None]:
# Plot how the distance between each pair of phrase tokens evolves across layers,
# and group the pairs into a few clusters of similar evolutions.
start_pos, end_pos = 2, -2
phrase = doc[start_pos:end_pos]
green_highlighter = lambda tok: html_util.highlight_html(tok, color='limegreen')
palette = Category10[10]
dim = 200  # width and height of each plot
n_pair_clusters = 4

# get pair evolutions
# Pairs are stored as positions in `doc` (not in `phrase`): the distance matrices are indexed
# by doc position and the highlighted context below is rendered from `doc`, so phrase-relative
# indices would be off by `start_pos`.
phrase_positions = list(range(start_pos, len(doc) + end_pos))
pairs = []
pair_distance = []
for rank, pos_i in enumerate(phrase_positions):
    for pos_j in phrase_positions[rank+1:]:
        pairs.append((pos_i, pos_j))
        pair_distance.append([layer_to_intertok_distances[layer][pos_i, pos_j] for layer in layers])

# based on their evolutions, cluster the pairs into a few prototypical evolutions
# random_state is fixed so the assignment (and hence the colors) is reproducible across runs;
# n_clusters is capped because KMeans cannot form more clusters than there are samples.
kmeans = KMeans(n_clusters=min(n_pair_clusters, len(pairs)), random_state=0)
pair_cluster_assignments = kmeans.fit(pair_distance).labels_
clusters = sorted(set(pair_cluster_assignments))  # one shared, deterministic order for plots and html
cluster_plots = {cluster: figure(width=dim, height=dim) for cluster in clusters}
cluster_html = {cluster: '' for cluster in clusters}

# plot
layer_ticks = list(range(len(layers)))
main_plot = figure(width=dim, height=dim)
main_plot.add_layout(Legend(orientation='horizontal', label_text_font_size='6pt', label_width=10), 'above')
for pair, distances, cluster in zip(pairs, pair_distance, pair_cluster_assignments):
    main_plot.line(layer_ticks, distances, color=palette[cluster])
    cluster_plots[cluster].line(layer_ticks, distances, color=palette[cluster])
    cluster_html[cluster] += f'''
        <div style= 'font-size:8pt;'>
        {context_util.multi_context_str(doc, list(pair), marker=green_highlighter)}
        </div>
        '''
plots = [main_plot]+list(cluster_plots.values())
divs = [None] + [Div(text=html, width=800) for html in list(cluster_html.values())]
# plot properties
for plot in plots: # set axes
    plot.xaxis.ticker = layer_ticks
    plot.xaxis.major_label_overrides = {i:layer for i, layer in enumerate(layers)}
    plot.xaxis.major_label_text_font_size = '6pt'
    # one x position per layer (does not rely on a variable left over from the loop above)
    plot.x_range = Range1d(0,len(layers)-1)
#     plot.y_range = Range1d(0.2,1.3)
# one grid row per plot: the plot on the left, its pair listing (if any) on the right
show(gridplot([list(cells) for cells in zip(plots, divs)]))

In [None]:
# Phrase window: drop the first two and the last two tokens of the doc.
start_pos, end_pos = 2, -2
phrase = doc[start_pos:end_pos]
# Ids of the phrase tokens (the same window applied to doc_ids).
phrase_ids = doc_ids[start_pos:end_pos]

In [None]:
# One small heatmap per layer: a square for every ordered pair of phrase tokens,
# whose opacity is 1 minus the normalized distance between the two tokens.
htmls = {layer: layer for layer in layers}
plots = []
phrase_positions = range(start_pos, len(doc)+end_pos)
for layer in layers:
    intertok_distances = layer_to_normalized_intertok_distances[layer]
    # every ordered (A, B) pair of phrase positions, A varying slowest
    position_pairs = [(pos_a, pos_b) for pos_a in phrase_positions for pos_b in phrase_positions]
    source = {
        'x': [doc[pos_a] for pos_a, pos_b in position_pairs],
        'y': [doc[pos_b] for pos_a, pos_b in position_pairs],
        'alpha': [1-intertok_distances[pos_a,pos_b] for pos_a, pos_b in position_pairs],
        'hover label': [
            context_util.multi_context_str(doc, [pos_a, pos_b], marker=html_util.highlighter())
            for pos_a, pos_b in position_pairs
        ],
    }

    # categorical axes labelled with the phrase tokens; y is reversed so the first token is at the top
    p = figure(x_axis_location="above", x_range=phrase, y_range=list(reversed(phrase)), width=200, height=200)
    p.rect(x='x', y='y', width=.9, height=.9, color='purple', alpha='alpha', source=source)
    # strip grid lines, axis lines and ticks
    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_standoff = 0
    p.xaxis.major_label_orientation = np.pi/3
    p.axis.major_label_text_font_size = "9px"
    # the hover tool replaces the figure's default tools
    p.tools = [HoverTool(tooltips=vis_util.custom_bokeh_tooltip('hover label'))]
    plots.append(p)
layer_titles = [Div(text=layer, align='center') for layer in layers]
show(gridplot([layer_titles, plots]))

In [None]:
# One histogram per layer of the (un-normalized) distances between neighbouring phrase tokens.
plots = []
left_positions = range(start_pos, len(doc)+end_pos-1)
for layer in layers:
    intertok_distances = layer_to_intertok_distances[layer]
    # distance from each phrase token to the token immediately after it
    tok_tok_distances = [intertok_distances[left, left+1] for left in left_positions]
    hist, edges = np.histogram(tok_tok_distances, bins=30)

    p = figure(width=500, height=100)
    p.quad(
        top=hist, bottom=0, left=edges[:-1], right=edges[1:],
        fill_color="purple", line_color="white", alpha=0.5,
    )
#     p.x_range = Range1d(0, 40)
    p.y_range = Range1d(0,10)  # fixed y-range so the layers share a common scale
    plots.append(p)
layer_titles = [Div(text=layer, align='center') for layer in layers]
# one grid row per layer: its name, then its histogram
show(gridplot(zip(layer_titles, plots)))

# Archive

In [None]:
# Archived view over the whole doc: per layer, one black box per token i, listing every doc
# token j on a green background derived from the normalized distance between i and j.
htmls = {layer: layer for layer in layers}  # each layer's HTML starts with the layer name
for layer in layers:
    intertok_distances = layer_to_normalized_intertok_distances[layer]   
    for tok_i in range(len(doc)):
        tok_i_html = ''
        for tok_j in range(len(doc)):
            if tok_i != tok_j:
                distance = intertok_distances[tok_i, tok_j]
                color_intensity = int((1-distance) * 255)  # brighter green = smaller distance to token i
#                 print(distance, color_intensity)
                tok_i_html += f''' <span style='background-color: rgba(0,{color_intensity},0); font-size: 4pt;'>{doc[tok_j]} </span>'''
            else:
                # token i itself gets a fixed (0,255,255) background
                tok_i_html += f''' <span style='background-color: rgba(0,255,255); font-size: 4pt;'>{doc[tok_i]} </span>'''
        htmls[layer] += html_util.box(tok_i_html, css='background-color: black;')
# one Div per layer, stacked vertically
divs = [Div(text=html, width=800) for html in htmls.values()]
show(column(divs))

In [None]:
# Archived view restricted to the phrase window: same rendering as the whole-doc variant,
# but over doc[start_pos:end_pos] and without the small font size.
htmls = {layer: layer for layer in layers}  # each layer's HTML starts with the layer name
start_pos, end_pos = 2, -2
sub_doc = doc[start_pos:end_pos]
for layer in layers:
    # slice the distance matrix to the same window so indices below are sub_doc-relative
    intertok_distances = layer_to_normalized_intertok_distances[layer][start_pos:end_pos,start_pos:end_pos]
    for tok_i in range(len(sub_doc)):
        tok_i_html = ''
        for tok_j in range(len(sub_doc)):
            if tok_i != tok_j:
                distance = intertok_distances[tok_i, tok_j]
                color_intensity = int((1-distance) * 255)  # brighter green = smaller distance to token i
#                 print(distance, color_intensity)
                tok_i_html += f''' <span style='background-color: rgba(0,{color_intensity},0)'>{sub_doc[tok_j]} </span>'''
            else:
                # token i itself gets a fixed (0,255,255) background
                tok_i_html += f''' <span style='background-color: rgba(0,255,255)'>{sub_doc[tok_i]} </span>'''
        htmls[layer] += html_util.box(tok_i_html, css='background-color: black;')
# one Div per layer, stacked vertically
divs = [Div(text=html, width=800) for html in htmls.values()]
show(column(divs))

# Token-to-next-token similarity

In [None]:
# One line per layer: Euclidean distance between each token's activations and the next token's.
colors = Inferno[256][::15]
p = figure(width=1000, height=300)
p.add_layout(Legend(orientation='horizontal', label_text_font_size='6pt', label_width=10), 'above')
for layer_idx, layer in enumerate(layers):
    doc_acts = layer_to_acts[layer][doc_ids]
    # NOTE(review): range(len(doc)-2) leaves out the last adjacent pair of the doc — confirm intended.
    intertok_distance = [
        np.linalg.norm(doc_acts[pos] - doc_acts[pos+1])
        for pos in range(len(doc)-2)
    ]
    xs = range(len(intertok_distance))
    p.line(xs, intertok_distance, color=colors[layer_idx], legend_label=layer)
    p.circle(xs, intertok_distance, color=colors[layer_idx])
p.axis.visible = True
# x tick k is labelled with the second token of pair k; tick 0 shows both tokens of the first pair
tick_labels = {x:f'-{doc[x+1]}' for x in range(1, len(intertok_distance))}
tick_labels[0] = f'{doc[0]}-{doc[1]}'
p.xaxis.ticker = list(range(len(intertok_distance)))
p.xaxis.major_label_overrides = tick_labels
p.xaxis.major_label_text_font_size = '6pt'
show(p)

In [None]:
# Same plot as the un-shifted version, but each layer's curve is shifted so its minimum is 0.
colors = Inferno[256][::15]
p = figure(width=1000, height=300)
p.add_layout(Legend(orientation='horizontal', label_text_font_size='6pt', label_width=10), 'above')
for layer_idx, layer in enumerate(layers):
    doc_acts = layer_to_acts[layer][doc_ids]
    # NOTE(review): range(len(doc)-2) leaves out the last adjacent pair of the doc — confirm intended.
    raw_distance = [
        np.linalg.norm(doc_acts[pos] - doc_acts[pos+1])
        for pos in range(len(doc)-2)
    ]
    intertok_distance = shift(raw_distance)
    xs = range(len(intertok_distance))
    p.line(xs, intertok_distance, color=colors[layer_idx], legend_label=layer)
    p.circle(xs, intertok_distance, color=colors[layer_idx])
p.axis.visible = True
# x tick k is labelled with the second token of pair k; tick 0 shows both tokens of the first pair
tick_labels = {x:f'-{doc[x+1]}' for x in range(1, len(intertok_distance))}
tick_labels[0] = f'{doc[0]}-{doc[1]}'
p.xaxis.ticker = list(range(len(intertok_distance)))
p.xaxis.major_label_overrides = tick_labels
p.xaxis.major_label_text_font_size = '6pt'
show(p)

In [None]:
# Shifted next-token distances again, but as one small figure per layer, stacked in a column.
colors = Inferno[256][::15]
column_plots = []
for layer_idx, layer in enumerate(layers):
    doc_acts = layer_to_acts[layer][doc_ids]
    # NOTE(review): range(len(doc)-2) leaves out the last adjacent pair of the doc — confirm intended.
    raw_distance = [
        np.linalg.norm(doc_acts[pos] - doc_acts[pos+1])
        for pos in range(len(doc)-2)
    ]
    intertok_distance = shift(raw_distance)
    n_gaps = len(intertok_distance)

    p = figure(width=225, height=50)
    p.line(range(n_gaps), intertok_distance, color=colors[layer_idx])
    p.circle(range(n_gaps), intertok_distance, color=colors[layer_idx])
    p.axis.visible = True
    # x tick k is labelled with the second token of pair k; tick 0 shows both tokens of the first pair
    tick_labels = {x:f'-{doc[x+1]}' for x in range(1, n_gaps)}
    tick_labels[0] = f'{doc[0]}-{doc[1]}'
    p.xaxis.ticker = list(range(n_gaps))
    p.xaxis.major_label_overrides = tick_labels
    # p.xaxis.major_label_orientation = math.pi/2
    p.xaxis.major_label_text_font_size = '6pt'
    column_plots.append(p)
show(column(column_plots))

In [None]:
# Per layer, print the doc (from its second token on) with each token on a red background
# whose intensity is the normalized distance to the previous token.
font_size = '9pt'
for layer in layers:
    doc_acts = layer_to_acts[layer][doc_ids]
    raw_distances = [
        np.linalg.norm(doc_acts[pos] - doc_acts[pos+1])
        for pos in range(len(doc)-1)
    ]
    intertok_distances = normalize(raw_distances)

    spans = []
    for tok, distance in zip(doc[1:], intertok_distances):
        # color faraway words; black indicates very similar to previous token
        color_intensity = int(distance * 255)
        spans.append(f"<span style='background-color: rgba({color_intensity},0,0); font-size: {font_size};'> {tok} </span>")
    html = ''.join(spans)
    display(HTML(html))

# Token-token similarity

In [None]:
def get_points_on_circle(n_points, radius=1):
    """Return `n_points` [x, y] pairs evenly spaced on a circle centred at the origin.

    Angles run from 360/n_points degrees up to and including 360 degrees,
    so the last point sits at angle 360 (on the positive x-axis).
    """
    step_degrees = 360/n_points
    angles = np.radians(np.linspace(step_degrees, 360, n_points))
    unit_points = np.array([[math.cos(angle), math.sin(angle)] for angle in angles])
    scaled_points = radius * unit_points
    return [[x, y] for x, y in scaled_points]

In [None]:
# Grid of tiny plots: one row per layer, one plot per token. Tokens are placed on a circle and
# each plot draws a line from its token to every token, greener when the two are closer.
tok_points = get_points_on_circle(len(doc)) # points
tok_xs, tok_ys = zip(*tok_points)
tok_xs, tok_ys = list(tok_xs), list(tok_ys)
# Only referenced by the commented-out scatter/label code inside the plotting loop.
source = ColumnDataSource(
    {
    'tok_xs': tok_xs,
    'tok_ys': tok_ys,
    'toks': doc,
#     'offset': [7]*(len(doc)//4) + [-30]*(len(doc)//2) + [7]*(len(doc)//4)
    }
)

rows = []
rows.append([Div(text=tok) for tok in doc])  # header row: the tokens themselves
for layer in layers:
    doc_acts = layer_to_acts[layer][doc_ids]
    # get distances
    intertok_distances = np.full((len(doc), len(doc)), np.nan)
    for tok1_idx in range(len(doc)):
        for tok2_idx in range(len(doc)):
            difference = doc_acts[tok1_idx]-doc_acts[tok2_idx]
            intertok_distances[tok1_idx, tok2_idx] = np.linalg.norm(difference)
    intertok_distances = normalize(intertok_distances)
    # draw row of plots
    plots = []
    for tok1_idx in range(len(doc)):
        p = figure(width=50, height=50)
#         p.scatter(x='tok_xs', y='tok_ys', color='black', source=source)# draw points
#         p.add_layout(LabelSet(x='tok_xs', y='tok_ys', text='toks', x_offset='offset', y_offset=0, source=source, render_mode='canvas', text_font_size='6pt'))
        for tok2_idx in range(len(doc)):
            distance = intertok_distances[tok1_idx, tok2_idx]
            color_intensity = int((1-distance) * 255)  # color close connections
            fade = 255-color_intensity  # 0 -> pure green (closest), 255 -> white (farthest)
            # The Line glyph's thickness property is `line_width`; `width` is not a Line
            # property and is rejected by bokeh.
            p.line(
                x=[tok_xs[tok1_idx], tok_xs[tok2_idx]],
                y=[tok_ys[tok1_idx], tok_ys[tok2_idx]],
                color=f'rgb({fade},255,{fade})',
                line_width=2,
            )
        p.grid.visible = False
        p.axis.visible = False
        plots.append(p)
    rows.append(plots)
show(gridplot(rows))

# Dimensionality reduced

In [None]:
# Reduce this document's activations to `dim` components per layer.
reduction, dim = 'NMF', 4
# `layer_to_doc_acts` was referenced here without ever being defined; build it the same way
# the plotting cells above select the doc's activations (rows `doc_ids` of each layer).
layer_to_doc_acts = {layer: layer_to_acts[layer][doc_ids] for layer in layers}
layer_to_doc_reduced_acts = {
    layer: acts_util.reduce_activations(doc_acts, reduction, dim)
    for layer, doc_acts in layer_to_doc_acts.items()
}
# One one-hot vector per reduced component.
pure_directions = np.eye(dim)

In [None]:
# One line per layer: Euclidean distance between consecutive tokens in the reduced activation space.
colors = Inferno[256][::15]
p = figure(width=1000, height=300)
p.add_layout(Legend(orientation='horizontal', label_text_font_size='6pt', label_width=10), 'above')
for layer_idx, layer in enumerate(layers):
    reduced_acts = layer_to_doc_reduced_acts[layer]
    # NOTE(review): range(len(doc)-2) leaves out the last adjacent pair of the doc — confirm intended.
    intertok_distance = [
        np.linalg.norm(reduced_acts[pos] - reduced_acts[pos+1])
        for pos in range(len(doc)-2)
    ]
#     mean = np.average(intertok_distance)
#     intertok_distance = intertok_distance - mean
    xs = range(len(intertok_distance))
    p.line(xs, intertok_distance, color=colors[layer_idx], legend_label=layer)
    p.circle(xs, intertok_distance, color=colors[layer_idx])
p.axis.visible = True
# x tick k is labelled with the second token of pair k; tick 0 shows both tokens of the first pair
tick_labels = {x:f'-{doc[x+1]}' for x in range(1, len(intertok_distance))}
tick_labels[0] = f'{doc[0]}-{doc[1]}'
p.xaxis.ticker = list(range(len(intertok_distance)))
p.xaxis.major_label_overrides = tick_labels
p.xaxis.major_label_text_font_size = '6pt'
show(p)

In [None]:
# Map each layer's reduced activations to RGB colours.
# The dict produced by the reduction step is named `layer_to_doc_reduced_acts`;
# the previous name `doc_reduced_acts` is not defined anywhere and raised a NameError.
layer_to_rgbs = {
    layer: vis_util.channels_to_rgbs(reduced_acts)
    for layer, reduced_acts in layer_to_doc_reduced_acts.items()
}
# Colour of each one-hot direction, used for the legend.
pure_colors = [vis_util.channels_to_rgbs(direction) for direction in pure_directions]
# pure_colors, layer_to_rgbs['arr_0'][1]

In [None]:
# Legend: the index of each one-hot direction, highlighted in that direction's colour.
legend_html = ''
for i, color in enumerate(pure_colors):
    color_str = html_util.rgb_to_color(*color[0])
    legend_html += html_util.highlight_html(f' {i} ', color=color_str)
display(HTML(legend_html))

# Per layer: every doc token on a background of its RGB colour.
for layer, rgbs in layer_to_rgbs.items():
    html = ''
    for tok, rgb in zip(doc, rgbs):
        # The span is built directly from the rgb components; the unused per-token
        # `color_str` that used to be computed here was dead code and has been dropped.
        html += f"<span style='background-color: rgba({rgb[0]},{rgb[1]},{rgb[2]},1);'> {tok} </span>"
    print(layer)
    display(HTML(html))

In [None]:
# Per layer, cluster the doc's token activations with KMeans and print every token
# on a background colour identifying its cluster.
n_clusters = 6
pure_directions = np.eye(n_clusters)  # one one-hot vector per cluster
pure_rgbs = [list(vis_util.channels_to_rgbs(direction)[0]) for direction in pure_directions]
# Legend: each cluster index highlighted in its colour.
legend_html = ''
for i, rgb in enumerate(pure_rgbs):
    color = html_util.rgb_to_color(*rgb)
    legend_html += html_util.highlight_html(f' {i} ', color=color)
display(HTML(legend_html))
for layer in layers:
    # Select the doc's rows directly from the layer's activations, as the other cells do;
    # `layer_to_doc_acts` is not defined by any earlier cell of the original notebook.
    doc_acts = layer_to_acts[layer][doc_ids]
    # random_state is fixed so the cluster labels (and hence the colours) are reproducible.
    doc_cluster_labels = KMeans(n_clusters=n_clusters, random_state=0).fit(doc_acts).labels_
    html = ''
    for tok, cluster_label in zip(doc, doc_cluster_labels):
        rgb = pure_rgbs[cluster_label]
        color = html_util.rgb_to_color(*rgb)
        html += f"<span style='background-color: {color};'> {tok} </span>"
    print(layer)
    display(HTML(html))
