# Understand a context through its clusters, its neighbors, and custom possible neighbors
Examine, at each layer, the given context's k-nearest neighbors (KNNs): both its KNNs in the corpus and its KNNs among its own masked variants, as well as any custom contexts you're curious about.

## Imports

In [1]:
# Kernel smoke test: evaluating a bare string literal just echoes it back as the cell output.
'test'

'test'

In [2]:
%load_ext autoreload
%autoreload 2

import pickle
import numpy as np
# from sklearn.neighbors import NearestNeighbors  # for now, switching to FasterNearestNeighbors
from IPython.core.display import display, HTML
import time

import os
import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils.acts_util import spherize, fit_components
from utils.bert_util import get_masked_variants
from utils.context_util import context_html, context_str
from utils.html_util import highlighter, fix_size, font_size, style
from utils import vis_util, html_util
from utils import references as refs
from utils.SimpleGPT2 import SimpleGPT2
from utils.FastNearestNeighbors import FastNearestNeighbors

In [16]:
# Hugging Face GPT-2 tokenizer + LM-head model, used only for the next-token sanity check below.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
_hf_checkpoint = 'gpt2'  # checkpoint id shared by the tokenizer and the model
tokenizer = GPT2Tokenizer.from_pretrained(_hf_checkpoint)
headmodel = GPT2LMHeadModel.from_pretrained(_hf_checkpoint)

In [3]:
# Project-local GPT-2 wrapper; its get_toks_and_acts / get_contexts_and_acts methods
# supply the per-layer activations used throughout this notebook.
_simple_gpt2_type = 'gpt2'
model = SimpleGPT2(gpt2_type=_simple_gpt2_type)

## Parameters

In [4]:
sphere = True  # Spherize the activations? This is equivalent to switching from standard Euclidean distance to cosine distance.
n_layers = None  # Only visualize n_layers layers, selected evenly from the available layers (None = all layers).
              # E.g. with the 13 layers available here (arr_0 ... arr_12) and n_layers=5, visualizes layers 0,3,6,9,12.
              # (n_layers = 2, i.e. only the first and last layer, is good for debugging.)
vis_color = True 
reduction, dim = 'NMF', 10  # Reduction with which to color the doc; dim = number of components (one legend color each)
vis_size = False  # If visualizing color, can also visualize size (font size scales with each token's activation norm)

vis_masked_KNNs = False
pruning = False  
# pruning is only checked if visualizing masked KNNs.     
# if pruning is true, in the masked neighbors vis, first the closest 1-token-long mask is presented, 
# then jumps to the closest 2-token-long mask, etc - forming a kind of slow pruning to capture the essence of the context.
vis_corpus_KNNs = True
corpus_dir = '/atlas/u/pkalluri/bert-vis/big-data/gpt-wiki-large'  # only checked if visualizing corpus KNNs
vis_custom_contexts = True
together = True  # Instead of visualizing all masked KNNs, then all corpus KNNs, etc, visualize them all together, sorted by distance?
save_to_html = False
html_path = '../../../building-blocks.html'  # only checked if saving to html

doc_txt = 'Later that day, he caught her eye.'
# doc_txt = '''
# Alice was beginning to get very tired of sitting by her sister on the
# bank, and of having nothing to do: once or twice she had peeped into the
# book her sister was reading, but it had no pictures or conversations in
# it, 'and what is the use of a book,' thought Alice 'without pictures or
# conversation?'

# So she was considering in her own mind (as well as she could, for the
# hot day made her feel very sleepy and stupid), whether the pleasure
# of making a daisy-chain would be worth the trouble of getting up and
# picking the daisies, when suddenly a White Rabbit with pink eyes ran
# close by her.

# There was nothing so VERY remarkable in that; nor did Alice think it so
# VERY much out of the way to hear the Rabbit say to itself, 'Oh dear!
# Oh dear! I shall be late!' (when she thought it over afterwards, it
# occurred to her that she ought to have wondered at this, but at the time
# it all seemed quite natural); but when the Rabbit actually TOOK A WATCH
# OUT OF ITS WAISTCOAT-POCKET, and looked at it, and then hurried on,
# Alice started to her feet, for it flashed across her mind that she had
# never before seen a rabbit with either a waistcoat-pocket, or a watch
# to take out of it, and burning with curiosity, she ran across the field
# after it, and fortunately was just in time to see it pop down a large
# rabbit-hole under the hedge.
# '''
selected_tok = 'caught'  # must be one of the doc's tokens; its activation is the KNN query

In [33]:
# Sanity check: print GPT-2's greedy (argmax) next-token prediction after each prefix of doc_txt.
import torch

inputs = tokenizer(doc_txt, return_tensors="pt")
with torch.no_grad():  # inference only; don't build an autograd graph
    outputs = headmodel(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits
# Argmax over the vocabulary at every position, taken directly on the tensor
# (no tensor -> Python list -> numpy round trip).
outs = logits[0].argmax(dim=-1).tolist()
input_ids = inputs["input_ids"][0].tolist()
for i, out in enumerate(outs):
    print(f'{tokenizer.decode(input_ids[:i+1])} --> {tokenizer.decode([out,])}')

Later --> ,
Later that -->  day
Later that day --> ,
Later that day, -->  he
Later that day, he -->  was
Later that day, he caught -->  a
Later that day, he caught her -->  in
Later that day, he caught her eye --> .
Later that day, he caught her eye. -->  "


## Get tokens and activations

In [5]:
# Tokenize the document and get its activations, a mapping from layer name to activations.
doc, doc_acts = model.get_toks_and_acts(doc_txt)
layers = list(doc_acts)
# Keep every n-th layer so that about n_layers layers, starting with the first, are visualized.
if n_layers and n_layers > 1:
    # max(1, ...) avoids a zero slice step when n_layers exceeds the number of layers.
    n = max(1, (len(layers) - 1) // (n_layers - 1))  # will visualize every nth layer
elif n_layers == 1:
    # A single layer requested: a step of len(layers) keeps only the first layer
    # (and avoids dividing by n_layers - 1 == 0).
    n = max(1, len(layers))
else:
    n = 1
layers = layers[::n]
if sphere: doc_acts = {layer: spherize(doc_acts[layer]) for layer in layers}

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')


Document:
Later that day , he caught her eye .

Layers: arr_0, arr_1, arr_2, arr_3, arr_4, arr_5, arr_6, arr_7, arr_8, arr_9, arr_10, arr_11, arr_12


## Reduce sample (fit a low-dimensional decomposition of the document's activations, used to color it)

In [6]:
import warnings
# For coloring: fit a `dim`-component `reduction` (e.g. 'NMF') to each layer's doc activations.
if vis_color:
    doc_components = {}
    doc_reduced_acts = {}
    # Silence warnings only while fitting, instead of globally for the rest of the session.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for layer in layers:
            doc_components[layer], doc_reduced_acts[layer] = fit_components(doc_acts[layer], reduction, dim)
    if vis_size:
        # Per-token activation norm, later mapped to font size.
        doc_acts_sizes = {layer: np.linalg.norm(doc_acts[layer], axis=1) for layer in layers}

## Get neighbors

In [7]:
# Position of the selected token in the doc. Every neighbor view (masked, corpus and custom)
# queries with the activation at this position, so compute it whenever any of them is on.
if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts:
    if selected_tok not in doc:
        raise ValueError(f'selected_tok {selected_tok!r} is not a token of the doc: {doc}')
    selected_tok_idx = doc.index(selected_tok)

### Get corpus neighbors

In [8]:
if vis_corpus_KNNs:
    # io: pickled corpus contexts plus per-layer corpus activations (indexed by layer name below)
    corpus_dir = os.path.abspath(corpus_dir)
    # NOTE: pickle.load can execute arbitrary code; only point corpus_dir at trusted data.
    with open(os.path.join(corpus_dir, refs.contexts_fn), 'rb') as _contexts_file:
        corpus_contexts = pickle.load(_contexts_file)
    corpus_acts = np.load(os.path.join(corpus_dir, refs.acts_fn))
    n_neighbors = 20

    # get KNNs: per layer, index the corpus activations and query with the selected token's activation
    corpus_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        _acts = spherize(corpus_acts[layer]) if sphere else corpus_acts[layer]
        print('Fitting nearest neighbors model.')
        knn_model = FastNearestNeighbors().fit(_acts)
        del _acts  # drop this reference to the (potentially large) corpus activations
        print('Finding neighbors.')
        _doc_acts = doc_acts[layer]
        neighs_dists, neighs_ids = knn_model.kneighbors([_doc_acts[selected_tok_idx]], n_neighbors=n_neighbors, return_distance=True)
        del knn_model
        # a neighbor is an id, the corresponding context, and the distance away
        neighs = [(neigh_id, corpus_contexts[neigh_id], neigh_dist) for neigh_id, neigh_dist in zip(neighs_ids[0], neighs_dists[0])]
        corpus_neighbors[layer] = neighs

Layer arr_0
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_1
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_2
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_3
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_4
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_5
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_6
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_7
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_8
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_9
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_10
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_11
Fitting nearest neighbors model.
Finding neighbors.
Layer arr_12
Fitting nearest neighbors model.
Finding neighbors.


### Get masked neighbors (open question: does masking apply to this model?)

In [9]:
if vis_masked_KNNs:
    # Which mask lengths to generate variants for; currently hard-coded relative to the doc length.
    mask_lengths = range(max(len(doc) - 3, 1), len(doc) - 2)
    print('Getting variants.')
    variants = get_masked_variants(doc, mask_lengths)
    print('Getting activations.')
    variants_contexts, variants_acts = model.get_contexts_and_acts(variants, tokenized=True)
    n_neighbors = 20

    # Per layer: index the masked variants' activations and query with the selected token's activation.
    masked_neighbors = {}
    for layer in layers:
        print(f'Layer {layer}')
        layer_acts = variants_acts[layer]
        if sphere:
            layer_acts = spherize(layer_acts)
        print('Fitting nearest neighbors model.')
        knn_model = FastNearestNeighbors().fit(layer_acts)
        print('Finding neighbors')
        query = [doc_acts[layer][selected_tok_idx]]
        dists, ids = knn_model.kneighbors(query, n_neighbors=n_neighbors, return_distance=True)
        # each neighbor is (masked-variant context, distance away)
        masked_neighbors[layer] = [(variants_contexts[i], d) for i, d in zip(ids[0], dists[0])]
        del knn_model

### Check similarity of custom sentences

In [10]:
if vis_custom_contexts:
    # (text, token) pairs whose token activation is compared against the selected token's; edit freely.
    custom_contexts_unprocessed = [
        ('Later that day, he caught her eye.', 'day'),
        ('Later that day, he caught her eye.', 'he'),
        ('Later that day, he caught her eye.', 'her'),
        ('Later that day, he caught her eye.', 'eye'),
        ]
    custom_neighbors = {layer: [] for layer in layers}
    for custom_txt, custom_tok in custom_contexts_unprocessed:
        custom_doc, custom_doc_acts = model.get_toks_and_acts(custom_txt)
        custom_pos = custom_doc.index(custom_tok)
        custom_context = (custom_doc, custom_pos)
        for layer in layers:
            ref_vec = doc_acts[layer][selected_tok_idx]
            cand_vec = custom_doc_acts[layer][custom_pos]
            # NOTE(review): when sphere is True, doc_acts was already spherized as a batch when
            # activations were computed, so ref_vec is spherized a second time here (single-row call).
            # That is harmless only if spherize is idempotent on a single row — confirm.
            if sphere:
                ref_vec, cand_vec = spherize([ref_vec])[0], spherize([cand_vec])[0]
            # Euclidean distance, appended as (context, distance) like the other neighbor lists.
            custom_neighbors[layer].append((custom_context, np.linalg.norm(cand_vec - ref_vec)))
    

## Vis

In [15]:
# VISUALIZE

override_params = False  # useful for quick iteration on visualization
if override_params:
    vis_color = False
    vis_size = False
    vis_masked_KNNs = False
    vis_corpus_KNNs = True
    vis_custom_contexts = False
    together = True
    pruning = True
    save_to_html = False
# other vis params
max_font_size = 10
# params if visualizing neighbors
if vis_corpus_KNNs or vis_masked_KNNs or vis_custom_contexts:
    n_neighbors_to_vis = 50
    token_styler = lambda t: t
    # token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')  # another pretty styling option

# setup
bluer = highlighter('lightblue')  # marker for a corpus neighbor shown for the first time
greener = highlighter('limegreen')  # marker for custom contexts
greyer = highlighter('grey')  # marker for a corpus neighbor already shown earlier
masker = highlighter('black')  # passed as masker= to context_html
if save_to_html: html_doc = ''
if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts: selected_tok_html = context_html(doc, selected_tok_idx)

# start visualizing
# vis legend: one color swatch per reduced component
if vis_color:
    pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
    html = ''
    for i, rgb in enumerate(pure_rgbs):
        html += html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    print('Legend')
    display(HTML(html))
    if save_to_html: html_doc += html + '<br>'
    print()
    rgbs = {layer: vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}  # prepare coloring

corpus_neighs_so_far = []  # corpus contexts already displayed (at any layer); repeats are drawn grey
for layer in layers:
    html = f'Layer {layer}'
    display(HTML(html))
    if save_to_html: html_doc += html + '<br>'  # break so the label doesn't run into the doc in the saved HTML

    # vis sample
    # doc_acts_sizes only exists when vis_color was on during reduction, and sizes are only drawn with color.
    if vis_color and vis_size:
        _sizes = doc_acts_sizes[layer]
        _size_range = np.max(_sizes) - np.min(_sizes)
        # Min-max normalize to [0, 1]; if every norm is equal, use full size rather than dividing by zero.
        if _size_range > 0:
            _sizes = (_sizes - np.min(_sizes)) / _size_range
        else:
            _sizes = np.ones_like(_sizes)
    if vis_color:
        _rgbs = rgbs[layer]
        color_html = ''
        for pos, tok in enumerate(doc):
            if vis_size:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {_sizes[pos]*max_font_size}pt;'
            else:
                css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {max_font_size}pt;'
            color_html += html_util.style(f' {tok} ', css=css)
        display(HTML(color_html))
        if save_to_html: html_doc += color_html + '<br>'

    # vis knns
    if vis_masked_KNNs or vis_corpus_KNNs or vis_custom_contexts:
        # Show the doc
        display(HTML(selected_tok_html))
        if save_to_html: html_doc += selected_tok_html + '<br>'

        if together:
            # Merge every enabled neighbor source into one list tagged with its source.
            neighbors = []
            if vis_masked_KNNs:
                neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighbors[layer]])
            if vis_corpus_KNNs:
                neighbors.extend([(context, dist, 'corpus') for context_id, context, dist in corpus_neighbors[layer]])
            if vis_custom_contexts:
                neighbors.extend([(context, dist, 'custom') for context, dist in custom_neighbors[layer]])
            neighbors.sort(key=lambda t: t[1])  # sort all by dist

            for neigh_context, neigh_dist, neigh_source in neighbors[:n_neighbors_to_vis]:
                html = f'({200*neigh_dist:.3f}) '
                if neigh_source == 'masked':
                    html += context_html(*neigh_context, masker=masker, token_styler=token_styler)
                elif neigh_source == 'corpus':
                    marker = bluer if neigh_context not in corpus_neighs_so_far else greyer
                    html += context_html(*neigh_context, marker=marker, masker=masker, token_styler=token_styler)
                    corpus_neighs_so_far.append(neigh_context)
                elif neigh_source == 'custom':
                    html += context_html(*neigh_context, marker=greener, masker=masker, token_styler=token_styler)
                display(HTML(html))
                if save_to_html: html_doc += html + '<br>'
        else:
            if vis_masked_KNNs:
                _masked_neighbors = masked_neighbors[layer]
                if pruning:
                    # For each mask length, show only the closest variant with exactly that many '[MASK]' tokens.
                    for mask_len_to_search in range(len(doc)-2):
                        for neigh_context, neigh_dist in _masked_neighbors:
                            neigh_toks, neigh_pos = neigh_context
                            mask_len = len([tok for tok in neigh_toks if tok == '[MASK]'])
                            if mask_len == mask_len_to_search:
                                html = token_styler(f'({200*neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=token_styler)
                                display(HTML(html))
                                if save_to_html: html_doc += html + '<br>'
                                break
                else:
                    for neigh_context, neigh_dist in _masked_neighbors[:n_neighbors_to_vis]:
                        html = token_styler(f'({200*neigh_dist:.3f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                        display(HTML(html))
                        if save_to_html: html_doc += html + '<br>'
            if vis_corpus_KNNs:
                _corpus_neighs = corpus_neighbors[layer][:n_neighbors_to_vis]
                for neigh_id, neigh_context, neigh_dist in _corpus_neighs:
                    marker = bluer if neigh_context not in corpus_neighs_so_far else greyer
                    html = token_styler(f'({200*neigh_dist:.3f})') + context_html(*neigh_context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
                    corpus_neighs_so_far.append(neigh_context)
            if vis_custom_contexts:
                _custom_neighbors = custom_neighbors[layer]
                for context, dist in _custom_neighbors:
                    marker = greener
                    html = token_styler(f'({200*dist:.3f})') + context_html(*context, marker=marker, token_styler=token_styler)
                    display(HTML(html))
                    if save_to_html: html_doc += html + '<br>'
    print('.')
    print('.')
    print('.')

if save_to_html:
    with open(html_path, 'w') as _html_file:
        _html_file.write(html_doc)

Legend





.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.


.
.
.
