# Understand a context through its clusters and its neighbors
Examine, at each layer, the given context's k-nearest neighbors (KNNs): both its KNNs in the corpus and its KNNs among its own masked variants.

## Imports

In [2]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np
from sklearn.neighbors import NearestNeighbors
import glob
import warnings
from typing import List
import collections
import itertools
import random

import sys
sys.path.insert(0, os.path.abspath('../../..'))
from src.utils import acts_util, vis_util, html_util, context_util, bert_util
from src.utils.context_util import context_html
from src.utils.html_util import highlighter, fix_size, font_size, style
from src.utils import references as refs
from src.utils.SimpleBert import SimpleBert

## Parameters

In [3]:
# Text of the context to analyze.
doc_txt = "The sky was blue."

# Preprocessing: when True, activations are passed through acts_util.spherize before use.
spherize = True

# Visualization toggles.
vis_color = True        # color tokens by their reduced activations
vis_size = True         # scale token font size by activation norm
vis_masked_KNNs = True  # compute/show neighbors among the document's masked variants
vis_corpus_KNNs = True  # compute/show neighbors from the corpus

## Get tokens and activations

In [4]:
# Instantiate the project's SimpleBert wrapper. Per the log output recorded below,
# this loads the bert-base-cased checkpoint into a BertModel.
bert = SimpleBert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Tokenize the document and collect its activations, keyed by layer name.
doc, doc_acts = bert.get_toks_and_acts(doc_txt)

# Iterating doc_acts yields its layer keys (it may be a dict or an npz archive).
# Only the first two layers are kept -- a debugging shortcut.
layers = list(doc_acts)[:2]

if spherize:
    doc_acts = {name: acts_util.spherize(doc_acts[name]) for name in layers}

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')


Document:
[CLS] The sky was blue . [SEP]

Layers: arr_0, arr_1


## Reduce sample

In [12]:
# params
reduction, dim = 'NMF', 3

# reduce
# For every layer, fit `dim` components with the chosen reduction and keep both
# the fitted components and the per-token reduced activations, keyed by layer.
doc_components = {}
doc_reduced_acts = {}
for layer in layers:
    doc_components[layer], doc_reduced_acts[layer] = acts_util.fit_components(doc_acts[layer], reduction, dim)

# Per-token activation norms; the visualization uses them to scale font sizes.
# Computed unconditionally (it is cheap): the Vis cell can switch vis_size on via
# override_params, and would otherwise fail on an undefined doc_acts_sizes.
doc_acts_sizes = {layer: np.linalg.norm(doc_acts[layer], axis=1) for layer in layers}

## Get corpus neighbors

In [13]:
# Document positions whose neighborhoods are looked up in the cells below (here: every token).
# The per-layer neighborhood lists built later follow the order of this sequence.
toks_of_interest = range(len(doc))

In [15]:
if vis_corpus_KNNs:

    # GET K NEAREST NEIGHBORS FROM DATASET
    # params
    # NOTE(review): machine-specific absolute path; point this at your local copy of the corpus.
    corpus_dir = '/Users/pkalluri/projects/clarity/bert-vis/big-data/wiki-large'
    subdir = 'sphere' if spherize else 'standard'
    knn_path = os.path.join(corpus_dir, subdir, refs.knn_models_fn)
    n_neighbors = 20

    # get KNNs
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted corpus files.
    corpus_neighborhoods = {}
    # Use a context manager so the contexts file handle is closed after loading.
    with open(os.path.join(corpus_dir, subdir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)
    with open(knn_path, 'rb') as f:
        for layer in layers:
            print(f'Layer {layer}')

            # One model is unpickled from the same open file per loop iteration.
            # Presumably the file stores one model per layer, in layer order -- TODO confirm.
            print('Loading nearest neighbors model.')
            knn_model = pickle.load(f)

            print('Finding neighbors')
            _doc_acts = doc_acts[layer]
            # a concise neighborhood is a single tuple of (neighbors' dists to the neighborhood's true token, neighbors' idxs)
            concise_neighborhoods = zip(*knn_model.kneighbors(_doc_acts[toks_of_interest], n_neighbors=n_neighbors, return_distance=True))
            # We want a more intuitive and useful representation:
            # a neighborhood contains a list of neighbors; a neighbor is a tuple of (context, its dist to the true token)
            neighborhoods = []
            for concise_neighborhood in concise_neighborhoods:
                neighborhood = [(corpus_contexts[neigh_idx], neigh_dist) for (neigh_dist, neigh_idx) in zip(*concise_neighborhood)]
                neighborhoods.append(neighborhood)
            corpus_neighborhoods[layer] = neighborhoods
            # Drop the model before loading the next layer's model.
            del knn_model


Layer arr_0
Loading nearest neighbors model.
Finding neighbors
Layer arr_1
Loading nearest neighbors model.
Finding neighbors


## Get masked neighbors

In [20]:
if vis_masked_KNNs:
    # params
    mask_lengths = range(1, len(doc)-2)
    # GET MASKED VARIANTS OF SAMPLE
    # The unmasked document is placed first, followed by its masked variants.
    variants = [doc] + bert_util.get_masked_variants(doc, mask_lengths)
    variants_contexts, variants_acts = bert.get_contexts_and_acts(variants, tokenized=True)
    if spherize: variants_acts = {layer: acts_util.spherize(variants_acts[layer]) for layer in layers}

    # GET KNNs FROM MASKED VARIANTS
    # params
    n_neighbors = 50

    # get KNNs
    masked_neighborhoods = {}
    for layer in layers:
        print(layer)

        _variants_acts = variants_acts[layer]
        # A short document can yield fewer variant contexts than n_neighbors, and
        # kneighbors raises ValueError when asked for more neighbors than fitted
        # samples -- so cap the request at the number of samples.
        _k = min(n_neighbors, len(_variants_acts))

        print('Fit nearest neighbors model.')
        knn_model = NearestNeighbors(n_neighbors=_k).fit(_variants_acts)

        print('Finding neighbors')
        # The rows selected by toks_of_interest serve as queries.
        # Assumes the leading rows belong to the unmasked document (variants[0]) -- TODO confirm.
        # a concise neighborhood is a single tuple of (neighbors' dists to the neighborhood's true token, neighbors' idxs)
        concise_neighborhoods = zip(*knn_model.kneighbors(_variants_acts[toks_of_interest], n_neighbors=_k, return_distance=True))
        # We want a more intuitive and useful representation:
        # a neighborhood contains a list of neighbors; a neighbor is a tuple of (context, its dist to the true token)
        neighborhoods = []
        for concise_neighborhood in concise_neighborhoods:
            neighborhood = [(variants_contexts[neigh_idx], neigh_dist) for (neigh_dist, neigh_idx) in zip(*concise_neighborhood)]
            neighborhoods.append(neighborhood)
        masked_neighborhoods[layer] = neighborhoods
        del knn_model

arr_0
Fit nearest neighbors model.
Finding neighbors
arr_1
Fit nearest neighbors model.
Finding neighbors


## Vis

In [24]:
# VISUALIZE
# For each layer: show the document colored/sized by its activations, then list
# the nearest neighbors of the token of interest (masked variants and/or corpus).

override_params = True  # useful for quick iteration on visualization
if override_params:
    vis_color = True
    vis_size = True
    vis_masked_KNNs = True
    vis_corpus_KNNs = True
# other vis params
max_font_size = 10
n_neighbors_to_vis = 30
together = False  # if true, masked and corpus neighbors are merged into one list sorted by distance
pruning = True
# if true, in the masked neighbors vis, first the closest 0-length mask is presented,
# then jumps to the closest 1-length mask, etc - forming a kind of slow pruning to capture the essence of the context.
token_styler = lambda t: t
# token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')  # another styling option

# vis
rgbs = {layer: vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}
tok_of_interest = '[CLS]'
tok_of_interest_idx = doc.index(tok_of_interest)  # position of the token within the document
# Neighborhood lists are ordered like toks_of_interest, so they are indexed by the
# token's position within toks_of_interest, not by its position in the document.
neighborhood_idx = toks_of_interest.index(tok_of_interest_idx)  # it should be one of the toks of interest we searched for
# context_html takes the document position directly.
tok_of_interest_html = context_html(doc, tok_of_interest_idx)
display(HTML(tok_of_interest_html))
print()

# vis legend
if vis_color:
    # One swatch per reduced component: the identity matrix gives each component's "pure" color.
    pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
    html = ''
    for i, rgb in enumerate(pure_rgbs):
        html += html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    print('Legend')
    display(HTML(html))
    print()

bluer = highlighter('lightblue')
for layer_idx, layer in enumerate(layers):
    display(HTML(f'Layer {layer_idx}'))
    # vis sample
    _rgbs = rgbs[layer]
    if vis_size:
        # Min-max normalize the activation norms to [0, 1].
        _sizes = doc_acts_sizes[layer]
        _size_range = np.max(_sizes) - np.min(_sizes)
        if _size_range > 0:
            _sizes = (_sizes - np.min(_sizes)) / _size_range
        else:
            # All norms equal: avoid dividing by zero and draw every token at full size.
            _sizes = np.ones_like(_sizes)
    if vis_color:
        color_html = ''
        for pos, tok in enumerate(doc):
            if vis_size:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {_sizes[pos]*max_font_size}pt;'
            else:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {max_font_size}pt;'
            color_html += html_util.style(f' {tok} ', css=css)
        display(HTML(tok_of_interest_html))
        display(HTML(color_html))

    # vis knns
    if together:
        # Merge both neighbor sources (skipping each list's first entry) and sort by distance.
        neighbors = []
        if vis_masked_KNNs:
            neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighborhoods[layer][neighborhood_idx]][1:])
        if vis_corpus_KNNs:
            neighbors.extend([(context, dist, 'corpus') for context, dist in corpus_neighborhoods[layer][neighborhood_idx]][1:])
        neighbors.sort(key=lambda t: t[1])
        for neigh_context, neigh_dist, neigh_dataset in neighbors[:n_neighbors_to_vis]:
            html = f'({neigh_dist:.3f}) '
            # No `masker` argument is passed: no such name is defined in this notebook,
            # and the separate-lists branch below calls context_html without one.
            if neigh_dataset == 'masked':
                html += context_html(*neigh_context, token_styler=token_styler)
            elif neigh_dataset == 'corpus':
                html += context_html(*neigh_context, marker=bluer, token_styler=token_styler)
            display(HTML(html))
    else:
        if vis_masked_KNNs:
            neighbors = masked_neighborhoods[layer][neighborhood_idx]
            if pruning:
                # For each mask length in turn, show only the closest neighbor with exactly that many masks.
                for mask_len_to_find in range(len(doc)-2):
                    for neigh_context, neigh_dist in neighbors:
                        neigh_toks, _neigh_pos = neigh_context
                        mask_len = sum(1 for tok in neigh_toks if tok == '[MASK]')
                        if mask_len == mask_len_to_find:
                            html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=token_styler)
                            display(HTML(html))
                            break
            else:
                for neigh_context, neigh_dist in neighbors[:n_neighbors_to_vis]:  # masked neighbors
                    html = token_styler(f'({neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                    display(HTML(html))
        if vis_corpus_KNNs:
            # Index by neighborhood_idx, like the masked neighbors above.
            neighbors = corpus_neighborhoods[layer][neighborhood_idx][:n_neighbors_to_vis]
            for neigh_context, neigh_dist in neighbors:  # corpus neighbors
                html = token_styler(f'{neigh_dist:.3f}') + context_html(*neigh_context, marker=bluer, token_styler=token_styler)
                display(HTML(html))
    print('.')
    print('.')
    print('.')



Legend





.
.
.


.
.
.


## Try a custom sentence

In [26]:
# Compare the token of interest in the original document with a token from a
# custom sentence, at one chosen layer, and display their distance.

# choose
layer = layers[1]
new_text = 'The cluster lies about .'
new_text_tok_of_interest = '[CLS]'

# vis
new_doc, new_acts = bert.get_toks_and_acts(new_text)
new_pos = new_doc.index(new_text_tok_of_interest)
# doc_acts was already spherized when it was built (if spherize is on),
# so the stored activation is used as-is rather than spherized a second time.
current_act = doc_acts[layer][tok_of_interest_idx]
# Treat the new sentence the same way the original document was treated:
# spherize the whole layer matrix, then pick out the token's row.
# NOTE(review): assumes acts_util.spherize accepts the raw layer matrix here as it does for doc_acts -- confirm.
new_layer_acts = new_acts[layer]
if spherize: new_layer_acts = acts_util.spherize(new_layer_acts)
new_act = new_layer_acts[new_pos]
display(HTML(f'Layer {layer}'))
dist = np.linalg.norm(current_act - new_act)
token_styler = lambda t: font_size(t,max_font_size)
html = token_styler(f'{dist:.3f}')
html += context_html(new_doc, new_pos, marker=bluer, token_styler=token_styler)
display(HTML(html))