# See a sample through BERT's eyes (clusters and size)

## Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np

import sys
sys.path.insert(0, os.path.abspath('../..'))
import references as refs
from utils import acts_util, vis_util, html_util
from utils.html_util import style, rgb_to_color

## Load tokens and activations

In [None]:
# Locations of the sample document's files. The file names themselves come from
# the project-level `references` module (presumably shared across notebooks).
dir_path = '../../../data/sentences/wells/'
tokens_path, acts_path = (
    os.path.join(dir_path, filename)
    for filename in (refs.toks_fn, refs.acts_fn)
)

In [None]:
# Load the tokenized document. The context manager closes the file handle as
# soon as unpickling finishes instead of leaving it open until garbage collection.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(tokens_path, 'rb') as tokens_file:
    doc = pickle.load(tokens_file)
print('\nDocument:')
print(' '.join(doc))

# Load the saved activations. `.files` lists the arrays stored in the archive
# (presumably an .npz file); each entry is treated as one layer below.
doc_acts = np.load(acts_path)
layers = doc_acts.files
print(f'\nLayers: {", ".join(layers)}')

## Visualize activations, reduced by NMF

In [None]:
# REDUCE ACTS
# Settings for the colour visualization below.
dim_reduction = 'NMF'   # reduction method name handed to acts_util.reduce_acts
dim = 3                 # number of components to reduce to
vis_size = True         # also compute per-row activation norms for font sizing
max_font_size = 10      # font size (pt) used by the visualization cell

# Reduce every layer's activations independently.
reduced_acts = dict(
    (name, acts_util.reduce_acts(doc_acts[name], dim_reduction, dim))
    for name in layers
)
if vis_size:
    # L2 norm along axis 1: one magnitude per row (presumably one row per token).
    acts_sizes = dict(
        (name, np.linalg.norm(doc_acts[name], axis=1))
        for name in layers
    )

In [None]:
# VISUALIZE
# Render each layer's tokens as styled HTML: the background colour comes from
# the reduced activations and, when `vis_size` is set, the font size from the
# normalized activation norm.

# legend: one colour swatch per reduced dimension
pure_directions = np.eye(dim)
pure_rgbs = [list(vis_util.channels_to_rgbs(direction)[0]) for direction in pure_directions]
html = ''.join(
    html_util.style(f' {i} ', css=f'background-color: {html_util.rgb_to_color(*rgb)}')
    for i, rgb in enumerate(pure_rgbs)
)
print('Legend')
display(HTML(html))

# vis
rgbs = {layer: vis_util.channels_to_rgbs(reduced_acts[layer]) for layer in layers}
for layer in layers:
    _rgbs = rgbs[layer]
    # The flag set in the reduction cell is `vis_size`; the previous name
    # `visualize_size` raised a NameError here.
    if vis_size:
        _sizes = acts_sizes[layer]
        _lo, _hi = np.min(_sizes), np.max(_sizes)
        if _hi > _lo:
            _sizes = (_sizes - _lo) / (_hi - _lo)  # normalize to range [0,1]
        else:
            # All norms are equal: avoid dividing by zero and show every token
            # at full size.
            _sizes = np.ones_like(_sizes, dtype=float)
    print(layer)
    html = ''
    for pos, tok in enumerate(doc):
        # With sizing on, the font scales with the normalized norm (the
        # smallest norm maps to 0pt); otherwise every token uses max_font_size.
        font_size = _sizes[pos] * max_font_size if vis_size else max_font_size
        css = f'background-color: {html_util.rgb_to_color(*_rgbs[pos])}; font-size: {font_size}pt;'
        html += html_util.style(f' {tok} ', css=css)
    display(HTML(html))

## Visualize activations, clustered by K-means

In [None]:
# CLUSTER
# Fit an independent K-means model on each layer's activations and keep the
# cluster label assigned to every row. No random_state is passed, so the
# labelling can differ from run to run.
n_clusters = 3
cluster_labels = dict(
    (name, KMeans(n_clusters=n_clusters).fit(doc_acts[name]).labels_)
    for name in layers
)

In [None]:
# VISUALIZE
# Colour every token by the K-means cluster its activation was assigned to.

# legend: one colour per cluster. The colour table is sized by `n_clusters`
# rather than the NMF setting `dim`, so every cluster label indexes a valid
# colour even when the two settings differ.
# NOTE(review): confirm vis_util.channels_to_rgbs supports n_clusters != 3.
pure_rgbs = vis_util.channels_to_rgbs(np.eye(n_clusters))
html = ''.join(
    style(f' {i} ', css=f'background-color: {rgb_to_color(*rgb)}')
    for i, rgb in enumerate(pure_rgbs)
)
print('Legend')
display(HTML(html))

# vis
for layer in layers:
    print(layer)
    # Tokens and cluster labels are paired positionally.
    html = ''.join(
        style(f' {tok} ', css=f'background-color: {rgb_to_color(*pure_rgbs[cluster])}; font-size: 10pt;')
        for tok, cluster in zip(doc, cluster_labels[layer])
    )
    display(HTML(html))