# See a document through BERT's eyes

# Imports

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np

import sys
sys.path.insert(0, os.path.abspath('../..'))
from utils import acts_util, vis_util, html_util

# Load tokens and activations

In [None]:
# Set parameters
# dir_path is the data directory for this notebook; the cells below look for
# "tokens.pickle" and "activations.npz" inside it.
# Alternative dataset (swap with the line below to use it):
# dir_path = '../../../data/alice/sample3'
dir_path = '../../../data/short-sentence/'

In [None]:
# Checking parameters are valid
# normpath drops any trailing slash before basename is taken; a plain
# split('/')[-1] returns '' for a path such as 'data/short-sentence/'.
name = os.path.basename(os.path.normpath(dir_path))
print(f'Directory: \'{name}\'')

# Explicit raises instead of asserts so the checks still run under `python -O`.
tokens_path = os.path.join(dir_path, "tokens.pickle")
print(f'Path to tokens: \'{tokens_path}\'')
if not os.path.exists(tokens_path):
    raise FileNotFoundError(f'File does not exist: {os.path.abspath(tokens_path)}')

acts_path = os.path.join(dir_path, "activations.npz")
print(f'Path to reduced activations: \'{acts_path}\'')
if not os.path.exists(acts_path):
    raise FileNotFoundError(f'File does not exist: {os.path.abspath(acts_path)}')

In [None]:
# Load the token list.
# NOTE: pickle can execute arbitrary code while loading; only point dir_path
# at data you trust.
with open(tokens_path, 'rb') as f:
    tokens = pickle.load(f)

doc = ' '.join(tokens)
print('\nDocument:')
print(doc)

# Read every array into a plain dict and close the archive afterwards.
# The NpzFile returned by np.load is lazy: it keeps the file open and loads
# an array again each time it is accessed, and the cells below iterate over
# the activations several times.
with np.load(acts_path) as npz:
    layers = list(npz.files)
    layer_to_acts = {layer: npz[layer] for layer in layers}
print(f'\nLayers: {", ".join(layers)}')

# Visualize activations, reduced by NMF

In [None]:
# REDUCE ACTS
dim_reduction, dim = 'NMF', 3
visualize_size = True
max_font_size = 10

# For each layer, keep element [1] of what fit_components returns as that
# layer's reduced activations.
layer_to_reduced_acts = {}
for layer, acts in layer_to_acts.items():
    fit_result = acts_util.fit_components(acts, dim_reduction, dim)
    layer_to_reduced_acts[layer] = fit_result[1]

if visualize_size:
    # Norm taken along axis 1 of each layer's activation array; used later
    # to scale the font size of each token.
    layer_to_acts_sizes = {}
    for layer, acts in layer_to_acts.items():
        layer_to_acts_sizes[layer] = np.linalg.norm(acts, axis=1)

In [None]:
# VISUALIZE

# legend
# Row i of the identity matrix is the "pure" direction of component i; each
# one is converted to a colour and shown as a numbered swatch.
pure_directions = np.eye(dim)
pure_rgbs = []
for direction in pure_directions:
    pure_rgbs.append(list(vis_util.channels_to_rgbs(direction)[0]))

swatches = []
for i, rgb in enumerate(pure_rgbs):
    css = f'background-color: {html_util.rgb_to_color(*rgb)}'
    swatches.append(html_util.style(f' {i} ', css=css))
html = ''.join(swatches)

print('Legend')
display(HTML(html))

# vis
# Colour every token by its reduced activations, one rendered line per layer.
# When visualize_size is set, the font size encodes the activation norm.
layer_to_rgbs = {
    layer: vis_util.channels_to_rgbs(reduced_acts)
    for layer, reduced_acts in layer_to_reduced_acts.items()
}
for layer in layers:
    rgbs = layer_to_rgbs[layer]
    if visualize_size:
        sizes = layer_to_acts_sizes[layer]
        smallest = np.min(sizes)
        size_range = np.max(sizes) - smallest
        if size_range > 0:
            # Min-max scale to [0, 1]; the smallest-norm token maps to 0
            # and is therefore rendered at 0pt.
            normalized_sizes = (sizes - smallest) / size_range
        else:
            # All norms are equal (e.g. a single token): dividing by the
            # zero range would give NaN font sizes, so use full size.
            normalized_sizes = np.ones_like(sizes, dtype=float)
    print(layer)
    spans = []
    for tok_idx, tok in enumerate(tokens):
        if visualize_size:
            css = f'background-color: {html_util.rgb_to_color(*rgbs[tok_idx])}; font-size: {normalized_sizes[tok_idx]*max_font_size}pt;'
        else:
            css = f'background-color: {html_util.rgb_to_color(*rgbs[tok_idx])}; font-size: {max_font_size}pt;'
        spans.append(html_util.style(f' {tok} ', css=css))
    # join once instead of growing a string with += per token
    html = ''.join(spans)
    display(HTML(html))

# Visualize activations, clustered by KMeans

In [None]:
# CLUSTER
n_clusters = 3

# Fit a separate KMeans model on each layer's activations and keep the
# cluster label assigned to every row.
layer_to_clusters = {}
for layer, acts in layer_to_acts.items():
    kmeans = KMeans(n_clusters=n_clusters)
    layer_to_clusters[layer] = kmeans.fit(acts).labels_

In [None]:
# VISUALIZE

# legend
# One numbered colour swatch per cluster, derived from the rows of an
# n_clusters x n_clusters identity matrix.
pure_directions = np.eye(n_clusters)
pure_rgbs = []
for direction in pure_directions:
    pure_rgbs.append(list(vis_util.channels_to_rgbs(direction)[0]))

swatches = []
for i, rgb in enumerate(pure_rgbs):
    css = f'background-color: {html_util.rgb_to_color(*rgb)}'
    swatches.append(html_util.style(f' {i} ', css=css))
html = ''.join(swatches)

print('Legend')
display(HTML(html))

# vis
# Render the document once per layer, with every token painted in the
# legend colour of the cluster it was assigned to.
for layer, clusters in layer_to_clusters.items():
    print(layer)
    spans = []
    for tok, cluster in zip(tokens, clusters):
        rgb = pure_rgbs[cluster]
        css = f'background-color: {html_util.rgb_to_color(*rgb)}; font-size: 4pt;'
        spans.append(html_util.style(f' {tok} ', css=css))
    html = ''.join(spans)
    display(HTML(html))