# Understand a context through its clusters, its neighbors, and custom possible neighbors
Examine, at each layer, the given context's k-nearest neighbors (KNNs) - both its KNNs in the corpus and its KNNs among its own masked variants - as well as any custom contexts you're curious about.

## Imports

In [1]:
# Sanity check: evaluating a bare string echoes it back as the cell output, showing the kernel runs cells.
'Test notebook is working.'

'Test notebook is working.'

In [2]:
%load_ext autoreload
%autoreload 2

import pickle
import glob
import random
from typing import List
import collections
import itertools
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from IPython.core.display import display, HTML

import os
import sys
sys.path.insert(0, os.path.abspath('../../..'))
from src.utils import acts_util, vis_util, html_util, context_util, bert_util
from src.utils.acts_util import spherize
from src.utils.context_util import context_html, context_str
from src.utils.html_util import highlighter, fix_size, font_size, style
from src.utils import references as refs
from src.utils.SimpleBert import SimpleBert

In [3]:
# Instantiate the project's SimpleBert wrapper; later cells call it to tokenize text and get per-layer activations.
bert = SimpleBert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Parameters

In [11]:
# --- Document to analyze ---
doc_txt = "Two Muslims walk into a bar."

# --- Activation preprocessing ---
# NOTE: this boolean flag reuses the name of the `spherize` function imported from
# src.utils.acts_util, so once this cell has run the bare name refers to the flag.
spherize = False

# --- Coloring of the document's own tokens ---
vis_color = False
reduction = 'NMF'  # reduction method passed to acts_util.fit_components
dim = 3            # number of components to reduce to
vis_size = False  # If visualizing color, can also visualize size

# --- Which neighbor views to show ---
vis_corpus_KNNs = True
vis_masked_KNNs = False
vis_custom_contexts = False
pruning = False   # masked-neighbor view: show only the closest neighbor per mask length
together = False  # merge all neighbor sources into one list sorted by distance

# --- Output ---
save_to_html = False

## Get tokens and activations & choose layers to analyze

In [12]:
# Tokenize the document and collect its activations, indexable by layer name.
doc, doc_acts = bert.get_toks_and_acts(doc_txt)
layers = list(doc_acts)[::3]  # keep every third layer
# layers = layers[:2] # good for debugging
# `spherize` is a boolean flag at this point (it shadows the imported function of the
# same name), so the function must be reached through its module; calling the bare
# name would raise "TypeError: 'bool' object is not callable".
if spherize: doc_acts = {layer: acts_util.spherize(doc_acts[layer]) for layer in layers}

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')


Document:
[CLS] Two Muslims walk into a bar . [SEP]

Layers: arr_0, arr_3, arr_6, arr_9, arr_12


## Reduce sample

In [13]:
if vis_color:
    # Fit a `dim`-component `reduction` on each layer's activations, keeping both
    # the fitted components and the reduced activations per layer.
    doc_components, doc_reduced_acts = {}, {}
    for layer in layers:
        doc_components[layer], doc_reduced_acts[layer] = acts_util.fit_components(doc_acts[layer], reduction, dim)
    if vis_size:
        # Per-token activation norm at each layer (norm taken along axis 1).
        doc_acts_sizes = {}
        for layer in layers:
            doc_acts_sizes[layer] = np.linalg.norm(doc_acts[layer], axis=1)

## Get neighbors

In [14]:
# Token whose neighbors are examined in the cells below.
selected_tok = 'Muslims'
# Position of its first occurrence in the tokenized document; assuming `doc` is a list,
# this raises ValueError when the string is not one of the tokens.
selected_tok_idx = doc.index(selected_tok)

### Get corpus neighbors

In [15]:
if vis_corpus_KNNs:
    # io
    # NOTE(review): hardcoded, machine-specific absolute path; point this at your own corpus directory.
    corpus_dir = '/atlas/u/pkalluri/bert-vis/big-data/wiki-large'
    corpus_dir = os.path.abspath(corpus_dir)
    subdir = refs.sphere_dir if spherize else refs.standard_dir
    knn_models_dir = os.path.join(corpus_dir, subdir, refs.knn_models_dirname)
    n_neighbors = 20

    # get KNNs
    corpus_neighbors = {}
    # NOTE(security): unpickling can execute arbitrary code; only load corpus files you trust.
    # Context managers make sure the file handles are closed after loading.
    with open(os.path.join(corpus_dir, subdir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)
    for layer in layers:
        print(f'Layer {layer}')

        print('Loading nearest neighbors model.')
        # knn_models_dir already includes corpus_dir and subdir, so join onto it directly.
        with open(os.path.join(knn_models_dir, f'{layer}.pickle'), 'rb') as knn_model_file:
            knn_model = pickle.load(knn_model_file)

        print('Finding neighbors')
        _doc_acts = doc_acts[layer]
        # Query with the selected token's activation only; kneighbors expects a 2D batch, hence the list.
        neighs_dists, neighs_ids = knn_model.kneighbors([_doc_acts[selected_tok_idx]], n_neighbors=n_neighbors, return_distance=True)
        neighs_dists, neighs_ids = neighs_dists[0], neighs_ids[0]
        neighs = [(neigh_id, corpus_contexts[neigh_id], neigh_dist) for neigh_id, neigh_dist in zip(neighs_ids, neighs_dists)]  # a neighbor is an id, the corresponding context, and the distance away
        corpus_neighbors[layer] = neighs
        # Release the per-layer model and temporaries before loading the next layer's model.
        del knn_model, neighs_dists, neighs_ids, neighs

Layer arr_0
Loading nearest neighbors model.
Finding neighbors
Layer arr_3
Loading nearest neighbors model.
Finding neighbors
Layer arr_6
Loading nearest neighbors model.
Finding neighbors
Layer arr_9
Loading nearest neighbors model.
Finding neighbors
Layer arr_12
Loading nearest neighbors model.
Finding neighbors


### Get masked neighbors

In [16]:
if vis_masked_KNNs:
    # params
    mask_lengths = range(max(len(doc)-3, 1),len(doc)-1)
    # GET MASKED VARIANTS OF SAMPLE
    print('Getting variants.')
    variants = [doc] + bert_util.get_masked_variants(doc, mask_lengths)
    print('Getting activations.')
    variants_contexts, variants_acts = bert.get_contexts_and_acts(variants, tokenized=True)
    # `spherize` is a boolean flag here (it shadows the imported function), so the
    # function is called through its module.
    if spherize: variants_acts = {layer: acts_util.spherize(variants_acts[layer]) for layer in layers}

    # GET KNNs FROM MASKED VARIANTS
    # params
    n_neighbors = 20

    # get KNNs
    masked_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        _variants_acts = variants_acts[layer]
        # scikit-learn raises if more neighbors are requested than samples were fitted,
        # which happens for short documents with few masked variants.
        _n_neighbors = min(n_neighbors, len(_variants_acts))

        print('Fitting nearest neighbors model.')
        knn_model = NearestNeighbors(n_neighbors=_n_neighbors).fit(_variants_acts)

        print('Finding neighbors')
        _doc_acts = doc_acts[layer]
        neighs_dists, neighs_ids = knn_model.kneighbors([_doc_acts[selected_tok_idx]], n_neighbors=_n_neighbors, return_distance=True)
        neighs_dists, neighs_ids = neighs_dists[0], neighs_ids[0]
        neighs = [(variants_contexts[neigh_id], neigh_dist) for neigh_id, neigh_dist in zip(neighs_ids, neighs_dists)]  # a neighbor is the context and the distance away
        masked_neighbors[layer] = neighs
        del knn_model

## Check similarity of custom sentences

In [17]:
if vis_custom_contexts:
    # Each entry is (text, token of interest within that text).
    custom_contexts_unprocessed = [
                                    ('He caught it.', 'caught'),
                                    ('He caught her eye.', 'caught'),
                                    ('The rabbit caught the fox\'s eye.', 'caught'),
                                    ('The rabbit caught the fox\'s attention.', 'caught'),
                                    ('0 1 2 3 4 caught 6 7 8.', 'caught'),
                                    ('He [MASK] it.', '[MASK]'),
                                    ('He [MASK] her eye.', '[MASK]'),
                                    ('The rabbit [MASK] the fox\'s eye.', '[MASK]'),
                                    ('The rabbit [MASK] the fox\'s attention.', '[MASK]'),
                                    ('0 1 2 3 4 [MASK] 6 7 8.', '[MASK]'),
                                    ('The rabbit caught the carrot.', 'caught')
                                    ]
    # For every layer, a list of (context, distance to the selected token's activation).
    custom_neighbors = {layer: [] for layer in layers}
    for custom_txt, custom_tok in custom_contexts_unprocessed:
        custom_doc, custom_doc_acts = bert.get_toks_and_acts(custom_txt)
        custom_pos = custom_doc.index(custom_tok)
        custom_context = (custom_doc, custom_pos)
        for layer in layers:
            _act = doc_acts[layer][selected_tok_idx]
            _custom_act = custom_doc_acts[layer][custom_pos]
            if spherize:
                # `spherize` is the boolean flag (it shadows the imported function),
                # so the function is called through its module.
                _act = acts_util.spherize([_act])[0]
                _custom_act = acts_util.spherize([_custom_act])[0]
            dist = np.linalg.norm(_custom_act - _act)
            custom_neighbors[layer].append((custom_context, dist))
    

## Vis

In [18]:
# VISUALIZE
# For each analyzed layer, render (as HTML) the optionally colored/sized document and
# the selected token's neighbors from the enabled sources (masked variants, corpus,
# custom contexts). When save_to_html is set, the same HTML is also accumulated and
# written to a file at the end.

override_params = False  # useful for quick iteration on visualization
if override_params:
    vis_color = False
    vis_size = False
    vis_masked_KNNs = False
    vis_corpus_KNNs = True
    vis_custom_contexts = False
    together = True
    # if pruning is true, in the masked neighbors vis, first the closest 0-length mask is presented,
    # then jumps to the closest 1-length mask, etc - forming a kind of slow pruning to capture the essence of the context.
    pruning = True
    save_to_html = False
# other vis params
max_font_size = 10
if save_to_html: html_doc = ''
# params if visualizing neighbors
if vis_corpus_KNNs or vis_masked_KNNs or vis_custom_contexts:
    n_neighbors_to_vis = 50
    token_styler = lambda t: t
    # token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')  # another pretty styling option
    bluer = highlighter('lightblue')
    greener = highlighter('limegreen')
    greyer = highlighter('grey')
    masker = highlighter('black')
    selected_tok_html = context_html(doc, selected_tok_idx)

# vis legend
if vis_color:
    pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
    html = ''
    for i, rgb in enumerate(pure_rgbs):
        html += html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    print('Legend')
    display(HTML(html))
    if save_to_html: html_doc += html + '<br>'
    print()
    rgbs = {layer:vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}  # prepare coloring

corpus_neighs_so_far = []  # corpus contexts already shown at an earlier layer (rendered grey when repeated)
for layer in layers:
    html = f'Layer {layer}'
    display(HTML(html))
    if save_to_html: html_doc += html

    # vis sample
    if vis_color:
        _rgbs = rgbs[layer]
        if vis_size:
            # Sizes only exist (and are only used) when coloring is on, so they are
            # normalized here rather than unconditionally; min-max scaled to [0, 1].
            _sizes = doc_acts_sizes[layer]
            _sizes = (_sizes - np.min(_sizes)) / (np.max(_sizes) - np.min(_sizes))
        color_html = ''
        for pos, tok in enumerate(doc):
            if vis_size:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {_sizes[pos]*max_font_size}pt;'
            else:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {max_font_size}pt;'
            color_html += html_util.style(f' {tok} ', css=css)
        display(HTML(color_html))
        if save_to_html: html_doc += color_html + '<br>'

    # vis knns
    if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts:
        # Show the doc
        display(HTML(selected_tok_html))
        if save_to_html: html_doc += selected_tok_html + '<br>'

        if together:
            # Merge every enabled source into (context, dist, source) triples.
            neighbors = []
            if vis_masked_KNNs:
                neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighbors[layer]])
            if vis_corpus_KNNs:
                neighbors.extend([(context, dist, 'corpus') for _, context, dist in corpus_neighbors[layer]])
            if vis_custom_contexts:
                neighbors.extend([(context, dist, 'custom') for context, dist in custom_neighbors[layer]])
            neighbors.sort(key=lambda t: t[1])  # sort all by dist

            for neigh_context, neigh_dist, neigh_source in neighbors[:n_neighbors_to_vis]:
                html = f'({neigh_dist:.3f}) '
                if neigh_source == 'masked':
                    html += context_html(*neigh_context, masker=masker, token_styler=token_styler)
                elif neigh_source == 'corpus':
                    marker = bluer if neigh_context not in corpus_neighs_so_far else greyer
                    html += context_html(*neigh_context, marker=marker, masker=masker, token_styler=token_styler)
                    corpus_neighs_so_far.append(neigh_context)
                elif neigh_source == 'custom':
                    html += context_html(*neigh_context, marker=greener, masker=masker, token_styler=token_styler)
                display(HTML(html))
                if save_to_html: html_doc += html + '<br>'
        else:
            if vis_masked_KNNs:
                _masked_neighbors = masked_neighbors[layer]
                if pruning:
                    # For each mask count in increasing order, show only the first listed
                    # neighbor with exactly that many [MASK] tokens.
                    for mask_len_to_search in range(len(doc)-2):
                        for neigh_context, neigh_dist in _masked_neighbors:
                            neigh_toks, neigh_pos = neigh_context
                            mask_len = sum(tok == '[MASK]' for tok in neigh_toks)
                            if mask_len == mask_len_to_search:
                                html = token_styler(f'({neigh_dist:.3f}) ') +  context_html(*neigh_context, token_styler=token_styler)
                                display(HTML(html))
                                if save_to_html: html_doc += html + '<br>'
                                break
                else:
                    for neigh_context, neigh_dist in _masked_neighbors[:n_neighbors_to_vis]:
                        html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                        display(HTML(html))
                        if save_to_html: html_doc += html + '<br>'
            if vis_corpus_KNNs:
                _corpus_neighs = corpus_neighbors[layer][:n_neighbors_to_vis]
                for neigh_id, neigh_context, neigh_dist in _corpus_neighs:
                    marker = bluer if neigh_context not in corpus_neighs_so_far else greyer
                    html = token_styler(f'({neigh_dist:.3f})') + context_html(*neigh_context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
                    corpus_neighs_so_far.append(neigh_context)
            if vis_custom_contexts:
                _custom_neighbors = custom_neighbors[layer]
                for context, dist in _custom_neighbors:
                    marker = greener
                    html = token_styler(f'({dist:.3f})') + context_html(*context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
    print('.')
    print('.')
    print('.')

if save_to_html:
    # Context manager guarantees the file is flushed and closed.
    with open('../../../building-blocks.html', 'w') as html_file:
        html_file.write(html_doc)

.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


## Visualize the evolution of context and all neighbors (In progress)

In [None]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'
import pandas as pd

# compile all neighbors and their acts
# Requires `corpus_neighbors` from the corpus-neighbors cell (vis_corpus_KNNs must be on).
layers = layers[:]
neighs_contexts = []
neighs_associations = []  # which layer was this neighbor originally associated with?
neighs_acts = {layer: [] for layer in layers}
for layer in layers[::-1]:
    for neigh_id, neigh_context, _ in corpus_neighbors[layer]:
        if neigh_context not in neighs_contexts:
            neighs_contexts.append(neigh_context)
            neighs_associations.append(layer)
            neigh_doc, neigh_pos = neigh_context
            neigh_doc, neigh_acts = bert.get_toks_and_acts(neigh_doc, tokenized=True)
            for layer_to_grow in layers:
                # Record the neighbor token's activation at every analyzed layer, so the
                # per-layer lists stay aligned with neighs_contexts.
                _neigh_act = neigh_acts[layer_to_grow][neigh_pos]
                # Mirror the custom-context cell: when the spherize flag is on, compare
                # spherized activations (the flag shadows the function, hence the module call).
                if spherize: _neigh_act = acts_util.spherize([_neigh_act])[0]
                neighs_acts[layer_to_grow].append(_neigh_act)
# preprocess
neighs_labels = [context_str(*context) for context in neighs_contexts]
neighs_acts = {layer: np.stack(neighs_acts[layer]) for layer in layers}

# reduce and visualize
# Cell-local name so the `reduction` chosen in the parameters cell is not overwritten.
neighs_reduction = 'TSNE'
layer_points_dfs = []  # one frame per layer, concatenated once for the faceted plot
for layer in layers:
    print('Preparing layer', layer)
    _tok_act = doc_acts[layer][selected_tok_idx]
    _neighs_acts = neighs_acts[layer]
    _acts = np.concatenate([[_tok_act,], _neighs_acts])  # row 0 is the selected token
#     _acts = acts_util.reduce_acts(_acts, reduction='KMeans', dim=10)
    _points = acts_util.reduce_acts(_acts, reduction=neighs_reduction, dim=2)
    _xs, _ys = zip(*_points)
    _points_df = pd.DataFrame(dict(label=[context_str(doc, selected_tok_idx)]+neighs_labels,
                                   x=_xs, y=_ys,
                                   layer=layer,
                                   association=['main']+neighs_associations,
                                   main=[True]+[False]*(len(_points)-1)  # only row 0 is the selected token
                                  ))
    # Keep the selected token plus the neighbors first found at this layer.
    _points_df = _points_df[_points_df['association'].isin([layer, 'main'])]
    layer_points_dfs.append(_points_df)
    fig = px.scatter(_points_df, x='x', y='y', color='association', hover_data={'x': False, 'y': False, 'association': False, 'layer': False, 'label': True}, height=300, width=800)
    fig.show()

# Build the combined frame once (rather than appending inside the loop) so the faceted plot has data.
points_df = pd.concat(layer_points_dfs, ignore_index=True)
fig = px.scatter(points_df, x='x', y='y', color='association', facet_col='layer', hover_data={'x': False, 'y': False, 'association': False, 'layer': False, 'label': True}, height=300, width=300*len(layers))
# fig.update_traces(mode="markers", hovertemplate=None)
# fig.update_layout(hovermode='label')
fig

## Try a custom sentence

In [None]:
# choose
new_text = 'Later that day, [MASK].'
new_text_tok_of_interest = '[MASK]'

# vis
# Show, per layer, the distance between the selected token's activation and the
# activation of the token of interest in the new text.
new_doc, new_acts = bert.get_toks_and_acts(new_text)
new_pos = new_doc.index(new_text_tok_of_interest)
# Built once outside the loop and under cell-local names, so this cell neither depends on
# the visualization cell's conditionally defined `bluer` nor overwrites its `token_styler`.
new_token_styler = lambda t: font_size(t,max_font_size)
new_marker = highlighter('lightblue')
for layer in layers:
    current_act = doc_acts[layer][selected_tok_idx]
    new_act = new_acts[layer][new_pos]
    # `spherize` is the boolean flag (it shadows the imported function), so the
    # function is called through its module.
    if spherize: new_act = acts_util.spherize([new_act])[0]
    display(HTML(f'Layer {layer}'))
    dist = np.linalg.norm(current_act - new_act)
    html = new_token_styler(f'({dist:.3f})')
    html += context_html(new_doc, new_pos, marker=new_marker, token_styler=new_token_styler)
    display(HTML(html))