# Interpreting the Projection and Feed-Forward Layers in a Self-Attention Block

> A summary of my experiments to understand the projection layer and feed-forward layer of a self-attention block.

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from dataclasses import dataclass
import json
import math
from pathlib import Path
from typing import Callable, Dict, List, Optional, Iterable, Protocol, Sequence, Tuple, TypeVar, Type

In [None]:
# | hide

from fastcore.test import *
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.nn import functional as F
from tqdm.auto import tqdm

In [None]:
# Set a manual seed so output is deterministic (used same value as @karpathy)
_ = torch.manual_seed(1337)

In [None]:
# | hide

from transformer_experiments.common.substring_generator import all_unique_substrings
from transformer_experiments.common.text_analysis import (
    build_next_token_map,
    SubstringFrequencyAnalysis,
    top_nonzero_tokens
)
from transformer_experiments.common.utils import (
    aggregate_by_string_key,
    DataWrapper,
    topk_across_batches,
)
from transformer_experiments.dataset_split import split_text_dataset
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.models.transformer import (
    n_layer,
    TransformerLanguageModel
)
from transformer_experiments.models.transformer_helpers import (
    unsqueeze_emb,
    EncodingHelpers,
    LogitsWrapper,
    TransformerAccessors
)
from transformer_experiments.trained_models.tinyshakespeare_transformer import (
    create_model_and_tokenizer
)
from transformer_experiments.experiments.block_internals import (
    BlockInternalsAccessors,
    BlockInternalsExperiment,
    BatchedBlockInternalsExperiment,
    BlockInternalsAnalysis,
)
from transformer_experiments.experiments.similar_strings import (
    SimilarStringsData,
    SimilarStringsExperiment,
    SimilarStringsResult
)
from transformer_experiments.experiments.logit_lens import LogitLens

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Dataset wrapper; `cache_file` points at the local copy of the text.
ts = TinyShakespeareDataSet(cache_file='../artifacts/input.txt')
# Build the model and its tokenizer from the saved weights file.
m, tokenizer = create_model_and_tokenizer(
    saved_model_filename='../artifacts/shakespeare.pt',
    dataset=ts,
    device=device,
)
# Only the second element of the split (the validation part) is kept here.
_, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9)
# Helper objects shared by the experiments in the rest of this notebook.
encoding_helpers = EncodingHelpers(tokenizer, device)
accessors = TransformerAccessors(m, device)

In [None]:
print(f"device is {device}")

device is cpu


In [None]:
strings10 = all_unique_substrings(ts.text, 10)

In [None]:
# Build one next token map per prefix length (3 through 10) that we've run
# experiments for. The generator is consumed in ascending prefix-length order
# and unpacked straight into the individual names.
(
    next_token_map3,
    next_token_map4,
    next_token_map5,
    next_token_map6,
    next_token_map7,
    next_token_map8,
    next_token_map9,
    next_token_map10,
) = (
    build_next_token_map(
        ts.text,
        prefix_len=prefix_len,
        vocab_size=tokenizer.vocab_size,
        stoi=tokenizer.stoi,
    )
    for prefix_len in range(3, 11)
)

In [None]:
all_token_lens = [3, 4, 5, 6, 7, 8, 9, 10]
all_token_maps = [
    next_token_map3,
    next_token_map4,
    next_token_map5,
    next_token_map6,
    next_token_map7,
    next_token_map8,
    next_token_map9,
    next_token_map10
]

In [None]:
# Combine all the token maps into one
next_token_map_all = {
    **next_token_map3,
    **next_token_map4,
    **next_token_map5,
    **next_token_map6,
    **next_token_map7,
    **next_token_map8,
    **next_token_map9,
    **next_token_map10
}

In [None]:
# Sanity check: report every entry whose next-token counts sum to zero.
# That should only happen when the last substring in the text is unique
# (nothing follows it).
for k, v in next_token_map_all.items():
    if v.sum() != 0:
        continue
    print(f"{repr(k)} has no next tokens")

In [None]:
# Sanity check: every key in a map must be exactly as long as the prefix
# length that map was built for.
for l, token_map in zip(all_token_lens, all_token_maps):
    for k in token_map:
        if len(k) == l:
            continue
        print(f"{repr(k)} has length {len(k)} but should have length {l}")

In [None]:
# Point the user at the make target when the slen10 results directory has no entries.
if not any(Path('../artifacts/block_internals_results/large_files/slen10/').glob('*')):
    print("Run `make block_internals_slen10_dataset` in the project root to generate the required dataset")

In [None]:
exp10 = BatchedBlockInternalsExperiment(
    eh=encoding_helpers,
    accessors=accessors,
    strings=strings10,
    output_dir=Path('../artifacts/block_internals_results/large_files/slen10/'),
    batch_size=10000,
)

In [None]:
# Run a similar strings experiment on a handful of sample strings used for
# analysis below. Each artifact family is generated only when no matching
# file is present in the output directory yet.
output_dir = Path('../artifacts/block_internals_results/similar_strings_sample')
output_dir.mkdir(exist_ok=True)

ssexp = SimilarStringsExperiment(output_dir, encoding_helpers)
sample_strings = ['First Citi', 'Citizen:\nB', 'Shyamalan ', 'more in jo']
n_similars = 10
batch_size = len(sample_strings)

if not (output_dir / 'string_to_batch_map.json').exists():
    ssexp.generate_string_to_batch_map(sample_strings, batch_size)

if not any(output_dir.glob('embs_sim_strings-*.json')):
    ssexp.generate_embeddings_files(
        sample_strings,
        accessors,
        exp10,
        batch_size=batch_size,
        n_similars=n_similars,
    )

if not any(output_dir.glob('proj_out_sim_strings-*.json')):
    ssexp.generate_proj_out_files(
        sample_strings,
        t_i=-1,
        accessors=accessors,
        exp=exp10,
        batch_size=batch_size,
        n_similars=n_similars,
    )

if not any(output_dir.glob('ffwd_out_sim_strings-*.json')):
    ssexp.generate_ffwd_out_files(
        sample_strings,
        t_i=-1,
        accessors=accessors,
        exp=exp10,
        batch_size=batch_size,
        n_similars=n_similars,
    )


In [None]:
T = TypeVar('T', bound='SimilarStringsFrequencyAndDistanceData')


@dataclass
class SimilarStringsFrequencyAndDistanceData:
    """Encapsulates the frequency and distance data associated
    with a set of `SimilarStringsResult`s.

    Instances are normally built with `from_results()`, which stacks the
    per-string data into one tensor per field so that a batch of prompts can
    be selected with a single index tensor.
    """

    strings: Sequence[str]  # N strings
    string_to_idx: Dict[str, int]  # Maps N strings to indices 0..N-1
    emb_freqs: torch.Tensor  # (N, n_similars, vocab_size)
    emb_distances: torch.Tensor  # (N, n_similars)
    proj_freqs: torch.Tensor  # (N, n_layer, n_similars, vocab_size)
    proj_distances: torch.Tensor  # (N, n_layer, n_similars)
    ffwd_freqs: torch.Tensor  # (N, n_layer, n_similars, vocab_size)
    ffwd_distances: torch.Tensor  # (N, n_layer, n_similars)

    @classmethod
    def from_results(
        cls: Type[T],
        ss_results: Dict[str, SimilarStringsResult],
        next_token_map: Dict[str, torch.Tensor],
        aggregate_over_t_is: Optional[Sequence[int]] = None,
        largest: bool = False,
    ) -> T:
        """Build the stacked tensors from a mapping of string -> result.

        Args:
            ss_results: Results keyed by the string they were computed for.
                Iteration order determines the index assigned to each string.
            next_token_map: Maps a string to its next-token frequency tensor.
                Every similar string referenced by the results must be a key.
            aggregate_over_t_is: Passed through to
                `SimilarStringsResult.aggregate_over_t_is`. Defaults to
                `[-1]` when omitted.
            largest: Passed through to
                `SimilarStringsResult.aggregate_over_t_is`.

        Returns:
            A new instance whose tensors have one leading entry per item of
            `ss_results`.

        Raises:
            KeyError: If a similar string is missing from `next_token_map`.
        """
        # A fresh list per call: the value is handed to external code, so a
        # shared mutable default must not be exposed to it.
        if aggregate_over_t_is is None:
            aggregate_over_t_is = [-1]

        def freqs_for(sim_strings: Iterable[str]) -> torch.Tensor:
            # One next-token frequency row per similar string.
            return torch.stack([next_token_map[sim] for sim in sim_strings])

        def freqs_per_block(aggr_out) -> torch.Tensor:
            # Stack the per-block frequency tables along a new leading axis.
            return torch.stack(
                [
                    freqs_for(aggr_out[block_idx].sim_strings)
                    for block_idx in range(n_layer)
                ]
            )

        def distances_per_block(aggr_out) -> torch.Tensor:
            # Stack the per-block distance vectors along a new leading axis.
            return torch.stack(
                [aggr_out[block_idx].distances for block_idx in range(n_layer)]
            )

        strings: List[str] = []
        string_to_idx: Dict[str, int] = {}

        all_emb_freqs = []
        all_emb_distances = []
        all_proj_freqs = []
        all_proj_distances = []
        all_ffwd_freqs = []
        all_ffwd_distances = []

        for idx, (prompt, result) in enumerate(ss_results.items()):
            strings.append(prompt)
            string_to_idx[prompt] = idx

            aggr_proj_out, aggr_ffwd_out = result.aggregate_over_t_is(
                aggregate_over_t_is,
                largest=largest,
            )

            all_emb_freqs.append(freqs_for(result.embs.sim_strings))
            all_emb_distances.append(result.embs.distances)

            all_proj_freqs.append(freqs_per_block(aggr_proj_out))
            all_proj_distances.append(distances_per_block(aggr_proj_out))

            all_ffwd_freqs.append(freqs_per_block(aggr_ffwd_out))
            all_ffwd_distances.append(distances_per_block(aggr_ffwd_out))

        return cls(
            strings=strings,
            string_to_idx=string_to_idx,
            # (len(ss_results), n_similars, vocab_size)
            emb_freqs=torch.stack(all_emb_freqs),
            # (len(ss_results), n_similars)
            emb_distances=torch.stack(all_emb_distances),
            # (len(ss_results), n_layer, n_similars, vocab_size)
            proj_freqs=torch.stack(all_proj_freqs),
            # (len(ss_results), n_layer, n_similars)
            proj_distances=torch.stack(all_proj_distances),
            # (len(ss_results), n_layer, n_similars, vocab_size)
            ffwd_freqs=torch.stack(all_ffwd_freqs),
            # (len(ss_results), n_layer, n_similars)
            ffwd_distances=torch.stack(all_ffwd_distances),
        )

In [None]:
# Tests for SimilarStringsFrequencyAndDistanceData
ss_data = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ssexp.load_results_for_strings(sample_strings),
    next_token_map=next_token_map_all,
    aggregate_over_t_is=[-1],
)
test_eq(ss_data.strings, sample_strings)
test_eq(len(ss_data.string_to_idx), len(sample_strings))
test_eq(ss_data.emb_freqs.shape, (len(sample_strings), n_similars, tokenizer.vocab_size))
test_eq(ss_data.emb_distances.shape, (len(sample_strings), n_similars))
test_eq(ss_data.proj_freqs.shape, (len(sample_strings), n_layer, n_similars, tokenizer.vocab_size))
test_eq(ss_data.proj_distances.shape, (len(sample_strings), n_layer, n_similars))
test_eq(ss_data.ffwd_freqs.shape, (len(sample_strings), n_layer, n_similars, tokenizer.vocab_size))
test_eq(ss_data.ffwd_distances.shape, (len(sample_strings), n_layer, n_similars))

In [None]:
class ComputeNextTokenFreqs(Protocol):
    """Structural type for the scoring functions plugged into `ModelSimulation`.

    Any callable taking a tensor of prompt indices plus the similar-strings
    data and returning a tensor satisfies this protocol.
    """

    def __call__(
        self,
        prompt_idxs: torch.Tensor,
        ss_data: SimilarStringsFrequencyAndDistanceData,
    ) -> torch.Tensor:
        """Return the next-token frequencies for the prompts at `prompt_idxs`."""
        ...


class ModelSimulation:
    """Approximates next-token predictions from similar-strings data.

    Calling an instance with a list of prompts looks up each prompt's index
    in `ss_data`, asks the supplied scoring function for next-token
    frequencies, and returns at most ten top tokens per prompt.
    """

    def __init__(
        self,
        ss_data: SimilarStringsFrequencyAndDistanceData,
        compute_next_token_freqs: ComputeNextTokenFreqs,
        encoding_helpers: EncodingHelpers,
    ):
        self.encoding_helpers = encoding_helpers
        self.get_next_token_freqs = compute_next_token_freqs
        self.ss_data = ss_data

    def __call__(self, prompts: Sequence[str]):
        """Return, per prompt, the first ten entries of `top_nonzero_tokens`.

        Raises `KeyError` when a prompt is not present in
        `ss_data.string_to_idx`.
        """
        index_of = self.ss_data.string_to_idx
        prompt_idxs = torch.tensor(
            [index_of[prompt] for prompt in prompts], dtype=torch.long
        )

        per_prompt_freqs = self.get_next_token_freqs(prompt_idxs, self.ss_data)

        top_tokens = []
        for f in per_prompt_freqs:
            ranked = top_nonzero_tokens(f, self.encoding_helpers.tokenizer.itos)
            top_tokens.append(ranked[:10])
        return top_tokens

In [None]:
def get_model_outputs(prompts: Sequence[str], encoding_helpers: EncodingHelpers):
    """Run the module-level model `m` on `prompts` and return its top tokens.

    Args:
        prompts: Strings to feed to the model.
        encoding_helpers: Used to tokenize the prompts and to supply the
            tokenizer for decoding the logits.

    Returns:
        A list with one entry per prompt: the last element of that prompt's
        `LogitsWrapper.topk_tokens(k=10)` result (presumably the top-10
        predictions at the final position - confirm against `LogitsWrapper`).
    """
    tokens = encoding_helpers.tokenize_strings(prompts)

    # Inference only: skip autograd bookkeeping so large prompt batches don't
    # keep the whole computation graph alive.
    with torch.no_grad():
        logits, _ = m(tokens)

    wrapped_logits = LogitsWrapper(logits, encoding_helpers.tokenizer)
    return [topk_tokens[-1] for topk_tokens in wrapped_logits.topk_tokens(k=10)]


In [None]:
def next_token_freqs_progressive_ffwd_weight(
    prompt_idxs: torch.Tensor, ss_data: SimilarStringsFrequencyAndDistanceData
):
    """Combine embedding, projection and feed-forward frequencies per prompt.

    Embedding and projection frequencies get weight 1; feed-forward
    frequencies of block `b` get weight `1 + b`. The weighted sums are then
    divided by their per-prompt total.
    """
    emb_weight = torch.tensor(1.0, dtype=torch.float32)
    # Column vectors so they broadcast over the trailing vocab axis.
    proj_weights = torch.tensor([1.0] * n_layer, dtype=torch.float32).unsqueeze(dim=1)
    ffwd_weights = torch.tensor(
        list(range(1, n_layer + 1)), dtype=torch.float32
    ).unsqueeze(dim=1)

    # Sum over the similar-strings axis for each source of frequencies.
    emb_term = (emb_weight * ss_data.emb_freqs[prompt_idxs]).sum(dim=1)

    proj_by_block = ss_data.proj_freqs[prompt_idxs].sum(dim=2)
    proj_term = (proj_weights * proj_by_block).sum(dim=1)

    ffwd_by_block = ss_data.ffwd_freqs[prompt_idxs].sum(dim=2)
    ffwd_term = (ffwd_weights * ffwd_by_block).sum(dim=1)

    freqs = emb_term + proj_term + ffwd_term
    totals = freqs.sum(dim=1, keepdim=True)
    return freqs / totals

In [None]:
simulate = ModelSimulation(
    ss_data=ss_data,
    compute_next_token_freqs=next_token_freqs_progressive_ffwd_weight,
    encoding_helpers=encoding_helpers,
)

In [None]:
strings10[14423:14433]

['more in jo',
 'ore in joy',
 're in joy ',
 'e in joy a',
 ' in joy at',
 'in joy at ',
 'n joy at f',
 ' joy at fi',
 'joy at fir',
 'oy at firs']

In [None]:
string_to_idx = {
    s: i for i, s in enumerate(sample_strings)
}
sample_model_outputs = get_model_outputs(sample_strings, encoding_helpers)


In [None]:
simulate(['First Citi'])[0], sample_model_outputs[string_to_idx['First Citi']]

([('z', 0.9380468726158142),
  ('i', 0.03080154024064541),
  ('e', 0.01610080525279045),
  ('c', 0.004550227429717779),
  ('h', 0.003850192530080676),
  ('p', 0.002100104931741953),
  (':', 0.0014000700321048498),
  ('o', 0.0014000700321048498),
  ('u', 0.0007000350160524249),
  (' ', 0.0007000350160524249)],
 [('z', 0.9996668100357056),
  ('u', 0.00010660554107744247),
  ('I', 7.993520557647571e-05),
  ('U', 2.734881672949996e-05),
  ('K', 2.4257360564661212e-05),
  ('P', 1.5074498151079752e-05),
  ('L', 1.0885321898967959e-05),
  ('n', 8.451069334114436e-06),
  ('O', 8.223939403251279e-06),
  ('f', 7.135453870432684e-06)])

In [None]:
simulate(['Citizen:\nB'])[0], sample_model_outputs[string_to_idx['Citizen:\nB']]

([('e', 0.3421829044818878),
  ('u', 0.17404130101203918),
  ('o', 0.13716813921928406),
  ('h', 0.09587020426988602),
  ('y', 0.08554572612047195),
  ('a', 0.07669616490602493),
  ('n', 0.033923305571079254),
  ('i', 0.028023598715662956),
  ('r', 0.011799409985542297),
  ('t', 0.007374631240963936)],
 [('e', 0.47825106978416443),
  ('u', 0.2509588301181793),
  ('y', 0.1266946792602539),
  ('r', 0.05788085237145424),
  ('i', 0.03135434538125992),
  ('o', 0.02440422773361206),
  ('a', 0.017102370038628578),
  ('l', 0.012891546823084354),
  ('s', 8.761954813962802e-05),
  ('R', 7.359156006714329e-05)])

In [None]:
simulate(['Shyamalan '])[0], sample_model_outputs[string_to_idx['Shyamalan ']]

([('t', 0.15815085172653198),
  ('b', 0.1228710487484932),
  ('o', 0.11313868314027786),
  ('a', 0.10340632498264313),
  ('i', 0.09975668787956238),
  ('d', 0.058394160121679306),
  ('s', 0.05352798104286194),
  ('w', 0.04866180196404457),
  ('m', 0.03041362576186657),
  ('f', 0.027980534359812737)],
 [('t', 0.16370470821857452),
  ('s', 0.10785210877656937),
  ('a', 0.09744462370872498),
  ('b', 0.09677103161811829),
  ('c', 0.08353256434202194),
  ('m', 0.056587863713502884),
  ('d', 0.048968441784381866),
  ('p', 0.04876013845205307),
  ('h', 0.04751131683588028),
  ('w', 0.038983285427093506)])

In [None]:
simulate(['more in jo'])[0], sample_model_outputs[string_to_idx['more in jo']]

([('y', 0.6007066965103149),
  ('t', 0.16607773303985596),
  ('i', 0.07773851603269577),
  ('s', 0.06713780760765076),
  ('u', 0.038869258016347885),
  ('m', 0.017667844891548157),
  ('d', 0.010600706562399864),
  ('r', 0.007067137863487005),
  ('k', 0.0035335689317435026),
  ('l', 0.0035335689317435026)],
 [('y', 0.8568735718727112),
  ('i', 0.06098264083266258),
  ('u', 0.04135835915803909),
  ('c', 0.016126777976751328),
  ('t', 0.012861563824117184),
  ('l', 0.0022685928270220757),
  ('o', 0.0021818610839545727),
  ('s', 0.0016513006994500756),
  ('v', 0.001559981144964695),
  ('w', 0.0011889823945239186)])

In [None]:
def analyze_simulate_results2(strings: Sequence[str], sim_outputs: Sequence, model_outputs: Sequence):
    """A version of analyze_simulate_results() that computes results for the
    full length of the returned results.

    Args:
        strings: The prompts; only their count is used, to pair up
            `sim_outputs[i]` with `model_outputs[i]`.
        sim_outputs: Per prompt, a sequence of `(token, value)` pairs from
            the simulation, best first.
        model_outputs: Per prompt, a sequence of `(token, value)` pairs from
            the model, best first.

    Returns:
        A pair of 10-element lists `(topn_matches, topn_matches_any_order)`.
        `topn_matches[j]` counts the prompts whose rank-`j` tokens agree;
        `topn_matches_any_order[j]` counts the prompts whose first `j + 1`
        tokens agree as sets. Ranks beyond the shorter of the two outputs
        (or beyond 10) are not counted, so an empty output contributes
        nothing.
    """
    max_rank = 10
    topn_matches = [0] * max_rank
    topn_matches_any_order = [0] * max_rank

    for i in range(len(strings)):
        # Extract tokens with comprehensions rather than zip(*...), which
        # cannot be unpacked when an output list is empty.
        sim_tokens = [token for token, _ in sim_outputs[i]]
        model_tokens = [token for token, _ in model_outputs[i]]

        # Cap at max_rank so longer outputs cannot index past the counters.
        n = min(len(sim_tokens), len(model_tokens), max_rank)
        for j in range(n):
            if sim_tokens[j] == model_tokens[j]:
                topn_matches[j] += 1
            if set(sim_tokens[: j + 1]) == set(model_tokens[: j + 1]):
                topn_matches_any_order[j] += 1

    return topn_matches, topn_matches_any_order


In [None]:
# Draw a reproducible random sample of strings from the slen10 experiment and
# load their similar-strings results.
n_samples = 20000
ss_exp20k = SimilarStringsExperiment(
    exp10.output_dir / 'similar_strings',
    encoding_helpers,
)

# Re-seed right before sampling so the permutation is the same on every run.
torch.manual_seed(1337)
indices = torch.randperm(len(exp10.strings))[:n_samples]
strings = [exp10.strings[idx] for idx in indices.tolist()]

ss_results20k = ss_exp20k.load_results_for_strings(strings)

ss_data20k = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_results20k,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=[-1],
)

In [None]:
model_outputs20k = get_model_outputs(strings, encoding_helpers)

In [None]:
sim20k = ModelSimulation(
    ss_data=ss_data20k,
    compute_next_token_freqs=next_token_freqs_progressive_ffwd_weight,
    encoding_helpers=encoding_helpers,
)

In [None]:
sim_outputs = sim20k(strings)

In [None]:
topn_matches, topn_matches_any_order = analyze_simulate_results2(strings, sim_outputs, model_outputs20k)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / n_samples:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / n_samples:.3f}")

Top 1 matches: 0.774
Top 1 matches (any order): 0.774
Top 2 matches: 0.398
Top 2 matches (any order): 0.452
Top 3 matches: 0.211
Top 3 matches (any order): 0.238
Top 4 matches: 0.141
Top 4 matches (any order): 0.143
Top 5 matches: 0.102
Top 5 matches (any order): 0.088
Top 6 matches: 0.081
Top 6 matches (any order): 0.054
Top 7 matches: 0.059
Top 7 matches (any order): 0.028
Top 8 matches: 0.053
Top 8 matches (any order): 0.018
Top 9 matches: 0.046
Top 9 matches (any order): 0.010
Top 10 matches: 0.037
Top 10 matches (any order): 0.005


In [None]:
t_is=[3, 4, 5, 6, 7, 8, 9]
ss_results20k_all_t_is = ss_exp20k.load_results_for_strings(strings, load_t_is=t_is)

In [None]:
ss_data20k_aggr = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_results20k_all_t_is,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
)

In [None]:
sim20k_aggr = ModelSimulation(
    ss_data=ss_data20k_aggr,
    compute_next_token_freqs=next_token_freqs_progressive_ffwd_weight,
    encoding_helpers=encoding_helpers,
)

In [None]:
sim_outputs2 = sim20k_aggr(strings)

In [None]:
n_samples = 20000
topn_matches, topn_matches_any_order = analyze_simulate_results2(strings, sim_outputs2, model_outputs20k)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / n_samples:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / n_samples:.3f}")

Top 1 matches: 0.742
Top 1 matches (any order): 0.742
Top 2 matches: 0.363
Top 2 matches (any order): 0.415
Top 3 matches: 0.197
Top 3 matches (any order): 0.212
Top 4 matches: 0.136
Top 4 matches (any order): 0.130
Top 5 matches: 0.098
Top 5 matches (any order): 0.081
Top 6 matches: 0.080
Top 6 matches (any order): 0.048
Top 7 matches: 0.060
Top 7 matches (any order): 0.027
Top 8 matches: 0.053
Top 8 matches (any order): 0.016
Top 9 matches: 0.044
Top 9 matches (any order): 0.010
Top 10 matches: 0.036
Top 10 matches (any order): 0.004


In [None]:
sim20k_aggr(['my most gr'])[0], model_outputs20k[ss_data20k.string_to_idx['my most gr']]

([('a', 0.7179487347602844),
  ('i', 0.09455128014087677),
  ('e', 0.06410256773233414),
  ('h', 0.035256411880254745),
  ('o', 0.028846153989434242),
  ('t', 0.025641025975346565),
  ('u', 0.01923076994717121),
  ('r', 0.008012820966541767),
  ('v', 0.0032051282469183207),
  ('c', 0.0016025641234591603)],
 [('a', 0.4602494537830353),
  ('e', 0.35252559185028076),
  ('o', 0.09188850224018097),
  ('i', 0.09030349552631378),
  ('u', 0.004192721098661423),
  ('y', 0.0007521358784288168),
  ('r', 6.647213740507141e-05),
  ('l', 3.957989065384027e-06),
  ('v', 2.812936827467638e-06),
  ('w', 2.738903503995971e-06)])

In [None]:
sim20k(['my most gr'])[0], model_outputs20k[ss_data20k.string_to_idx['my most gr']]

([('a', 0.74631267786026),
  ('i', 0.14454276859760284),
  ('o', 0.05604719743132591),
  ('e', 0.02654867246747017),
  ('n', 0.005899704992771149),
  ('c', 0.005899704992771149),
  ('v', 0.005899704992771149),
  ('d', 0.0029498524963855743),
  ('r', 0.0029498524963855743),
  ('l', 0.0029498524963855743)],
 [('a', 0.4602494537830353),
  ('e', 0.35252559185028076),
  ('o', 0.09188850224018097),
  ('i', 0.09030349552631378),
  ('u', 0.004192721098661423),
  ('y', 0.0007521358784288168),
  ('r', 6.647213740507141e-05),
  ('l', 3.957989065384027e-06),
  ('v', 2.812936827467638e-06),
  ('w', 2.738903503995971e-06)])

In [None]:
def print_sim_strings(result: SimilarStringsResult, aggregate_over_t_is: Sequence[int], largest: bool = False):
    """Print the similar strings and their distances, one column per block.

    The projection outputs are printed first, then a blank line, then the
    feed-forward outputs. Each row holds the i-th similar string of every
    block, formatted as a right-aligned repr followed by the distance.
    """
    aggr_proj_out, aggr_ffwd_out = result.aggregate_over_t_is(aggregate_over_t_is, largest=largest)
    n_similars = len(aggr_proj_out[0].sim_strings)

    def print_rows(per_block):
        # One printed line per similar-string rank, blocks side by side.
        for i in range(n_similars):
            cells = [
                f"{repr(per_block[block_idx].sim_strings[i]):>14} ({per_block[block_idx].distances[i]:.2f})"
                for block_idx in range(n_layer)
            ]
            print(''.join(cells))

    print("Proj Outputs")
    print_rows(aggr_proj_out)

    print()
    print("FFwd Outputs")
    print_rows(aggr_ffwd_out)

In [None]:
print_sim_strings(ss_results20k_all_t_is['my most gr'], t_is)

Proj Outputs
    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)
    ur most gr (0.79)    ur most gr (0.95)    is most gr (2.27)     y most gr (3.56)     my most r (4.90)     my most r (2.75)
    is most gr (0.80)    ne most gr (0.96)    ur most gr (2.43)    ur most gr (3.95)     y most gr (5.00)     my most l (3.52)
    ne most gr (0.80)    he most gr (1.05)     y most gr (2.56)     r most gr (4.34)         my gr (5.38)     my most h (3.79)
    ilst my gr (0.82)    is most gr (1.06)    ne most gr (2.63)       most gr (4.51)         my sl (5.54)    my most st (3.91)
    he most gr (0.84)    e, most gr (1.27)    he most gr (2.88)       most gu (4.61)     my most g (5.54)        my mos (3.97)
    unto my gr (0.89)    o, must gr (1.35)     r most gr (2.99)       most gl (4.67)         my gh (5.69)        my mod (4.01)
    e, most gr (0.89)    t, most gr (1.36)    e, most gr (3.16)    ne most gr (4.72)     my most l

Compare to just looking at t_i=-1:

In [None]:
print_sim_strings(ss_results20k_all_t_is['my most gr'], [-1])

Proj Outputs
    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)    my most gr (0.00)
    ur most gr (0.79)    ur most gr (0.95)    is most gr (2.27)    ur most gr (3.95)    he most gr (5.85)    my most st (3.91)
    is most gr (0.80)    ne most gr (0.96)    ur most gr (2.43)    ne most gr (4.72)    my most st (5.86)    my most sa (4.33)
    ne most gr (0.80)    he most gr (1.05)    ne most gr (2.63)    nd most gu (5.16)    ur most gr (6.43)     my most r (4.56)
    ilst my gr (0.82)    is most gr (1.06)    he most gr (2.88)    nd most gl (5.34)     my most r (6.52)     my most l (5.16)
    he most gr (0.84)    e, most gr (1.27)    e, most gr (3.16)    he most ge (5.49)    my young r (6.56)    my high bl (5.20)
    unto my gr (0.89)    o, must gr (1.35)    t, most gr (3.29)    is most gr (5.53)    my young p (6.58)    m thy moth (5.22)
    e, most gr (0.89)    t, most gr (1.36)    nd most gl (3.57)    ld most gl (5.64)    my young c

How would we do if we just used the next-token frequencies observed for the prompt itself?

In [None]:
def sim_just_next_tokens_from_prompt(prompt: str, next_token_map: Dict[str, torch.Tensor], encoding_helpers: EncodingHelpers):
    """Return up to ten top next tokens using only the prompt's own entry.

    The prompt's entry in `next_token_map` is divided by its sum and passed
    to `top_nonzero_tokens`. Raises `KeyError` if the prompt is not a key.
    """
    counts = next_token_map[prompt]
    normalized = counts.float() / counts.sum()
    ranked = top_nonzero_tokens(normalized, encoding_helpers.tokenizer.itos)
    return ranked[:10]

In [None]:
sim_jntfp_out = [
    sim_just_next_tokens_from_prompt(s, next_token_map_all, encoding_helpers)
    for s in tqdm(strings)
]


  0%|          | 0/20000 [00:00<?, ?it/s]

In [None]:
n_samples = 20000
topn_matches, topn_matches_any_order = analyze_simulate_results2(strings, sim_jntfp_out, model_outputs20k)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / n_samples:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / n_samples:.3f}")

Top 1 matches: 0.606
Top 1 matches (any order): 0.606
Top 2 matches: 0.011
Top 2 matches (any order): 0.012
Top 3 matches: 0.002
Top 3 matches (any order): 0.001
Top 4 matches: 0.001
Top 4 matches (any order): 0.001
Top 5 matches: 0.001
Top 5 matches (any order): 0.000
Top 6 matches: 0.000
Top 6 matches (any order): 0.000
Top 7 matches: 0.000
Top 7 matches (any order): 0.000
Top 8 matches: 0.000
Top 8 matches (any order): 0.000
Top 9 matches: 0.000
Top 9 matches (any order): 0.000
Top 10 matches: 0.000
Top 10 matches (any order): 0.000


OK, so about 61% on the top-1 token, but it falls off sharply after that (about 1% at top 2).

In [None]:
def unique_contributors(result: SimilarStringsResult, aggregate_over_t_is: Sequence[int]):
    """Collect each distinct similar string together with its next tokens.

    Strings are gathered from the embedding results first, then block by
    block from the projection and feed-forward results; a repeated string
    keeps the position of its first occurrence. The next-token data comes
    from the module-level `next_token_map_all`.

    Returns a `DataWrapper` over the `(string, next_tokens)` items with a
    formatter that renders one line per string.
    """
    aggr_proj_out, aggr_ffwd_out = result.aggregate_over_t_is(aggregate_over_t_is)

    contributors = list(result.embs.sim_strings)
    for block_idx in range(n_layer):
        contributors.extend(aggr_proj_out[block_idx].sim_strings)
        contributors.extend(aggr_ffwd_out[block_idx].sim_strings)

    # Dict insertion order de-duplicates while preserving first-seen order.
    s_to_next_tokens_map = {s: next_token_map_all[s] for s in contributors}

    def _format_entry(item: Tuple[str, torch.Tensor]):
        s, next_tokens = item
        normalized = next_tokens.float() / next_tokens.sum()
        top_tokens = top_nonzero_tokens(normalized, encoding_helpers.tokenizer.itos)
        tokens_str = ', '.join([f"{repr(t):>3} ({p:.2f})" for t, p in top_tokens])
        return f"{repr(s):>14}: {tokens_str}"

    return DataWrapper(s_to_next_tokens_map.items(), _format_entry)



In [None]:
unique_contributors(ss_results20k_all_t_is['my most gr'], t_is).print()

  'my most gr': 'a' (1.00)
  'my most sa': 'c' (1.00)
  't, most gr': 'a' (1.00)
  'my most st': 'a' (1.00)
  'e, most gr': 'a' (1.00)
  'ur most gr': 'a' (1.00)
  'is most gr': 'i' (1.00)
  'my most so': 'v' (1.00)
  'my most re': 'd' (1.00)
  'my most he': 'a' (1.00)
  'ne most gr': 'a' (1.00)
  'ilst my gr': 'o' (1.00)
  'he most gr': 'a' (1.00)
  'unto my gr': 'a' (1.00)
  'yman to gr': 'i' (1.00)
  'o, must gr': 'a' (1.00)
  'be past gr': 'i' (1.00)
  'yet not gr': 'e' (1.00)
  'ver yet gr': 'e' (1.00)
  ' cannot gr': 'e' (1.00)
   'y most gr': 'a' (1.00)
   'r most gr': 'a' (1.00)
   's most gr': 'i' (1.00)
    ' most gr': 'a' (0.89), 'i' (0.11)
  'do them gr': 'a' (1.00)
  'im that gr': 'a' (1.00)
     'most gr': 'a' (0.89), 'i' (0.11)
     'most gu': 'i' (1.00)
    ' most gl': 'a' (1.00)
   'e most gr': 'a' (1.00)
   'my most r': 'e' (1.00)
       'my gr': 'a' (0.55), 'i' (0.22), 'e' (0.16), 'o' (0.06)
       'my sl': 'e' (0.40), 'a' (0.40), 'i' (0.20)
   'my most g': 'r' (1.00

In [None]:
def next_token_freqs_inv_distances(
    prompt_idxs: torch.Tensor, ss_tensors: SimilarStringsFrequencyAndDistanceData
):
    """Combine frequencies, down-weighting similar strings by their distance.

    Each similar string's next-token frequencies are scaled by
    `1 / (1 + distance)`. Embedding and projection contributions then get
    weight 1 and the feed-forward contribution of block `b` gets weight
    `1 + b`. The result is divided by its per-prompt total.

    All frequencies and distances are read from `ss_tensors`, so the function
    gives consistent results for whichever data set it is handed.

    Args:
        prompt_idxs: Indices of the prompts to score within `ss_tensors`.
        ss_tensors: The similar-strings frequency and distance data.

    Returns:
        A tensor with one row of normalized next-token frequencies per prompt.
    """
    emb_weight = torch.tensor(1.0, dtype=torch.float32)
    proj_weights = torch.tensor(
        [1.0 for _ in range(n_layer)], dtype=torch.float32
    ).unsqueeze(dim=1)
    ffwd_weights = torch.tensor(
        [1 + block_idx for block_idx in range(n_layer)], dtype=torch.float32
    ).unsqueeze(dim=1)

    # Trailing unit axis lets each factor broadcast over the vocab axis.
    inv_emb_distances = (1 / (1 + ss_tensors.emb_distances[prompt_idxs, :])).unsqueeze(
        dim=2
    )

    inv_proj_distances = (
        1 / (1 + ss_tensors.proj_distances[prompt_idxs, :, :])
    ).unsqueeze(dim=3)

    inv_ffwd_distances = (
        1 / (1 + ss_tensors.ffwd_distances[prompt_idxs, :, :])
    ).unsqueeze(dim=3)

    freqs = (
        (emb_weight * ss_tensors.emb_freqs[prompt_idxs, :] * inv_emb_distances).sum(
            dim=1
        )
        + (
            proj_weights
            * (ss_tensors.proj_freqs[prompt_idxs, :, :, :] * inv_proj_distances).sum(
                dim=2
            )
        ).sum(dim=1)
        + (
            ffwd_weights
            * (ss_tensors.ffwd_freqs[prompt_idxs, :, :, :] * inv_ffwd_distances).sum(
                dim=2
            )
        ).sum(dim=1)
    )
    return freqs / freqs.sum(dim=1, keepdim=True)

In [None]:
sim20k_aggr_alt = ModelSimulation(
    ss_data=ss_data20k_aggr,
    compute_next_token_freqs=next_token_freqs_inv_distances,
    encoding_helpers=encoding_helpers,
)

In [None]:
sim_outputs_alt = sim20k_aggr_alt(strings)

In [None]:
n_samples = 20000
topn_matches, topn_matches_any_order = analyze_simulate_results2(strings, sim_outputs_alt, model_outputs20k)
for i in range(10):
    print(f"Top {i+1} matches: {topn_matches[i] / n_samples:.3f}")
    print(f"Top {i+1} matches (any order): {topn_matches_any_order[i] / n_samples:.3f}")

Top 1 matches: 0.699
Top 1 matches (any order): 0.699
Top 2 matches: 0.352
Top 2 matches (any order): 0.404
Top 3 matches: 0.208
Top 3 matches (any order): 0.220
Top 4 matches: 0.141
Top 4 matches (any order): 0.136
Top 5 matches: 0.103
Top 5 matches (any order): 0.089
Top 6 matches: 0.083
Top 6 matches (any order): 0.055
Top 7 matches: 0.065
Top 7 matches (any order): 0.032
Top 8 matches: 0.053
Top 8 matches (any order): 0.020
Top 9 matches: 0.043
Top 9 matches (any order): 0.012
Top 10 matches: 0.036
Top 10 matches (any order): 0.005


It does worse on the top-1 match (0.699 vs. 0.742 for the unweighted aggregate above) but a little better further down the ranking.

Try a few other functions:

In [None]:
def next_token_freqs_only_last_ffwd(
    prompt_idxs: torch.Tensor, ss_data: SimilarStringsFrequencyAndDistanceData
):
    """Use only the final block's feed-forward frequencies.

    The frequencies of that block's similar strings are summed per prompt
    and divided by the per-prompt total.
    """
    last_block_freqs = ss_data.ffwd_freqs[prompt_idxs, -1]
    summed = last_block_freqs.sum(dim=1)
    totals = summed.sum(dim=1, keepdim=True)
    return summed / totals

In [None]:
def next_token_freqs_only_last_ffwd_with_distances(
    prompt_idxs: torch.Tensor, ss_data: SimilarStringsFrequencyAndDistanceData
):
    """Use only the final block's feed-forward data, weighted by distance.

    Each similar string's frequencies are scaled by `1 / (1 + distance)`
    before being summed per prompt and divided by the per-prompt total.
    Both the frequencies and the distances are read from `ss_data`.

    Args:
        prompt_idxs: Indices of the prompts to score within `ss_data`.
        ss_data: The similar-strings frequency and distance data.

    Returns:
        A tensor with one row of normalized next-token frequencies per prompt.
    """
    # Trailing unit axis lets the factor broadcast over the vocab axis.
    inv_ffwd_distances = (
        1 / (1 + ss_data.ffwd_distances[prompt_idxs, -1, :])
    ).unsqueeze(dim=2)

    freqs = (inv_ffwd_distances * ss_data.ffwd_freqs[prompt_idxs, -1, :, :]).sum(dim=1)
    return freqs / freqs.sum(dim=1, keepdim=True)

In [None]:
def try_next_token_freqs_function(
    next_token_freqs_fn: ComputeNextTokenFreqs,
    ss_data: SimilarStringsFrequencyAndDistanceData,
    strings: Sequence[str],
    model_outputs: Sequence[Tuple[str, float]],
):
    """Run a simulation with `next_token_freqs_fn` and print its match rates.

    Builds a `ModelSimulation` over `ss_data`, runs it on `strings`, compares
    the simulated outputs with `model_outputs` via `analyze_simulate_results2`,
    and prints the top-1 through top-10 figures divided by `len(strings)`.
    """
    simulation = ModelSimulation(
        ss_data=ss_data,
        compute_next_token_freqs=next_token_freqs_fn,
        encoding_helpers=encoding_helpers,
    )
    simulated = simulation(strings)
    total = len(strings)

    in_order, any_order = analyze_simulate_results2(strings, simulated, model_outputs)

    for rank in range(1, 11):
        print(f"Top {rank} matches: {in_order[rank - 1] / total:.3f}")
        print(f"Top {rank} matches (any order): {any_order[rank - 1] / total:.3f}")

In [None]:
try_next_token_freqs_function(
    next_token_freqs_only_last_ffwd, ss_data20k, strings, model_outputs20k
)

Top 1 matches: 0.744
Top 1 matches (any order): 0.744
Top 2 matches: 0.311
Top 2 matches (any order): 0.362
Top 3 matches: 0.134
Top 3 matches (any order): 0.148
Top 4 matches: 0.071
Top 4 matches (any order): 0.070
Top 5 matches: 0.037
Top 5 matches (any order): 0.030
Top 6 matches: 0.020
Top 6 matches (any order): 0.010
Top 7 matches: 0.011
Top 7 matches (any order): 0.003
Top 8 matches: 0.007
Top 8 matches (any order): 0.001
Top 9 matches: 0.004
Top 9 matches (any order): 0.001
Top 10 matches: 0.002
Top 10 matches (any order): 0.000


This is nearly as good as the best baseline results for top1 and top2, but drops off after that. 

In [None]:
try_next_token_freqs_function(
    next_token_freqs_only_last_ffwd_with_distances, ss_data20k_aggr, strings, model_outputs20k
)

Top 1 matches: 0.676
Top 1 matches (any order): 0.676
Top 2 matches: 0.253
Top 2 matches (any order): 0.300
Top 3 matches: 0.109
Top 3 matches (any order): 0.117
Top 4 matches: 0.057
Top 4 matches (any order): 0.054
Top 5 matches: 0.032
Top 5 matches (any order): 0.024
Top 6 matches: 0.017
Top 6 matches (any order): 0.009
Top 7 matches: 0.010
Top 7 matches (any order): 0.004
Top 8 matches: 0.006
Top 8 matches (any order): 0.002
Top 9 matches: 0.005
Top 9 matches (any order): 0.001
Top 10 matches: 0.003
Top 10 matches (any order): 0.001


For a single example, is it possible to choose weights that give you the same output as the model? 

In [None]:
strings[:10]

['is dreams,',
 'by present',
 's eyes may',
 'eart of ho',
 ' man, as I',
 'mour in a ',
 'LLA:\nAnd h',
 ' crave no ',
 'o find the',
 'e,\nplease ']

In [None]:
prompt = 'my most gr'

prompt_idx = ss_data20k.string_to_idx[prompt]
prompt_idxs = torch.tensor([prompt_idx], dtype=torch.long)

emb_freqs = ss_data20k.emb_freqs[prompt_idxs, :, :]
proj_freqs = ss_data20k.proj_freqs[prompt_idxs, :, :, :]
ffwd_freqs = ss_data20k.ffwd_freqs[prompt_idxs, :, :, :]

In [None]:
# Compute the model's next-token distribution for `prompt`; the learned
# weights are fitted against this fixed target.
tokens = encoding_helpers.tokenize_string(prompt)
# The target is a constant, so run the forward pass without recording an
# autograd graph through the model.
with torch.no_grad():
    logits, _ = m(tokens)
# Softmax over the logits at the last position; squeeze(dim=0) drops the
# leading axis (presumably a batch of one — confirm).
model_output = F.softmax(logits[:, -1, :], dim=-1).squeeze(dim=0)

In [None]:
torch.manual_seed(1337) # Ensure stable random values

# Initialize all the learnable params.
# The initial values are moved to `device` *before* being wrapped in
# Parameter: calling `.to(device)` on a Parameter returns a non-leaf copy when
# the device differs, and the optimizer cannot update a non-leaf tensor.
# Sampling still happens exactly as before, so the seeded values are unchanged.
emb_weight_param = torch.nn.Parameter(
    torch.randn(1, dtype=torch.float32).to(device), requires_grad=True
)
proj_weights_param = torch.nn.Parameter(
    torch.randn(n_layer, 1, dtype=torch.float32).to(device), requires_grad=True
)
ffwd_weights_param = torch.nn.Parameter(
    torch.randn(n_layer, 1, dtype=torch.float32).to(device), requires_grad=True
)

In [None]:
learning_rate = 3e-3
max_iters = 10000
eval_interval=500

In [None]:
# Fit one scalar weight for the embedding frequencies and one weight per block
# for the proj and ffwd frequencies, so that the weighted, normalized
# frequencies approximate the model's output for the single prompt.
optimizer = torch.optim.AdamW([emb_weight_param, proj_weights_param, ffwd_weights_param], lr=learning_rate)
# NOTE: despite its name, `eval_iters` is used here as the logging interval.
eval_iters = max_iters // 10

# The inputs and the target are constants of the optimization. Detach them
# once, before the loop, so that no backward pass (including the first one)
# reaches into whatever graph produced them.
emb_freqs = emb_freqs.detach()
proj_freqs = proj_freqs.detach()
ffwd_freqs = ffwd_freqs.detach()
model_output = model_output.detach()

losses = []

for step in tqdm(range(max_iters)):
    optimizer.zero_grad()

    # Weighted sum of the frequency tables, then normalization to probabilities.
    freqs = (
        (emb_weight_param * emb_freqs).sum(dim=1)
        + (proj_weights_param * proj_freqs.sum(dim=2)).sum(dim=1)
        + (ffwd_weights_param * ffwd_freqs.sum(dim=2)).sum(dim=1)
    )
    probs = freqs.squeeze(dim=0).float() / freqs.sum(dim=1)

    # L2 distance between the simulated and the real distribution.
    loss = torch.norm(probs - model_output, p=2)

    losses.append(loss.item())

    if step % eval_iters == 0:
        print(f"step {step}, loss: {loss.item():.3f}")

    # Take a step
    loss.backward()

    optimizer.step()


  0%|          | 0/10000 [00:00<?, ?it/s]

step 0, loss: 0.437
step 1000, loss: 0.017
step 2000, loss: 0.005
step 3000, loss: 0.005
step 4000, loss: 0.005
step 5000, loss: 0.005
step 6000, loss: 0.005
step 7000, loss: 0.005
step 8000, loss: 0.005
step 9000, loss: 0.005


In [None]:
# Recompute the simulated distribution with the current weights; no gradients
# are needed for this comparison.
with torch.no_grad():
    freqs = (
        (emb_weight_param * emb_freqs).sum(dim=1)
        + (proj_weights_param * proj_freqs.sum(dim=2)).sum(dim=1)
        + (ffwd_weights_param * ffwd_freqs.sum(dim=2)).sum(dim=1)
    )
    probs = freqs.squeeze(dim=0).float() / freqs.sum(dim=1)

# Display the first 10 entries returned by top_nonzero_tokens for the
# simulated distribution (first) and for the model's output (second).
(
    top_nonzero_tokens(probs, encoding_helpers.tokenizer.itos)[:10],
    top_nonzero_tokens(model_output, encoding_helpers.tokenizer.itos)[:10]
)

([('a', 0.4609237313270569),
  ('e', 0.35285136103630066),
  ('o', 0.09236107766628265),
  ('i', 0.09082776308059692),
  ('v', 0.0006160509656183422),
  ('r', 0.0005453210324048996),
  ('l', 0.0005453210324048996),
  ('n', 0.0005146691692061722),
  ('c', 0.0005067135789431632),
  ('d', 0.0003080254828091711)],
 [('a', 0.4602494537830353),
  ('e', 0.35252559185028076),
  ('o', 0.09188850224018097),
  ('i', 0.09030349552631378),
  ('u', 0.004192721098661423),
  ('y', 0.0007521358784288168),
  ('r', 6.647213740507141e-05),
  ('l', 3.957989065384027e-06),
  ('v', 2.812936827467638e-06),
  ('w', 2.738903503995971e-06)])

What if we tried these weights for everything? 

In [None]:
emb_weight_param.data, proj_weights_param.data, ffwd_weights_param.data

(tensor([-1.0925]),
 tensor([[ 0.0070],
         [-0.7119],
         [ 1.1149],
         [ 0.0808],
         [-0.0404],
         [-0.6956]]),
 tensor([[-0.1006],
         [ 0.2627],
         [ 0.0467],
         [ 0.1357],
         [ 0.6802],
         [-1.1193]]))

In [None]:
# Try a run using these weights for everything

def next_token_freqs_progressive_learned_weights(
    prompt_idxs: torch.Tensor, ss_data: SimilarStringsFrequencyAndDistanceData
):
    """Combine emb, proj and ffwd frequencies using the learned weights.

    Reads the current values of the notebook-level `emb_weight_param`,
    `proj_weights_param` and `ffwd_weights_param` at call time, forms the
    weighted sum of the three frequency tables for the given prompts, and
    normalizes each row by its sum.
    """
    emb_term = (emb_weight_param.data * ss_data.emb_freqs[prompt_idxs]).sum(dim=1)
    proj_term = (
        proj_weights_param.data * ss_data.proj_freqs[prompt_idxs].sum(dim=2)
    ).sum(dim=1)
    ffwd_term = (
        ffwd_weights_param.data * ss_data.ffwd_freqs[prompt_idxs].sum(dim=2)
    ).sum(dim=1)

    weighted = emb_term + proj_term + ffwd_term
    return weighted / weighted.sum(dim=1, keepdim=True)

try_next_token_freqs_function(
    next_token_freqs_progressive_learned_weights, ss_data20k, strings, model_outputs20k
)

Top 1 matches: 0.548
Top 1 matches (any order): 0.548
Top 2 matches: 0.181
Top 2 matches (any order): 0.184
Top 3 matches: 0.098
Top 3 matches (any order): 0.050
Top 4 matches: 0.072
Top 4 matches (any order): 0.021
Top 5 matches: 0.060
Top 5 matches (any order): 0.007
Top 6 matches: 0.050
Top 6 matches (any order): 0.004
Top 7 matches: 0.040
Top 7 matches (any order): 0.002
Top 8 matches: 0.036
Top 8 matches (any order): 0.002
Top 9 matches: 0.031
Top 9 matches (any order): 0.001
Top 10 matches: 0.022
Top 10 matches (any order): 0.000


Ok, clearly very bad.

## Attempt to learn weights over a large number of examples

In [None]:
def get_batch(
    batch_size: int,
    ss_data: SimilarStringsFrequencyAndDistanceData,
    split: str = 'train',
    train_pct: float = 0.5,
):
    """Sample a random batch of prompts together with the model's outputs.

    Prompt indices are drawn uniformly, with replacement, from the first
    `train_pct` fraction of `ss_data` when `split == 'train'` and from the
    remainder for any other `split` value.

    Returns a tuple `(emb_freqs, proj_freqs, ffwd_freqs, model_output)` of
    detached tensors: the three frequency tables for the sampled prompts and
    the softmax of the model's logits at the last position of each prompt.
    """
    n_total = len(ss_data.string_to_idx)
    n_train = int(n_total * train_pct)
    if split == 'train':
        low, high = 0, n_train
    else:
        low, high = n_train, n_total

    prompt_idxs = torch.randint(
        low=low, high=high, size=(batch_size,), dtype=torch.long
    )
    batch_strings = [ss_data.strings[i] for i in prompt_idxs.tolist()]

    emb_freqs = ss_data.emb_freqs[prompt_idxs, :, :]
    proj_freqs = ss_data.proj_freqs[prompt_idxs, :, :, :]
    ffwd_freqs = ss_data.ffwd_freqs[prompt_idxs, :, :, :]

    tokens = encoding_helpers.tokenize_strings(batch_strings)
    # The model output is only a training target, so skip building an autograd
    # graph through the model on every batch.
    with torch.no_grad():
        logits, _ = m(tokens)
    model_output = F.softmax(logits[:, -1, :], dim=-1)

    return (
        emb_freqs.detach(),
        proj_freqs.detach(),
        ffwd_freqs.detach(),
        model_output.detach(),
    )

In [None]:
torch.manual_seed(1337) # Ensure stable random values

# Initialize all the learnable params.
# The initial values are moved to `device` *before* being wrapped in
# Parameter: calling `.to(device)` on a Parameter returns a non-leaf copy when
# the device differs, and the optimizer cannot update a non-leaf tensor.
# Sampling still happens exactly as before, so the seeded values are unchanged.
emb_weight_param = torch.nn.Parameter(
    torch.randn(1, dtype=torch.float32).to(device), requires_grad=True
)
proj_weights_param = torch.nn.Parameter(
    torch.randn(n_layer, 1, dtype=torch.float32).to(device), requires_grad=True
)
ffwd_weights_param = torch.nn.Parameter(
    torch.randn(n_layer, 1, dtype=torch.float32).to(device), requires_grad=True
)

In [None]:
learning_rate = 3e-4
max_iters = 10000
eval_interval=500
eval_iters = 200
batch_size=100

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean batch loss on the train and validation splits.

    For each split, draws `eval_iters` batches of `batch_size` prompts from
    `ss_data20k`, computes the summed per-example L2 distance between the
    simulated and the real distributions, and averages over the batches.

    Returns a dict with keys 'train' and 'val' mapping to the mean loss.
    """
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # Sample from the split being evaluated; previously this always
            # sampled from 'train', so the reported val loss was a train loss.
            emb_freqs, proj_freqs, ffwd_freqs, model_output = get_batch(
                batch_size, ss_data20k, split=split, train_pct=0.5
            )

            freqs = (
                (emb_weight_param * emb_freqs).sum(dim=1)
                + (proj_weights_param * proj_freqs.sum(dim=2)).sum(dim=1)
                + (ffwd_weights_param * ffwd_freqs.sum(dim=2)).sum(dim=1)
            )
            probs = freqs.float() / freqs.sum(dim=1, keepdim=True)

            loss = torch.norm(probs - model_output, p=2, dim=1).sum()
            losses[k] = loss.item()

        out[split] = losses.mean()
    return out

In [None]:
# Fit the emb/proj/ffwd weights over random batches of prompts drawn from the
# 'train' split of ss_data20k.
optimizer = torch.optim.AdamW([emb_weight_param, proj_weights_param, ffwd_weights_param], lr=learning_rate)

for step in tqdm(range(max_iters)):
    optimizer.zero_grad()

    emb_freqs, proj_freqs, ffwd_freqs, model_output = get_batch(batch_size, ss_data20k, split='train', train_pct=0.5)

    # Weighted sum of the frequency tables, normalized per example.
    freqs = (
        (emb_weight_param * emb_freqs).sum(dim=1)
        + (proj_weights_param * proj_freqs.sum(dim=2)).sum(dim=1)
        + (ffwd_weights_param * ffwd_freqs.sum(dim=2)).sum(dim=1)
    )
    probs = freqs.float() / freqs.sum(dim=1, keepdim=True)

    # Sum over the batch of per-example L2 distances to the model's output.
    loss = torch.norm(probs - model_output, p=2, dim=1).sum()

    # Periodically report the losses estimated by estimate_loss().
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}, train loss: {losses['train']:.3f}, val_loss {losses['val']:.3f}")

    # Take a step
    loss.backward()

    optimizer.step()


  0%|          | 0/10000 [00:00<?, ?it/s]

step 0, train loss: 18.787, val_loss 18.656
step 500, train loss: 18.488, val_loss 18.526
step 1000, train loss: 18.462, val_loss 18.536
step 1500, train loss: 18.320, val_loss 18.076
step 2000, train loss: 18.156, val_loss 18.145
step 2500, train loss: 18.047, val_loss 18.269
step 3000, train loss: 18.095, val_loss 18.123
step 3500, train loss: 18.156, val_loss 18.288
step 4000, train loss: 18.011, val_loss 17.939
step 4500, train loss: 17.811, val_loss 18.024
step 5000, train loss: 18.163, val_loss 18.125
step 5500, train loss: 18.144, val_loss 18.114
step 6000, train loss: 18.147, val_loss 18.117
step 6500, train loss: 17.959, val_loss 18.084
step 7000, train loss: 18.106, val_loss 18.113
step 7500, train loss: 17.988, val_loss 17.954
step 8000, train loss: 18.221, val_loss 18.156
step 8500, train loss: 18.161, val_loss 18.105
step 9000, train loss: 18.026, val_loss 18.064
step 9500, train loss: 18.170, val_loss 18.122


In [None]:
# Try a run using these weights for everything

def next_token_freqs_progressive_learned_weights(
    prompt_idxs: torch.Tensor, ss_data: SimilarStringsFrequencyAndDistanceData
):
    """Combine emb, proj and ffwd frequencies using the learned weights.

    Reads the current values of the notebook-level `emb_weight_param`,
    `proj_weights_param` and `ffwd_weights_param` at call time, forms the
    weighted sum of the three frequency tables for the given prompts, and
    normalizes each row by its sum.
    """
    emb_term = (emb_weight_param.data * ss_data.emb_freqs[prompt_idxs]).sum(dim=1)
    proj_term = (
        proj_weights_param.data * ss_data.proj_freqs[prompt_idxs].sum(dim=2)
    ).sum(dim=1)
    ffwd_term = (
        ffwd_weights_param.data * ss_data.ffwd_freqs[prompt_idxs].sum(dim=2)
    ).sum(dim=1)

    weighted = emb_term + proj_term + ffwd_term
    return weighted / weighted.sum(dim=1, keepdim=True)

try_next_token_freqs_function(
    next_token_freqs_progressive_learned_weights, ss_data20k, strings, model_outputs20k
)

Top 1 matches: 0.772
Top 1 matches (any order): 0.772
Top 2 matches: 0.403
Top 2 matches (any order): 0.458
Top 3 matches: 0.221
Top 3 matches (any order): 0.251
Top 4 matches: 0.149
Top 4 matches (any order): 0.156
Top 5 matches: 0.111
Top 5 matches (any order): 0.102
Top 6 matches: 0.088
Top 6 matches (any order): 0.065
Top 7 matches: 0.066
Top 7 matches (any order): 0.037
Top 8 matches: 0.052
Top 8 matches (any order): 0.021
Top 9 matches: 0.045
Top 9 matches (any order): 0.013
Top 10 matches: 0.033
Top 10 matches (any order): 0.006


This is definitely the best result for top 2 we've seen and very nearly the best for top1. 

In [None]:
emb_weight_param.data, proj_weights_param.data, ffwd_weights_param.data

(tensor([-0.0092]),
 tensor([[ 0.0223],
         [-0.1555],
         [ 0.0149],
         [ 0.0122],
         [-0.0932],
         [-0.0848]]),
 tensor([[-0.5732],
         [-0.7713],
         [-0.6631],
         [-0.7545],
         [-1.1689],
         [-2.0063]]))

In [None]:
sim_learned_weights = ModelSimulation(
    ss_data=ss_data20k,
    compute_next_token_freqs=next_token_freqs_progressive_learned_weights,
    encoding_helpers=encoding_helpers,
)

In [None]:
sim_learned_weights(['my most gr'])[0], model_outputs20k[ss_data20k.string_to_idx['my most gr']], sim20k(['my most gr'])[0]

([('a', 0.7675697803497314),
  ('i', 0.14597564935684204),
  ('o', 0.055422890931367874),
  ('e', 0.027286706492304802),
  ('c', 0.0012530256062746048),
  ('r', 0.0012431245995685458),
  ('l', 0.0012431245995685458),
  ('v', 0.0002444349229335785),
  ('d', 0.00012221746146678925),
  ('3', -0.0)],
 [('a', 0.4602494537830353),
  ('e', 0.35252559185028076),
  ('o', 0.09188850224018097),
  ('i', 0.09030349552631378),
  ('u', 0.004192721098661423),
  ('y', 0.0007521358784288168),
  ('r', 6.647213740507141e-05),
  ('l', 3.957989065384027e-06),
  ('v', 2.812936827467638e-06),
  ('w', 2.738903503995971e-06)],
 [('a', 0.74631267786026),
  ('i', 0.14454276859760284),
  ('o', 0.05604719743132591),
  ('e', 0.02654867246747017),
  ('n', 0.005899704992771149),
  ('c', 0.005899704992771149),
  ('v', 0.005899704992771149),
  ('d', 0.0029498524963855743),
  ('r', 0.0029498524963855743),
  ('l', 0.0029498524963855743)])

## Cosine Similarity Results

The cosine similarity data generation ran overnight. Let's look at the results.

In [None]:
ss_cos = SimilarStringsExperiment(
    exp10.output_dir / 'similar_strings_cos',
    encoding_helpers,
)
t_is = [7, 8, 9]
ss_cos_results = ss_cos.load_results_for_strings(strings, load_t_is=t_is)

In [None]:
t_is = [7, 8, 9]
ss_data_cos = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_cos_results,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
    largest=True,
)
try_next_token_freqs_function(
    next_token_freqs_progressive_ffwd_weight, ss_data_cos, strings, model_outputs20k
)

Top 1 matches: 0.750
Top 1 matches (any order): 0.750
Top 2 matches: 0.366
Top 2 matches (any order): 0.415
Top 3 matches: 0.191
Top 3 matches (any order): 0.207
Top 4 matches: 0.134
Top 4 matches (any order): 0.125
Top 5 matches: 0.096
Top 5 matches (any order): 0.077
Top 6 matches: 0.077
Top 6 matches (any order): 0.045
Top 7 matches: 0.060
Top 7 matches (any order): 0.024
Top 8 matches: 0.050
Top 8 matches (any order): 0.014
Top 9 matches: 0.040
Top 9 matches (any order): 0.008
Top 10 matches: 0.033
Top 10 matches (any order): 0.003


The results are pretty similar to the best we saw with Euclidean distance. But here's something interesting: let's flip the `largest` param so that it's not looking for the largest cosine similarity, but the smallest. This will mean we're looking at less similar samples.

In [None]:
t_is = [7, 8, 9]
ss_data_cos = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_cos_results,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
    largest=False,
)


In [None]:
try_next_token_freqs_function(
    next_token_freqs_progressive_ffwd_weight, ss_data_cos, strings, model_outputs20k
)

Top 1 matches: 0.783
Top 1 matches (any order): 0.783
Top 2 matches: 0.422
Top 2 matches (any order): 0.484
Top 3 matches: 0.233
Top 3 matches (any order): 0.271
Top 4 matches: 0.170
Top 4 matches (any order): 0.181
Top 5 matches: 0.128
Top 5 matches (any order): 0.119
Top 6 matches: 0.101
Top 6 matches (any order): 0.076
Top 7 matches: 0.083
Top 7 matches (any order): 0.041
Top 8 matches: 0.069
Top 8 matches (any order): 0.029
Top 9 matches: 0.059
Top 9 matches (any order): 0.018
Top 10 matches: 0.051
Top 10 matches (any order): 0.011


The results actually get BETTER! In fact, these are the best results yet. 

(I'm writing this up as if it were an intentional experiment but in fact I found this by accident. At first when I ran the cosine similarity results, I didn't have a `largest` parameter to pass into `SimilarStringsFrequencyAndDistanceData.from_results()`. So it was in fact doing the `largest=False` version. I got the results above. Then I fixed it to thread `largest` through and the results got worse. And that inspired this line of thought.)

This suggests I've gotten something fundamentally wrong thus far. I've been trying methods to find more and more similar values: looking for more similar values amongst shorter strings, investigating cosine similarity as a potentially better metric. But maybe the model is actually casting a much wider net, i.e. the predictions from a given embedding are an aggregate of similar values from a much wider range. 

This explains why efforts to produce more similar values yield worse results e.g. the results from the aggregation of t_is from 3-9 produces more similar values, but worse overall results. Perhaps there just isn't enough variety in the most similar values to produce final results that resemble the model's predictions.

Let's try a few experiments to test this hypothesis.

First, let's look at the cosine results with `largest=True` but only considering t_i=9:

In [None]:
t_is = [9]
ss_cos_results_only9 = ss_cos.load_results_for_strings(strings, load_t_is=t_is)

In [None]:
# Build the frequency/distance data restricted to the current `t_is`, looking
# for the largest cosine similarities.
# NOTE(review): the previous cell loads `ss_cos_results_only9`, but this cell
# still reads `ss_cos_results`; confirm whether `ss_cos_results_only9` was
# meant to be passed here.
ss_data_cos = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_cos_results,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
    largest=True,
)

In [None]:
try_next_token_freqs_function(
    next_token_freqs_progressive_ffwd_weight, ss_data_cos, strings, model_outputs20k
)

Top 1 matches: 0.775
Top 1 matches (any order): 0.775
Top 2 matches: 0.397
Top 2 matches (any order): 0.452
Top 3 matches: 0.211
Top 3 matches (any order): 0.237
Top 4 matches: 0.143
Top 4 matches (any order): 0.142
Top 5 matches: 0.103
Top 5 matches (any order): 0.087
Top 6 matches: 0.082
Top 6 matches (any order): 0.055
Top 7 matches: 0.065
Top 7 matches (any order): 0.029
Top 8 matches: 0.054
Top 8 matches (any order): 0.021
Top 9 matches: 0.047
Top 9 matches (any order): 0.010
Top 10 matches: 0.038
Top 10 matches (any order): 0.005


This is better than the results from aggregating across t_is 7-9 with `largest=True`.

What happens if we try `largest=False` with t_i=9?

In [None]:
ss_data_cos = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_cos_results,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
    largest=False,
)
try_next_token_freqs_function(
    next_token_freqs_progressive_ffwd_weight, ss_data_cos, strings, model_outputs20k
)

Top 1 matches: 0.783
Top 1 matches (any order): 0.783
Top 2 matches: 0.422
Top 2 matches (any order): 0.484
Top 3 matches: 0.233
Top 3 matches (any order): 0.271
Top 4 matches: 0.170
Top 4 matches (any order): 0.181
Top 5 matches: 0.128
Top 5 matches (any order): 0.119
Top 6 matches: 0.101
Top 6 matches (any order): 0.076
Top 7 matches: 0.083
Top 7 matches (any order): 0.041
Top 8 matches: 0.069
Top 8 matches (any order): 0.029
Top 9 matches: 0.059
Top 9 matches (any order): 0.018
Top 10 matches: 0.051
Top 10 matches (any order): 0.011


This is better than `largest=True` but worse than when we aggregated t_is 7-9. Probably because amongst the t_i=7 and t_i=8 values there were some less similar candidates. 

Let's try the Euclidean distance version, but flip `largest=True`. 

In [None]:
t_is=[7, 8, 9]
ss_data20k_aggr_backwards = SimilarStringsFrequencyAndDistanceData.from_results(
    ss_results=ss_results20k_all_t_is,
    next_token_map=next_token_map_all,
    aggregate_over_t_is=t_is,
    largest=True,
)
try_next_token_freqs_function(
    next_token_freqs_progressive_ffwd_weight, ss_data20k_aggr_backwards, strings, model_outputs20k
)

Top 1 matches: 0.783
Top 1 matches (any order): 0.783
Top 2 matches: 0.426
Top 2 matches (any order): 0.487
Top 3 matches: 0.236
Top 3 matches (any order): 0.275
Top 4 matches: 0.171
Top 4 matches (any order): 0.185
Top 5 matches: 0.126
Top 5 matches (any order): 0.120
Top 6 matches: 0.102
Top 6 matches (any order): 0.080
Top 7 matches: 0.082
Top 7 matches (any order): 0.045
Top 8 matches: 0.072
Top 8 matches (any order): 0.032
Top 9 matches: 0.060
Top 9 matches (any order): 0.019
Top 10 matches: 0.052
Top 10 matches (any order): 0.009


I tried this several times, and it seems the sweet spot is with t_is=[7, 8, 9]. If you include smaller t_is, the results get progressively worse. Which suggests there is some upper bound to the distance that yields best results. Let's see if we can hone in on the sweet spot. 

In [None]:
print_sim_strings(ss_results20k_all_t_is['my most gr'], aggregate_over_t_is=[7, 8, 9], largest=True)

Proj Outputs
    'om is gr' (3.76)    ' dost gi' (4.09)    'st my gr' (6.16)    'must giv' (6.67)    'ost my c' (7.47)  'my merry m' (5.45)
    'at is gr' (3.75)    'uldst gi' (4.09)    ' some gr' (6.10)    ' many ge' (6.64)    'my own d' (7.46)  'my most he' (5.34)
    'on my gu' (3.75)    'heart go' (4.07)    ' more gr' (6.00)    ' many gu' (6.64)    'ms thy c' (7.41)   'm my moth' (5.25)
    ' dost gu' (3.72)    'canst gi' (4.04)    'ommon gr' (5.95)    'most gui' (6.42)    'y most g' (7.41)  'mt my mast' (5.22)
    's not gu' (3.71)    ' last go' (4.03)    ' most gu' (5.83)   ', most gr' (6.13)    'most gen' (7.38)  'm thy moth' (5.22)
    ' most gr' (3.68)    ' most go' (4.03)    ' most go' (5.79)  'e, most gr' (6.00)    's most r' (7.25)    'm thy mo' (5.21)
    'or my gu' (3.68)    ' most gi' (4.00)   ' most\ngl' (5.57)  'e; most go' (5.96)    'e most r' (7.16)    'my misfo' (5.20)
    'nt is gu' (3.67)   '\nMust gi' (3.94)   's more gr' (5.57)   'o most go' (5.81)    ' most gr'

In [None]:
print_sim_strings(ss_results20k_all_t_is['my most gr'], aggregate_over_t_is=[6, 7, 8, 9], largest=True)

Proj Outputs
    '.\nKing ' (4.46)     ' yet gr' (4.76)    'st my gr' (6.16)    'must giv' (6.67)    'ost my c' (7.47)  'my merry m' (5.45)
     'm to gi' (4.46)    '\nLet go' (4.75)     'some gr' (6.14)    ' many ge' (6.64)    'my own d' (7.46)  'my most he' (5.34)
     ', or gi' (4.46)     'most gi' (4.75)    ' some gr' (6.10)    ' many gu' (6.64)     ' most g' (7.41)   'm my moth' (5.25)
     'k to gi' (4.46)     'past gr' (4.73)     'm my gr' (6.06)    'most gui' (6.42)    'ms thy c' (7.41)  'mt my mast' (5.22)
     '; to gr' (4.46)     'Most go' (4.72)    ' more gr' (6.00)     'm my gr' (6.37)    'y most g' (7.41)  'm thy moth' (5.22)
    '\nFor gi' (4.46)     'fast gr' (4.72)    'ommon gr' (5.95)     'more gr' (6.31)    'most gen' (7.38)    'm thy mo' (5.21)
     'm is gr' (4.44)     'must gr' (4.71)    ' most gu' (5.83)     'most gi' (6.31)    'most\ngl' (7.37)    'my misfo' (5.20)
     ', so gi' (4.44)    '\nBut gr' (4.70)     'most go' (5.80)     'must gr' (6.16)     ' most y'

This is a pretty crude analysis because we only have what we thought were the top 10 most similar values and we're flipping whether we're looking for smallest vs largest. But these numbers give a ballpark of the distances that are worth considering. 

## Analysis of whether there are similar strings of smaller length

In [None]:
slens = [3, 5, 9, 10]

In [None]:
strings_n = {
    n: all_unique_substrings(ts.text, n)
    for n in slens
}

In [None]:
next_token_maps = {
    n: build_next_token_map(
        ts.text, prefix_len=n, vocab_size=tokenizer.vocab_size, stoi=tokenizer.stoi
    )
    for n in slens
}

In [None]:
# Remind the user to generate any dataset whose directory has no entries.
for n in slens:
    # any() stops at the first entry instead of listing the whole directory.
    if not any(Path(f'../artifacts/block_internals_results/large_files/slen{n}/').glob('*')):
        print(f"Run `make block_internals_slen{n}_dataset` in the project root to generate the required dataset")

In [None]:
exps = {
    n: BatchedBlockInternalsExperiment(
        eh=encoding_helpers,
        accessors=accessors,
        strings=strings_n[n],
        output_dir=Path(f'../artifacts/block_internals_results/large_files/slen{n}/'),
        batch_size=10000,
    )
    for n in slens
}

In [None]:
def compare_similar_strings(
    prompt: str,
    block_idx: int,
    compare_to_slen: int,
):
    """Print, side by side, the strings with the closest proj outputs to `prompt`.

    The query is `prompt`'s proj output for block `block_idx` at its last
    position. It is looked up in the experiment for strings of length
    `compare_to_slen` and in the experiment for strings of the same length as
    `prompt`; the top 10 matches of each are printed with their distances.

    Raises KeyError if `exps` has no experiment for either length.
    """
    k = 10  # number of similar strings to fetch and to print
    prompt_accessors = BlockInternalsAccessors(prompt, encoding_helpers, accessors)

    # Both lookups use the same query, so compute it once.
    query = prompt_accessors.proj_output(block_idx=block_idx)[:, -1, :]

    sim_strings_comp, distances_comp = exps[compare_to_slen].strings_with_topk_closest_proj_outputs(
        block_idx=block_idx,
        t_i=-1,
        queries=query,
        k=k,
        largest=False,
    )

    sim_strings, distances = exps[len(prompt)].strings_with_topk_closest_proj_outputs(
        block_idx=block_idx,
        t_i=-1,
        queries=query,
        k=k,
        largest=False,
    )

    print(f"Length {compare_to_slen} similars:   Length {len(prompt)} similars:")
    for i in range(k):
        print(f'{repr(sim_strings_comp[0][i])} {distances_comp[i].item():.3f}        {repr(sim_strings[0][i])} {distances[i].item():.3f}')



After running this a bunch of times, my conclusions from this are:

* Yes it's possible to find similar strings with clear patterns in shorter strings
* The distances are greater when the strings are shorter
* But seeing which shorter strings are similar is illuminating

e.g. for block_idx = 1, looking at similar strings of length 5 vs same length as the prompt:

In [None]:
prompt = 'my most gr'
compare_similar_strings(prompt, block_idx=1, compare_to_slen=5)

Length 5 similars:   Length 10 similars:
'st go' 4.564        'my most gr' 0.000
'ot gl' 4.591        'ur most gr' 0.949
'st ga' 4.595        'ne most gr' 0.958
'ot ga' 4.609        'he most gr' 1.053
'st gl' 4.629        'is most gr' 1.056
'ot gr' 4.638        'e, most gr' 1.268
'st gr' 4.639        'o, must gr' 1.352
'rt go' 4.647        't, most gr' 1.361
'st gi' 4.662        'be past gr' 1.372
'et go' 4.679        'yet not gr' 1.506


All the length 5 similar strings have `t g` in common. 

Now look at block 5:

In [None]:
prompt = 'my most gr'
compare_similar_strings(prompt, block_idx=5, compare_to_slen=5)

Length 5 similars:   Length 10 similars:
'my mo' 4.464        'my most gr' 0.000
'my st' 4.917        'my most st' 3.911
'my br' 4.969        'my most sa' 4.328
'my tr' 4.977        ' my most r' 4.556
'my gr' 5.039        ' my most l' 5.160
'my sw' 5.115        'my high bl' 5.198
'my sc' 5.132        'm thy moth' 5.221
'my wr' 5.140        'mt my mast' 5.225
'my bl' 5.215        'my most he' 5.338
'my tw' 5.267        'my merry m' 5.445


Here the common pattern in the length 5 strings is `my `.

I think this says something. The closest length 5 strings could have been any substring of the full prompt. Seeing what gets picked must be meaningful.

Let's try with length 3:

In [None]:
prompt = 'my most gr'
compare_similar_strings(prompt, block_idx=1, compare_to_slen=3)

Length 3 similars:   Length 10 similars:
' gn' 6.114        'my most gr' 0.000
' gy' 6.142        'ur most gr' 0.949
' gl' 6.156        'ne most gr' 0.958
' gr' 6.166        'he most gr' 1.053
' gu' 6.202        'is most gr' 1.056
' gh' 6.218        'e, most gr' 1.268
' go' 6.231        'o, must gr' 1.352
' ga' 6.232        't, most gr' 1.361
' ge' 6.267        'be past gr' 1.372
' gi' 6.397        'yet not gr' 1.506


Greater distance but still, a pattern. And block 5:

In [None]:
prompt = 'my most gr'
compare_similar_strings(prompt, block_idx=5, compare_to_slen=3)

Length 3 similars:   Length 10 similars:
'mys' 6.682        'my most gr' 0.000
'my-' 6.722        'my most st' 3.911
'ms-' 7.015        'my most sa' 4.328
'myr' 7.065        ' my most r' 4.556
'mso' 7.115        ' my most l' 5.160
"my'" 7.123        'my high bl' 5.198
'my?' 7.148        'm thy moth' 5.221
'mfu' 7.158        'mt my mast' 5.225
'moc' 7.168        'my most he' 5.338
"ms'" 7.214        'my merry m' 5.445


A different pattern, but still a pattern.

Let's see if we can do this with just the length 10 data and not have to generate all the block internals data for the other string lengths from scratch. 

As a prereq, let's first see if the intermediate values for substrings within a longer string are the same as the values that would have been produced for those substrings on their own.

In [None]:
# Show that for the common letters, the intermediates for a substring
# are the same as those in a longer string: `prompt_short` is a prefix of
# `prompt_long`, and for every position of the short prompt and every block
# the proj and ffwd outputs of the two prompts must be close.
prompt_short = 'my mo'
prompt_long = 'my most gr'

bia_short = BlockInternalsAccessors(prompt_short, encoding_helpers, accessors)
bia_long = BlockInternalsAccessors(prompt_long, encoding_helpers, accessors)

for t_i in range(len(prompt_short)):
    for block_idx in range(n_layer):
        # test_close is expected to fail the cell if the values differ.
        test_close(
            bia_short.proj_output(block_idx=block_idx)[0, t_i, :],
            bia_long.proj_output(block_idx=block_idx)[0, t_i, :],
        )
        test_close(
            bia_short.ffwd_output(block_idx=block_idx)[0, t_i, :],
            bia_long.ffwd_output(block_idx=block_idx)[0, t_i, :],
        )

It passes, so this shows we can use the values at the other t_i's from the slen10 dataset. Let's try it. 

In [None]:
full_slen=10
target_exp = exps[full_slen]

In [None]:
def compare_similar_strings2(
    prompt: str,
    block_idx: int,
    compare_to_slen: int,
    k: int = 10,
) -> None:
    """Same as compare_similar_strings() above, but the shorter-length
    comparison is read from `target_exp` instead of a separate experiment.

    Prints two side-by-side columns of the `k` closest strings to the
    prompt's final-position proj output at `block_idx`:

    * left: matches from `target_exp` at position `compare_to_slen - 1`,
      truncated to their first `compare_to_slen` characters;
    * right: matches from `exps[len(prompt)]` at the last position.

    Args:
        prompt: The string whose proj output is used as the query.
        block_idx: Index of the transformer block to read proj outputs from.
        compare_to_slen: The shorter string length to compare against.
        k: Number of nearest strings to look up and print (default 10).
    """
    prompt_accessors = BlockInternalsAccessors(prompt, encoding_helpers, accessors)

    # Both lookups use the same query: the proj output at the prompt's last
    # position.
    query = prompt_accessors.proj_output(block_idx=block_idx)[:, -1, :]

    # `t_i` is a 0-based position, not a string length, so a string of
    # length `compare_to_slen` ends at position `compare_to_slen - 1`.
    sim_strings_comp, distances_comp = target_exp.strings_with_topk_closest_proj_outputs(
        block_idx=block_idx,
        t_i=compare_to_slen - 1,
        queries=query,
        k=k,
        largest=False,
    )

    sim_strings, distances = exps[len(prompt)].strings_with_topk_closest_proj_outputs(
        block_idx=block_idx,
        t_i=-1,
        queries=query,
        k=k,
        largest=False,
    )

    print(f"Length {compare_to_slen} similars:   Length {len(prompt)} similars:")
    for i in range(k):
        # Index 0 selects the results for the single query. The comparison
        # strings are full-length, so show only their first `compare_to_slen`
        # characters.
        print(f'{repr(sim_strings_comp[0][i][:compare_to_slen])} {distances_comp[i].item():.3f}        {repr(sim_strings[0][i])} {distances[i].item():.3f}')


In [None]:
prompt = 'my most gr'
compare_similar_strings2(prompt, block_idx=5, compare_to_slen=5)

Length 5 similars:   Length 10 similars:
'my mo' 4.464        'my most gr' 0.000
'my st' 4.917        'my most st' 3.911
'my br' 4.969        'my most sa' 4.328
'my tr' 4.977        ' my most r' 4.556
'my gr' 5.039        ' my most l' 5.160
'my sw' 5.115        'my high bl' 5.198
'my sc' 5.132        'm thy moth' 5.221
'my wr' 5.140        'mt my mast' 5.225
'my bl' 5.215        'my most he' 5.338
'my tw' 5.267        'my merry m' 5.445


Notice that this is the same as the output above when we compared to a length 5 experiment's outputs!

And now we can do it for other lengths without having to run experiments for all of them.

In [None]:
prompt = 'my most gr'
compare_similar_strings2(prompt, block_idx=5, compare_to_slen=9)

Length 9 similars:   Length 10 similars:
'my most r' 2.754        'my most gr' 0.000
'my most l' 3.517        'my most st' 3.911
'my most h' 3.795        'my most sa' 4.328
'm my mout' 4.802        ' my most r' 4.556
'm thy mot' 4.808        ' my most l' 5.160
'm, my mot' 4.832        'my high bl' 5.198
'y most gr' 4.976        'm thy moth' 5.221
'my most g' 5.051        'mt my mast' 5.225
'mes my br' 5.177        'my most he' 5.338
'm my moth' 5.252        'my merry m' 5.445


## Try out loading with mmap_mode

>Code in this section doesn't run anymore because the Slicer doesn't exist and some other internal changes have been made based on the experiments here, but I'm leaving this in for the historical record. 

Set up a query:

In [None]:
prompt = 'my most gr'
prompt_accessors = BlockInternalsAccessors(prompt, encoding_helpers, accessors)
query = prompt_accessors.proj_output(block_idx=0)[:, -1, :]

Time loading a full batch from slen10 the regular way and subtracting the query:

In [None]:
%%timeit
batch = torch.load(exps[10].output_dir / 'proj_output-000-00.pt')
batch - query

28.4 ms ± 348 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Now load with `mmap=True` and try again:

In [None]:
%%timeit
batch = torch.load(str(exps[10].output_dir / 'proj_output-000-00.pt'), mmap=True)
batch - query

11.8 ms ± 56.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Ooh, it's way faster. Let's see if we can load multiple batches, cat them and do the subtraction:

In [None]:
%%timeit
batch1 = torch.load(str(exps[10].output_dir / 'proj_output-000-00.pt'), mmap=True)
batch2 = torch.load(str(exps[10].output_dir / 'proj_output-001-00.pt'), mmap=True)
batch3 = torch.load(str(exps[10].output_dir / 'proj_output-002-00.pt'), mmap=True)

big_batch = torch.cat([batch1, batch2, batch3])
big_batch - query

52.1 ms ± 637 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Wow. OK, this is a big deal. We did 3 batches in 52ms vs 28.4ms for just one batch the regular way. And I suspect this scales non-linearly. Let's try it with all the batches:


In [None]:
%%timeit
big_batch = torch.cat([
    torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exps[10].n_batches)
])
big_batch - query


8.88 s ± 294 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


So 8.88s for 86 batches. Memory usage peaked at 25GB during the run but went up and down and settled back down to where it was before the run started. 

But that's 103ms per batch which seems slower than just loading each batch one at a time. But the computation is simpler (no need for `topk_across_batches()` etc). 

Let's try running through all the batches the old way (no mmap) and time it.

In [None]:
%%timeit
for batch_idx in range(exps[10].n_batches):
    batch = torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0))) # no mmap
    batch - query

3.32 s ± 157 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


So this is actually faster. But let's try to do more of the complete operation and do multiple queries:

In [None]:
# Define prompts and extract the query values
prompts = ['First Citi', 'Citizen:\nB', 'Shyamalan ', 'more in jo']
prompts_exp = BlockInternalsExperiment(encoding_helpers, accessors, prompts)

t_i = -1

queries = prompts_exp.proj_output(block_idx=0)[:, t_i, :]

Time doing the equivalent of strings_with_topk_closest_proj_outputs() on the mmapped data:

In [None]:
%%timeit
# Build one tensor holding block 0's proj outputs for every batch on disk:
# each batch file is loaded with mmap=True and the results are concatenated
# along dim 0.
big_batch = torch.cat([
    torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exps[10].n_batches)
])

n_queries, _ = queries.shape
B, T, _ = big_batch.shape
# Take position t_i of every item, add a singleton dimension and expand it to
# n_queries so that all the queries can be subtracted in one go. The norm
# over the last dimension leaves a (B, n_queries) tensor of distances.
distances = torch.norm(
    big_batch[:, t_i, :].reshape(B, 1, -1).expand(-1, n_queries, -1) - queries,
    dim=2
)
# For each query, the 10 smallest distances across all items (dim 0).
topk = torch.topk(distances, k=10, dim=0, largest=False)
exps[10].strings_from_indices(topk.indices), topk.values

6.51 s ± 395 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


OK, that is slow. But I suspect there is some overhead in loading the data the first time. What if we load the data once and then process the queries on the loaded data:

In [None]:
big_batch = torch.cat([
    torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exps[10].n_batches)
])

In [None]:
%%timeit
n_queries, _ = queries.shape
B, T, _ = big_batch.shape
distances = torch.norm(
    big_batch[:, t_i, :].reshape(B, 1, -1).expand(-1, n_queries, -1) - queries,
    dim=2
)
topk = torch.topk(distances, k=10, dim=0, largest=False)
exps[10].strings_from_indices(topk.indices), topk.values

303 ms ± 8.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Compare that to doing it with a slicer.

In [None]:
%%timeit
slicers[10].strings_with_topk_closest_proj_outputs(block_idx=0, queries=queries, k=10, largest=False)

471 ms ± 3.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


We don't have the required methods on the exp class anymore, but if we want to test how long this would have taken without the slicer i.e. on the same data that the mmap version is using but without using mmap, we can resurrect the relevant code:

In [None]:
# Copy of the code from block_internals, just so we can run it below
def batch_distances(batch: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
    """Returns the distance between each item in the batch and the queries.

    Args:
        batch: 2-D tensor of shape (B, D), one row per item. B may be 0.
        queries: 2-D tensor of shape (n_queries, D), one row per query.

    Returns:
        Tensor of shape (B, n_queries) whose entry [i, j] is the Euclidean
        (L2) distance between batch[i] and queries[j].

    Raises:
        AssertionError: if either input is not 2-D or their last dimensions
            differ.
    """
    assert batch.dim() == 2, f"batch.dim() should be 2, was {batch.dim()}"
    assert queries.dim() == 2, f"queries.dim() should be 2, was {queries.dim()}"
    assert (
        batch.shape[-1] == queries.shape[-1]
    ), f"last dimension of batch was {batch.shape[-1]}, which does not match last dimension of queries {queries.shape[-1]}"

    # Insert a singleton dimension after the batch dimension so that
    # broadcasting subtracts every query from every batch item in one go:
    # (B, 1, D) - (n_queries, D) -> (B, n_queries, D).
    # unsqueeze() is used rather than reshape(B, 1, -1) because reshape
    # cannot infer the -1 dimension when the batch has zero elements.
    differences = batch.unsqueeze(1) - queries

    # Norm over the feature dimension leaves one distance per (item, query).
    return torch.norm(differences, dim=2)


In [None]:
%%timeit
n_queries, _ = queries.shape
values, indices = topk_across_batches(
    n_batches=exps[10].n_batches,
    k=10,
    largest=False,
    load_batch=lambda i: torch.load(exps[10]._proj_output_filename(i, block_idx=0))[:, t_i, :],
    process_batch=lambda batch: batch_distances(batch, queries=queries),
)
exps[10].strings_from_indices(indices), values


2.18 s ± 29.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The mmap way is faster than the slicer (303ms vs 471ms) and doesn't require materializing the slices. And it's waaaay faster than doing it without the slicer and without mmap (303ms vs 2.18s).

In summary: 
It seems there is a one-time cost to loading all the batches via: 

```python
big_batch = torch.cat([
    torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exps[10].n_batches)
])
```

But this doesn't take up too much memory (fresh Jupyter kernel running just the stuff in this and the previous section has about 13 GB of memory per Activity Monitor). So we can load this once and run a lot of queries very fast. 

Let's check that the results are correct:

In [None]:
# Do for real so we can compare the results
# Concatenate block 0's proj outputs from every batch file (loaded with
# mmap=True) into a single tensor.
big_batch = torch.cat([
    torch.load(str(exps[10]._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exps[10].n_batches)
])

n_queries, _ = queries.shape
B, T, _ = big_batch.shape
# Distance from position t_i of every item to every query: (B, n_queries).
distances = torch.norm(
    big_batch[:, t_i, :].reshape(B, 1, -1).expand(-1, n_queries, -1) - queries,
    dim=2
)
topk = torch.topk(distances, k=10, dim=0, largest=False)
# NOTE(review): this bare expression is not the last statement in the cell,
# so its value is discarded; the strings are recomputed in test_eq() below.
exps[10].strings_from_indices(topk.indices), topk.values

# Run the same queries through the slicer and check that both approaches
# agree. Note that this rebinds `distances` to the slicer's result.
sim_strings, distances = slicers[10].strings_with_topk_closest_proj_outputs(block_idx=0, queries=queries, k=10, largest=False)
test_eq(exps[10].strings_from_indices(topk.indices), sim_strings)
test_close(topk.values, distances)

These tests pass, so the output is the same!

### Perf tests for using mmap with the slicer

The analysis above showed that using mmap on the raw batch data is faster than using the slicer. But yesterday I found that just setting mmap=True on the load_batch function when finding closest embeddings made a huge difference. Let's try the same thing for the slicer and see if it makes a difference.

In [None]:
t_i = -1
queries = prompts_exp.proj_output(block_idx=0)[:, t_i, :]

Though we have a measurement for using the slicer above, let's just replicate it for completeness. Ran this line before making any changes:

In [None]:
%%timeit
slicers[10].strings_with_topk_closest_proj_outputs(block_idx=0, queries=queries, k=10, largest=False)

482 ms ± 14.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


OK, that's in line with the measurement above. Now let's try it after changing the implementation to use `mmap=True`:

In [None]:
%%timeit
slicers[10].strings_with_topk_closest_proj_outputs(block_idx=0, queries=queries, k=10, largest=False)

361 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


OK, so it got faster, but it's still not as fast as using the raw batch data with `mmap=True`. So we'll go with that. 

## Perf tests for finding closest embeddings

In [None]:
exp10 = BatchedBlockInternalsExperiment(
    eh=encoding_helpers,
    accessors=accessors,
    strings=strings_n[10],
    output_dir=Path(f'../artifacts/block_internals_results/large_files/slen10/'),
    batch_size=10000,
)

Measurement before any changes

In [None]:
%%timeit
exp10.strings_with_topk_closest_embeddings(queries=prompts_exp.embeddings, k=10, largest=False)

4.49 s ± 9.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Simulate loading all the batch data at once and then finding the topk closest strings for all the queries:

In [None]:
embeddings_data = torch.cat([
    torch.load(str(exp10._embeddings_filename(batch_idx=batch_idx)), mmap=True)
    for batch_idx in range(exp10.n_batches)
])

In [None]:
# Equivalent of topk_closest_embeddings - slow
B, _, _ = embeddings_data.shape

n_queries, _, _ = prompts_exp.embeddings.shape

# Flatten the last two dimensions of both the data and the queries so each
# string is represented by one long vector, then compute the distance from
# every string to every query.
distances = batch_distances(
    embeddings_data.reshape(B, -1),
    prompts_exp.embeddings.reshape(n_queries, -1)
)
# Use the literal k=10, as the other top-k cells do, so this cell does not
# depend on a `k` variable being defined elsewhere in the notebook.
topk = torch.topk(distances, dim=0, k=10, largest=False)
exp10.strings_from_indices(topk.indices), topk.values


The above was really slow and took a ton of memory.

Verify that the equivalent thing on the proj_out data is still fast:

In [None]:
proj_out_data = torch.cat([
    torch.load(str(exp10._proj_output_filename(batch_idx=batch_idx, block_idx=0)), mmap=True)
    for batch_idx in range(exp10.n_batches)
])

In [None]:
n_queries, _ = queries.shape
B, T, _ = proj_out_data.shape
distances = torch.norm(
    proj_out_data[:, t_i, :].reshape(B, 1, -1).expand(-1, n_queries, -1) - queries,
    dim=2
)
topk = torch.topk(distances, k=10, dim=0, largest=False)
exp10.strings_from_indices(topk.indices), topk.values


Yes it is. 

I suspect the issue is the reshape. Let's show that we can calculate the norm we want without it.

What the reshape does is stack the embeddings across the T dimension. We have embeddings

$$
e_1, e_2, \ldots, e_T \in \mathbb{R}^{n\_embed}
$$

By stacking them, we get one big embedding:

$$
e_{1:T} \in \mathbb{R}^{T * n\_embed}
$$

We do the same with the queries:

$$
q_1, q_2, \ldots, q_T \in \mathbb{R}^{n\_embed} \rightarrow
q_{1:T} \in \mathbb{R}^{T * n\_embed}
$$

We then want to compute

$$
\Vert e_{1:T} - q_{1:T} \Vert_2 = \sqrt{\sum_{i=1}^{T * n\_embed} (e_{1:T} - q_{1:T})_i^2}
$$

Can we get to this if we only have $e_1, e_2, \ldots, e_T$ and $q_1, q_2, \ldots, q_T$? Yes, we can.

$$
\begin{align}
\Vert e_{1:T} - q_{1:T} \Vert_2 &= \sqrt{\sum_{i=1}^{T * n\_embed} (e_{1:T} - q_{1:T})_i^2} \\
\Vert e_{1:T} - q_{1:T} \Vert_2^2 &= \sum_{i=1}^{T * n\_embed} (e_{1:T} - q_{1:T})_i^2 \\
&=\sum_{i=1}^{n\_embed}(e_{1:T} - q_{1:T})_i^2 + \sum_{i=n\_embed+1}^{2*n\_embed}(e_{1:T} - q_{1:T})_i^2 + \ldots + \sum_{i=(T-1)*n\_embed+1}^{T*n\_embed}(e_{1:T} - q_{1:T})_i^2 \\
&=\sum_{i=1}^{n\_embed}(e_{1} - q_{1})_i^2 + \sum_{i=1}^{n\_embed}(e_{2} - q_{2})_i^2 + \ldots + \sum_{i=1}^{n\_embed}(e_{T} - q_{T})_i^2 \\
&=\Vert e_1 - q_1 \Vert_2^2 + \Vert e_2 - q_2 \Vert_2^2 + \ldots + \Vert e_T - q_T \Vert_2^2 \\
&=\sum_{i=1}^{T}\Vert e_i - q_i \Vert_2^2
\end{align}
$$

So 

$$
\begin{align}
\Vert e_{1:T} - q_{1:T} \Vert_2 &= \sqrt{\sum_{i=1}^{T}\Vert e_i - q_i \Vert_2^2}
\end{align}
$$

Let's check it in code. 

In [None]:
# Sanity check on random data: the norm of the flattened difference equals
# the square root of the summed squared per-position norms, so the reshape
# is not needed to get the same distance.
x = torch.randn((100, 5, 384))
B, T, _ = x.shape
q = torch.randn(5, 384)

# Reference: flatten the last two dims of both tensors and take one norm.
flat_difference = x.reshape(B, -1) - q.reshape(-1)
norm1 = torch.norm(flat_difference, dim=-1)

# Alternative: norm per position, squared, summed over positions, then sqrt.
per_position_norms = torch.norm(x - q, dim=-1)
norm2 = per_position_norms.pow(2).sum(dim=-1).sqrt()

test_close(norm1, norm2)


Now let's do it with the embeddings

In [None]:
# Version without reshaping
B, _, _ = embeddings_data.shape

n_queries, _, _ = prompts_exp.embeddings.shape

# Insert a query dimension and expand it to n_queries. Subtracting the
# queries then produces a 4-D result: a full copy of each item's last two
# dimensions for every (item, query) pair, all materialized at once.
# The norm over the last dimension is squared, summed over the remaining
# trailing dimension and square-rooted, following the identity derived above.
distances = (
    torch.norm(embeddings_data.unsqueeze(dim=1).expand(-1, n_queries, -1, -1) - prompts_exp.embeddings, dim=-1) ** 2
).sum(dim=-1).sqrt()

topk = torch.topk(distances, dim=0, k=10, largest=False)
exp10.strings_from_indices(topk.indices), topk.values


: 

The above used so much memory that it crashed the kernel. So this is a no go. I think the fundamental problem is we're computing over a lot more data: all elements of the T dimension vs just one with the proj_outputs/ffwd_outputs. So let's go back to the original way of doing it in batches, but let's see if it helps to load "super batches" by combining several of the batches on disk into one batch in memory.

In [None]:
# Parameters for creating super batches

# Arguments for the top-k search: keep the 10 smallest distances.
k=10
largest=False
# Number of on-disk batches to combine into one in-memory super batch.
combine_n_batches = 5

# A super batch holds `combine_n_batches` on-disk batches' worth of strings.
batch_size = exp10.batch_size * combine_n_batches
# Round up so a final, partially filled super batch is still counted.
n_batches = math.ceil(len(exp10.strings) / batch_size)
# NOTE: this rebinds `queries` to the prompts' embeddings.
queries = prompts_exp.embeddings


In [None]:
%%timeit
def _load_batch(batch_idx: int):
    """Load super batch `batch_idx`: up to `combine_n_batches` consecutive
    on-disk embedding batches, memory-mapped and concatenated."""
    first = batch_idx * combine_n_batches
    # Clamp so the final super batch doesn't run past the last batch on disk.
    last = min(first + combine_n_batches, exp10.n_batches)

    filenames = [
        str(exp10._embeddings_filename(batch_idx=i))
        for i in range(first, last)
    ]
    return torch.cat([torch.load(name, mmap=True) for name in filenames])

n_queries, _, _ = queries.shape

def _process_batch(batch: torch.Tensor) -> torch.Tensor:
    """Return the distance from every item in `batch` to every query, with
    the trailing two dimensions of each flattened into one vector."""
    B, _, _ = batch.shape
    # The batch and the queries are both 3-dimensional (presumably
    # (B, s_len, n_embed) and (n_queries, s_len, n_embed) - TODO confirm).
    # For the purposes of finding the closest values, we
    # reshape both the batch and queries to eliminate the
    # s_len dimension, effectively concatenating all the
    # embedding tensors across positions.
    return batch_distances(batch.reshape(B, -1), queries.reshape(n_queries, -1))

values, indices = topk_across_batches(
    n_batches=n_batches,
    k=k,
    largest=largest,
    load_batch=_load_batch,
    process_batch=_process_batch,
)

exp10.strings_from_indices(indices), values

4.19 s ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


This helps only a tiny amount. And it seems to get faster the fewer batches we combine. So let's just do it with one batch at a time. I added code to load the one batch in the existing implementation with mmap=True and timed it:

In [None]:
%%timeit
exp10.strings_with_topk_closest_embeddings(queries=prompts_exp.embeddings, k=10, largest=False)

3.34 s ± 46.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


This is the best result so far so we'll go with this. 