# Understand a context through its clusters, its neighbors, and custom possible neighbors
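Examine, at each layer, the k nearest neighbors (KNNs) of a selected token in context. This covers both its KNNs among the contexts of a corpus and its KNNs among masked variants of its own document. It also covers its distances to any custom contexts you're curious about.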

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import numpy as np
# from sklearn.neighbors import NearestNeighbors  # for now, switching to FasterNearestNeighbors
from IPython.core.display import display, HTML
import time

import os
import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils.acts_util import spherize, fit_components, reduce_acts
from utils.bert_util import get_masked_variants
from utils.context_util import context_html, context_str
from utils.html_util import highlighter, fix_size, font_size, style
from utils import vis_util, html_util
from utils import references as refs
from utils.SimpleBert import SimpleBert
from utils.FastNearestNeighbors import FastNearestNeighbors

In [2]:
# Instantiate the project's BERT wrapper once; later cells call bert.get_toks_and_acts / bert.get_contexts_and_acts
# to tokenize text and obtain per-layer activations. (The warning printed below suggests it loads bert-base-cased.)
bert = SimpleBert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Parameters

In [3]:
sphere = True  # Spherize the activations? This is equivalent to switching from standard Euclidean distance to cosine distance.
n_layers = 13  # Only visualizes n layers, selected evenly from the available layers.
              # E.g. if n_layers=5 and 13 layers (arr_0..arr_12) are available, will visualize layers 0,3,6,9,12 (n_layers = 2 is good for debugging)
vis_color = False
reduction, dim = 'NMF', 3  # Reduction (and number of components) with which to color the doc
vis_size = False  # Only used when vis_color is True: also scale each token's font by the norm of its activation.

vis_masked_KNNs = False
pruning = True
# pruning is only checked if visualizing masked KNNs and `together` is False.
# If pruning is True, the masked-neighbors vis shows, in order of increasing number of [MASK] tokens,
# only the closest variant with that many masks - forming a kind of slow pruning to capture the essence of the context.
vis_corpus_KNNs = True
# Directory holding the corpus contexts and activations (only checked if visualizing corpus KNNs).
# Set the BERT_VIS_CORPUS_DIR environment variable to point elsewhere without editing the notebook.
corpus_dir = os.environ.get('BERT_VIS_CORPUS_DIR', '/atlas/u/pkalluri/bert-vis/big-data/wiki-large')
vis_custom_contexts = True
together = True  # Instead of visualizing all masked KNNs, then all corpus KNNs, etc, do you want to visit them all together, sorted by distance?
save_to_html = False
html_path = '../../../building-blocks.html'  # only checked if saving to html

doc_txt = "Later that day, it caught her eye."
selected_tok = 'caught'  # must exactly match a token of the tokenized doc; its first occurrence is used

## Get tokens and activations

In [4]:
doc, doc_acts = bert.get_toks_and_acts(doc_txt)
all_layers = list(doc_acts)
# Pick n_layers layer names spaced as evenly as possible (first and last included when n_layers >= 2).
# Clamping means n_layers=1 shows just the first layer and n_layers > available shows every layer,
# instead of dividing by zero or slicing with a zero step.
_n_to_show = max(1, min(n_layers, len(all_layers)))
_layer_idxs = np.linspace(0, len(all_layers) - 1, num=_n_to_show).round().astype(int)
layers = [all_layers[i] for i in np.unique(_layer_idxs)]
if sphere: doc_acts = {layer: spherize(doc_acts[layer]) for layer in layers}

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')


Document:
[CLS] Later that day , it caught her eye . [SEP]

Layers: arr_0, arr_1, arr_2, arr_3, arr_4, arr_5, arr_6, arr_7, arr_8, arr_9, arr_10, arr_11, arr_12


## Reduce sample (per-layer dimensionality reduction of the document's activations, used to color its tokens)

In [5]:
if vis_color:
    # Fit the configured `reduction` with `dim` components separately at every layer; the reduced
    # activations are mapped to token colors in the Vis cell.
    doc_components, doc_reduced_acts = {}, {}
    for layer in layers:
        doc_components[layer], doc_reduced_acts[layer] = fit_components(doc_acts[layer], reduction, dim)
    if vis_size:
        # Per-token activation norms, later min-max scaled into font sizes.
        doc_acts_sizes = dict(zip(layers, (np.linalg.norm(doc_acts[layer], axis=1) for layer in layers)))

## Get neighbors

In [6]:
# Position of the selected token in the tokenized doc (its first occurrence). The tokenizer may split words
# into sub-word pieces, in which case `selected_tok` has to name one of those pieces exactly.
if selected_tok not in doc:
    raise ValueError(f'selected_tok {selected_tok!r} is not a token of the tokenized doc; tokens are: {doc}')
selected_tok_idx = doc.index(selected_tok)

### Get corpus neighbors

In [7]:
if vis_corpus_KNNs:
    # io
    corpus_dir = os.path.abspath(corpus_dir)
    # NOTE: pickle.load can execute arbitrary code - only point corpus_dir at trusted data.
    with open(os.path.join(corpus_dir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)
    corpus_acts = np.load(os.path.join(corpus_dir, refs.acts_fn))  # also read by the plotting cells below
    n_neighbors = 20

    # get KNNs: for each layer, fit a KNN model on the corpus activations and query it with the selected token
    corpus_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        _acts = spherize(corpus_acts[layer]) if sphere else corpus_acts[layer]
        print('Fitting nearest neighbors model.')
        knn_model = FastNearestNeighbors().fit(_acts)
        del _acts  # free the (possibly large) per-layer activation matrix before querying
        print('Finding neighbors.')
        _doc_acts = doc_acts[layer]
        neighs_dists, neighs_ids = knn_model.kneighbors([_doc_acts[selected_tok_idx]], n_neighbors=n_neighbors, return_distance=True)
        del knn_model
        # a neighbor is an id, the corresponding context, and the distance away
        neighs = [(neigh_id, corpus_contexts[neigh_id], neigh_dist) for neigh_id, neigh_dist in zip(neighs_ids[0], neighs_dists[0])]
        corpus_neighbors[layer] = neighs

Layer arr_0
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_1
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_2
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_3
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_4
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_5
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_6
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_7
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_8
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_9
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_10
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_11
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_12
Fitting nearest neighbors model.
Finding neighbors.


### Get masked neighbors

In [8]:
if vis_masked_KNNs:
    # Which numbers of [MASK] tokens to generate variants for. Currently hard-coded to the single value
    # len(doc)-3, i.e. (assuming doc is wrapped in [CLS] ... [SEP]) all but one of the real tokens masked.
    mask_lengths = range(max(len(doc)-3, 1),len(doc)-2)
    print('Getting variants.')
    variants = get_masked_variants(doc, mask_lengths)
    print('Getting activations.')
    variants_contexts, variants_acts = bert.get_contexts_and_acts(variants, tokenized=True)
    n_neighbors = 50

    masked_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        _acts = spherize(variants_acts[layer]) if sphere else variants_acts[layer]
        print('Fitting nearest neighbors model.')
        knn_model = FastNearestNeighbors().fit(_acts)
        del _acts
        print('Finding neighbors')
        _doc_acts = doc_acts[layer]
        # A doc has only a few masked variants, so there may be fewer candidate contexts than n_neighbors;
        # never ask for more neighbors than exist (assumes one context per row of variants_acts[layer],
        # as the id -> context lookup below already does).
        _k = min(n_neighbors, len(variants_contexts))
        neighs_dists, neighs_ids = knn_model.kneighbors([_doc_acts[selected_tok_idx]], n_neighbors=_k, return_distance=True)
        # a neighbor is a context, and the distance away
        neighs = [(variants_contexts[neigh_id], neigh_dist) for neigh_id, neigh_dist in zip(neighs_ids[0], neighs_dists[0])]
        masked_neighbors[layer] = neighs
        del knn_model

### Check similarity of custom sentences

In [9]:
if vis_custom_contexts:
    # (text, token) pairs: each text is tokenized, and the activation of the named token (first occurrence)
    # is compared with the selected token's activation in the main doc at every layer.
    custom_contexts_unprocessed = [
#         ('A young Muslim woman created the artwork.', '[CLS]'),
#         ('A young Christian woman created the artwork.', '[CLS]'),
#         ('A young Muslim woman created the violence.', '[CLS]'),
#         ('A young Christian woman created the violence.', '[CLS]'),
#         ('A young Muslim woman created the object.', '[CLS]'),
#         ('A young Christian woman created the object.', '[CLS]'),
#         ('A Chinese woman walked in.', 'Chinese'),
#         ('A Chinese woman walked in.', 'woman'),
#         ('A Chinese man walked in.', '[CLS]'),
#         ('A Chinese woman walked in.', '[CLS]'),
#         ('A Indian man walked in.', '[CLS]'),
#         ('A Indian woman walked in.', '[CLS]'),
        ('It caught it.', 'caught'),
        ('It caught her eye.', 'caught'),
        ('The rabbit caught the fox\'s eye.', 'caught'),
        ('The rabbit caught the fox\'s attention.', 'caught'),
        ('0 1 2 3 4 caught 6 7 8.', 'caught'),
        ('It [MASK] it.', '[MASK]'),
        ('It [MASK] her eye.', '[MASK]'),
        ('The rabbit [MASK] the fox\'s eye.', '[MASK]'),
        ('The rabbit [MASK] the fox\'s attention.', '[MASK]'),
        ('0 1 2 3 4 [MASK] 6 7 8.', '[MASK]'),
        ('The rabbit caught the carrot.', 'caught'),
    ]

    def _prep(vec):
        """Return `vec` passed through `spherize` (as a one-row batch) when `sphere` is on, else `vec` unchanged."""
        return spherize([vec])[0] if sphere else vec

    custom_neighbors = {layer: [] for layer in layers}
    for custom_txt, custom_tok in custom_contexts_unprocessed:
        custom_doc, custom_doc_acts = bert.get_toks_and_acts(custom_txt)
        custom_pos = custom_doc.index(custom_tok)
        custom_context = (custom_doc, custom_pos)
        for layer in layers:
            selected_vec = _prep(doc_acts[layer][selected_tok_idx])
            custom_vec = _prep(custom_doc_acts[layer][custom_pos])
            # Euclidean distance, so it can be sorted alongside the KNN distances in the Vis cell.
            custom_neighbors[layer].append((custom_context, np.linalg.norm(custom_vec - selected_vec)))
    

## Vis

In [10]:
override_params = False  # useful for quick iteration on visualization
if override_params:
    vis_color = False
    vis_size = False
    vis_masked_KNNs = False
    vis_corpus_KNNs = True
    vis_custom_contexts = False
    together = True
    pruning = True
    save_to_html = False
# other vis params
max_font_size = 10  # font size (pt) of the colored doc; with vis_size, token sizes are scaled into [0, max_font_size]
# params if visualizing neighbors
if vis_corpus_KNNs or vis_masked_KNNs or vis_custom_contexts:
    n_neighbors_to_vis = 50  # caps how many neighbors are listed per layer
    token_styler = lambda t: t
    # token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')  # another pretty styling option

# setup
bluer = highlighter('lightblue')    # corpus neighbor shown for the first time
greener = highlighter('limegreen')  # custom context
greyer = highlighter('grey')        # corpus neighbor that was already shown
masker = highlighter('black')       # passed to context_html as its `masker`
if save_to_html: html_doc = ''
if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts: selected_tok_html = context_html(doc, selected_tok_idx)


def _show(html_snippet, sep='<br>'):
    """Render `html_snippet` in the notebook; when save_to_html is on, also append it (followed by `sep`) to html_doc."""
    global html_doc
    display(HTML(html_snippet))
    if save_to_html:
        html_doc += html_snippet + sep


# start visualizing
# vis legend
if vis_color:
    pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
    html = ''
    for i, rgb in enumerate(pure_rgbs):
        html += html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    print('Legend')
    _show(html)
    print()
    rgbs = {layer: vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}  # prepare coloring

corpus_neighs_so_far = []  # corpus contexts already shown (at this or an earlier layer); repeats are drawn grey
for layer in layers:
    _show(f'Layer {layer}', sep='')

    # vis sample: color each token by its reduced activation, optionally sizing it by activation norm.
    # doc_acts_sizes is only computed when vis_color is on, so vis_size is only honored together with vis_color
    # (previously vis_size alone raised a NameError).
    if vis_color:
        _rgbs = rgbs[layer]
        if vis_size:
            _sizes = doc_acts_sizes[layer]
            _size_range = np.max(_sizes) - np.min(_sizes)
            # Min-max scale to [0, 1]; if all sizes are equal, avoid 0/0 and draw every token at full size.
            _sizes = (_sizes - np.min(_sizes)) / _size_range if _size_range > 0 else np.ones_like(_sizes)
        color_html = ''
        for pos, tok in enumerate(doc):
            if vis_size:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {_sizes[pos]*max_font_size}pt;'
            else:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {max_font_size}pt;'
            color_html += html_util.style(f' {tok} ', css=css)
        _show(color_html)

    # vis knns
    if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts:
        # Show the doc
        _show(selected_tok_html)

        if together:
            # Merge every enabled neighbor source into one list and show the closest ones first.
            neighbors = []
            if vis_masked_KNNs:
                neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighbors[layer]])
            if vis_corpus_KNNs:
                neighbors.extend([(context, dist, 'corpus') for context_id, context, dist in corpus_neighbors[layer]])
            if vis_custom_contexts:
                neighbors.extend([(context, dist, 'custom') for context, dist in custom_neighbors[layer]])
            neighbors.sort(key=lambda t: t[1])  # sort all by dist

            for neigh_context, neigh_dist, neigh_source in neighbors[:n_neighbors_to_vis]:
                html = f'({neigh_dist:.3f}) '
                if neigh_source == 'masked':
                    html += context_html(*neigh_context, masker=masker, token_styler=token_styler)
                elif neigh_source == 'corpus':
                    marker = bluer if not neigh_context in corpus_neighs_so_far else greyer
                    html += context_html(*neigh_context, marker=marker, masker=masker, token_styler=token_styler)
                    corpus_neighs_so_far.append(neigh_context)
                elif neigh_source == 'custom':
                    html += context_html(*neigh_context, marker=greener, masker=masker, token_styler=token_styler)
                _show(html)
        else:
            if vis_masked_KNNs:
                _masked_neighbors = masked_neighbors[layer]
                if pruning:
                    # For each number of [MASK] tokens, show only the closest variant with exactly that many masks.
                    for mask_len_to_search in range(len(doc)-2):
                        for neigh_context, neigh_dist in _masked_neighbors:
                            neigh_toks, neigh_pos = neigh_context
                            mask_len = len([tok for tok in neigh_toks if tok == '[MASK]'])
                            if mask_len == mask_len_to_search:
                                html = token_styler(f'({neigh_dist:.3f}) ') +  context_html(*neigh_context, token_styler=token_styler)
                                _show(html)
                                break
                else:
                    for neigh_context, neigh_dist in _masked_neighbors[:n_neighbors_to_vis]:
                        html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                        _show(html)
            if vis_corpus_KNNs:
                _corpus_neighs = corpus_neighbors[layer][:n_neighbors_to_vis]
                for neigh_id, neigh_context, neigh_dist in _corpus_neighs:
                    marker = bluer if not neigh_context in corpus_neighs_so_far else greyer
                    html = token_styler(f'({neigh_dist:.3f})') + context_html(*neigh_context, marker=marker, token_styler=token_styler)
                    _show(html)
                    corpus_neighs_so_far.append(neigh_context)
            if vis_custom_contexts:
                _custom_neighbors = custom_neighbors[layer]
                for context, dist in _custom_neighbors:
                    marker = greener
                    html = token_styler(f'({dist:.3f})') + context_html(*context, marker=marker, token_styler=token_styler)
                    _show(html)
    print('.')
    print('.')
    print('.')

if save_to_html:
    # Close the file deterministically instead of relying on garbage collection.
    with open(html_path, 'w') as html_file:
        html_file.write(html_doc)

.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


In [61]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'
import pandas as pd

# reduce and visualize: for each layer, project the selected token's activation together with its corpus
# neighbors' activations to 2D. Uses its own name so the `reduction` parameter (used for coloring) survives.
plot_reduction = 'PCA'
layer_points_dfs = []
for layer in layers:
    print('Preparing layer', layer)
    _tok_act = doc_acts[layer][selected_tok_idx]
    # (named, rather than `_`, so IPython's last-output variable is not clobbered)
    _neighs_ids, _neighs_contexts, _neighs_dists = zip(*corpus_neighbors[layer][:n_neighbors_to_vis])
    _neighs_acts = corpus_acts[layer][list(_neighs_ids), :]
    if sphere:
        # doc_acts were spherized when computed and the neighbors were found among spherized corpus acts,
        # so spherize these raw corpus rows too - otherwise the token and its neighbors live in different spaces.
        # Assumes spherize works row by row (it is also applied to single vectors elsewhere) - TODO confirm.
        _neighs_acts = spherize(_neighs_acts)
    _acts = np.concatenate([[_tok_act], _neighs_acts])
    _points = reduce_acts(_acts, reduction=plot_reduction, dim=2)
    _xs, _ys = zip(*_points)
    _labels = [context_str(doc, selected_tok_idx)] + [context_str(*context) for context in _neighs_contexts]
    layer_points_dfs.append(pd.DataFrame(dict(label=_labels,
                                              x=_xs, y=_ys,
                                              layer=layer,
                                              main=[True] + [False] * (len(_points) - 1)  # row 0 is the selected token
                                              )))
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0 (and was quadratic in a loop);
# concatenate once, keeping the original column order.
points_df = pd.concat(layer_points_dfs, ignore_index=True)[['label', 'layer', 'x', 'y', 'main']]

Preparing layer arr_0
Preparing layer arr_1
Preparing layer arr_2
Preparing layer arr_3
Preparing layer arr_4
Preparing layer arr_5
Preparing layer arr_6
Preparing layer arr_7
Preparing layer arr_8
Preparing layer arr_9
Preparing layer arr_10
Preparing layer arr_11
Preparing layer arr_12


In [62]:
# One facet per layer: the selected token (main=True) in orange, its corpus neighbors in turquoise.
color_discrete_map = {False: 'darkturquoise', True: 'darkorange'}
fig = (
    px.scatter(
        points_df,
        x='x',
        y='y',
        color='main',
        facet_col='layer',
        color_discrete_map=color_discrete_map,
        hover_data={'x': False, 'y': False, 'layer': False, 'label': True},
        height=300,
        width=300 * len(layers),
    )
    .update_traces(marker=dict(size=10, symbol='circle-open', line=dict(width=1.3)))
    # Independent, unlabeled axes per facet: projections of different layers are not comparable.
    .update_yaxes(matches=None, showticklabels=False, visible=True)
    .update_xaxes(matches=None, showticklabels=False, visible=True)
)
fig

In [63]:
# Project a sample of corpus contexts plus the selected token to 2D at a single layer, to see where the token
# sits among ordinary corpus contexts rather than only among its nearest neighbors.
plot_reduction = 'PCA'  # own name, so the `reduction` parameter used for coloring is not overwritten
n_sample_contexts = 100  # how many corpus contexts to plot (the first ones in the corpus)
sample_layer_idx = 2  # index into `layers`; clamped so the cell also runs when few layers are visualized
layer = layers[min(sample_layer_idx, len(layers) - 1)]
print('Preparing layer', layer)
_tok_act = doc_acts[layer][selected_tok_idx]
_extra_contexts = corpus_contexts[:n_sample_contexts]
_extra_acts = corpus_acts[layer][:n_sample_contexts]
if sphere:
    # doc_acts are already spherized; put the raw corpus sample in the same space.
    # Assumes spherize works row by row - TODO confirm.
    _extra_acts = spherize(_extra_acts)
_acts = np.concatenate([_extra_acts, [_tok_act]])
_points = reduce_acts(_acts, reduction=plot_reduction, dim=2)
_xs, _ys = zip(*_points)
_labels = [context_str(*context) for context in _extra_contexts] + [context_str(doc, selected_tok_idx)]
_points_df = pd.DataFrame(dict(label=_labels,
                               x=_xs, y=_ys,
                               layer=layer,
                               main=[False] * (len(_points) - 1) + [True]  # the selected token is the last row
                               ))

Preparing layer arr_2


In [58]:
# Single-layer view: the selected token (main=True) in turquoise, the sampled corpus contexts in orange.
color_discrete_map = {True: 'darkturquoise', False: 'orange'}
fig = (
    px.scatter(
        _points_df, x='x', y='y', color='main', facet_col='layer',
        color_discrete_map=color_discrete_map,
        hover_data={'x': False, 'y': False, 'layer': False, 'label': True},
        height=300, width=400,
    )
    .update_layout(showlegend=False)
    .update_traces(marker=dict(size=10, symbol='circle-open', line=dict(width=1.3)))
    .update_yaxes(matches=None, showticklabels=False, visible=True)
    .update_xaxes(matches=None, showticklabels=False, visible=True)
)
fig

In [None]:
# NOTE(review): apart from the fig.update_traces(...) call below, this cell is commented out; the commented
# block is an earlier variant of the per-layer PCA cell that plotted only the selected token's point.
# # import plotly.express as px
# # import plotly.io as pio
# # pio.renderers.default = 'iframe'
# # import pandas as pd

# Restyle the current `fig` (from whichever plotting cell ran last): size-10 markers with a 2px
# DarkSlateGrey outline, applied only to traces whose mode is 'markers'.
fig.update_traces(marker=dict(size=10,
                              line=dict(width=2,
                                        color='DarkSlateGrey')),
                              selector=dict(mode='markers'))

# # reduce and visualize
# reduction = 'PCA'
# points_df = pd.DataFrame(columns=['label', 'layer', 'x','y', 'main'])
# layer = layers[0]  
# print('Preparing layer', layer)
# _tok_act = doc_acts[layer][selected_tok_idx]
# _neighs_ids, _neighs_contexts, _  = zip(*corpus_neighbors[layer][:n_neighbors_to_vis])
# _neighs_acts = corpus_acts[layer][_neighs_ids,:]
# _acts = np.concatenate([[_tok_act,], _neighs_acts])
# _points = reduce_acts(_acts, reduction=reduction, dim=2)
# _xs, _ys = zip(*_points)
# _labels = [context_str(doc, selected_tok_idx)] + [context_str(*context) for context in _neighs_contexts]
# _points_df = pd.DataFrame(dict(label=[_labels[0]], 
#                                x=[_xs[0]], y=[_ys[0]],
#                                layer=layer, 
#                                main=[True]
#                               ))
# fig = px.scatter(_points_df, x='x', y='y', color='main', hover_data={'x': False, 'y': False, 'layer': False, 'label': True}, height=300, width=300)
# fig.show()

In [None]:
# NOTE(review): this whole cell is commented out, and it would not run as-is if uncommented:
#   - the `for layer_to_grow in layers:` loop has no body, so neighs_acts is never filled;
#   - `subdir` is not defined anywhere in this notebook, and `acts_util` is never imported as a module
#     (only individual names are imported from utils.acts_util);
#   - it reassigns `corpus_acts`, which earlier plotting cells read.
# Finish it or delete the cell.
# import plotly.express as px
# import plotly.io as pio
# pio.renderers.default = 'iframe'
# import pandas as pd

# # compile all neighbors and their acts
# layers = layers[:]
# neighs_contexts = []
# neighs_associations = []  # which layer was this neighbor originally associated with?
# neighs_acts = {layer: [] for layer in layers}
# for layer in layers[::-1]:
#     for neigh_id, neigh_context, _ in corpus_neighbors[layer]:
#         if neigh_context not in neighs_contexts:
#             neighs_contexts.append(neigh_context)
#             neighs_associations.append(layer)
#             neigh_doc, neigh_pos = neigh_context
#             neigh_doc, neigh_acts = bert.get_toks_and_acts(neigh_doc, tokenized=True)
#             for layer_to_grow in layers:
# # preprocess
# neighs_labels = [context_str(*context) for context in neighs_contexts]
# neighs_acts = {layer: np.stack(neighs_acts[layer]) for layer in layers}

# # reduce and visualize
# reduction = 'TSNE'
# corpus_acts = np.load(os.path.join(corpus_dir, subdir, refs.acts_fn))
# points_df = pd.DataFrame(columns=['label', 'layer', 'x','y', 'association', 'main'])
# for layer in layers:  
#     print('Preparing layer', layer)
#     _tok_act = doc_acts[layer][selected_tok_idx]
#     _neighs_acts = neighs_acts[layer]
#     _acts = np.concatenate([[_tok_act,], _neighs_acts])
# #     _acts = acts_util.reduce_acts(_acts, reduction='KMeans', dim=10)
#     _points = acts_util.reduce_acts(_acts, reduction=reduction, dim=2)
#     _xs, _ys = zip(*_points)
#     _points_df = pd.DataFrame(dict(label=[context_str(doc, selected_tok_idx)]+neighs_labels, 
#                                    x=_xs, y=_ys,
#                                    layer=layer, 
#                                    association=['main']+neighs_associations,
#                                    main=[False]+[True]*(len(_points)-1)
#                                   ))
#     _points_df = _points_df[_points_df['association'].isin([layer, 'main'])]
#     #     points_df = points_df.append(_points_df['layer'==layer], ignore_index=True)
#     fig = px.scatter(_points_df, x='x', y='y', color='association', hover_data={'x': False, 'y': False, 'association': False, 'layer': False, 'label': True}, height=300, width=800)
#     fig.show()
    
# fig = px.scatter(points_df, x='x', y='y', color='association', facet_col='layer', hover_data={'x': False, 'y': False, 'association': False, 'layer': False, 'label': True}, height=300, width=300*len(layers))
# # fig.update_traces(mode="markers", hovertemplate=None)
# # fig.update_layout(hovermode='label')
# fig