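# Studying the evolution of token representations
How do a token's representations evolve across layers? Is a token first represented in a shallow way and then, in deeper layers, in progressively richer ways?
Rich in what sense? For example, do deeper layers group a token's occurrences by more similar surrounding words?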

In [None]:
# Enable IPython's autoreload (mode 2) so edited project modules such as src.utils
# are re-imported before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# When True, the imports cell also calls output_file('visualize-wiki.html') so Bokeh
# output is written to that HTML file.
view_vis_as_html = False

In [None]:
# Imports
# # LOAD
import os
import pickle
import sys

import numpy as np

# Put the repo root on sys.path *before* any `src` import; otherwise `src` only
# resolves if the kernel happens to start in the repo root.
sys.path.insert(0, os.path.abspath('../../..'))
from src import references as refs
# # PROCESS
import nltk
import pandas as pd
from sklearn.cluster import KMeans
from src.utils import context_util, acts_util
# # VIS
# IPython.display is the public home of display/HTML (importing display from
# IPython.core.display is deprecated).
from IPython.display import display, HTML
from bokeh.plotting import figure, show, output_file
from bokeh.layouts import gridplot
from bokeh.palettes import Inferno, Category10, Category20, Category20c, Pastel1, Pastel2, Bokeh, Plasma, Colorblind
from bokeh.models import Label, LabelSet, Div, ColumnDataSource, Legend, LegendItem, Range1d
from bokeh.models import HoverTool, CustomJS, PanTool, BoxZoomTool, WheelZoomTool, ResetTool, TapTool, OpenURL
from bokeh.models.glyphs import Circle
from bokeh.io import output_notebook
# Render Bokeh plots inline; optionally also write them to an HTML file.
output_notebook()
if view_vis_as_html:
    output_file('visualize-wiki.html')
from src.utils import vis_util, html_util

# Loading contexts

In [None]:
# Load all contexts and acts
data_dir = os.path.abspath('../../../bucket/wiki-large/wiki-split/')
layers = [f'arr_{i}' for i in range(13)]  # which layers to visualize
# layers = ['arr_0', 'arr_3', 'arr_6', 'arr_9', 'arr_12']
# layers = ['arr_0']  # good for debugging

# NOTE: pickle.load can execute arbitrary code; only point this at trusted data files.
# The `with` block closes the file handle instead of leaking it.
with open(os.path.join(data_dir, refs.contexts_fn), 'rb') as contexts_file:
    contexts = pickle.load(contexts_file)

# Not wrapped in `with`: later cells index `acts[layer]` by the names in `layers`.
# Presumably this is an .npz archive read lazily from disk — TODO confirm.
acts = np.load(os.path.join(data_dir, refs.acts_fn))

In [None]:
# Choose contexts and acts: every occurrence of `token` among the loaded contexts.
token = 'woman'
# Each context is a (doc, pos) pair; keep those whose focus word doc[pos] is `token`,
# together with the context's index into `contexts` (used to select its activations).
matches = [(i, (doc, pos)) for i, (doc, pos) in enumerate(contexts) if doc[pos] == token]
if not matches:
    # zip(*[]) would otherwise fail with an opaque "not enough values to unpack".
    raise ValueError(f'token {token!r} does not occur in any loaded context')
chosen_idxs, chosen_contexts = zip(*matches)
n_chosen = len(chosen_contexts)
print(n_chosen, 'contexts:')
for i, context in enumerate(chosen_contexts):
    display(HTML(f'({i}) {context_util.context_html(*context)}'))

In [None]:
# Restrict each layer's activations to the rows of the chosen occurrences.
# chosen_idxs is a tuple; convert it to a list so NumPy treats it as a row selection
# rather than a multi-axis index.
chosen_rows = list(chosen_idxs)
chosen_acts = dict(zip(layers, (acts[name][chosen_rows] for name in layers)))

# Show clustered contexts

In [None]:
# Cluster the chosen occurrences' activations in each layer and display their contexts,
# highlighted in their cluster's color.
n_clusters = 3
palette = Category10[10]
css = 'font-size:10px; line-height: 12px; display: block; text-align: left;'
sort = True  # True: group contexts by cluster; False: keep the original order


def _cluster_context_html(context, cluster_idx):
    """Render one (doc, pos) context as styled HTML, highlighted in its cluster's color."""
    color = palette[cluster_idx % len(palette)]  # wrap so >len(palette) clusters still work
    marked = context_util.context_html(*context, marker=html_util.highlighter(color))
    return html_util.style(marked, css=css)


# KMeans needs at least as many samples as clusters.
k = min(n_clusters, n_chosen)
# Explicit n_init/random_state: results are reproducible across runs and do not depend
# on the scikit-learn version's default for n_init.
clusters = {
    layer: KMeans(n_clusters=k, n_init=10, random_state=0).fit(chosen_acts[layer]).labels_
    for layer in layers
}
for layer in layers:
    display(HTML(html_util.highlight(layer)))
    contexts_and_clusters = list(zip(chosen_contexts, clusters[layer]))
    if sort:
        # one HTML block per cluster, in cluster order
        for cluster_idx in range(k):
            display(HTML(''.join(
                _cluster_context_html(context, cluster)
                for context, cluster in contexts_and_clusters
                if cluster == cluster_idx
            )))
    else:
        display(HTML(''.join(
            _cluster_context_html(context, cluster)
            for context, cluster in contexts_and_clusters
        )))
    print()
    

In [None]:
# Fresh vis: a grid with one row per layer — the layer name, then a scatter plot of the
# chosen occurrences (triangles) and their neighborhoods (circles).
# NOTE(review): no earlier cell in this notebook defines `reduction`, `reduced_chosen_acts`,
# `neighborhoods`, `neighborhoods_contexts`, `reduced_neighborhoods_acts` or `chosen_ids`,
# so this cell raises NameError as-is. They presumably came from a reduction /
# nearest-neighbor step that was removed. `chosen_ids` is used the way `chosen_idxs`
# would be (positions of the chosen occurrences) — confirm before wiring it up.
columns = []
layer_name_column = [None] + [Div(text=layer, align=('center', 'center')) for layer in layers]
columns.append(layer_name_column)

# create a column of plots, headed by the reduction's name
plot_column = [Div(text=f'{reduction}', align=('center', 'center'))]
for layer in layers:
    chosen_points = {
        'x': reduced_chosen_acts[layer][:, 0],
        'y': reduced_chosen_acts[layer][:, 1],
        # wrap around the palette so more than len(palette) occurrences don't IndexError
        'color': [palette[idx % len(palette)] for idx in range(n_chosen)],
        'line color': ['black'] * n_chosen,
        'line width': [1] * n_chosen,
        'label': [f'{i}' for i in range(n_chosen)],
        'legend label': [f'{i}' for i in range(n_chosen)],
        # `marker=` matches the keyword the other context_html calls in this notebook use
        'hover label': [
            context_util.context_html(*context, marker=html_util.highlighter(color='yellow'))
            for context in chosen_contexts
        ],
    }

    neighbor_points = {'x': [], 'y': [], 'color': [], 'legend label': [], 'label': [], 'hover label': []}
    processed_neighbors = []  # neighbors already plotted, in the same order as neighbor_points rows
    neighborhoods_info = zip(neighborhoods[layer], neighborhoods_contexts[layer], reduced_neighborhoods_acts[layer])
    # The loop variable is `neighborhood_contexts`, not `contexts`: reusing that name
    # overwrote the list of all contexts loaded at the top of the notebook.
    for neighborhood_idx, (neighbors, neighborhood_contexts, reduced_acts) in enumerate(neighborhoods_info):
        # visualize different kinds of neighbors differently
        for neighbor, context, reduced_act in zip(neighbors, neighborhood_contexts, reduced_acts):
            if neighbor in chosen_ids:
                # update chosen point to say it's also a neighbor in this neighborhood
                chosen_idx = chosen_ids.index(neighbor)
                chosen_points['label'][chosen_idx] += f'({neighborhood_idx})'
                chosen_points['line color'][chosen_idx] = 'aqua'
                # NOTE(review): 1 is already every chosen point's width, so this is a
                # no-op — was a thicker outline intended?
                chosen_points['line width'][chosen_idx] = 1
            elif neighbor in processed_neighbors:
                # update existing neighbor point to say it's also in this neighborhood
                row = processed_neighbors.index(neighbor)
                neighbor_points['label'][row] += f'({neighborhood_idx})'
                neighbor_points['color'][row] = 'aqua'
            else:
                # new neighbor, say which neighborhood
                neighbor_points['x'].append(reduced_act[0])
                neighbor_points['y'].append(reduced_act[1])
                neighbor_points['color'].append(palette[neighborhood_idx % len(palette)])
                neighbor_points['legend label'].append(f'{neighborhood_idx}')
                neighbor_points['label'].append(f'({neighborhood_idx})')
                neighbor_points['hover label'].append(
                    context_util.context_html(*context, marker=html_util.highlighter(color='lightgrey'))
                )
                processed_neighbors.append(neighbor)

    # plot
    chosen_points_source = ColumnDataSource(chosen_points)
    neighbor_points_source = ColumnDataSource(neighbor_points)
    p = vis_util.empty_plot(width=400, height=250, darkmode=False)
    p.add_layout(Legend(), 'right')
    p.circle(x='x', y='y', color='color', size=10, legend_group='legend label', source=neighbor_points_source)
    p.add_layout(LabelSet(x='x', y='y', text='label', x_offset=2, y_offset=2, text_font_size='10pt', source=neighbor_points_source))
    p.triangle(x='x', y='y', color='color', line_color='line color', size=15, line_width='line width', legend_group='legend label', source=chosen_points_source)
    p.add_layout(LabelSet(x='x', y='y', text='label', x_offset=2, y_offset=2, text_font_size='10pt', source=chosen_points_source))
    zoom_tool = WheelZoomTool()
    p.tools = [PanTool(), zoom_tool, BoxZoomTool(), ResetTool(), HoverTool(tooltips=vis_util.custom_bokeh_tooltip('hover label'))]
    p.toolbar.active_scroll = zoom_tool
    plot_column.append(p)
columns.append(plot_column)
# Transpose the columns into rows and pass gridplot an explicit list of lists,
# the documented input shape, rather than a one-shot zip iterator.
show(gridplot([list(row) for row in zip(*columns)]))