# See a sample through BERT's eyes (clusters, size, and KNNs)

In [None]:
# Switches controlling which parts of the visualization are produced.
vis_legend: bool = True           # show the color legend (one swatch per component)
vis_sample: bool = True           # render the sample's tokens, colored per layer
vis_size: bool = True             # scale each token's font size by its activation norm
vis_components_knns: bool = True  # list nearest corpus contexts for each component
vis_toks_knns: bool = True        # list nearest corpus contexts for each token of interest

## Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np
from sklearn.neighbors import NearestNeighbors
import glob
import warnings
from typing import List
import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils import acts_util, vis_util, html_util, context_util, bert_util
import references as refs

## Load tokens and activations

In [None]:
# Directory of the sample to inspect, given relative to the current working
# directory and resolved to an absolute path.
_sample_rel_dir = '../../../data/sentences/wells/'
dir_path = os.path.abspath(_sample_rel_dir)

In [None]:
# load toks and acts
# File names are taken from the project's `references` module.
tokens_path = os.path.join(dir_path, refs.toks_fn)
acts_path = os.path.join(dir_path, refs.acts_fn)

# Read the pickled token list inside a context manager so the file handle is
# closed as soon as loading finishes (the bare open() previously leaked it).
with open(tokens_path, 'rb') as tokens_file:
    doc = pickle.load(tokens_file)

# The activations are expected to be an .npz archive: `.files` below lists the
# names of the arrays it holds, which are used as the layer names.
# NOTE(review): npz arrays are read from disk on access, so the archive is
# intentionally not closed here — later cells index `doc_acts[layer]`.
doc_acts = np.load(acts_path)
layers = doc_acts.files

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')

## Visualize sample

In [None]:
# REDUCE ACTS
# params
reduction, dim = 'NMF', 5  # reduction method name and number of components
max_font_size = 10         # font size (pt) used for the largest token

# reduce
# For every layer, fit the reduction on that layer's activations and keep both
# the fitted components and the reduced activations, keyed by layer name.
doc_components, doc_reduced_acts = {}, {}
for layer in layers:
    doc_components[layer], doc_reduced_acts[layer] = acts_util.fit_components(
        doc_acts[layer], reduction, dim
    )

# Per-row L2 norm of the raw activations (presumably one row per token —
# the visualization cell indexes these by token position).
if vis_size:
    doc_acts_sizes = {}
    for layer in layers:
        doc_acts_sizes[layer] = np.linalg.norm(doc_acts[layer], axis=1)

In [None]:
# GET K NEAREST NEIGHBORS FROM DATASET
# params
# NOTE: machine-specific absolute path — point this at your local corpus copy.
corpus_dir = '/Users/pkalluri/projects/clarity/bert-vis/bucket/wiki-large/wiki-split/'
corpus_dir = os.path.abspath(corpus_dir)
knn_fn = 'KNN_models_K5.pickle'
n_neighbors = 5  # neighbors retrieved per query (previously ignored: calls hard-coded 3)
toks_of_interest = [0,1,2,3,4]  # token positions in the sample to look up


def _nearest_contexts(model, queries, contexts, k):
    """Return the k nearest corpus contexts for each query row.

    Args:
        model: fitted nearest-neighbors model exposing
            ``kneighbors(X, n_neighbors=..., return_distance=True)``
            (presumably an sklearn ``NearestNeighbors`` — confirm).
        queries: 2-D array with one query vector per row.
        contexts: sequence indexed by the neighbor indices the model returns.
        k: number of neighbors to retrieve per query.

    Returns:
        A list with one entry per query row; each entry is a list of
        ``(context, distance)`` pairs in the order the model returned them.
    """
    distances, indices = model.kneighbors(queries, n_neighbors=k, return_distance=True)
    return [
        [(contexts[idx], dist) for idx, dist in zip(row_indices, row_distances)]
        for row_indices, row_distances in zip(indices, distances)
    ]


if vis_components_knns or vis_toks_knns:
    components_neighborhoods = {}
    tokens_neighborhoods = {}
    # Load the corpus contexts inside a context manager so the handle is closed.
    with open(os.path.join(corpus_dir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)
    # The models file is read with one pickle.load per layer from the same
    # handle, so the models must have been dumped in the same order as `layers`.
    with open(os.path.join(corpus_dir, knn_fn), 'rb') as f:
        for layer in layers:
            print(layer)

            print('Loading nearest neighbors model.')
            model = pickle.load(f)

            print('Finding neighbors')
            # Neighbors of each reduced component of the sample.
            components_neighborhoods[layer] = _nearest_contexts(
                model, doc_components[layer], corpus_contexts, n_neighbors
            )
            # Neighbors of the raw activations of the selected tokens.
            tokens_neighborhoods[layer] = _nearest_contexts(
                model, doc_acts[layer][toks_of_interest], corpus_contexts, n_neighbors
            )

            # Drop the reference before unpickling the next layer's model.
            del model

In [None]:
# VISUALIZE
# For each layer, optionally display: the sample's tokens colored by their
# reduced activations (and sized by activation norm), the nearest corpus
# contexts of each component, and the nearest corpus contexts of each token
# of interest.

# legend
# np.eye(dim) has one one-hot row per component; presumably channels_to_rgbs
# maps each row to that component's "pure" color.
pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
if vis_legend:
    html = ''.join(
        html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
        for i, rgb in enumerate(pure_rgbs)
    )
    print('Legend')
    display(HTML(html))
    print()

# vis
rgbs = {layer:vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers}
for layer_idx, layer in enumerate(layers):
    _rgbs = rgbs[layer]
    if vis_size:
        # Min-max normalize this layer's norms to [0, 1].
        _sizes = doc_acts_sizes[layer]
        _size_min = np.min(_sizes)
        _size_range = np.max(_sizes) - _size_min
        if _size_range > 0:
            _sizes = (_sizes - _size_min) / _size_range
        else:
            # All norms equal (e.g. a single token): avoid 0/0 -> nan font sizes
            # and render every token at full size instead.
            _sizes = np.ones_like(_sizes, dtype=float)
    display(HTML(f'Layer {layer_idx}'))

    if vis_sample:
        token_spans = []
        for pos, tok in enumerate(doc):
            font_size = _sizes[pos]*max_font_size if vis_size else max_font_size
            css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {font_size}pt;'
            token_spans.append(html_util.style(f' {tok} ', css=css))
        display(HTML(''.join(token_spans)))

    if vis_components_knns:
        _neighborhoods = components_neighborhoods[layer]
        # Pair each component's neighbors with that component's legend color.
        for neighborhood_idx, (pure_rgb, neighborhood) in enumerate(zip(pure_rgbs, _neighborhoods)):
            color = html_util.rgb_to_color(*pure_rgb)
            display(HTML(f'Cluster {neighborhood_idx} KNNs:'))
            for context, dist in neighborhood:
                html = f'({dist:.2f}) ' + context_util.context_html(*context, marker=html_util.highlighter(color=color))
                display(HTML(html))
    if vis_toks_knns:
        _neighborhoods = tokens_neighborhoods[layer]
        for pos, neighborhood in zip(toks_of_interest, _neighborhoods):
            display(HTML(f'Token {pos} ({doc[pos]}) KNNs:'))
            for context, dist in neighborhood:
                html = f'({dist:.2f}) ' + context_util.context_html(*context)
                display(HTML(html))
    # Visual separator between layers.
    print('.')
    print('.')
    print('.')