# See a sample through BERT's eyes (clusters, size, and KNNs)

## Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import os
import pickle
import numpy as np
import glob
from typing import List
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import warnings
from IPython.core.display import display, HTML

import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils import context_util, gpt_util, acts_util, vis_util, html_util


## Get doc and acts

In [None]:
# Get doc and acts from str
# doc = 'The second floor exhibits classical art.'

# Or from textfile
# NOTE: machine-specific absolute path; point this at your own copy of the document.
doc_path = '/Users/pkalluri/projects/clarity/bert-vis/data/paragraphs/alice-1/doc.txt'
# The context manager closes the file handle as soon as the text has been read.
with open(doc_path, 'r') as f:
    doc = f.read()
print(doc)

# Get toks and acts
# `doc` is rebound here from the raw text to the first value returned by
# gpt_util.get_doc_acts (the visualization cell iterates it token by token);
# `doc_acts` is indexed by layer in the cells below.
doc, doc_acts = gpt_util.get_doc_acts(doc)
layers = list(range(len(doc_acts)))

## Visualize sample

In [None]:
# REDUCE ACTS

# params
reduction = 'KMeans'  # reduction name handed to acts_util.fit_components
dim = 6  # passed to acts_util.fit_components; also sizes the colour legend
vis_size = True  # when set, also compute per-layer activation norms

# reduce: fit every layer independently, silencing any warnings raised while fitting
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    _layer_fits = [acts_util.fit_components(doc_acts[layer], reduction, dim) for layer in layers]

# split the per-layer (components, reduced acts) pairs into two parallel lists
doc_components = [components for components, _ in _layer_fits]
doc_reduced_acts = [reduced for _, reduced in _layer_fits]

if vis_size:
    # norm taken along axis 1 of each layer's activations
    doc_acts_sizes = []
    for layer in layers:
        doc_acts_sizes.append(np.linalg.norm(doc_acts[layer], axis=1))

In [None]:
# GET K NEAREST NEIGHBORS FROM DATASET


def _neighborhoods_for(knn_model, queries, corpus_contexts, n_neighbors):
    """Return one neighborhood per row of `queries`.

    knn_model.kneighbors returns (dists, idxs), one row of each per query.
    We want a more intuitive and useful representation: a neighborhood is a
    list of neighbors, and a neighbor is a tuple of (context, its dist to the
    query), where the context is looked up in `corpus_contexts` by index.
    """
    dists, idxs = knn_model.kneighbors(queries, n_neighbors=n_neighbors, return_distance=True)
    return [[(corpus_contexts[idx], dist) for dist, idx in zip(query_dists, query_idxs)]
            for query_dists, query_idxs in zip(dists, idxs)]


# params
vis_KNNs = False
corpus_dir = ''
contexts_fn = 'contexts.pickle'
knn_fn = 'KNN_models_K5.pickle'
n_neighbors = 10
toks_of_interest = [0,1,2,3,4]

# get KNNs
if vis_KNNs:
    components_neighborhoods = []
    toks_neighborhoods = []
    # NOTE: unpickling can execute arbitrary code -- only load corpus files you trust.
    # The context manager closes the contexts file once it has been read.
    with open(os.path.join(corpus_dir, contexts_fn), 'rb') as f:
        corpus_contexts = pickle.load(f)
    with open(os.path.join(corpus_dir, knn_fn), 'rb') as f:
        for layer in layers:
            print(layer)

            print('Loading nearest neighbors model.')
            # each pickle.load call reads the next object from the same open file,
            # so one model is consumed per layer
            knn_model = pickle.load(f)

            print('Finding neighbors')
            # neighbors of this layer's fitted components
            components_neighborhoods.append(
                _neighborhoods_for(knn_model, doc_components[layer], corpus_contexts, n_neighbors))
            # neighbors of the selected tokens' activations
            toks_neighborhoods.append(
                _neighborhoods_for(knn_model, doc_acts[layer][toks_of_interest], corpus_contexts, n_neighbors))
            # drop the model before unpickling the next layer's
            del knn_model


In [None]:
# VISUALIZE
# params
vis_legend = True
vis_sample = True
vis_components_KNNs = False  # reads components_neighborhoods, only built when the KNN cell ran with vis_KNNs = True
vis_tokens_KNNs = False  # reads toks_neighborhoods, only built when the KNN cell ran with vis_KNNs = True
vis_size = False  # reads doc_acts_sizes, only built when the reduce cell ran with vis_size = True
max_font_size = 12

# legend: one swatch per component, coloured by its pure channel
pure_rgbs = vis_util.channels_to_rgbs(np.eye(dim))
if vis_legend:
    html = ''.join(
        html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
        for i, rgb in enumerate(pure_rgbs)
    )
    print('Legend')
    display(HTML(html))
    print()

# vis
rgbs = [vis_util.channels_to_rgbs(doc_reduced_acts[layer]) for layer in layers]
for layer in layers:
    _rgbs = rgbs[layer]
    if vis_size:
        # min-max normalize this layer's sizes into [0, 1]
        _sizes = doc_acts_sizes[layer]
        _min_size = np.min(_sizes)
        _size_range = np.max(_sizes) - _min_size
        if _size_range > 0:
            _normalized_sizes = (_sizes - _min_size) / _size_range
        else:
            # all tokens have the same size: avoid 0/0 (nan font sizes) and show them at full size
            _normalized_sizes = np.ones_like(_sizes, dtype=float)
    display(HTML(f'Layer {layer}'))
    if vis_sample:
        # vis sample: every token on a background coloured by its reduced activations
        # NOTE(review): tok is inserted into the markup as-is; presumably html_util.style
        # escapes it or tokens never contain markup -- TODO confirm
        _tok_htmls = []
        for tok_idx, tok in enumerate(doc):
            _font_size = _normalized_sizes[tok_idx]*max_font_size if vis_size else max_font_size
            css = f'background-color: {html_util.rgb_to_color(*_rgbs[tok_idx])}; font-size: {_font_size}pt;'
            _tok_htmls.append(html_util.style(f' {tok} ', css=css))
        html = ''.join(_tok_htmls)
        display(HTML(html))

    # vis neighbors
    if vis_components_KNNs:
        _neighborhoods = components_neighborhoods[layer]
        # pair each component's neighborhood with that component's legend colour
        for neighborhood_idx, (pure_rgb, neighborhood) in enumerate(zip(pure_rgbs, _neighborhoods)):
            color = html_util.rgb_to_color(*pure_rgb)
            display(HTML(f'Cluster {neighborhood_idx} KNNs:'))
            for context, dist in neighborhood:
                html = f'({dist:.2f}) ' + context_util.context_html(*context, marker=html_util.highlighter(color=color))
                html = html_util.style(html, css = 'font-size: 10px;')
                display(HTML(html))
    if vis_tokens_KNNs:
        _neighborhoods = toks_neighborhoods[layer]
        for tok_idx, neighborhood in zip(toks_of_interest, _neighborhoods):
            display(HTML(f'Token {tok_idx} ({doc[tok_idx]}) KNNs:'))
            for context, dist in neighborhood:
                html = f'({dist:.2f}) ' + context_util.context_html(*context)
                html = html_util.style(html, css = 'font-size: 10px;')
                display(HTML(html))