# Investigation of Cosine Similarity of Block Intermediates

> Thus far, all of the similarity investigations have been based on Euclidean distance. In this notebook, we look at whether cosine similarity might be a better measure. 

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from pathlib import Path
from typing import Callable, Dict, List, Optional, Iterable, Protocol, Sequence, Tuple, TypeVar, Type

In [None]:
#| hide
from fastcore.test import *
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.nn import functional as F
from tqdm.auto import tqdm

In [None]:
# | hide

from transformer_experiments.common.substring_generator import all_unique_substrings
from transformer_experiments.common.text_analysis import (
    build_next_token_map,
    SubstringFrequencyAnalysis,
    top_nonzero_tokens
)
from transformer_experiments.common.utils import (
    aggregate_by_string_key,
    DataWrapper,
    topk_across_batches,
)
from transformer_experiments.dataset_split import split_text_dataset
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.models.transformer import (
    n_layer,
    TransformerLanguageModel
)
from transformer_experiments.models.transformer_helpers import (
    unsqueeze_emb,
    EncodingHelpers,
    LogitsWrapper,
    TransformerAccessors
)
from transformer_experiments.trained_models.tinyshakespeare_transformer import (
    create_model_and_tokenizer
)
from transformer_experiments.experiments.block_internals import (
    BlockInternalsAccessors,
    BlockInternalsExperiment,
    BatchedBlockInternalsExperiment,
    BlockInternalsAnalysis,
    batch_cosine_sim,
)
from transformer_experiments.experiments.similar_strings import (
    SimilarStringsData,
    SimilarStringsExperiment,
    SimilarStringsResult
)
from transformer_experiments.experiments.logit_lens import LogitLens

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dataset text, plus the saved model and its tokenizer.
ts = TinyShakespeareDataSet(cache_file='../artifacts/input.txt')
m, tokenizer = create_model_and_tokenizer(
    saved_model_filename='../artifacts/shakespeare.pt', dataset=ts, device=device
)

# Only the second element of the split is kept (presumably the validation
# portion, given train_pct=0.9 -- confirm against split_text_dataset).
_, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9)

# Helper objects handed to the experiments constructed further down.
encoding_helpers = EncodingHelpers(tokenizer, device)
accessors = TransformerAccessors(m, device)

In [None]:
# Record which device this run used, so the saved outputs can be interpreted.
print(f"device is {device}")

device is cpu


In [None]:
# The batched experiment below is pointed at this directory; if it holds no
# files, remind the reader how to generate them.
slen10_dir = Path('../artifacts/block_internals_results/large_files/slen10/')
if not any(slen10_dir.glob('*')):
    print("Run `make block_internals_slen10_dataset` in the project root to generate the required dataset")

In [None]:
# Going by the helper's name and its second argument, these are the distinct
# length-10 substrings of the dataset text (the recorded results further down
# are 10-character strings) -- confirm against all_unique_substrings.
strings10 = all_unique_substrings(ts.text, 10)

In [None]:
# Experiment object over all of `strings10`, pointed at the on-disk results
# directory checked above. The later cells load per-batch files through it and
# run top-k similarity searches against it.
# NOTE(review): batch_size presumably has to match the value used when the
# dataset was generated (batch 0 has 10000 rows in the recorded output of the
# mock-up cell) -- confirm.
exp10 = BatchedBlockInternalsExperiment(
    eh=encoding_helpers,
    accessors=accessors,
    strings=strings10,
    output_dir=Path('../artifacts/block_internals_results/large_files/slen10/'),
    batch_size=10000,
)

First, investigate whether there is a lot of variance in the norms of the block intermediates. If so, it suggests that cosine similarity may be a better measure than Euclidean distance.

In [None]:
# For every layer, load batch 0 of the stored activations and summarise the
# vector norms at the last index of the second axis (presumably the final
# token position).
# NOTE(review): the variables are named proj_out_* and the prose after this
# cell discusses proj_out, but the file loaded here comes from
# `_block_output_filename`. Confirm which intermediate is intended.
for block_idx in range(n_layer):
    batch_file = exp10._block_output_filename(batch_idx=0, block_idx=block_idx)
    proj_out_batch = torch.load(str(batch_file), mmap=True)

    last_position_acts = proj_out_batch[:, -1, :]
    proj_out_norms = torch.norm(last_position_acts, dim=-1)

    print(f"Layer {block_idx}: mean {proj_out_norms.mean()}, std {proj_out_norms.std()}")


Layer 0: mean 27.912761688232422, std 1.1137561798095703
Layer 1: mean 33.36977767944336, std 1.7041479349136353
Layer 2: mean 39.782466888427734, std 2.0023622512817383
Layer 3: mean 46.48314666748047, std 3.300010919570923
Layer 4: mean 53.44303894042969, std 6.607938289642334
Layer 5: mean 61.70024871826172, std 11.696634292602539


In [None]:
# Same norm summary as the previous cell, this time for the stored
# feed-forward outputs of batch 0 at the last index of the second axis.
for block_idx in range(n_layer):
    batch_file = exp10._ffwd_output_filename(batch_idx=0, block_idx=block_idx)
    ffwd_out_batch = torch.load(str(batch_file), mmap=True)

    last_position_acts = ffwd_out_batch[:, -1, :]
    ffwd_out_norms = torch.norm(last_position_acts, dim=-1)

    print(f"Layer {block_idx}: mean {ffwd_out_norms.mean()}, std {ffwd_out_norms.std()}")


Layer 0: mean 6.409949779510498, std 1.142516851425171
Layer 1: mean 8.440470695495605, std 0.9452682137489319
Layer 2: mean 9.34270191192627, std 1.0641635656356812
Layer 3: mean 11.903395652770996, std 1.5840272903442383
Layer 4: mean 13.59791374206543, std 2.9059391021728516
Layer 5: mean 19.13654136657715, std 5.285085201263428


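OK, so for both proj_out and ffwd_out, the mean norm goes up in the later layers and so does its standard deviation. Euclidean distance is sensitive to vector magnitude, whereas cosine similarity depends only on direction, so cosine similarity is probably a better measure than Euclidean distance.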

In [None]:
# Two 10-character example prompts used as queries in every search below.
prompts = ['my most gr', 'is dreams,']
# Experiment over just the prompts; later cells read its proj_output(...) and
# embeddings to build query vectors.
prompts_exp = BlockInternalsExperiment(encoding_helpers, accessors, prompts)

In [None]:
# Block-0 proj output for each prompt, taken at the last index of the second
# axis (presumably the final token position -- confirm).
queries = prompts_exp.proj_output(block_idx=0)[:, -1, :]

In [None]:
# Mock-up of what a batched cosine similarity function would look like.
# NOTE(review): `proj_out_batch` is whatever the norm loop above loaded on its
# final iteration (the last layer), whereas `queries` come from block 0. That
# is harmless for this shape check but the two are not from the same layer.
batch = proj_out_batch[:, -1, :]
B, _ = batch.shape
n_queries, _ = queries.shape

# Give every stored row one (non-copying) view per query so that rows and
# queries pair up along dim 1, then compare along the feature axis.
batch_per_query = batch.unsqueeze(1).expand(-1, n_queries, -1)
sims = F.cosine_similarity(batch_per_query, queries, dim=-1)
sims.shape

torch.Size([10000, 2])

In [None]:
# Top-10 search over the stored proj outputs of block 0, scored with
# batch_cosine_sim. Queries are the prompts' own block-0 proj outputs at the
# last position.
block_idx = 0
query_vecs = prompts_exp.proj_output(block_idx=block_idx)[:, -1, :]

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_proj_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=query_vecs,
    k=10,
    largest=True,  # presumably keeps the k highest scores; recorded output is descending
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	  'my most gr' (1.000)
	  'ur most gr' (0.995)
	  'is most gr' (0.995)
	  'ne most gr' (0.995)
	  'ilst my gr' (0.995)
	  'he most gr' (0.994)
	  'unto my gr' (0.994)
	  'e, most gr' (0.994)
	  't, most gr' (0.994)
	  'yman to gr' (0.993)

Closest to 'is dreams,': 
	  'is dreams,' (1.000)
	  'ly dreams,' (0.995)
	  'en dreams,' (0.994)
	  'he dreams,' (0.994)
	  'ur dreams,' (0.994)
	  'nd dreams,' (0.993)
	  'ery beams,' (0.992)
	  'of dreams,' (0.991)
	  "n's beams," (0.990)
	  'hese arms,' (0.989)



In [None]:
# Top-10 search over the stored ffwd outputs of block 0, scored with
# batch_cosine_sim.
# NOTE(review): the queries are built from `prompts_exp.proj_output(...)` even
# though they are compared against stored *ffwd* outputs. The recorded output
# also lacks a 1.000 self-match for each prompt, unlike the proj-output and
# embedding searches. This looks like a copy-paste slip -- confirm whether the
# prompts' ffwd outputs were intended as the queries.
block_idx = 0
query_vecs = prompts_exp.proj_output(block_idx=block_idx)[:, -1, :]

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_ffwd_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=query_vecs,
    k=10,
    largest=True,
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	  'at unsubst' (0.454)
	  't in subst' (0.454)
	  'ften burst' (0.454)
	  'd by subst' (0.454)
	  'n of subst' (0.454)
	  'most burst' (0.453)
	  'it unconst' (0.453)
	  'l, inconst' (0.453)
	  'o be subst' (0.453)
	  're unconst' (0.453)

Closest to 'is dreams,': 
	  'Schoolmast' (0.473)
	  'schoolmast' (0.472)
	  'hath chast' (0.471)
	  'Stand fast' (0.470)
	  ' notwithst' (0.470)
	 'ng\nfantast' (0.470)
	  'stand fast' (0.470)
	 'thou\nhadst' (0.469)
	  'ough chast' (0.469)
	  'ch fantast' (0.469)



In [None]:
# Top-10 search over the stored proj outputs of block 5, scored with
# batch_cosine_sim. Queries are the prompts' own block-5 proj outputs at the
# last position.
block_idx = 5
query_vecs = prompts_exp.proj_output(block_idx=block_idx)[:, -1, :]

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_proj_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=query_vecs,
    k=10,
    largest=True,
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	  'my most gr' (1.000)
	  'my most st' (0.897)
	  'my most sa' (0.864)
	  ' my most r' (0.849)
	  ' my most l' (0.809)
	  'my high bl' (0.799)
	  'mt my mast' (0.799)
	  'm thy moth' (0.795)
	  'm, my mour' (0.788)
	  'my most he' (0.784)

Closest to 'is dreams,': 
	  'is dreams,' (1.000)
	  'is hoarse,' (0.932)
	  'ish hairs,' (0.924)
	  'ith oaths,' (0.914)
	  'is events,' (0.912)
	  'ish tears,' (0.911)
	  'ir mouths,' (0.911)
	  'ir plumes,' (0.905)
	  'is throne,' (0.899)
	  'is mother,' (0.896)



In [None]:
# Top-10 search over the stored ffwd outputs of block 5, scored with
# batch_cosine_sim.
# NOTE(review): the queries are built from `prompts_exp.proj_output(...)` even
# though they are compared against stored *ffwd* outputs. The recorded output
# also lacks a 1.000 self-match for each prompt. This looks like a copy-paste
# slip -- confirm whether the prompts' ffwd outputs were intended.
block_idx = 5
query_vecs = prompts_exp.proj_output(block_idx=block_idx)[:, -1, :]

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_ffwd_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=query_vecs,
    k=10,
    largest=True,
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	  's fast bel' (0.347)
	  's part bel' (0.337)
	  's that bel' (0.334)
	  ' drops bel' (0.331)
	  'assage bel' (0.331)
	  'e step bel' (0.325)
	  'y best bel' (0.325)
	  'w then bel' (0.323)
	  'myself bel' (0.319)
	 'ts:\nSometi' (0.317)

Closest to 'is dreams,': 
	  ' may she--' (0.402)
	  '! should--' (0.397)
	 'tantly,\n--' (0.396)
	  ' it were--' (0.391)
	  's is she--' (0.390)
	 'LO:\nAnd,--' (0.390)
	  "ty in't,--" (0.390)
	  'im--dead--' (0.390)
	  't are so--' (0.389)
	  ": here's--" (0.389)



In [None]:
# Same block-5 ffwd search as the previous cell, but with t_i=8 instead of -1
# (the recorded matches are 9-character strings).
# NOTE(review): two mismatches to confirm. The queries come from
# `prompts_exp.proj_output(...)` while the search runs over stored *ffwd*
# outputs, and the queries are taken at position -1 while the search uses
# t_i=8.
block_idx = 5
query_vecs = prompts_exp.proj_output(block_idx=block_idx)[:, -1, :]

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_ffwd_outputs(
    block_idx=block_idx,
    t_i=8,
    queries=query_vecs,
    k=10,
    largest=True,
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	   'ssage bel' (0.346)
	   ' step bel' (0.339)
	   ' fast bel' (0.337)
	  'g\ninto so' (0.322)
	   'drops bel' (0.320)
	   'grave bel' (0.320)
	  ' it\nTo so' (0.317)
	   ' part bel' (0.311)
	   ' best bel' (0.310)
	   'place bel' (0.307)

Closest to 'is dreams,': 
	   ' should--' (0.400)
	   'who has--' (0.397)
	   'may she--' (0.395)
	   ' so mad--' (0.391)
	   'derates--' (0.387)
	  'antly,\n--' (0.386)
	   ' is she--' (0.386)
	   'no soul--' (0.384)
	   ' camest--' (0.384)
	   'f these--' (0.384)



Try it for embeddings:

In [None]:
# Top-10 search over the stored embeddings, scored with batch_cosine_sim. The
# queries are the prompts' own embeddings (no block or position argument here).
query_embeddings = prompts_exp.embeddings

# Despite the names, `sims` holds the matching strings (one collection per
# query) and `distances` holds the scores, indexed [rank][query] below.
sims, distances = exp10.strings_with_topk_closest_embeddings(
    queries=query_embeddings,
    k=10,
    largest=True,
    distance_function=batch_cosine_sim,
)

# One section per prompt: each match followed by its score.
for idx, strings in enumerate(sims):
    print(f"Closest to {repr(prompts[idx])}: ")

    for i, s in enumerate(strings):
        print(f"\t{repr(s):>14} ({distances[i][idx]:.3f})")

    print()

Closest to 'my most gr': 
	  'my most gr' (1.000)
	  'my most sa' (0.912)
	  't, most gr' (0.909)
	  'my most st' (0.909)
	  'my most so' (0.906)
	  'e, most gr' (0.905)
	  'my most re' (0.905)
	  'ur most gr' (0.905)
	  'my most he' (0.905)
	  'is most gr' (0.904)

Closest to 'is dreams,': 
	  'is dreams,' (1.000)
	  'is dream o' (0.906)
	  'ur dreams,' (0.906)
	  'of dreams,' (0.905)
	  'us dreams.' (0.904)
	  'he dreams,' (0.903)
	  'ly dreams,' (0.902)
	  'en dreams,' (0.902)
	  'nd dreams,' (0.896)
	 'as dream\nS' (0.865)

