# Understand a context through its clusters, its neighbors, and custom possible neighbors
Examine, at each layer, the given context's k-nearest neighbors (KNNs) — both its KNNs in the corpus and its KNNs among its own masked variants — as well as its distance to any custom contexts you're curious about.

In [None]:
# Appears to be a smoke-test marker: evaluating this string only shows that the kernel executes cells.
'test: notebook is running'

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import glob
import random
from typing import List
import collections
import itertools
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from IPython.core.display import display, HTML

import os
import sys
sys.path.insert(0, os.path.abspath('../../..'))
from src.utils import acts_util, vis_util, html_util, context_util, bert_util
from src.utils.acts_util import spherize
from src.utils.context_util import context_html
from src.utils.html_util import highlighter, fix_size, font_size, style
from src.utils import references as refs
from src.utils.SimpleBert import SimpleBert

In [None]:
# Model wrapper used below to get tokens and per-layer activations
# (bert.get_toks_and_acts, bert.get_contexts_and_acts).
bert = SimpleBert()

## Parameters

In [None]:
# NOTE: `spherize` below is a boolean flag that shadows the `spherize` function imported
# from acts_util; call `acts_util.spherize(...)` wherever the function itself is needed.
spherize = False  # whether to spherize activations before comparing them

# Token coloring uses a `dim`-component `reduction` of the activations; token size uses activation norm.
vis_color, vis_size = True, True
reduction, dim = 'NMF', 3

# Which neighbor sets to compute/visualize, and how to lay them out.
vis_masked_KNNs, vis_corpus_KNNs, vis_custom_contexts = True, True, True
pruning, together, save_to_html = False, True, False

# The document (context) under study.
doc_txt = "Later that day, he caught her eye."

## Get tokens and activations & choose layers to analyze

In [None]:
# Tokenize the document and get its activations, keyed by layer.
doc, doc_acts = bert.get_toks_and_acts(doc_txt)
# Analyze every third layer to keep the output manageable.
layers = list(doc_acts)[::3]
# layers = layers[:2] # good for debugging
if spherize:
    # `spherize` is rebound to a boolean flag in the Parameters cell (shadowing the imported
    # function), so the function must be reached through its module.
    doc_acts = {layer: acts_util.spherize(doc_acts[layer]) for layer in layers}

print('\nDocument:')
print(' '.join(doc))
# str() so non-string layer keys don't break the join.
print(f'\nLayers: {", ".join(str(layer) for layer in layers)}')

## Reduce sample

In [None]:
if vis_color:
    # Per layer, fit a `dim`-component `reduction` on the document's activations;
    # the reduced activations are what the Vis cell turns into token colors.
    doc_components, doc_reduced_acts = {}, {}
    for layer in layers:
        layer_components, layer_reduced = acts_util.fit_components(doc_acts[layer], reduction, dim)
        doc_components[layer], doc_reduced_acts[layer] = layer_components, layer_reduced
    if vis_size:
        # Per-token L2 norm of the activations, used by the Vis cell for font size.
        doc_acts_sizes = {}
        for layer in layers:
            doc_acts_sizes[layer] = np.linalg.norm(doc_acts[layer], axis=1)

## Get neighbors

In [None]:
# Token of the document whose neighborhoods are examined below.
selected_tok = 'caught'
try:
    selected_tok_idx = doc.index(selected_tok)
except ValueError:
    # Show the actual tokenization so it is obvious which token string to pick.
    raise ValueError(f'{selected_tok!r} is not a token of the document; tokens are: {doc}') from None

### Get corpus neighbors

In [None]:
if vis_corpus_KNNs:
    # io
    corpus_dir = os.path.abspath('/atlas/u/pkalluri/bert-vis/big-data/wiki-small')
    subdir = refs.sphere_dir if spherize else refs.standard_dir
    knn_models_dir = os.path.join(corpus_dir, subdir, refs.knn_models_dirname)
    n_neighbors = 20

    # get KNNs
    corpus_neighbors = {}
    # NOTE: unpickling can execute arbitrary code; only point corpus_dir at trusted, locally built data.
    with open(os.path.join(corpus_dir, subdir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)
    for layer in layers:
        print(f'Layer {layer}')

        print('Loading nearest neighbors model.')
        # knn_models_dir is already rooted at corpus_dir/subdir, so only the filename is joined.
        with open(os.path.join(knn_models_dir, f'{layer}.pickle'), 'rb') as model_file:
            knn_model = pickle.load(model_file)

        print('Finding neighbors')
        query = [doc_acts[layer][selected_tok_idx]]
        neighs_dists, neighs_ids = knn_model.kneighbors(query, n_neighbors=n_neighbors, return_distance=True)
        neighs_dists, neighs_ids = neighs_dists[0], neighs_ids[0]
        # a neighbor is (corpus id, the corresponding context, distance to the selected token)
        corpus_neighbors[layer] = [(neigh_id, corpus_contexts[neigh_id], neigh_dist)
                                   for neigh_id, neigh_dist in zip(neighs_ids, neighs_dists)]
        del knn_model  # free the (potentially large) model before loading the next layer's

## Get masked neighbors

In [None]:
if vis_masked_KNNs:
    # params: mask lengths len(doc)-3 and len(doc)-2 (lower bound clamped to 1)
    mask_lengths = range(max(len(doc)-3, 1), len(doc)-1)
    # GET MASKED VARIANTS OF SAMPLE
    print('Getting variants.')
    variants = [doc] + bert_util.get_masked_variants(doc, mask_lengths)
    print('Getting activations.')
    variants_contexts, variants_acts = bert.get_contexts_and_acts(variants, tokenized=True)
    if spherize:
        # `spherize` is a boolean flag here (it shadows the imported function).
        variants_acts = {layer: acts_util.spherize(variants_acts[layer]) for layer in layers}

    # GET KNNs FROM MASKED VARIANTS
    # params
    n_neighbors = 20

    # get KNNs
    masked_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        _variants_acts = variants_acts[layer]
        # kneighbors raises if asked for more neighbors than there are fitted points.
        _n_neighbors = min(n_neighbors, len(_variants_acts))

        print('Fitting nearest neighbors model.')
        knn_model = NearestNeighbors(n_neighbors=_n_neighbors).fit(_variants_acts)

        print('Finding neighbors')
        query = [doc_acts[layer][selected_tok_idx]]
        neighs_dists, neighs_ids = knn_model.kneighbors(query, n_neighbors=_n_neighbors, return_distance=True)
        neighs_dists, neighs_ids = neighs_dists[0], neighs_ids[0]
        # a neighbor is (the corresponding variant context, distance to the selected token)
        masked_neighbors[layer] = [(variants_contexts[neigh_id], neigh_dist)
                                   for neigh_id, neigh_dist in zip(neighs_ids, neighs_dists)]
        del knn_model

## Check similarity of custom sentences

In [None]:
if vis_custom_contexts:
    # (text, token of interest) pairs whose activation is compared to the selected token's
    custom_contexts_unprocessed = [
                                    ('He caught it.', 'caught'),
                                    ('He caught her eye.', 'caught'),
                                    ('The rabbit caught the fox\'s eye.', 'caught'),
                                    ('The rabbit caught the fox\'s attention.', 'caught'),
                                    ('0 1 2 3 4 caught 6 7 8.', 'caught'),
                                    ('He [MASK] it.', '[MASK]'),
                                    ('He [MASK] her eye.', '[MASK]'),
                                    ('The rabbit [MASK] the fox\'s eye.', '[MASK]'),
                                    ('The rabbit [MASK] the fox\'s attention.', '[MASK]'),
                                    ('0 1 2 3 4 [MASK] 6 7 8.', '[MASK]'),
                                    ('The rabbit caught the carrot.', 'caught')
                                    ]
    custom_neighbors = {layer: [] for layer in layers}
    for custom_txt, custom_tok in custom_contexts_unprocessed:
        custom_doc, custom_doc_acts = bert.get_toks_and_acts(custom_txt)
        custom_pos = custom_doc.index(custom_tok)
        custom_context = (custom_doc, custom_pos)
        for layer in layers:
            # doc_acts was already spherized (under the same flag) when it was computed,
            # so only the custom activation is spherized here. The flag shadows the imported
            # function, hence the explicit acts_util call.
            _act = doc_acts[layer][selected_tok_idx]
            _custom_act = custom_doc_acts[layer][custom_pos]
            if spherize:
                _custom_act = acts_util.spherize([_custom_act])[0]
            dist = np.linalg.norm(_custom_act - _act)
            custom_neighbors[layer].append((custom_context, dist))
    

## Vis

In [None]:
# VISUALIZE

override_params = True  # useful for quick iteration on visualization
if override_params:
    vis_color = True
    vis_size = True
    vis_masked_KNNs = False
    vis_corpus_KNNs = True
    vis_custom_contexts = False
    together = True
    pruning = True
    save_to_html = False
    # If pruning is true (only used when together is False), the masked-neighbors vis first shows the
    # closest 0-length mask, then jumps to the closest 1-length mask, etc. - forming a kind of slow
    # pruning that captures the essence of the context.
# other vis params
max_font_size = 10
if save_to_html: html_doc = ''
# params if visualizing neighbors
if vis_corpus_KNNs or vis_masked_KNNs or vis_custom_contexts:
    n_neighbors_to_vis = 50
    token_styler = lambda t: t
    # token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')  # another pretty styling option
    bluer = highlighter('lightblue')
    greener = highlighter('limegreen')
    greyer = highlighter('grey')
    masker = highlighter('black')
    selected_tok_html = context_html(doc, selected_tok_idx)

# vis legend: one swatch per reduced component
if vis_color:
    pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
    html = ''
    for i, rgb in enumerate(pure_rgbs):
        html += html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    print('Legend')
    display(HTML(html))
    if save_to_html: html_doc += html + '<br>'
    print()
    rgbs = {layer: vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}  # prepare coloring

# Corpus contexts already shown (at this or an earlier layer); repeats are greyed out.
corpus_neighs_so_far = []
for layer in layers:
    html = f'Layer {layer}'
    display(HTML(html))
    if save_to_html: html_doc += html

    # vis sample: rgbs / doc_acts_sizes only exist when vis_color is set, so read them only then.
    if vis_color:
        _rgbs = rgbs[layer]
        if vis_size:
            _sizes = doc_acts_sizes[layer]
            _min_size = np.min(_sizes)
            _size_range = np.max(_sizes) - _min_size
            # Min-max scale to [0, 1]; avoid 0/0 when every token has the same norm.
            _sizes = (_sizes - _min_size) / _size_range if _size_range > 0 else np.ones_like(_sizes)
        color_html = ''
        for pos, tok in enumerate(doc):
            tok_font_size = _sizes[pos] * max_font_size if vis_size else max_font_size
            css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {tok_font_size}pt;'
            color_html += html_util.style(f' {tok} ', css=css)
        display(HTML(color_html))
        if save_to_html: html_doc += color_html + '<br>'

    # vis knns
    if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts:
        # Show the doc
        display(HTML(selected_tok_html))
        if save_to_html: html_doc += selected_tok_html + '<br>'

        if together:
            # Merge all neighbor sources into one list ordered by distance.
            neighbors = []
            if vis_masked_KNNs:
                neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighbors[layer]])
            if vis_corpus_KNNs:
                neighbors.extend([(context, dist, 'corpus') for context_id, context, dist in corpus_neighbors[layer]])
            if vis_custom_contexts:
                neighbors.extend([(context, dist, 'custom') for context, dist in custom_neighbors[layer]])
            neighbors.sort(key=lambda t: t[1])  # sort all by dist

            for neigh_context, neigh_dist, neigh_source in neighbors[:n_neighbors_to_vis]:
                html = f'({neigh_dist:.3f}) '
                if neigh_source == 'masked':
                    html += context_html(*neigh_context, masker=masker, token_styler=token_styler)
                elif neigh_source == 'corpus':
                    marker = greyer if neigh_context in corpus_neighs_so_far else bluer
                    html += context_html(*neigh_context, marker=marker, masker=masker, token_styler=token_styler)
                    corpus_neighs_so_far.append(neigh_context)
                elif neigh_source == 'custom':
                    html += context_html(*neigh_context, marker=greener, masker=masker, token_styler=token_styler)
                display(HTML(html))
                if save_to_html: html_doc += html + '<br>'
        else:
            if vis_masked_KNNs:
                _masked_neighbors = masked_neighbors[layer]
                if pruning:
                    # For each mask count, show only the closest variant with exactly that many masks.
                    for mask_len_to_search in range(len(doc)-2):
                        for neigh_context, neigh_dist in _masked_neighbors:
                            neigh_toks, neigh_pos = neigh_context
                            mask_len = len([tok for tok in neigh_toks if tok == '[MASK]'])
                            if mask_len == mask_len_to_search:
                                html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=token_styler)
                                display(HTML(html))
                                if save_to_html: html_doc += html + '<br>'
                                break
                else:
                    for neigh_context, neigh_dist in _masked_neighbors[:n_neighbors_to_vis]:
                        html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                        display(HTML(html))
                        if save_to_html: html_doc += html + '<br>'
            if vis_corpus_KNNs:
                _corpus_neighs = corpus_neighbors[layer][:n_neighbors_to_vis]
                for neigh_id, neigh_context, neigh_dist in _corpus_neighs:
                    marker = greyer if neigh_context in corpus_neighs_so_far else bluer
                    html = token_styler(f'({neigh_dist:.3f})') + context_html(*neigh_context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
                    corpus_neighs_so_far.append(neigh_context)
            if vis_custom_contexts:
                _custom_neighbors = custom_neighbors[layer]
                for context, dist in _custom_neighbors:
                    marker = greener
                    html = token_styler(f'({dist:.3f})') + context_html(*context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
    print('.')
    print('.')
    print('.')

if save_to_html:
    with open('../../../building-blocks.html', 'w') as html_file:
        html_file.write(html_doc)

## Visualize the evolution of context and all neighbors (In progress)

In [None]:
# import umap
# import itertools
# # VIS
# from bokeh.plotting import figure, show, output_file
# from IPython.core.display import display, HTML
# from bokeh.io import output_notebook
# from bokeh.models import Div, HoverTool, ColumnDataSource, PanTool, BoxZoomTool, WheelZoomTool, ResetTool
# from bokeh.layouts import gridplot
# from bokeh.palettes import Category20
# view_vis_as_html = False
# output_notebook()
# if view_vis_as_html:
#     output_file('visualize-wiki.html')

In [None]:
# import plotly.express as px
# import plotly.io as pio
# pio.renderers.default = 'iframe'

In [None]:
# # Fresh visualization
# grid = vis_util.Grid()
# verbose = True
# palette = Category20[20]
# layer_palette = {layer: palette[layer_idx] for layer_idx, layer in enumerate(layers)}
# n_neighbors = 20  # todo: set automatically from corpus knns section

# col = []
# points_ids = []
# points_contexts = []
# points_layers = []
# points_colors = []
# points_htmls = []
# for layer in layers:
#     _points_ids, _points_contexts, _ = zip(*corpus_neighborhoods[layer][tok_of_interest_idx])
#     points_ids += _points_ids
#     points_contexts += _points_contexts 
#     points_layers += [layer] * len(_points_ids)
#     points_colors += [layer_palette[layer]] * len(_points_ids)
#     points_htmls += [context_util.context_html(*context) for context in _points_contexts]
# reducer = umap.UMAP(random_state=1)
# for layer_to_plot in layers:
#     if verbose: print('Gathering plot info...')
#     points_acts = corpus_acts[layer_to_plot][points_ids]
#     points_2d = reducer.fit_transform(points_acts)
#     tok_point_act = doc_acts[layer_to_plot][tok_of_interest_idx]
#     tok_point_2d = reducer.transform([tok_point_act])[0]
#     if verbose: print('Making empty plot...')
#     p: figure = vis_util.empty_plot(size=200)
#     if verbose: print('Adding points....')
#     source = ColumnDataSource({'x': points_2d[:,0], 'y': points_2d[:,1], 'color': points_colors, 'label': points_htmls})
#     p.circle('x', 'y', color='color', source=source)
#     star_source = ColumnDataSource({'x': [tok_point_2d[0],], 'y': [tok_point_2d[1],], 'color': [layer_palette[layer_to_plot],], 'size': [10], 'label': [tok_of_interest_html,]})
#     p.star('x', 'y', color='color', siz='size', source=star_source)
#     if verbose: print('Adding tools...')
#     p.tools = [WheelZoomTool(), PanTool(), BoxZoomTool(), ResetTool(), vis_util.hover_tool('label')]
#     col.append(p)
# grid.add_column('', col)
# grid.show()

In [None]:
    #     colors = ['red' if neigh_layer==layer else 'black' for neigh_layer in all_neighs_layers]
#     colors = []
#     for neigh_layer, neigh_context in zip(all_neighs_layers, all_neighs_contexts):
#         color = 'black'
#         if neigh_layer == layer:
#             color = 'blue'
#         if neigh_context in neighs_contexts_so_far:
#             print(layer, neighs_contexts_so_far)
#             color = 'green'
#         colors.append(color)
#         neighs_contexts_so_far.append(neigh_context)

## Try a custom sentence

In [None]:
# choose
new_text = 'Later that day, [MASK].'
new_text_tok_of_interest = '[MASK]'

# vis
new_doc, new_acts = bert.get_toks_and_acts(new_text)
new_pos = new_doc.index(new_text_tok_of_interest)
# Defined here (not relied on from the Vis cell, which only defines `bluer` conditionally).
bluer = highlighter('lightblue')
token_styler = lambda t: font_size(t, max_font_size)
for layer in layers:
    current_act = doc_acts[layer][selected_tok_idx]
    new_act = new_acts[layer][new_pos]
    # doc_acts was already spherized under the same flag when computed; the flag shadows the
    # imported function, so call it through acts_util.
    if spherize: new_act = acts_util.spherize([new_act])[0]
    display(HTML(f'Layer {layer}'))
    dist = np.linalg.norm(current_act - new_act)
    html = token_styler(f'({dist:.3f})')
    html += context_html(new_doc, new_pos, marker=bluer, token_styler=token_styler)
    display(HTML(html))