# Studying the evolution of tokens
How do token representations evolve across layers? Is a token first represented in a shallow way and then in progressively richer ways?
Richer in what sense? For example, do its nearest neighbors increasingly share similar surrounding words?

In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to project modules
# such as src.utils are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# When True, the imports cell also directs Bokeh output to 'visualize-wiki.html' via output_file.
view_vis_as_html = False

In [None]:
# Imports
# # LOAD
import pickle
import numpy as np
import os
import sys
# Make the directory three levels up importable so that `src` resolves below.
sys.path.insert(0, os.path.abspath('../../..'))
from src import references as refs
# # PROCESS
import nltk
import pandas as pd
from sklearn.cluster import KMeans
from src.utils import context_util, acts_util
# # VIS
# NOTE(review): newer IPython releases deprecate importing `display` from
# IPython.core.display in favor of IPython.display — consider switching.
from IPython.core.display import display, HTML
from bokeh.plotting import figure, show, output_file
from bokeh.layouts import gridplot
from bokeh.palettes import Inferno, Category10, Category20, Category20b, Category20c, Pastel1, Pastel2, Bokeh, Plasma, Colorblind
from bokeh.models import Label, LabelSet, Div, ColumnDataSource, Legend, LegendItem, Range1d
from bokeh.models import HoverTool, CustomJS, PanTool, BoxZoomTool, WheelZoomTool, ResetTool, TapTool, OpenURL
from bokeh.models.glyphs import Circle
from bokeh.io import output_notebook
# Render Bokeh plots inline in the notebook.
output_notebook()
# Optionally also send Bokeh output to an HTML file (flag set in the previous cell).
if view_vis_as_html:
    output_file('visualize-wiki.html')
from src.utils import vis_util, html_util

# Loading contexts

In [None]:
# Load all contexts and acts
data_dir = os.path.abspath('../../../bucket/wiki-large/wiki-split/')
layers = [f'arr_{i}' for i in range(13)]  # which layers to visualize
# layers = ['arr_0','arr_3','arr_6', 'arr_9', 'arr_12']
# layers = ['arr_0']  # good for debugging

# Use a context manager so the contexts file handle is closed after loading.
# NOTE: pickle can execute arbitrary code; only load files you produced yourself.
with open(os.path.join(data_dir, refs.contexts_fn), 'rb') as contexts_file:
    contexts = pickle.load(contexts_file)

# Indexed by layer name below (acts[layer]); the 'arr_i' keys suggest an np.savez
# archive — TODO confirm. Kept open so later cells can read layers on demand.
acts = np.load(os.path.join(data_dir, refs.acts_fn))

In [None]:
# Choose contexts and acts
# Keep every (doc, pos) context whose token doc[pos] equals `token`, together
# with its index into `contexts` (used to slice the activations).
token = 'woman'
matches = [(i, (doc, pos)) for i, (doc, pos) in enumerate(contexts) if doc[pos] == token]
if not matches:
    # zip(*[]) would otherwise fail with an opaque "not enough values to unpack" error.
    raise ValueError(f'No contexts found whose token is {token!r}.')
chosen_idxs, chosen_contexts = zip(*matches)
n_chosen = len(chosen_contexts)
print(n_chosen, 'contexts:')
for i, context in enumerate(chosen_contexts):
    display(HTML(f'({i}) {context_util.context_html(*context)}'))

In [None]:
# Per layer, keep only the activation rows of the chosen contexts (fancy indexing by their indices).
chosen_acts = dict((layer, acts[layer][list(chosen_idxs)]) for layer in layers)

# Show clustered contexts as text

In [None]:
# Get contentful reduction
# Cluster the chosen contexts' activations per layer with k-means, then display
# the contexts as highlighted HTML: grouped by cluster (sort=True) or in their
# original order (sort=False), each highlighted with its cluster's color.
n_clusters = 5  # must not exceed len(palette)
kmeans_seed = 0  # fixed seed so re-running the cell reproduces the same clusters
palette = Category10[10]
css = 'font-size:10px; line-height: 12px; display: block; text-align: left;'
sort = True

clusters = {
    layer: KMeans(n_clusters=n_clusters, n_init=10, random_state=kmeans_seed).fit(chosen_acts[layer]).labels_
    for layer in layers
}


def _context_snippet(context, color):
    """Return one (doc, pos) context rendered as styled HTML, highlighted in `color`."""
    marker = html_util.highlighter(color)
    # Pass css by keyword in every call so both display modes style identically.
    return html_util.style(context_util.context_html(*context, marker=marker), css=css)


for layer in layers:
    display(HTML(html_util.highlight(layer)))
    contexts_and_clusters = list(zip(chosen_contexts, clusters[layer]))
    if sort:
        # One HTML block per cluster, in cluster-index order.
        for cluster_idx in range(n_clusters):
            color = palette[cluster_idx]
            html = ''.join(
                _context_snippet(context, color)
                for context, cluster in contexts_and_clusters
                if cluster == cluster_idx
            )
            display(HTML(html))
    else:
        html = ''.join(_context_snippet(context, palette[cluster]) for context, cluster in contexts_and_clusters)
        display(HTML(html))
    print()
    

# Visualize contexts as plots

In [None]:
# Get 2D reduction
# Fit one 2D reducer per layer on the chosen contexts' activations. Each
# fit_reducer result is indexed as (fitted reducer, reduced points).
reduction_2d = 'PCA'
reducers_and_reduced_acts_2d = dict(
    (layer, acts_util.fit_reducer(chosen_acts[layer], reduction=reduction_2d, dim=2))
    for layer in layers
)
reducers_2d = {name: pair[0] for name, pair in reducers_and_reduced_acts_2d.items()}
reduced_chosen_acts_2d = {name: pair[1] for name, pair in reducers_and_reduced_acts_2d.items()}

In [None]:
# Create info table
# One row per chosen context. For each layer, add an x column then a y column
# holding that layer's 2D-reduced coordinates.
contexts_info = pd.DataFrame({
    f'{layer} {reduction_2d} {axis_name}': reduced_chosen_acts_2d[layer][:, axis]
    for layer in layers
    for axis, axis_name in enumerate(('x', 'y'))
})

In [None]:
# add basic properties
contexts_info['doc'] = [doc for doc, _pos in chosen_contexts]
contexts_info['position'] = [pos for _doc, pos in chosen_contexts]


def reverse_position(doc, position):
    """Return `position` counted from the end of `doc` (the last token maps to 0)."""
    return len(doc) - 1 - position


def POS_tag(doc, pos):
    """Return the tag nltk.pos_tag assigns to doc[pos] when tagging the whole doc."""
    return nltk.pos_tag(doc)[pos][1]


# Columns computed element-wise from each row's (doc, position) pair, added in this order.
for column_name, pair_func in [
    ('token', lambda doc, position: doc[position]),
    ('context str', context_util.context_str),
    ('context html', context_util.context_html),
    ('abbreviated context', context_util.abbreviated_context),
    ('abbreviated context html', context_util.abbreviated_context_html),
]:
    contexts_info[column_name] = contexts_info['doc'].combine(contexts_info['position'], pair_func)
# more properties
contexts_info['doc length'] = contexts_info['doc'].apply(len)
for column_name, pair_func in [('position from end', reverse_position), ('POS', POS_tag)]:
    contexts_info[column_name] = contexts_info['doc'].combine(contexts_info['position'], pair_func)

In [None]:
# Fresh vis
# Build a grid with one visualization per column and one row per grid entry.
# The first column holds the layer names; its first entry is left blank,
# presumably to line up with a leading cell returned by visualize_columns.
columns = []
layer_name_column = [None] + [Div(text=layer, align=('center', 'center')) for layer in layers]
columns.append(layer_name_column)
columns.append(vis_util.visualize_columns(contexts_info, layers, reduction_2d, ('position',), size=100))
# gridplot expects a list of rows (each a list), so transpose the columns into concrete row lists.
show(gridplot([list(row) for row in zip(*columns)]))

# Color by clusters

In [None]:
# Get clusters
# Fit k-means per layer on the chosen activations. The models are k-means, not
# nearest-neighbor models, hence the name. Their labels go into `components`,
# and their predictions on the same data go into `reduced_chosen_acts`.
n_clusters = 5
kmeans_seed = 0  # fixed seed so re-running the cell reproduces the same clustering
kmeans_models = {
    layer: KMeans(n_clusters=n_clusters, n_init=10, random_state=kmeans_seed).fit(chosen_acts[layer])
    for layer in layers
}
components = {layer: model.labels_ for layer, model in kmeans_models.items()}
reduced_chosen_acts = {layer: model.predict(chosen_acts[layer]) for layer, model in kmeans_models.items()}
del kmeans_models  # only the labels are needed afterwards

In [None]:
# Add to contexts df and vis
# Store each layer's cluster labels as a '<layer> clusters' column, then add a
# cluster-colored column to the grid.
for layer in layers:
    contexts_info[f'{layer} clusters'] = reduced_chosen_acts[layer]
# NOTE: `columns` is the list built in the "Fresh vis" cell; re-running this
# cell appends another clusters column to it.
columns.append(vis_util.visualize_columns(contexts_info, layers, reduction_2d, ['clusters'], size=100, layerwise=True))
# gridplot expects a list of rows (each a list), so transpose the columns into concrete row lists.
show(gridplot([list(row) for row in zip(*columns)]))

# Visualize contexts and their neighbors as plots

In [None]:
# Get all contexts' neighborhoods
# For each requested layer, query that layer's pickled nearest-neighbors model
# for the neighbors of every chosen context. The results are stored as indices
# into `contexts` (neighborhoods) and as the contexts themselves
# (neighborhoods_contexts).
knn_models_fn = 'KNN_models_K5.pickle'
n_neighbors = 4  # the first returned neighbor is dropped below
# NOTE(review): the file is read sequentially, one pickled model per layer. The
# previous code paired the i-th model with layers[i], which is only correct when
# `layers` is the full arr_0..arr_12 list (the default). Here we assume the file
# stores exactly that sequence and skip models for unrequested layers — confirm.
stored_layers = [f'arr_{i}' for i in range(13)]
requested_layers = set(layers)
neighborhoods = {}
neighborhoods_contexts = {}
# NOTE: pickle can execute arbitrary code; only load files you produced yourself.
with open(os.path.join(data_dir, knn_models_fn), 'rb') as f:
    for layer in stored_layers:
        if requested_layers.issubset(neighborhoods):
            break  # every requested layer is done; no need to read further models
        print(layer)
        print('Loading nearest neighbors model.')
        knn_model = pickle.load(f)  # must load even unrequested models to advance the stream
        if layer not in requested_layers:
            continue
        print('Finding neighbors')
        _neighborhoods = knn_model.kneighbors(chosen_acts[layer], n_neighbors=n_neighbors, return_distance=False)  # indices
        # Drop the first neighbor — presumably the query context itself, since the
        # chosen contexts are part of the indexed data (TODO confirm).
        neighborhoods[layer] = [neighborhood[1:] for neighborhood in _neighborhoods]
        neighborhoods_contexts[layer] = [[contexts[idx] for idx in neighborhood] for neighborhood in neighborhoods[layer]]
missing_layers = requested_layers.difference(neighborhoods)
if missing_layers:
    raise ValueError(f'No nearest-neighbors model found for layers: {sorted(missing_layers)}')

In [None]:
# Reduce neighborhoods
# Project every chosen context's neighbors into the layer's 2D space, using the
# reducer fit on the chosen contexts for that layer.
def _reduce_neighborhoods(layer):
    """Return the 2D projection of each neighborhood of `layer`, in neighborhood order."""
    # Fetch the layer's activations once instead of once per neighborhood. If
    # `acts` is an np.load-ed .npz archive (NpzFile), every `acts[layer]` lookup
    # re-reads the whole array from disk.
    layer_acts = acts[layer]
    reducer = reducers_2d[layer]
    return [reducer.transform(layer_acts[neighborhood]) for neighborhood in neighborhoods[layer]]


neighborhoods_2d = {layer: _reduce_neighborhoods(layer) for layer in layers}

In [None]:
# For each layer, plot the chosen contexts (black-outlined triangles) and their
# nearest neighbors (circles in the same color as the context they belong to).
# Hovering shows each point's context HTML.
palette = Category20[20] + Category20b[20]
# Cycle through the palette so more contexts than colors does not raise IndexError.
contexts_colors = [palette[i % len(palette)] for i in range(len(chosen_contexts))]
chosen_htmls = [context_util.context_html(*context) for context in chosen_contexts]
for layer in layers:
    points = reduced_chosen_acts_2d[layer]
    p = vis_util.empty_plot(size=200)
    p.tools = [vis_util.hover_tool('hover label')]
    points_source = ColumnDataSource({'x': points[:, 0], 'y': points[:, 1], 'color': contexts_colors, 'hover label': chosen_htmls})
    p.triangle('x', 'y', color='color', line_color='black', size=10, source=points_source)
    for color, neighborhood_2d, neighborhood_contexts in zip(contexts_colors, neighborhoods_2d[layer], neighborhoods_contexts[layer]):
        neighs_htmls = [context_util.context_html(*context) for context in neighborhood_contexts]
        neighs_source = ColumnDataSource({'x': neighborhood_2d[:, 0], 'y': neighborhood_2d[:, 1], 'hover label': neighs_htmls})
        p.circle(x='x', y='y', color=color, source=neighs_source)
    show(p)