# See a sample through BERT's eyes (clusters, size, and KNNs)

## Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np
from sklearn.neighbors import NearestNeighbors
import glob

import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils import acts_util, vis_util, html_util, context_util

## Load tokens and activations

In [None]:
# Directory holding the sample to inspect, given relative to the working directory
# and resolved to an absolute path.
sample_rel_dir = '../../../data/short-sentence/'
dir_path = os.path.abspath(sample_rel_dir)

In [None]:
# Check that the sample's token and activation files exist before loading them.
# os.path.basename is used (instead of splitting on '/') so the directory name is
# extracted correctly whatever path separator the platform uses.
name = os.path.basename(dir_path)
tokens_path = os.path.join(dir_path, "tokens.pickle")
acts_path = os.path.join(dir_path, "activations.npz")

print(f'Directory: \'{name}\'')
print(f'Path to tokens: \'{tokens_path}\'')
assert os.path.exists(tokens_path), f'File does not exist: {os.path.abspath(tokens_path)}'
print(f'Path to reduced activations: \'{acts_path}\'')
assert os.path.exists(acts_path), f'File does not exist: {os.path.abspath(acts_path)}'

In [None]:
# Read the sample's token list and open its activations archive.
with open(tokens_path, 'rb') as tokens_file:
    tokens = pickle.load(tokens_file)
doc = ' '.join(tokens)  # the sample as one space-separated string
layer_to_acts = np.load(acts_path)
layers = layer_to_acts.files  # archive keys, used as layer names from here on

# Show the sample text and the layers that are available.
print('\nDocument:')
print(doc)
layer_names = ", ".join(layers)
print(f'\nLayers: {layer_names}')

## Visualize activations, reduced by NMF

In [None]:
# REDUCE ACTS
# Settings: reduction method and number of components, plus display options that
# are consumed when rendering.
reduction, dim = 'NMF', 3
visualize_size = True
max_font_size = 10

# Fit the reduction separately for each layer and keep both outputs.
layer_to_components, layer_to_reduced_acts = {}, {}
# NOTE: the loop variable `acts` stays bound after the loop (to the last layer's
# activations); it is deliberately not renamed because later cells may read it.
for layer, acts in layer_to_acts.items():
    fitted = acts_util.fit_components(acts, reduction, dim)
    layer_to_components[layer], layer_to_reduced_acts[layer] = fitted

# Optionally record the norm of every row of each layer's activations.
if visualize_size:
    layer_to_acts_sizes = {
        lyr: np.linalg.norm(lyr_acts, axis=1)
        for lyr, lyr_acts in layer_to_acts.items()
    }

In [None]:
# GET ACTS NEAREST NEIGHBORS
# For every layer, look up the corpus contexts nearest to (a) each fitted
# component and (b) the first row of the sample's activations in that layer,
# using nearest-neighbor models that were pickled ahead of time.
visualize_NNs = True
corpus_dir = '/Users/pkalluri/projects/clarity/bert-vis/bucket/wiki-large/wiki-split/'
contexts_fn = 'contexts.pickle'
acts_fn = 'activations.npz'
knn_fn = 'KNN_models_K5.pickle'
n_neighbors = 5
# Number of neighbors actually requested per query below.
# NOTE(review): `n_neighbors` (5) is not used by the queries, which ask for 3 --
# confirm which value is intended.
n_shown_neighbors = 3

print(f'Path to corpus: {corpus_dir}')

layer_to_comp_neighborhoods = {}
layer_to_target_neighborhoods = {}
if visualize_NNs:
    # Only require the corpus files when neighbors are actually going to be computed.
    contexts_path = os.path.abspath(os.path.join(corpus_dir, contexts_fn))
    knn_path = os.path.join(corpus_dir, knn_fn)
    corpus_acts_path = os.path.join(corpus_dir, acts_fn)
    assert os.path.exists(corpus_dir), f'Directory does not exist: {corpus_dir}'
    assert os.path.exists(knn_path), f'File does not exist: {knn_path}'
    assert os.path.exists(contexts_path), f'File does not exist: {contexts_path}'
    assert os.path.exists(corpus_acts_path), f'File does not exist: {corpus_acts_path}'

    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only point
    # corpus_dir at files you created or trust.
    with open(contexts_path, 'rb') as f:
        corpus_contexts = pickle.load(f)

    with open(knn_path, 'rb') as f:
        # One model is read from the file per layer, in sequence.
        # Assumes the models were pickled one after another in the same order as
        # `layers` -- TODO confirm against the script that wrote the file.
        for layer in layers:
            print(layer)

            print('Loading nearest neighbors model.')
            model = pickle.load(f)

            print('Finding neighbors')
            components = layer_to_components[layer]
            neighborhoods = model.kneighbors(components, n_neighbors=n_shown_neighbors, return_distance=False) # indices
            neighborhoods = [[corpus_contexts[idx] for idx in neighborhood] for neighborhood in neighborhoods] # contexts
            layer_to_comp_neighborhoods[layer] = neighborhoods

            # Query with THIS layer's activations (first row), not with a leftover
            # `acts` variable from an earlier cell, so the query vector comes from
            # the same layer as the model being queried.
            target_act = layer_to_acts[layer][0]
            neighborhoods = model.kneighbors([target_act], n_neighbors=n_shown_neighbors, return_distance=False) # indices
            neighborhoods = [[corpus_contexts[idx] for idx in neighborhood] for neighborhood in neighborhoods] # contexts
            layer_to_target_neighborhoods[layer] = neighborhoods

            # Drop the model before loading the next one.
            del model
        

In [None]:
# VISUALIZE
# Render, for every layer: the sample's tokens colored by their reduced
# activations (and optionally sized by activation norm), followed by the
# nearest-neighbor contexts of each component and of the target activation.

# legend: one colored swatch per component index
pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
html = ''.join(
    html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    for i, rgb in enumerate(pure_rgbs)
)
print('Legend')
display(HTML(html))

# vis
layer_to_rgbs = {layer:vis_util.channels_to_rgbs(reduced_acts) for layer,reduced_acts in layer_to_reduced_acts.items()}
for layer in layers:
    rgbs = layer_to_rgbs[layer]
    if visualize_size:
        sizes = layer_to_acts_sizes[layer]
        smallest = np.min(sizes)
        size_range = np.max(sizes) - smallest
        if size_range > 0:
            # Min-max scale to [0, 1]; the smallest-norm token therefore gets 0pt.
            normalized_sizes = (sizes - smallest) / size_range
        else:
            # All norms are equal (e.g. a single token): avoid dividing by zero
            # and draw every token at the maximum font size.
            normalized_sizes = np.ones_like(sizes, dtype=float)
    print(layer)
    # vis sample
    styled_tokens = []
    for tok_idx, tok in enumerate(tokens):
        if visualize_size:
            css = f'background-color: {html_util.rgb_to_color(*rgbs[tok_idx])}; font-size: {normalized_sizes[tok_idx]*max_font_size}pt;'
        else:
            css = f'background-color: {html_util.rgb_to_color(*rgbs[tok_idx])}; font-size: {max_font_size}pt;'
        styled_tokens.append(html_util.style(f' {tok} ', css=css))
    html = ''.join(styled_tokens)
    display(HTML(html))

    # vis neighbors
    # .get(..., []) keeps this cell working when no neighbors were computed for
    # the layer (the neighborhood dicts are then empty).
    neighborhoods = layer_to_comp_neighborhoods.get(layer, [])
    for pure_rgb, neighborhood in zip(pure_rgbs, neighborhoods):
        # Highlight each component's neighbors in that component's legend color.
        color = html_util.rgb_to_color(*pure_rgb)
        for context in neighborhood:
            html = context_util.context_html(*context, highlighter=html_util.highlighter(color=color))
            display(HTML(html))
    neighborhoods = layer_to_target_neighborhoods.get(layer, [])
    for i, neighborhood in enumerate(neighborhoods):
        print(f'Target {i+1}')
        for context in neighborhood:
            html = context_util.context_html(*context)
            display(HTML(html))
            