# Understand a context through its neighbors
At each layer, examine the given context's k-nearest neighbors (KNNs) in the corpus, and its KNNs among its own masked variants.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import pickle
import numpy as np
from sklearn.cluster import KMeans
import os
from IPython.core.display import display, HTML
import numpy as np
from sklearn.neighbors import NearestNeighbors
import glob
import warnings
from typing import List
import collections
import itertools
import random

import sys
sys.path.insert(0, os.path.abspath('../../..'))
from src.utils import acts_util, vis_util, html_util, context_util, bert_util
from src.utils.context_util import context_html
from src.utils.html_util import highlighter, fix_size, font_size, style
from src import references as refs

## Parameters

In [None]:
# Feature flags: each one gates the `if vis_corpus_KNNs:` / `if vis_masked_KNNs:`
# cells below that compute the corresponding set of neighbors.
vis_corpus_KNNs = True
vis_masked_KNNs = True

## Load tokens and activations

In [None]:
# set sample directory
# (the relative path is resolved against the current working directory)
dir_path = os.path.abspath('../../../data/sentences/art/')

In [None]:
# load doc and acts
# NOTE(review): pickle can execute arbitrary code while loading; only point
# this at token files you trust.
with open(os.path.join(dir_path, refs.toks_fn), 'rb') as toks_file:
    doc = pickle.load(toks_file)
# The loaded archive is deliberately kept open: later cells read from it by
# layer name (`doc_acts[layer]`).
doc_acts = np.load(os.path.join(dir_path, refs.acts_fn))
# Only the first two stored arrays are used; adjust the slice to visualize
# more or fewer layers.
layers = doc_acts.files[:2]

print('\nDocument:')
print(' '.join(doc))
print(f'\nLayers: {", ".join(layers)}')

## Get neighbors

In [None]:
# Positions in `doc` (1 through 7) whose neighbors are looked up in the cells below.
toks_of_interest = list(range(1,8))

In [None]:
if vis_corpus_KNNs:
    # GET K NEAREST NEIGHBORS FROM DATASET
    # params
    # NOTE(review): machine-specific absolute path; change it to wherever the
    # corpus lives on your machine.
    corpus_dir = '/Users/pkalluri/projects/clarity/bert-vis/big-data/wiki-large/standard'
    knn_path = os.path.join(corpus_dir, refs.knn_models_fn)
    n_neighbors = 20

    # Load the corpus contexts; the context manager closes the file once the
    # pickle has been read.
    with open(os.path.join(corpus_dir, refs.contexts_fn), 'rb') as contexts_file:
        corpus_contexts = pickle.load(contexts_file)

    # get KNNs
    corpus_neighborhoods = {}
    # One model is unpickled from the same file per layer, in iteration order.
    # NOTE(review): this presumes the models were dumped in the same order as
    # `layers` -- confirm against the script that wrote the file.
    with open(knn_path, 'rb') as f:
        for layer in layers:
            print(f'Layer {layer}')

            print('Loading nearest neighbors model.')
            knn_model = pickle.load(f)

            print('Finding neighbors')
            # kneighbors returns (distances, indices), with one row per queried token.
            all_dists, all_idxs = knn_model.kneighbors(
                doc_acts[layer][toks_of_interest],
                n_neighbors=n_neighbors,
                return_distance=True,
            )
            # A neighborhood is a list of neighbors; a neighbor is a tuple of
            # (context, its dist to the queried token). Neighborhoods are
            # ordered like `toks_of_interest`.
            corpus_neighborhoods[layer] = [
                [(corpus_contexts[idx], dist) for dist, idx in zip(tok_dists, tok_idxs)]
                for tok_dists, tok_idxs in zip(all_dists, all_idxs)
            ]
            # Drop the model before the next one is unpickled.
            del knn_model


In [None]:
if vis_masked_KNNs:
    # GET MASKED VARIANTS OF SAMPLE
    # Mask lengths to generate: 1 through 7.
    mask_lengths = list(range(1, 8))
    # The unmasked document goes first, followed by its masked variants.
    variants = [doc, *bert_util.get_masked_variants(doc, mask_lengths)]
    variants_contexts, variants_acts = bert_util.get_contexts_and_acts(
        variants, tokenized=True
    )

In [None]:
if vis_masked_KNNs:
    # GET KNNs FROM MASKED VARIANTS
    # params
    n_neighbors = 50

    # get KNNs: for every layer, fit a model on the variants' activations and
    # query it with the activations at the positions of interest.
    masked_neighborhoods = {}
    for layer in layers:
        print(layer)
        layer_acts = variants_acts[layer]

        print('Fit nearest neighbors model.')
        knn_model = NearestNeighbors(n_neighbors=n_neighbors).fit(layer_acts)

        print('Finding neighbors')
        # kneighbors returns (distances, indices), with one row per queried token.
        all_dists, all_idxs = knn_model.kneighbors(
            layer_acts[toks_of_interest],
            n_neighbors=n_neighbors,
            return_distance=True,
        )
        # A neighborhood is a list of neighbors; a neighbor is a tuple of
        # (context, its dist to the queried token).
        masked_neighborhoods[layer] = [
            [(variants_contexts[idx], dist) for dist, idx in zip(tok_dists, tok_idxs)]
            for tok_dists, tok_idxs in zip(all_dists, all_idxs)
        ]
        del knn_model


In [None]:
# random.seed(0)
# vis_rand_KNNs = True
# if vis_rand_KNNs:
#     all_toks = [tok for (toks, pos) in corpus_contexts for tok in toks]
#     rand_toks = random.sample(all_toks, 20)
#     variants = itertools.combinations(all_toks, 1)
# #     counter=collections.Counter(all_toks)
#     variants_contexts, variants_acts = bert_util.get_docs_contexts_and_acts(variants, tokenized=True)

## Vis KNNs in corpus and KNNs in masked variants

In [None]:
# VISUALIZE
# params
n_neighbors_to_vis = 20
# vis_corpus_KNNs / vis_masked_KNNs are taken from the Parameters cell. They
# are not re-set here: forcing them to True would look up a neighborhood dict
# that was never built when the flag was disabled.
together = False
token_styler = lambda t: style(font_size(fix_size(t),7), 'line-height:0px;')
masker = highlighter('black')  # used by the `together` branch
bluer = highlighter('lightblue')  # marks corpus neighbors

# vis
tok_of_interest = 'The'
# Position of the token in the document.
tok_of_interest_idx = doc.index(tok_of_interest)
# Position of that token within toks_of_interest (it must be one of the tokens
# we searched for, otherwise .index raises ValueError). Neighborhood lists are
# ordered like toks_of_interest, so this is the index to use on them.
neigh_idx = toks_of_interest.index(tok_of_interest_idx)
# Highlight the token itself, i.e. its document position.
display(HTML(context_html(doc, tok_of_interest_idx)))
for layer in layers:
    print(f'Layer {layer}')
    if together:
        # Merge both neighbor sets (skipping each set's first entry) and show
        # the closest ones overall.
        neighbors = []
        if vis_masked_KNNs:
            neighbors.extend([(context, dist, 'masked') for context, dist in masked_neighborhoods[layer][neigh_idx]][1:n_neighbors_to_vis])
        if vis_corpus_KNNs:
            neighbors.extend([(context, dist, 'corpus') for context, dist in corpus_neighborhoods[layer][neigh_idx]][1:n_neighbors_to_vis])
        neighbors.sort(key=lambda t: t[1])
        for neigh_context, neigh_dist, neigh_dataset in neighbors[:n_neighbors_to_vis]:
            html = f'({neigh_dist:.2f}) '
            if neigh_dataset == 'masked':
                html += context_html(*neigh_context, masker=masker, token_styler=token_styler)
            elif neigh_dataset == 'corpus':
                html += context_html(*neigh_context, marker=bluer, masker=masker, token_styler=token_styler)
            display(HTML(html))
    else:
        if vis_masked_KNNs:
            neighbors = masked_neighborhoods[layer][neigh_idx][:n_neighbors_to_vis]
            for neigh_context, neigh_dist in neighbors:  # masked neighbors
                html = token_styler(f'({neigh_dist:.2f}) ') + context_html(*neigh_context, token_styler=lambda t: token_styler(fix_size(t)))
                display(HTML(html))
        if vis_corpus_KNNs:
            # Index by neigh_idx, exactly like the masked neighborhoods above.
            neighbors = corpus_neighborhoods[layer][neigh_idx][:n_neighbors_to_vis]
            for neigh_context, neigh_dist in neighbors:  # corpus neighbors
                html = token_styler(f'({neigh_dist:.2f}) ') + context_html(*neigh_context, marker=bluer, token_styler=token_styler)
                display(HTML(html))
    # visual separator between layers
    print('.')
    print('.')
    print('.')


## Vis mask pruning (remove one token at a time to visualize the essence of the context)

In [None]:
# PRUNING
# params
masker = highlighter('black')


def token_styler(t):
    """Render `t` at font size 10 with zero line height."""
    return style(font_size(t,10), 'line-height:0px;')


# vis one token's neighbors
tok_of_interest = 'The'
tok_of_interest_idx = doc.index(tok_of_interest)
neigh_idx = toks_of_interest.index(tok_of_interest_idx)
display(HTML(context_html(doc, tok_of_interest_idx)))
for layer in layers:
    print(f'Layer {layer}')
    bluer = highlighter('lightblue')
    # Only masked-variant neighbors are collected (the first entry is skipped);
    # the 'corpus' branch below fires only if corpus neighbors are added here too.
    neighbors = [
        (neigh_context, neigh_dist, 'masked')
        for neigh_context, neigh_dist in masked_neighborhoods[layer][neigh_idx]
    ][1:]

    # For every mask count, remember the first neighbor (in list order) that
    # has exactly that many '[MASK]' tokens.
    first_with_mask_len = {}
    for neighbor in neighbors:
        neigh_toks, neigh_pos = neighbor[0]
        mask_len = sum(1 for tok in neigh_toks if tok == '[MASK]')
        first_with_mask_len.setdefault(mask_len, neighbor)

    # Show one neighbor per mask count, from fewest masked tokens to most.
    for mask_len_to_find in range(len(doc)-2):
        if mask_len_to_find not in first_with_mask_len:
            continue
        neigh_context, neigh_dist, neigh_dataset = first_with_mask_len[mask_len_to_find]
        html = ''
        dist_label = f'({neigh_dist:.2f}) '
        if neigh_dataset == 'masked':
            html += token_styler(dist_label) + context_html(*neigh_context, masker=masker, token_styler=token_styler)
        elif neigh_dataset == 'corpus':
            html += token_styler(dist_label) + context_html(*neigh_context, marker=bluer, masker=masker, token_styler=token_styler)
        display(HTML(html))
    # visual separator between layers
    for _ in range(3):
        print('.')



## Try a custom sentence

In [None]:
# choose
layer = layers[1]
new_text = 'The [MASK] floor exhibits classical art.'
new_tok_of_interest = 'The'

# vis
# Tokenize the custom sentence, get its activations, and locate the token.
new_doc, new_acts = bert_util.get_toks_and_acts(new_text)
new_poos = new_doc.index(new_tok_of_interest)

# Compare against the original document's token chosen in the visualization
# cells (tok_of_interest_idx), at the selected layer.
current_act = doc_acts[layer][tok_of_interest_idx]
new_act = new_acts[layer][new_poos]
display(HTML(f'Layer {layer}'))

# Euclidean distance between the two activations.
dist = np.linalg.norm(current_act - new_act)
bluer = highlighter('lightblue')


def token_styler(t):
    """Render `t` at font size 10."""
    return font_size(t,10)


html = token_styler(f'({dist:.2f}) ') + context_html(
    new_doc, new_poos, marker=bluer, token_styler=token_styler
)
display(HTML(html))