# Widening the Space of Similar Values

> A major finding was that the current approaches are considering values that are too similar. This notebook investigates ways to search a wider space.  

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from pathlib import Path
from typing import Callable, Dict, List, Optional, Iterable, Protocol, Sequence, Tuple, TypeVar, Type

In [None]:
#| hide
from fastcore.test import *
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.nn import functional as F
from tqdm.auto import tqdm

In [None]:
# | hide

from transformer_experiments.common.substring_generator import all_unique_substrings
from transformer_experiments.common.text_analysis import (
    build_next_token_map,
    SubstringFrequencyAnalysis,
    top_nonzero_tokens
)
from transformer_experiments.common.utils import (
    aggregate_by_string_key,
    DataWrapper,
    topk_across_batches,
)
from transformer_experiments.dataset_split import split_text_dataset
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.models.transformer import (
    n_embed,
    n_layer,
    TransformerLanguageModel
)
from transformer_experiments.models.transformer_helpers import (
    unsqueeze_emb,
    EncodingHelpers,
    LogitsWrapper,
    TransformerAccessors
)
from transformer_experiments.trained_models.tinyshakespeare_transformer import (
    create_model_and_tokenizer
)
from transformer_experiments.experiments.block_internals import (
    BlockInternalsAccessors,
    BlockInternalsExperiment,
    BatchedBlockInternalsExperiment,
    BlockInternalsAnalysis,
    batch_cosine_sim,
)
from transformer_experiments.experiments.final_ffwd import FinalFFWDExperiment
from transformer_experiments.experiments.similar_strings import (
    SimilarStringsData,
    SimilarStringsExperiment,
    SimilarStringsResult
)
from transformer_experiments.experiments.logit_lens import LogitLens

In [None]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ts = TinyShakespeareDataSet(cache_file='../artifacts/input.txt')
# `m` is the model used by every cell below (including get_model_outputs()).
m, tokenizer = create_model_and_tokenizer(
    saved_model_filename='../artifacts/shakespeare.pt',
    dataset=ts,
    device=device,
)
# Only the second element of the split is kept; the first is discarded.
_, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9)
encoding_helpers = EncodingHelpers(tokenizer, device)
accessors = TransformerAccessors(m, device)

In [None]:
# Record which device this run used.
print(f"device is {device}")

device is cpu


In [None]:
# The cells below read precomputed files from this folder; tell the reader how
# to generate them when the folder has no entries.
if not any(Path('../artifacts/block_internals_results/large_files/slen10/').glob('*')):
    print("Run `make block_internals_slen10_dataset` in the project root to generate the required dataset")

In [None]:
# Strings of length 10 drawn from the corpus (unique substrings, per the
# helper's name — confirm against all_unique_substrings).
strings10 = all_unique_substrings(ts.text, 10)

In [None]:
# Experiment object over all length-10 strings; the stored results live in
# `output_dir` and are organised in batches of 10000 strings.
exp10 = BatchedBlockInternalsExperiment(
    eh=encoding_helpers,
    accessors=accessors,
    strings=strings10,
    output_dir=Path('../artifacts/block_internals_results/large_files/slen10/'),
    batch_size=10000,
)

In [None]:
# Fixed seed so the random sample below is the same on every run.
torch.manual_seed(1337)
n_samples = 20000
# First `n_samples` entries of a random permutation = sampling without replacement.
indices = torch.randperm(len(strings10))[:n_samples]
strings20k = [strings10[i.item()] for i in indices]

In [None]:
# Create a sample of 500 strings: the first `sample_size` entries of the
# already-shuffled 20k sample.
sample_size = 500
strings_sample = strings20k[:sample_size]

In [None]:
# TODO: put this in a common component
def get_model_outputs(prompts: Sequence[str], encoding_helpers: EncodingHelpers):
    """Run the notebook-level model `m` on `prompts` and return its top-10 predictions.

    Args:
        prompts: the strings to feed to the model.
        encoding_helpers: used to tokenize the prompts; its tokenizer is also
            handed to `LogitsWrapper` to decode the predictions.

    Returns:
        A list with one entry per item yielded by `LogitsWrapper.topk_tokens(k=10)`;
        each entry is the last element of that item (presumably the top-10
        predictions for the final position of each prompt — confirm against
        LogitsWrapper).
    """
    tokens = encoding_helpers.tokenize_strings(prompts)

    # Inference only: disable autograd so no graph or intermediate activations
    # are retained while computing the logits.
    with torch.no_grad():
        logits, _ = m(tokens)

    logits = LogitsWrapper(logits, encoding_helpers.tokenizer)
    return [topk_tokens[-1] for topk_tokens in logits.topk_tokens(k=10)]


In [None]:
# Model predictions for the 500 sample prompts; compared against the
# similarity-based predictions later in the notebook.
model_outputs_sample = get_model_outputs(strings_sample, encoding_helpers)

In [None]:
# Block-internals experiment over just the sample prompts; it supplies the
# query vectors (embeddings, proj outputs, ffwd outputs) used below.
prompts_exp = BlockInternalsExperiment(encoding_helpers, accessors, strings_sample)

Start by examining what we get when we ask for a much larger top-k value.

In [None]:
# For the first 5 sample prompts, fetch the 200 stored strings whose embeddings
# score highest under cosine similarity (`largest=True`: higher = closer).
emb_sims, emb_distances = exp10.strings_with_topk_closest_embeddings(
    prompts_exp.embeddings[:5, :, :],
    k=200,
    largest=True,
    distance_function=batch_cosine_sim,
)

In [None]:
# Show the first 10 and last 10 rows of the top-k similarities (one column per query).
emb_distances[:10, :], emb_distances[-10:, :]

(tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [0.9062, 0.9470, 0.9045, 0.9493, 0.9544],
         [0.9062, 0.9462, 0.8602, 0.9083, 0.9443],
         [0.9053, 0.9429, 0.8593, 0.9082, 0.9064],
         [0.9037, 0.9429, 0.8586, 0.9055, 0.9057],
         [0.9027, 0.9404, 0.8566, 0.9052, 0.9045],
         [0.9020, 0.9027, 0.8562, 0.9044, 0.8996],
         [0.9017, 0.9009, 0.8548, 0.9035, 0.8982],
         [0.8962, 0.9009, 0.8545, 0.9034, 0.8656],
         [0.8651, 0.9008, 0.8532, 0.9032, 0.8631]]),
 tensor([[0.7591, 0.7566, 0.7604, 0.8064, 0.8047],
         [0.7591, 0.7566, 0.7604, 0.8063, 0.8047],
         [0.7591, 0.7566, 0.7603, 0.8063, 0.8042],
         [0.7590, 0.7564, 0.7602, 0.8063, 0.8041],
         [0.7589, 0.7563, 0.7601, 0.8062, 0.8041],
         [0.7589, 0.7563, 0.7601, 0.8061, 0.8040],
         [0.7588, 0.7557, 0.7601, 0.8061, 0.8039],
         [0.7587, 0.7557, 0.7600, 0.8060, 0.8039],
         [0.7587, 0.7557, 0.7599, 0.8060, 0.8038],
         [0.7587, 0.7557, 0.7

In [None]:
# One printed row per rank, one column per query: first the 10 top-ranked
# strings, then (after a blank line) the last 10 of the top-k results.
for rank in range(10):
    print('   '.join(repr(matches[rank]) for matches in emb_sims))

print()
for rank in range(-10, 0):
    print('   '.join(repr(matches[rank]) for matches in emb_sims))


'is dreams,'   'by present'   's eyes may'   'eart of ho'   ' man, as I'
'is dream o'   'My present'   's eye, mak'   'eart of mo'   ' men, as I'
'ur dreams,'   'be present'   's eyes in '   'earn of hi'   ' man, as y'
'of dreams,'   'dy present'   'l eyes can'   'ears of ha'   ' man, if I'
'us dreams.'   'my present'   's eyes to '   'park of ho'   'oman, as t'
'he dreams,'   'ry present'   'l eyes gaz'   'eart of ge'   ' men, as i'
'ly dreams,'   'y, present'   'r foes may'   'eard of hi'   ' many as y'
'en dreams,'   'on present'   'r ever may'   'east of yo'   'cian, as I'
'nd dreams,'   't, present'   's eyes do '   'part of hi'   ' son, as t'
'as dream\nS'   'in present'   's eye; tal'   'earn of yo'   ' long as I'

'is presenc'   ' a prisone'   'g over mas'   'efit of se'   ' man: we s'
'is prowess'   'ot prone t'   'l even tak'   'wist of ro'   ' wind as s'
'is be all,'   'is project'   't so I may'   'e is of so'   ' man; all '
'is the mad'   'my person.'   'n thou may'   'gen

In [None]:
# Same top-200 search, but on block 0's proj output at the last position
# (`t_i=-1`) for the first 5 sample prompts.
block_idx = 0
proj_sims, proj_distances = exp10.strings_with_topk_closest_proj_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=prompts_exp.proj_output(block_idx=block_idx)[:5, -1, :],
    k=200,
    largest=True,
    distance_function=batch_cosine_sim,
)

In [None]:
# Show the first 10 and last 10 rows of the top-k similarities (one column per query).
proj_distances[:10, :], proj_distances[-10:, :]

(tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [0.9948, 0.9967, 0.9947, 0.9974, 0.9956],
         [0.9942, 0.9967, 0.9945, 0.9960, 0.9933],
         [0.9937, 0.9964, 0.9921, 0.9958, 0.9897],
         [0.9935, 0.9963, 0.9917, 0.9957, 0.9896],
         [0.9928, 0.9962, 0.9910, 0.9956, 0.9871],
         [0.9922, 0.9950, 0.9905, 0.9955, 0.9857],
         [0.9911, 0.9944, 0.9900, 0.9954, 0.9856],
         [0.9897, 0.9942, 0.9895, 0.9950, 0.9850],
         [0.9891, 0.9939, 0.9894, 0.9943, 0.9846]]),
 tensor([[0.9709, 0.9826, 0.9796, 0.9855, 0.9726],
         [0.9709, 0.9825, 0.9795, 0.9855, 0.9726],
         [0.9709, 0.9824, 0.9794, 0.9854, 0.9726],
         [0.9708, 0.9824, 0.9793, 0.9854, 0.9725],
         [0.9708, 0.9824, 0.9793, 0.9854, 0.9725],
         [0.9708, 0.9823, 0.9793, 0.9854, 0.9724],
         [0.9708, 0.9823, 0.9791, 0.9854, 0.9724],
         [0.9708, 0.9823, 0.9791, 0.9853, 0.9724],
         [0.9707, 0.9823, 0.9791, 0.9853, 0.9723],
         [0.9707, 0.9823, 0.9

In [None]:
# One printed row per rank, one column per query: first the 10 top-ranked
# strings, then (after a blank line) the last 10 of the top-k results.
for rank in range(10):
    print('   '.join(repr(matches[rank]) for matches in proj_sims))

print()
for rank in range(-10, 0):
    print('   '.join(repr(matches[rank]) for matches in proj_sims))


'is dreams,'   'by present'   's eyes may'   'eart of ho'   ' man, as I'
'ly dreams,'   'my present'   'e case may'   'ster of ho'   ' men, as I'
'en dreams,'   'dy present'   ' sense may'   'ffer to ha'   ' and, as I'
'he dreams,'   'ry present'   'ied as may'   'ruth of ho'   ' not, as I'
'ur dreams,'   'be present'   'ay she may'   'otes of ho'   'o me, as I'
'nd dreams,'   'My present'   'r foes may'   'anes of ho'   'nd I, as I'
'ery beams,'   'y; present'   'So she may'   'oint of ho'   'cian, as I'
'of dreams,'   'in present'   ' haste may'   'yers of ho'   'I am, as t'
"n's beams,"   'is present'   'odesty may'   'ains of ho'   '-day, as I'
'hese arms,'   'im present'   'esence may'   'ound of ho'   '\nAnd, as I'

'sires most'   'ast ungent'   "'s some am"   'lf with ho'   ' be, was l'
'teous mass'   'What scene'   'e they mad'   'tell of hi'   'Look, as I'
'm to kiss,'   'lest scent'   ' am to say'   'Thus to ha'   'wick, as o'
'much amiss'   'ondon sent'   ' early mad'   's o

In [None]:
# Same top-200 search, but on block 0's ffwd output at the last position
# (`t_i=-1`) for the first 5 sample prompts.
block_idx = 0
ffwd_sims, ffwd_distances = exp10.strings_with_topk_closest_ffwd_outputs(
    block_idx=block_idx,
    t_i=-1,
    queries=prompts_exp.ffwd_output(block_idx=block_idx)[:5, -1, :],
    k=200,
    largest=True,
    distance_function=batch_cosine_sim,
)

In [None]:
# Show the first 10 and last 10 rows of the top-k similarities (one column per query).
ffwd_distances[:10, :], ffwd_distances[-10:, :]

(tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [0.9998, 0.9999, 0.9997, 0.9998, 0.9999],
         [0.9998, 0.9998, 0.9996, 0.9998, 0.9997],
         [0.9998, 0.9998, 0.9996, 0.9998, 0.9997],
         [0.9998, 0.9998, 0.9995, 0.9998, 0.9997],
         [0.9997, 0.9998, 0.9995, 0.9998, 0.9996],
         [0.9997, 0.9997, 0.9995, 0.9998, 0.9995],
         [0.9997, 0.9997, 0.9995, 0.9998, 0.9995],
         [0.9996, 0.9997, 0.9995, 0.9997, 0.9994],
         [0.9996, 0.9997, 0.9995, 0.9997, 0.9994]]),
 tensor([[0.9988, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9988, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9988, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9988, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9988, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9987, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9987, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9987, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9987, 0.9990, 0.9989, 0.9992, 0.9983],
         [0.9987, 0.9990, 0.9

In [None]:
# One printed row per rank, one column per query: first the 10 top-ranked
# strings, then (after a blank line) the last 10 of the top-k results.
for rank in range(10):
    print('   '.join(repr(matches[rank]) for matches in ffwd_sims))

print()
for rank in range(-10, 0):
    print('   '.join(repr(matches[rank]) for matches in ffwd_sims))

'is dreams,'   'by present'   's eyes may'   'eart of ho'   ' man, as I'
'en dreams,'   'my present'   ' sense may'   'ster of ho'   ' men, as I'
'ly dreams,'   'dy present'   'e case may'   'otes of ho'   ' and, as I'
'he dreams,'   'ry present'   'r foes may'   'anes of ho'   ' not, as I'
'nd dreams,'   'My present'   'ied as may'   'ains of ho'   'o me, as I'
'ur dreams,'   'be present'   'ay she may'   'oint of ho'   'nd I, as I'
'of dreams,'   'y, present'   ' bones may'   'ound of ho'   '-day, as I'
'ery beams,'   'y; present'   'So she may'   'fear to ho'   ' but, as I'
'rate arms,'   'im present'   ' grace may'   'ally of ho'   'rd me as I'
'hese arms,'   'in present'   'e more may'   'yers of ho'   '\nYes, as I'

'ck groans,'   'll\nPresent'   'en you say'   'ace our ho'   'r sakes, I'
'te builds,'   's innocent'   'Hope I may'   'he that ho'   'early as I'
'she finds,'   'me Florent'   'e must say'   'Half an ho'   'nes, and I'
'reat loss,'   'their gent'   'n thou may'   'as

## For the 500 Sample Strings, Can we Calculate Cosine Similarity to Every Other String?


In [None]:
class CosineSimilaritiesExperiment:
    """Computes and saves cosine similarities between query vectors and the
    activations stored on disk by a `BatchedBlockInternalsExperiment`.

    Results are written into `output_folder`: one file per batch for the
    embeddings, and one file per (batch, block) for the proj and ffwd outputs.
    Every saved tensor has one row per stored string in the batch and one
    column per query.
    """

    def __init__(
        self,
        exp: BatchedBlockInternalsExperiment,
        output_folder: Path,
    ):
        self.exp = exp
        self.output_folder = output_folder

        self.n_batches = exp.n_batches

    def embedding_sims_filename(self, batch_idx: int):
        """Path of the saved embedding similarities for batch `batch_idx`."""
        return self.output_folder / f'embedding_sims_{batch_idx:03d}.pt'

    def proj_out_sims_filename(self, batch_idx: int, block_idx: int):
        """Path of the saved proj-output similarities for a (batch, block) pair."""
        return self.output_folder / f'proj_out_sims_{batch_idx:03d}_{block_idx:02d}.pt'

    def ffwd_out_sims_filename(self, batch_idx: int, block_idx: int):
        """Path of the saved ffwd-output similarities for a (batch, block) pair."""
        return self.output_folder / f'ffwds_out_sims_{batch_idx:03d}_{block_idx:02d}.pt'

    def generate_embedding_sims(self, chunk_size: int, queries: torch.Tensor, disable_progress_bar: bool = False):
        """Save, for every stored batch, the cosine similarity between each
        flattened stored embedding and each query.

        Args:
            chunk_size: number of stored rows compared against the queries at
                once; bounds the size of the expanded intermediate tensor.
            queries: 2-D tensor with one flattened query embedding per row.
            disable_progress_bar: forwarded to tqdm's `disable`.

        Raises:
            ValueError: if `chunk_size` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        assert queries.dim() == 2
        n_queries = queries.shape[0]

        def sims_for_chunk(chunk: torch.Tensor):
            # Flatten each stored embedding to one vector and broadcast it
            # against every query: result is (rows_in_chunk, n_queries).
            return F.cosine_similarity(
                chunk.reshape(chunk.shape[0], 1, -1).expand(-1, n_queries, -1),
                queries,
                dim=-1
            )

        for batch_idx in tqdm(range(self.exp.n_batches), disable=disable_progress_bar):
            emb_batch = torch.load(str(self.exp._embeddings_filename(batch_idx=batch_idx)), mmap=True)
            # Chunk by the rows actually present in this batch so that a short
            # final batch, or a batch size that is not a multiple of
            # `chunk_size`, neither drops rows nor produces empty chunks.
            sims = torch.cat([
                sims_for_chunk(emb_batch[start:start + chunk_size])
                for start in range(0, emb_batch.shape[0], chunk_size)
            ], dim=0)
            torch.save(sims, str(self.embedding_sims_filename(batch_idx=batch_idx)))

    def _generate_block_output_sims(
        self,
        input_filename: Callable[..., Path],
        sims_filename: Callable[..., Path],
        get_queries: Callable[[int], torch.Tensor],
        disable_progress_bar: bool,
    ):
        """Shared implementation for the per-block proj/ffwd similarity files.

        `input_filename` and `sims_filename` are called with `batch_idx=` and
        `block_idx=` keywords to locate the stored activations and the output.
        """
        # The queries depend only on the block, so fetch them once per block
        # rather than once per (batch, block).
        queries_by_block = [get_queries(block_idx) for block_idx in range(n_layer)]
        for queries in queries_by_block:
            assert queries.dim() == 2

        for batch_idx in tqdm(range(self.exp.n_batches), disable=disable_progress_bar):
            for block_idx, queries in enumerate(queries_by_block):
                n_queries = queries.shape[0]

                out_batch = torch.load(str(input_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True)
                batch_size = out_batch.shape[0]
                # Only the last position of each stored string is compared.
                sims = F.cosine_similarity(
                    out_batch[:, -1, :].reshape(batch_size, 1, -1).expand(-1, n_queries, -1),
                    queries,
                    dim=-1
                )
                torch.save(sims, str(sims_filename(batch_idx=batch_idx, block_idx=block_idx)))

    def generate_proj_out_sims(self, get_queries: Callable[[int], torch.Tensor], disable_progress_bar=False):
        """Save cosine similarities between the stored last-position proj
        outputs and `get_queries(block_idx)` for every batch and block.

        Args:
            get_queries: maps a block index to a 2-D tensor of queries.
            disable_progress_bar: forwarded to tqdm's `disable`.
        """
        self._generate_block_output_sims(
            input_filename=self.exp._proj_output_filename,
            sims_filename=self.proj_out_sims_filename,
            get_queries=get_queries,
            disable_progress_bar=disable_progress_bar,
        )

    def generate_ffwd_out_sims(self, get_queries: Callable[[int], torch.Tensor], disable_progress_bar=False):
        """Save cosine similarities between the stored last-position ffwd
        outputs and `get_queries(block_idx)` for every batch and block.

        Args:
            get_queries: maps a block index to a 2-D tensor of queries.
            disable_progress_bar: forwarded to tqdm's `disable`.
        """
        self._generate_block_output_sims(
            input_filename=self.exp._ffwd_output_filename,
            sims_filename=self.ffwd_out_sims_filename,
            get_queries=get_queries,
            disable_progress_bar=disable_progress_bar,
        )



In [None]:
# Similarity files go in a subfolder of the experiment's output dir.
# `exist_ok=True` makes re-running this cell safe; the parent must already exist.
output_folder = exp10.output_dir / 'cosine_sims'
output_folder.mkdir(exist_ok=True)

In [None]:
# Similarity experiment over the length-10 dataset, writing into `output_folder`.
cos_exp = CosineSimilaritiesExperiment(exp10, output_folder)

In [None]:
# Flatten each sample prompt's embedding to a single vector: the method
# requires 2-D queries and flattens the stored embeddings the same way.
cos_exp.generate_embedding_sims(chunk_size=2000, queries=prompts_exp.embeddings.reshape(sample_size, -1))

In [None]:
# Queries are the sample prompts' proj outputs at the last position, per block.
cos_exp.generate_proj_out_sims(get_queries=lambda block_idx: prompts_exp.proj_output(block_idx=block_idx)[:, -1, :])

In [None]:
# Queries are the sample prompts' ffwd outputs at the last position, per block.
cos_exp.generate_ffwd_out_sims(get_queries=lambda block_idx: prompts_exp.ffwd_output(block_idx=block_idx)[:, -1, :])

## Analyze the Results

First, collect some stats about the results

In [None]:
def stats_across_batches(
    get_batch: Callable[[int], torch.Tensor],
    n_batches: int,
    n_queries: int,
):
    """Compute per-query min, max, mean and standard deviation over all batches
    without holding more than one batch in memory.

    Args:
        get_batch: returns batch `i` as a 2-D tensor of shape
            (rows_in_batch, n_queries).
        n_batches: number of batches to request from `get_batch`.
        n_queries: expected number of columns in every batch.

    Returns:
        Tuple `(mins, maxs, means, stds)`, each a tensor with `n_queries`
        entries. `stds` is the sample standard deviation (divides by N - 1),
        so it is NaN when fewer than two rows were seen in total.
    """
    # Start from +/- infinity so the first batch always defines the running
    # extremes. Starting from zero would clamp the minimum of all-positive
    # data (and the maximum of all-negative data) to 0.
    mins = torch.full((n_queries,), float('inf'))
    maxs = torch.full((n_queries,), float('-inf'))
    means = torch.zeros(n_queries)
    m2 = torch.zeros(n_queries)  # running sum of squared deviations from the mean

    total_count = 0
    for i in range(n_batches):
        batch = get_batch(i)
        batch_size, n_queries_batch = batch.shape
        assert n_queries_batch == n_queries
        if batch_size == 0:
            # Nothing to fold in; min()/max() are undefined on an empty batch.
            continue

        mins = torch.minimum(mins, batch.min(dim=0).values)
        maxs = torch.maximum(maxs, batch.max(dim=0).values)

        # Chan, Golub, and LeVeque pairwise update. The batch's sum of squared
        # deviations is computed directly rather than as var * (n - 1), which
        # would be NaN for a single-row batch.
        batch_means = batch.mean(dim=0)
        batch_m2 = ((batch - batch_means) ** 2).sum(dim=0)

        prev_count = total_count
        total_count += batch_size
        delta = batch_means - means
        means += delta * batch_size / total_count
        m2 += batch_m2 + delta**2 * batch_size * prev_count / total_count

    # max(..., 0) keeps the "fewer than two rows" cases at 0/0 = NaN instead of
    # dividing by -1 when no rows were seen.
    variances = m2 / max(total_count - 1, 0)
    stds = torch.sqrt(variances)
    return mins, maxs, means, stds

In [None]:
# Per-query statistics of the saved embedding similarities, streamed one batch file at a time.
emb_mins, emb_maxs, emb_means, emb_stds = stats_across_batches(
    get_batch=lambda i: torch.load(str(cos_exp.embedding_sims_filename(batch_idx=i))),
    n_batches=cos_exp.n_batches,
    n_queries=sample_size,
)

In [None]:
# Per-block statistics of the proj-output similarities. zip(*...) regroups the
# per-block (mins, maxs, means, stds) tuples into four tuples indexed by block.
# The lambda reads `block_idx` when called, which is safe here because
# stats_across_batches() calls it within the same comprehension iteration.
proj_mins, proj_maxs, proj_means, proj_stds = zip(*[
    stats_across_batches(
        get_batch=lambda i: torch.load(str(cos_exp.proj_out_sims_filename(batch_idx=i, block_idx=block_idx))),
        n_batches=cos_exp.n_batches,
        n_queries=sample_size,
    )
    for block_idx in range(n_layer)
])

In [None]:
# Per-block statistics of the ffwd-output similarities. zip(*...) regroups the
# per-block (mins, maxs, means, stds) tuples into four tuples indexed by block.
# The lambda reads `block_idx` when called, which is safe here because
# stats_across_batches() calls it within the same comprehension iteration.
ffwd_mins, ffwd_maxs, ffwd_means, ffwd_stds = zip(*[
    stats_across_batches(
        get_batch=lambda i: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=i, block_idx=block_idx))),
        n_batches=cos_exp.n_batches,
        n_queries=sample_size,
    )
    for block_idx in range(n_layer)
])

In [None]:
# Mean and spread, across the sample queries, of each query's mean similarity.
print(f"emb mean: {emb_means.mean():.3f} ± {emb_means.std():.3f}")

emb mean: 0.534 ± 0.009


In [None]:
# Same summary per block, for both the proj and ffwd output similarities.
for block_idx in range(n_layer):
    print(f"Block {block_idx}: proj_out mean: {proj_means[block_idx].mean():.3f} ± {proj_means[block_idx].std():.3f}, ffwd_out mean: {ffwd_means[block_idx].mean():.3f} ± {ffwd_means[block_idx].std():.3f}")

Block 0: proj_out mean: 0.380 ± 0.040, ffwd_out mean: 0.801 ± 0.092
Block 1: proj_out mean: 0.203 ± 0.035, ffwd_out mean: 0.365 ± 0.042
Block 2: proj_out mean: 0.112 ± 0.027, ffwd_out mean: 0.160 ± 0.033
Block 3: proj_out mean: 0.076 ± 0.032, ffwd_out mean: 0.072 ± 0.036
Block 4: proj_out mean: 0.085 ± 0.036, ffwd_out mean: 0.088 ± 0.059
Block 5: proj_out mean: 0.169 ± 0.059, ffwd_out mean: 0.133 ± 0.091


In [None]:
def filter_across_batches(
    get_batch: Callable[[int], torch.Tensor],
    n_batches: int,
    filter_fn: Callable[[torch.Tensor], torch.Tensor],
    n_queries: int,
):
    """For each query column, collect the dataset-wide row indices that pass `filter_fn`.

    Args:
        get_batch: returns batch `i` as a 2-D tensor of shape
            (rows_in_batch, n_queries).
        n_batches: number of batches to request from `get_batch`.
        filter_fn: maps a batch to a same-shaped tensor whose nonzero (True)
            entries mark the matches.
        n_queries: expected number of columns in every batch.

    Returns:
        A list with `n_queries` entries. Entry `q` lists, in ascending order,
        the indices of the rows that matched for query `q`, where rows are
        numbered consecutively across batches.
    """
    matching_indices: List[List[int]] = [[] for _ in range(n_queries)]
    row_offset = 0  # number of rows in all previously processed batches
    for batch_idx in range(n_batches):
        batch = get_batch(batch_idx)
        batch_size, n_queries_batch = batch.shape
        assert n_queries_batch == n_queries

        # torch.nonzero yields one [row_in_batch, query_idx] pair per match in
        # row-major order, so each query's list stays sorted. Converting once
        # with tolist() avoids two .item() calls per match.
        hits = torch.nonzero(filter_fn(batch)).tolist()
        for row_in_batch, query_idx in hits:
            matching_indices[query_idx].append(row_offset + row_in_batch)

        row_offset += batch_size

    return matching_indices

In [None]:
# Tests for filter_across_batches()

# Two batches (3 rows, then 2 rows) with 4 query columns each.
batches = [
    torch.tensor([
        [0.0, 0.6, 0.4, 0.3],
        [0.1, 0.3, 0.5, 0.1],
        [0.0, 0.1, 0.8, 0.0],
    ]),
    torch.tensor([
        [0.7, 0.2, 0.6, 0.3],
        [0.1, 0.8, 0.2, 0.8],
    ]),
]

result = filter_across_batches(
    get_batch=lambda i: batches[i],
    n_batches=len(batches),
    filter_fn=lambda batch: batch > 0.5,
    n_queries=4,
)
# Expected indices count rows across batches: the second batch's rows are 3 and 4.
test_eq(result, [
    [3,],
    [0, 4],
    [2, 3],
    [4],
])

In [None]:
def filter_result_stats(
    filter_results: List[List[int]],
):
    """Summarize how many matches each query received.

    Args:
        filter_results: one list of matching indices per query.

    Returns:
        Dict with keys 'min', 'max', 'mean' and 'std' describing the
        distribution of the per-query match counts.
    """
    match_counts = list(map(len, filter_results))

    stats = {}
    stats['min'] = min(match_counts)
    stats['max'] = max(match_counts)
    stats['mean'] = np.mean(match_counts)
    stats['std'] = np.std(match_counts)
    return stats

In [None]:
def get_matching_strings(
    filter_result: List[List[int]],
    strings: Sequence[str],
):
    """Translate per-query lists of indices into per-query lists of strings.

    Args:
        filter_result: one list of indices into `strings` per query.
        strings: the sequence the indices refer to.

    Returns:
        A list parallel to `filter_result`, with every index replaced by
        `strings[index]`.
    """
    matched = []
    for indices in filter_result:
        matched.append([strings[idx] for idx in indices])
    return matched

In [None]:
# For each sample query, the indices of all stored strings whose embedding
# cosine similarity exceeds 0.8.
embs_result = filter_across_batches(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.embedding_sims_filename(batch_idx=batch_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.8,
    n_queries=sample_size,
)

In [None]:
# Distribution of the number of matches per query at this threshold.
filter_result_stats(embs_result)

{'min': 2, 'max': 1895, 'mean': 193.438, 'std': 268.96681236910996}

In [None]:
# Print every query followed by up to its first 100 matching strings.
for i, result in enumerate(embs_result):
    print(f"Query {i} ({repr(strings_sample[i])}): {', '.join([repr(exp10.strings[j]) for j in result[:100]])}")

Query 0 ('is dreams,'): 'is great;\n', 'is breast ', 'is great a', "'s breast,", 'is treadin', 'is breathi', 'nd dreams,', 'nd dreams;', 'er dreamt ', "ns drew'st", "'s dread c", 'ng dream\nA', 'ly dreams,', 'ur dream? ', 'my dream w', 'he dream.\n', 'He dreamt ', 'is dreams,', 'id dream t', 'is treason', 'us dreams.', '\nA dream o', 'is greates', 'ut dream.\n', 'ul dream!\n', 'ng dreams\n', ' a dream.\n', 'ir spears,', 'is great; ', 'py dream;\n', 'us groans,', 'is deep si', 'er treads,', ' I dream n', ' a dream t', 'at dreamer', 'do dream t', 'ey dream o', 'at dream o', 'ht dream o', 'es dream,\n', 'en dreams ', 'en dreams,', 'of dreams,', ' a dream,\n', 'My dreams ', '\nI dreamt ', 'ge dream, ', ' I dream i', 'r, dreamin', 'is breast;', 'ut dream o', 'to dream u', 'is treasur', "or dream'd", 'th dreams;', 'is great s', 'is great e', 'He dreads ', 'ur dreams,', 'my dreams;', "ut dream'd", 'as dream\nS', 'ot dreams ', 'an dream o', 'is dream o', 'is precise', ' a dream!\n', ' I dream 

In [None]:
# Per block: for each sample query, the indices of stored strings whose
# proj-output similarity exceeds 0.95. The lambda reads `block_idx` when
# called, which happens inside the same comprehension iteration.
proj_out_results = [
    filter_across_batches(
        get_batch=lambda batch_idx: torch.load(str(cos_exp.proj_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
        n_batches=cos_exp.n_batches,
        filter_fn=lambda batch: batch > 0.95,
        n_queries=sample_size,
    )
    for block_idx in tqdm(range(n_layer))
]

# Same for the ffwd outputs.
# NOTE(review): neither result list is used in the part of the notebook visible here.
ffwd_out_results = [
    filter_across_batches(
        get_batch=lambda batch_idx: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
        n_batches=cos_exp.n_batches,
        filter_fn=lambda batch: batch > 0.95,
        n_queries=sample_size,
    )
    for block_idx in tqdm(range(n_layer))
]

In [None]:
# Map from string to a tensor that analyze_dataset() sums over matching strings
# (presumably next-token counts for each length-10 string — confirm against
# build_next_token_map).
next_token_map10 = build_next_token_map(ts.text, 10, tokenizer.vocab_size, tokenizer.stoi)

In [None]:
# TODO: put this in a common component
def analyze_simulate_results(sim_freqs, model_outputs):
    """Count, per rank, how often the simulated top tokens agree with the model's.

    Uses the notebook-level `encoding_helpers` to decode token indices.

    Args:
        sim_freqs: one token-distribution tensor per query; its top tokens are
            obtained with `top_nonzero_tokens` and truncated to 10.
        model_outputs: one sequence of (token, value) pairs per query, parallel
            to `sim_freqs`.

    Returns:
        Tuple `(topn_matches, topn_matches_any_order)` of two 10-element lists.
        `topn_matches[j]` counts queries whose tokens at rank `j` are equal;
        `topn_matches_any_order[j]` counts queries whose first `j + 1` tokens
        form the same set on both sides.
    """
    assert len(sim_freqs) == len(model_outputs)
    topn_matches = [0] * 10
    topn_matches_any_order = [0] * 10

    for query_idx in range(len(sim_freqs)):
        sim_output = top_nonzero_tokens(sim_freqs[query_idx], encoding_helpers.tokenizer.itos)[:10]

        sim_tokens, _ = zip(*sim_output)
        model_tokens, _ = zip(*model_outputs[query_idx])

        # Grow the prefix sets one rank at a time; zip() stops at the shorter
        # of the two token lists.
        sim_prefix = set()
        model_prefix = set()
        for rank, (sim_token, model_token) in enumerate(zip(sim_tokens, model_tokens)):
            sim_prefix.add(sim_token)
            model_prefix.add(model_token)
            if sim_token == model_token:
                topn_matches[rank] += 1
            if sim_prefix == model_prefix:
                topn_matches_any_order[rank] += 1

    return topn_matches, topn_matches_any_order


In [None]:
def analyze_dataset(
    get_batch: Callable[[int], torch.Tensor],
    n_batches: int,
    filter_fn: Callable[[torch.Tensor], torch.Tensor],
    n_queries: int,
    all_strings: Sequence[str],
    next_token_map: Dict[str, torch.Tensor],
    model_outputs: Sequence[Sequence[Tuple[str, float]]],
):
    """Predict next tokens from the strings that pass a similarity filter and
    report how often those predictions agree with the model.

    For every query, the `next_token_map` entries of all matching strings are
    summed and normalized, then compared against `model_outputs` with
    `analyze_simulate_results`. Match-count statistics and per-rank agreement
    rates are printed.

    Args:
        get_batch: returns batch `i` of similarities, shape (rows, n_queries).
        n_batches: number of similarity batches.
        filter_fn: marks the entries of a batch that count as matches.
        n_queries: number of queries (columns in every batch).
        all_strings: the strings the similarity rows refer to, in row order.
        next_token_map: per-string tensor to accumulate over the matches.
        model_outputs: the model's predictions, one entry per query.

    Returns:
        List with one summed (un-normalized) tensor per query.
    """
    filter_results = filter_across_batches(
        get_batch=get_batch,
        n_batches=n_batches,
        filter_fn=filter_fn,
        n_queries=n_queries,
    )

    print(filter_result_stats(filter_results))

    filter_results_strings = get_matching_strings(filter_results, all_strings)
    filter_result_freqs = [
        torch.stack([
            next_token_map[matching_string]
            for matching_string in matching_strings
        ]).sum(dim=0)
        for matching_strings in filter_results_strings
    ]

    filter_result_probs = [
        freqs / freqs.sum()
        for freqs in filter_result_freqs
    ]

    topn_matches, topn_matches_any_order = analyze_simulate_results(filter_result_probs, model_outputs)
    # Normalize by this call's own query count rather than the notebook-level
    # `sample_size`, so the rates stay correct if the two ever differ.
    for i in range(len(topn_matches)):
        print(f"Top {i+1} matches: {topn_matches[i] / n_queries:.3f}")
        print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / n_queries:.3f}")

    return filter_result_freqs

In [None]:
# Block 5 ffwd outputs, keeping strings with similarity above 0.95.
# The lambda reads the notebook-level `block_idx` set on the next line.
block_idx = 5
ffwd5_freqs = analyze_dataset(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.95,
    n_queries=sample_size,
    all_strings=exp10.strings,
    next_token_map=next_token_map10,
    model_outputs=model_outputs_sample,
)

{'min': 1, 'max': 16224, 'mean': 922.742, 'std': 2399.3301297312128}
Top 1 matches: 0.786
Top 1 matches (any order): 0.786
Top 2 matches: 0.404
Top 2 matches (any order): 0.440
Top 3 matches: 0.236
Top 3 matches (any order): 0.306
Top 4 matches: 0.186
Top 4 matches (any order): 0.222
Top 5 matches: 0.118
Top 5 matches (any order): 0.164
Top 6 matches: 0.104
Top 6 matches (any order): 0.124
Top 7 matches: 0.066
Top 7 matches (any order): 0.082
Top 8 matches: 0.052
Top 8 matches (any order): 0.054
Top 9 matches: 0.040
Top 9 matches (any order): 0.048
Top 10 matches: 0.034
Top 10 matches (any order): 0.034


In [None]:
# Embedding similarities, keeping strings with similarity above 0.96.
emb_freqs = analyze_dataset(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.embedding_sims_filename(batch_idx=batch_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.96,
    n_queries=sample_size,
    all_strings=exp10.strings,
    next_token_map=next_token_map10,
    model_outputs=model_outputs_sample,
)


{'min': 1, 'max': 2, 'mean': 1.002, 'std': 0.044676615807377355}
Top 1 matches: 0.616
Top 1 matches (any order): 0.616
Top 2 matches: 0.020
Top 2 matches (any order): 0.022
Top 3 matches: 0.004
Top 3 matches (any order): 0.004
Top 4 matches: 0.002
Top 4 matches (any order): 0.000
Top 5 matches: 0.000
Top 5 matches (any order): 0.000
Top 6 matches: 0.000
Top 6 matches (any order): 0.000
Top 7 matches: 0.000
Top 7 matches (any order): 0.000
Top 8 matches: 0.000
Top 8 matches (any order): 0.000
Top 9 matches: 0.000
Top 9 matches (any order): 0.000
Top 10 matches: 0.000
Top 10 matches (any order): 0.000


In [None]:
# Block 4 ffwd outputs, keeping strings with similarity above 0.83.
# The lambda reads the notebook-level `block_idx` set on the next line.
block_idx = 4
ffwd4_freqs = analyze_dataset(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.83,
    n_queries=sample_size,
    all_strings=exp10.strings,
    next_token_map=next_token_map10,
    model_outputs=model_outputs_sample,
)

{'min': 1, 'max': 48461, 'mean': 1877.104, 'std': 5356.329997786171}
Top 1 matches: 0.776
Top 1 matches (any order): 0.776
Top 2 matches: 0.456
Top 2 matches (any order): 0.520
Top 3 matches: 0.296
Top 3 matches (any order): 0.382
Top 4 matches: 0.198
Top 4 matches (any order): 0.260
Top 5 matches: 0.174
Top 5 matches (any order): 0.196
Top 6 matches: 0.130
Top 6 matches (any order): 0.176
Top 7 matches: 0.096
Top 7 matches (any order): 0.124
Top 8 matches: 0.078
Top 8 matches (any order): 0.092
Top 9 matches: 0.078
Top 9 matches (any order): 0.080
Top 10 matches: 0.062
Top 10 matches (any order): 0.052


In [None]:
# Block 3 ffwd outputs, keeping strings with similarity above 0.97.
# The lambda reads the notebook-level `block_idx` set on the next line.
block_idx = 3
ffwd3_freqs = analyze_dataset(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.97,
    n_queries=sample_size,
    all_strings=exp10.strings,
    next_token_map=next_token_map10,
    model_outputs=model_outputs_sample,
)

{'min': 1, 'max': 1260, 'mean': 38.416, 'std': 114.08503383003399}
Top 1 matches: 0.724
Top 1 matches (any order): 0.724
Top 2 matches: 0.256
Top 2 matches (any order): 0.282
Top 3 matches: 0.084
Top 3 matches (any order): 0.098
Top 4 matches: 0.060
Top 4 matches (any order): 0.068
Top 5 matches: 0.046
Top 5 matches (any order): 0.062
Top 6 matches: 0.030
Top 6 matches (any order): 0.022
Top 7 matches: 0.022
Top 7 matches (any order): 0.016
Top 8 matches: 0.018
Top 8 matches (any order): 0.008
Top 9 matches: 0.018
Top 9 matches (any order): 0.016
Top 10 matches: 0.008
Top 10 matches (any order): 0.008


In [None]:
# Block 2 ffwd outputs, keeping strings with similarity above 0.90.
# The lambda reads the notebook-level `block_idx` set on the next line.
block_idx = 2
ffwd2_freqs = analyze_dataset(
    get_batch=lambda batch_idx: torch.load(str(cos_exp.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
    n_batches=cos_exp.n_batches,
    filter_fn=lambda batch: batch > 0.90,
    n_queries=sample_size,
    all_strings=exp10.strings,
    next_token_map=next_token_map10,
    model_outputs=model_outputs_sample,
)

{'min': 1, 'max': 6046, 'mean': 468.448, 'std': 876.6662804602445}
Top 1 matches: 0.770
Top 1 matches (any order): 0.770
Top 2 matches: 0.454
Top 2 matches (any order): 0.500
Top 3 matches: 0.278
Top 3 matches (any order): 0.360
Top 4 matches: 0.214
Top 4 matches (any order): 0.246
Top 5 matches: 0.136
Top 5 matches (any order): 0.174
Top 6 matches: 0.102
Top 6 matches (any order): 0.130
Top 7 matches: 0.110
Top 7 matches (any order): 0.098
Top 8 matches: 0.062
Top 8 matches (any order): 0.082
Top 9 matches: 0.060
Top 9 matches (any order): 0.072
Top 10 matches: 0.052
Top 10 matches (any order): 0.050


In [None]:
# Combine the per-block frequency tensors for blocks 2-5 into one weighted sum
# per query, then normalize and score against the model as before.
# NOTE(review): the weights 0.08 / 2 / 3 / 10 are not derived anywhere in the
# visible notebook; they look hand-tuned.
total_freqs = [
    0.08*ffwd2_freqs[i] + 2 * ffwd3_freqs[i] + 3 * ffwd4_freqs[i] + 10 * ffwd5_freq
    for i, ffwd5_freq in enumerate(ffwd5_freqs)
]
total_probs = [
    freqs / freqs.sum()
    for freqs in total_freqs
]
topn_matches, topn_matches_any_order = analyze_simulate_results(total_probs, model_outputs_sample)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / sample_size:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / sample_size:.3f}")

Top 1 matches: 0.800
Top 1 matches (any order): 0.800
Top 2 matches: 0.500
Top 2 matches (any order): 0.546
Top 3 matches: 0.320
Top 3 matches (any order): 0.430
Top 4 matches: 0.250
Top 4 matches (any order): 0.306
Top 5 matches: 0.188
Top 5 matches (any order): 0.218
Top 6 matches: 0.152
Top 6 matches (any order): 0.190
Top 7 matches: 0.110
Top 7 matches (any order): 0.130
Top 8 matches: 0.084
Top 8 matches (any order): 0.106
Top 9 matches: 0.088
Top 9 matches (any order): 0.092
Top 10 matches: 0.060
Top 10 matches (any order): 0.070


In [None]:
# NOTE(review): this cell repeats the evaluation at the end of the previous
# cell. Its saved output differs from the previous cell's, which suggests it
# was last run against a `total_probs` built with different weights.
topn_matches, topn_matches_any_order = analyze_simulate_results(total_probs, model_outputs_sample)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / sample_size:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / sample_size:.3f}")

Top 1 matches: 0.800
Top 1 matches (any order): 0.800
Top 2 matches: 0.474
Top 2 matches (any order): 0.522
Top 3 matches: 0.308
Top 3 matches (any order): 0.416
Top 4 matches: 0.236
Top 4 matches (any order): 0.288
Top 5 matches: 0.182
Top 5 matches (any order): 0.208
Top 6 matches: 0.142
Top 6 matches (any order): 0.184
Top 7 matches: 0.112
Top 7 matches (any order): 0.124
Top 8 matches: 0.074
Top 8 matches (any order): 0.102
Top 9 matches: 0.082
Top 9 matches (any order): 0.078
Top 10 matches: 0.052
Top 10 matches (any order): 0.066


## Cosine Sims for Length 256 Strings


In [None]:
# Strings of length 256 drawn from the corpus (unique substrings, per the
# helper's name — confirm against all_unique_substrings).
strings256 = all_unique_substrings(ts.text, 256)

In [None]:
# Re-seed so the length-256 sample is reproducible on its own.
torch.manual_seed(1337)
# Reassigns the notebook-level `sample_size` (same value as before).
sample_size = 500
indices256 = torch.randperm(len(strings256))[:sample_size]
strings256_sample = [strings256[i.item()] for i in indices256]

In [None]:
# Model predictions for the 500 length-256 sample prompts.
model_outputs_sample256 = get_model_outputs(strings256_sample, encoding_helpers)

In [None]:
# Block-internals experiment over the length-256 sample; supplies the ffwd query vectors below.
prompts_exp256 = BlockInternalsExperiment(encoding_helpers, accessors, strings256_sample)

In [None]:
class CosineSimilaritiesForFinalFFWDExperiment:
    """Computes cosine similarities between query vectors and the saved
    feed-forward outputs of a `FinalFFWDExperiment`, one file per batch.

    Results are written into `output_folder`; use `ffwd_out_sims_filename`
    to locate the file for a given batch.
    """

    def __init__(
        self,
        exp: FinalFFWDExperiment,
        output_folder: Path,
    ):
        """
        Args:
            exp: experiment whose per-batch ffwd output files are read.
            output_folder: folder the similarity tensors are saved into.
                It is not created here.
        """
        self.exp = exp
        self.output_folder = output_folder

        # Mirrors the wrapped experiment so callers can iterate over the
        # similarity files without reaching into `exp`.
        self.n_batches = exp.n_batches

    def ffwd_out_sims_filename(self, batch_idx: int, block_idx: int) -> Path:
        """Path of the similarity tensor for one batch and block."""
        return self.output_folder / f'ffwds_out_sims_{batch_idx:04d}_{block_idx:02d}.pt'

    def generate_ffwd_out_sims(self, get_queries: Callable[[int], torch.Tensor], disable_progress_bar=False):
        """Computes and saves the similarities for every batch of the experiment.

        Only the last block (`n_layer - 1`) is processed.

        Args:
            get_queries: called once with the block index; must return a 2-D
                tensor with one query vector per row.
            disable_progress_bar: passed through to tqdm's `disable`.

        For each batch a tensor of shape (batch_size, n_queries) is saved,
        where entry [i, q] is the cosine similarity between the flattened
        ffwd output of batch item i and query q.
        """
        block_idx = n_layer - 1

        # The queries do not depend on the batch, so fetch and validate them
        # once instead of once per batch.
        queries = get_queries(block_idx)
        assert queries.dim() == 2
        n_queries = queries.shape[0]

        for batch_idx in tqdm(range(self.n_batches), disable=disable_progress_bar):
            ffwd_out_batch = torch.load(str(self.exp._ffwd_output_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True)
            batch_size = ffwd_out_batch.shape[0]

            # Flatten each item to a single vector and repeat it (as a view)
            # along a new query axis so it lines up with `queries`.
            sims = F.cosine_similarity(
                ffwd_out_batch.reshape(batch_size, 1, -1).expand(-1, n_queries, -1),
                queries,
                dim=-1
            )
            torch.save(sims, str(self.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)))




In [None]:
# Final-ffwd experiment over *all* length-256 strings (not just the sample).
# The cosine-similarity run below reads this experiment's per-batch ffwd output
# files, so those files are presumably expected to already exist under
# `output_dir` -- TODO confirm where/when they are generated.
ffwd_exp256 = FinalFFWDExperiment(
    eh=encoding_helpers,
    accessors=accessors,
    strings=strings256,
    output_dir=Path('../artifacts/block_internals_results/large_files/slen256'),
    batch_size=400,
)

In [None]:
# Keep the similarity tensors in a `cosine_sims` subfolder of the experiment's
# output directory. `exist_ok=True` makes the cell safe to re-run; missing
# parent directories are not created.
output_folder = ffwd_exp256.output_dir / 'cosine_sims'
output_folder.mkdir(exist_ok=True)

In [None]:
# Pair the full-dataset ffwd experiment with the folder its similarity files go to.
cos_exp256 = CosineSimilaritiesForFinalFFWDExperiment(ffwd_exp256, output_folder)

In [None]:
# Long-running: writes one similarity file per batch of `ffwd_exp256`.
# Queries are the sampled prompts' ffwd outputs at the last index along dim 1
# (presumably the final token position -- TODO confirm the axis layout).
cos_exp256.generate_ffwd_out_sims(get_queries=lambda block_idx: prompts_exp256.ffwd_output(block_idx=block_idx)[:, -1, :])

  0%|          | 0/2788 [00:00<?, ?it/s]

In [None]:
# Lookup keyed by string; each value is a tensor that gets stacked and summed
# below. Presumably per-token counts of what follows each length-256 string in
# the text -- TODO confirm against build_next_token_map.
next_token_map256 = build_next_token_map(ts.text, 256, tokenizer.vocab_size, tokenizer.stoi)

In [None]:
# Only the last block's similarities were generated, so read those files back.
# NOTE: the `get_batch` lambda looks up the module-level `block_idx` when it is
# called, not when it is defined.
block_idx = n_layer - 1
ffwd256 = filter_across_batches(
    get_batch=lambda batch_idx: torch.load(str(cos_exp256.ffwd_out_sims_filename(batch_idx=batch_idx, block_idx=block_idx)), mmap=True),
    n_batches=cos_exp256.n_batches,
    # Keep only entries whose cosine similarity exceeds 0.91 (hand-picked threshold).
    filter_fn=lambda batch: batch > 0.91,
    n_queries=sample_size,
)
# Summary statistics of the filter result (min / max / mean / std, per the
# output); presumably of the number of matches per query -- TODO confirm.
filter_result_stats(ffwd256)

{'min': 1, 'max': 38780, 'mean': 1561.9, 'std': 4162.8430373964375}

In [None]:
# Turn the filter results into the matching strings themselves. Entries of
# `ffwd256` index into `strings256` (the inspection cell below relies on this).
ffwd256_strings = get_matching_strings(ffwd256, strings256)

In [None]:
# One tensor per query: the element-wise sum of the `next_token_map256`
# entries of every string that matched that query.
ffwd256_freqs = [
    torch.stack([next_token_map256[candidate] for candidate in candidates]).sum(dim=0)
    for candidates in ffwd256_strings
]

# Normalise each summed tensor by its own total so that it sums to 1.
ffwd256_probs = [totals / totals.sum() for totals in ffwd256_freqs]

In [None]:
# Score the similarity-derived distributions against the model outputs for the
# length-256 sample. Entry i of each result is reported under the label "Top i+1".
topn_matches, topn_matches_any_order = analyze_simulate_results(ffwd256_probs, model_outputs_sample256)
for i in range(10):
    # Each entry is divided by `sample_size` and printed as a fraction to three decimals.
    print(f"Top {i+1} matches: {topn_matches[i] / sample_size:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / sample_size:.3f}")

Top 1 matches: 0.802
Top 1 matches (any order): 0.802
Top 2 matches: 0.362
Top 2 matches (any order): 0.412
Top 3 matches: 0.246
Top 3 matches (any order): 0.298
Top 4 matches: 0.200
Top 4 matches (any order): 0.250
Top 5 matches: 0.156
Top 5 matches (any order): 0.154
Top 6 matches: 0.128
Top 6 matches (any order): 0.148
Top 7 matches: 0.070
Top 7 matches (any order): 0.086
Top 8 matches: 0.044
Top 8 matches (any order): 0.068
Top 9 matches: 0.032
Top 9 matches (any order): 0.038
Top 10 matches: 0.030
Top 10 matches (any order): 0.016


In [None]:
# Eyeball one query: print the prompt, then up to 20 of the strings that
# passed the similarity filter for it.
idx = 0
print(f"Query {idx} ({repr(strings256_sample[idx])}): ")
result = ffwd256[idx]
for j in result[:20]:
    # `j` indexes into the full `strings256` list, not the sample.
    print(f"  {repr(strings256[j])}")


Query 0 (" see my shame in him.\nThou art a widow; yet thou art a mother,\nAnd hast the comfort of thy children left thee:\nBut death hath snatch'd my husband from mine arms,\nAnd pluck'd two crutches from my feeble limbs,\nEdward and Clarence. O, what cause have I,\nThin"): 
  " so.\n\nSICINIUS:\nLet them assemble,\nAnd on a safer judgment all revoke\nYour ignorant election; enforce his pride,\nAnd his old hate unto you; besides, forget not\nWith what contempt he wore the humble weed,\nHow in his suit he scorn'd you; but your loves,\nThin"
  "I'ld have beaten him like a dog, but for\ndisturbing the lords within.\n\nAUFIDIUS:\nWhence comest thou? what wouldst thou? thy name?\nWhy speak'st not? speak, man: what's thy name?\n\nCORIOLANUS:\nIf, Tullus,\nNot yet thou knowest me, and, seeing me, dost not\nThin"
  "e not--to save my life, for if\nI had fear'd death, of all the men i' the world\nI would have 'voided thee, but in mere spite,\nTo be full quit of those my banishers,\nStand I bef

In [None]:
class TensorBatchIterator:
    """Iterates row by row over tensors that are loaded one batch at a time.

    Batches are fetched on demand through `get_batch(batch_idx)` for
    `batch_idx` in `range(n_batches)`; only one batch is held at a time.
    Each call to `next()` yields `batch[i, :]` for the next row `i`, moving
    on to the following batch when the current one is used up. Batches with
    zero rows are skipped.
    """

    def __init__(self, n_batches: int, get_batch: Callable[[int], torch.Tensor]):
        """
        Args:
            n_batches: number of batches available; may be 0.
            get_batch: returns the batch tensor for a given batch index.
                Each batch must have at least two dimensions.
        """
        self.n_batches = n_batches
        self.get_batch = get_batch

        self.next_batch_idx = 0
        self.current_batch: Optional[torch.Tensor] = None
        self.idx_within_batch = 0

        # Load the first batch up front, but only if there is one: with no
        # batches the iterator is simply empty rather than raising
        # StopIteration out of the constructor.
        if self.n_batches > 0:
            self._load_next_batch()

    def __iter__(self):
        # Iterators must be iterable themselves so they work directly in
        # `for` loops, `list(...)`, etc.
        return self

    def _load_next_batch(self):
        """Loads the next batch, or raises StopIteration if none are left."""
        if self.next_batch_idx >= self.n_batches:
            # Drop the reference to the last batch and mark the iterator as
            # exhausted, so later `next()` calls keep raising StopIteration.
            self.current_batch = None
            raise StopIteration()

        self.current_batch = self.get_batch(self.next_batch_idx)
        self.idx_within_batch = 0
        self.next_batch_idx += 1

    def __next__(self):
        if self.current_batch is None:
            raise StopIteration()

        # A loop rather than a single `if`, so that empty batches are skipped
        # instead of being indexed out of range.
        while self.idx_within_batch >= self.current_batch.shape[0]:
            self._load_next_batch()

        result = self.current_batch[self.idx_within_batch, :]
        self.idx_within_batch += 1
        return result

class EmbeddingCosineSims:
    """Iterable view over the embedding similarity files of an experiment.

    Every call to `iter()` starts a fresh pass from the first batch and
    yields one row of a saved similarity tensor at a time.
    """

    def __init__(self, exp: CosineSimilaritiesExperiment):
        self.exp = exp

    def __iter__(self):
        experiment = self.exp

        def load_batch(batch_idx: int) -> torch.Tensor:
            # Memory-map the saved tensor rather than reading it eagerly.
            sims_path = experiment.embedding_sims_filename(batch_idx=batch_idx)
            return torch.load(str(sims_path), mmap=True)

        return TensorBatchIterator(n_batches=experiment.n_batches, get_batch=load_batch)



In [None]:
# Smoke test: walk every row of the saved embedding similarities and print its
# index and shape. NOTE(review): this uses `cos_exp`, not the `cos_exp256`
# created in this section -- it is presumably defined earlier in the notebook.
emb_sims = EmbeddingCosineSims(cos_exp)
for i, sim in enumerate(emb_sims):
    print(i, sim.shape)

In [None]:
# `iter()` starts a new pass, so this is the first row of the first batch.
sims = next(iter(emb_sims))
sims.shape

torch.Size([500])