# Figures for "The Developmental Landscape of In-Context Learning"

TODO: add a description

In [1]:
import os
from dotenv import load_dotenv

load_dotenv();

In [1]:
import logging
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
import seaborn as sns
from copy import deepcopy 

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.nn import functional as F
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import os
from typing import Optional
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from icl.analysis.utils import get_unique_run, get_unique_config
from icl.constants import FIGURES, SWEEPS, DATA
from icl.figures.notation import str_d_dlogt, str_d_dt, str_dlog_dlogt
from icl.figures.colors import plot_transitions, gen_transition_colors, get_transition_type, PRIMARY, SECONDARY, TERTIARY, BRED, BBLUE, BRED, BGREEN
from icl.constants import DEVICE
from icl.figures.plotting import WIDTH, HEIGHT, FULL_WIDTH, FULL_HEIGHT
from icl.monitoring import stdlogger

MODELS_ID = "L2H4Minf"
LLC_SWEEP_ID = "hmy71gjb"

stdlogger.setLevel(logging.DEBUG)

FIGURES





PosixPath('/Users/Jesse/Projects/icl/figures')

In [2]:
# shorthands
BATCH_SIZE = 8192
K = 8
D = 4

runs = [get_unique_run(
    str(SWEEPS / "training-runs/L2H4Minf.yaml"), 
    task_config={"model_seed": model_seed, "layer_norm": True},
) for model_seed in range(5)]

steps = set(runs[0].checkpointer.file_ids)

# for run in runs[1:]:
#     if steps != set(run.checkpointer.file_ids):
#         stdlogger.warning("Not all runs have the same checkpoints. Using intersection.")

#     steps.intersection_update(run.checkpointer.file_ids)

steps = list(steps)
num_steps = len(steps)
num_steps

  ws_hat = torch.linalg.solve(LHS, RHS)   # BKDD^-1 @ BKD1 -> B K D 1


190

# Retrieve checkpoints

In [4]:
all_models = []
all_optimizer_state_dicts = []

if os.path.exists(DATA / MODELS_ID / 'models.pt'):
    stdlogger.info(f"Loading models from {DATA / MODELS_ID}")
    all_models = torch.load(DATA / MODELS_ID / 'models.pt')
    all_optimizer_state_dicts = torch.load(DATA / MODELS_ID / 'optimizer_state_dicts.pt')
else:
    stdlogger.info(f"Retrieving models from bucket")

    pbar = tqdm.tqdm(runs)

    for run in pbar:
        models = []
        optimizer_state_dicts = []

        for i, step in enumerate(steps):
            checkpoint = run.checkpointer.load_file(step)
            m = deepcopy(run.model)
            m.load_state_dict(checkpoint["model"])
            models.append(m)
            optimizer_state_dicts.append(checkpoint["optimizer"])

            pbar.set_description(f"Checkpoint {i}/{num_steps}")
            
        all_models.append(models)
        all_optimizer_state_dicts.append(optimizer_state_dicts)

    stdlogger.info(f"Saving models to {DATA / MODELS_ID}")
    
    with open(DATA / MODELS_ID / 'models.pt', 'wb') as f:
        torch.save(all_models, f)

    with open(DATA / MODELS_ID / 'optimizer_state_dicts.pt', 'wb') as f:
        torch.save(all_optimizer_state_dicts, f)

  0%|          | 0/5 [00:00<?, ?it/s]

Checkpoint 99/190:  40%|████      | 2/5 [03:26<04:01, 80.63s/it] 

In [16]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params} ({total_params//1e6}M)")
    return total_params

count_parameters(all_models[0][-1]);

+----------------------------------------------------------------+------------+
|                            Modules                             | Parameters |
+----------------------------------------------------------------+------------+
|       token_sequence_transformer.token_embedding.weight        |    320     |
|       token_sequence_transformer.postn_embedding.weight        |    1024    |
| token_sequence_transformer.blocks.0.attention.attention.weight |   12288    |
|  token_sequence_transformer.blocks.0.attention.output.weight   |    4096    |
|      token_sequence_transformer.blocks.0.compute.0.weight      |    4096    |
|       token_sequence_transformer.blocks.0.compute.0.bias       |     64     |
|      token_sequence_transformer.blocks.0.compute.2.weight      |    4096    |
|       token_sequence_transformer.blocks.0.compute.2.bias       |     64     |
|    token_sequence_transformer.blocks.0.layer_norms.0.weight    |     64     |
|     token_sequence_transformer.blocks.

# Evals

## Behavioral (loss, delta ridge, etc.)

In [7]:
from icl.figures.derivatives import d_dt, d_dlogt, dlog_dlogt

def add_slopes(df, column, model_seed):
    seed_subset = df[df.model_seed == model_seed].sort_values("step")
    _steps = seed_subset['step'].values

    d_dts = d_dt(_steps, seed_subset[column].values)
    d_dlogts = d_dlogt(_steps, seed_subset[column].values)
    dlog_dlogts = dlog_dlogt(_steps, seed_subset[column].values)
    
    for step, _d_dt, _d_dlogt, _dlog_dlogt, in zip(seed_subset['step'], d_dts, d_dlogts, dlog_dlogts):
        df.loc[((df.step == step) & (df.model_seed==model_seed)), f"{column}/d_dt"] = _d_dt
        df.loc[((df.step == step) & (df.model_seed==model_seed)), f"{column}/d_dlogt"] = _d_dlogt
        df.loc[((df.step == step) & (df.model_seed==model_seed)), f"{column}/dlog_dlogt"] = _dlog_dlogt
    

In [17]:
from icl.analysis.evals import ICLEvaluator

torch.manual_seed(0)

run = runs[0]

evaluator = ICLEvaluator(
    pretrain_dist=run.pretrain_dist,
    true_dist=run.true_dist,
    max_examples=run.config.task_config.max_examples,
    eval_batch_size=BATCH_SIZE,
    seed=run.config.task_config.true_seed,  # type: ignore 
)

evals_over_time_df = pd.DataFrame()

if os.path.exists(DATA / MODELS_ID / "evals_over_time.pt"):
    stdlogger.info("Loading evals from disk")
    with open(DATA / MODELS_ID / "evals_over_time.pt", 'rb') as f:
        evals_over_time = torch.load(f)
else:
    stdlogger.info("Running evals")
    evals_over_time = [{**evaluator(model), "step": step, "model_seed": i} for i, _models in enumerate(all_models) for step, model in zip(steps, tqdm.tqdm(_models))]
    evals_over_time_df = pd.DataFrame(evals_over_time)

    stdlogger.info("Calculating weight norms")
    for i, _models in enumerate(all_models):
        for step, model in zip(steps, _models):
            evals_over_time_df.loc[((evals_over_time_df.step == step) & (evals_over_time_df.model_seed==i)), "weight/norm"] = (sum(torch.norm(p) ** 2 for p in model.parameters()) ** 0.5).item()

    stdlogger.info("Calculating derivatives")
    for i in evals_over_time_df.model_seed.unique():
        for column in ['pretrain/mse', "weight/norm"]:
            add_slopes(evals_over_time_df, column, i)

    stdlogger.info("Saving evals to disk")
    with open(DATA / MODELS_ID / "evals_over_time.pt", 'wb') as f:
        torch.save(evals_over_time, f)

evals_over_time_df

 99%|█████████▉| 189/190 [08:47<00:02,  2.79s/it]
 99%|█████████▉| 189/190 [08:46<00:02,  2.79s/it]
 99%|█████████▉| 189/190 [08:52<00:02,  2.82s/it]
 99%|█████████▉| 189/190 [08:57<00:02,  2.84s/it]
 99%|█████████▉| 189/190 [08:49<00:02,  2.80s/it]


Unnamed: 0,pretrain/mse_subsequence,pretrain/mse_subseq/token/0,pretrain/mse_subseq/token/1,pretrain/mse_subseq/token/2,pretrain/mse_subseq/token/3,pretrain/mse_subseq/token/4,pretrain/mse_subseq/token/5,pretrain/mse_subseq/token/6,pretrain/mse_subseq/token/7,pretrain/mse_subseq,...,true/delta_ridge,step,model_seed,weight/norm,pretrain/mse/d_dt,pretrain/mse/d_dlogt,pretrain/mse/dlog_dlogt,weight/norm/d_dt,weight/norm/d_dlogt,weight/norm/dlog_dlogt
0,4.264059,4.223984,4.360218,4.157660,4.104471,4.213433,4.166513,4.175281,4.247562,4.206140,...,2.536717,0,0,24.909348,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000
1,4.222313,4.223981,4.360214,4.157659,4.104469,4.213432,4.166512,4.175280,4.247561,4.206139,...,2.536715,1,0,24.909348,-2.622604e-06,-0.000006,-0.000001,0.000000,0.000000,0.000000
2,4.219220,4.223976,4.360209,4.157657,4.104465,4.213430,4.166508,4.175277,4.247559,4.206135,...,2.536712,2,0,24.909348,-3.814697e-06,-0.000011,-0.000003,0.000000,0.000000,0.000000
3,4.283030,4.223969,4.360202,4.157654,4.104461,4.213428,4.166503,4.175275,4.247555,4.206131,...,2.536708,3,0,24.909348,-4.768372e-06,-0.000019,-0.000005,0.000000,0.000000,0.000000
4,4.233604,4.223961,4.360194,4.157651,4.104454,4.213425,4.166497,4.175271,4.247551,4.206125,...,2.536702,4,0,24.909348,-6.198883e-06,-0.000031,-0.000007,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
945,2.673234,4.083452,3.260144,2.266535,1.499078,1.086742,0.774978,0.600018,0.521194,1.761518,...,0.102291,373737,4,222.775085,-2.638620e-07,-0.098285,-0.055775,0.000193,72.314930,0.324610
946,2.781868,4.085014,3.266781,2.301377,1.627052,1.218396,0.926676,0.754805,0.637363,1.852183,...,0.199315,75757,4,40.282597,-1.311492e-06,-0.096718,-0.052085,0.000352,26.459183,0.659412
947,2.779250,4.084165,3.258761,2.265608,1.477728,1.050309,0.745599,0.574186,0.493149,1.743688,...,0.084611,484848,4,232.149216,-1.004480e-07,-0.048488,-0.027801,0.000009,4.391759,0.018918
948,2.744716,4.083723,3.258915,2.299166,1.548486,1.126998,0.850175,0.667554,0.594283,1.803662,...,0.146541,186868,4,115.900276,8.405793e-07,0.153130,0.085275,0.000774,144.731806,1.247879


## Geometric

### LLC estimates

In [26]:
import wandb

def merge_dfs(df1, df2, inplace=True):
    if not inplace:
        raise NotImplementedError()
    
    seeds = df1["model_seed"].unique()
    steps = df1["step"].unique()

    stdlogger.info(f"Merging {len(seeds)} seeds and {len(steps)} steps")
    for seed in seeds:
        for step in steps:
            for k in df2.columns:
                if k not in df1.columns:
                    df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]

    return df1

if os.path.exists(DATA / MODELS_ID / "llcs.pt"):
    stdlogger.info("Retrieving LLC estimates from disk")
    with open(DATA / MODELS_ID / "llcs.pt", 'rb') as f:
        llc_df = torch.load(f)

else:
    stdlogger.info("Retrieving LLC estimates from wandb sweep")

    api = wandb.Api()
    sweep = api.sweep(f"devinterp/icl/{LLC_SWEEP_ID}")
    wandb_runs = sweep.runs

    llc_df = None

    for llc_run in tqdm.tqdm(wandb_runs):
        history_df = llc_run.history()

        llc_mean_columns = [f'llc/mean/{i}' for i in range(8)]
        history_df[llc_mean_columns] = history_df[llc_mean_columns].replace("NaN", np.nan)

        llc_std_columns = [f'llc/std/{i}' for i in range(8)]
        history_df[llc_std_columns] = history_df[llc_std_columns].replace("NaN", np.nan)
        history_df['model_seed'] = llc_run.config['task_config']['model_seed']
        history_df['step'] = history_df['_step']

        if llc_df is None:
            llc_df = history_df
        else:
            llc_df = pd.concat([llc_df, history_df])

    stdlogger.info("Saving LLCs to disk")
    with open(DATA / MODELS_ID / "llcs.pt", 'wb') as f:
        torch.save(llc_df, f)

stdlogger.info("Merging LLCs into evals")
evals_over_time_df = merge_dfs(evals_over_time_df, llc_df)

stdlogger.info("Calculating derivatives")
for i in evals_over_time_df.model_seed.unique():
    for column in ['llc/mean/mean']:
        add_slopes(evals_over_time_df, column, i)

llc_df

100%|██████████| 5/5 [00:03<00:00,  1.40it/s]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] == seed) & (df1["step"] == step), k] = df2.loc[(df2["model_seed"] == seed) & (df2["step"] == step), k]
  df1.loc[(df1["model_seed"] =

Unnamed: 0,batch-loss/chain-2/std/std,batch-loss/chain-7/std/4,batch-loss/chain-6/std/mean,loss/mean/2,loss/mean/7,batch-loss/chain-5/std/5,wbic/mean/6,batch-loss/chain-0/mean/5,batch-loss/chain-1/std/0,batch-loss/chain-2/mean/6,...,batch-loss/chain-8/std/0,batch-loss/chain-1/mean/2,batch-loss/chain-7/std/5,batch-loss/chain-1/mean/3,llc/std/2,batch-loss/chain-7/mean/3,wbic/std/std,batch-loss/chain-0/std/mean,model_seed,step
0,0.002800,0.264470,0.265750,4.439744,4.474955,0.285651,4688701.50,4.453844,0.273078,4.487923,...,0.266796,4.433786,0.256581,4.451212,17.755661,4.432512,1703.146477,0.268458,1,0
1,0.003820,0.265399,0.263671,4.425170,4.477533,0.301905,4667845.00,4.461284,0.257812,4.469196,...,0.274583,4.419053,0.268279,4.446731,17.693489,4.426115,2348.498732,0.270319,1,1
2,0.003786,0.265424,0.263710,4.425153,4.477554,0.301775,4667819.50,4.461306,0.257505,4.469173,...,0.274791,4.419082,0.267966,4.446737,17.725080,4.426130,2216.673692,0.270326,1,2
3,0.003757,0.265280,0.263705,4.425192,4.477541,0.301968,4667816.50,4.461197,0.257453,4.469088,...,0.274774,4.419080,0.268162,4.446727,17.679588,4.426131,2413.262776,0.270338,1,3
4,0.003815,0.265399,0.263675,4.425185,4.477520,0.301639,4667831.00,4.461257,0.257757,4.469166,...,0.274576,4.419018,0.268173,4.446680,17.683422,4.426101,2358.911137,0.270344,1,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
185,0.073918,0.340378,0.313977,3.699310,3.328932,0.503810,3466857.00,3.775999,0.238505,3.090042,...,0.267582,3.617803,0.345056,3.469490,24.637434,3.420046,126904.854249,0.486331,0,479797
186,0.073030,0.340323,0.314419,3.698307,3.326499,0.492948,3463597.25,3.777043,0.238893,3.087273,...,0.267225,3.611918,0.344430,3.462779,24.557392,3.418888,125944.830586,0.485179,0,484848
187,0.073581,0.338793,0.315070,3.696387,3.323102,0.482055,3459223.00,3.774743,0.238713,3.089475,...,0.267813,3.607190,0.343320,3.456326,24.520861,3.418793,125107.340861,0.485441,0,489898
188,0.074062,0.337162,0.315380,3.696808,3.322949,0.488047,3460043.50,3.777103,0.238817,3.091058,...,0.267835,3.606944,0.342204,3.456615,24.499849,3.415974,125458.254138,0.485523,0,494949


In [27]:
# Wait if my layer norm theory is right. Then we should see a sudden improvement in the ability of the model to make predictions for out-of-distribution xs/ys (not ws). 
from devinfra.utils.seed import set_seed
from icl.analysis.baselines import fit_ridge
from icl.regression.tasks import apply_transformations
from devinfra.utils.iterables import flatten_dict
from copy import deepcopy

torch.manual_seed(0)

def eval_loss(yhats, ys):
    losses = ((yhats - ys) ** 2).mean(dim=0)[:, 0]
    return [loss.item() for loss in losses]

OOD_MULTIPLIER = 5
ws = torch.normal(
    mean=0.,
    std=1.,
    size=(BATCH_SIZE, D,),
    device=DEVICE
)

# sample i.i.d. inputs and outputs for each task according to the
# regression model
xs = torch.normal(
    mean=0.,
    std=1.,
    size=(BATCH_SIZE, K, D),
    device=DEVICE
)

errors = torch.normal(
    mean=0.,
    std=0.125,
    size=(BATCH_SIZE, K, 1,),
    device=DEVICE,
)
ys = apply_transformations(ws, xs, 0.125, DEVICE) # xs @ ws.view(BATCH_SIZE, D, 1) + errors

ood_a_xs = 3 * xs
ys_ood_a_inputs = apply_transformations(ws, ood_a_xs, 0.125, DEVICE) # ood_xs @ ws.view(BATCH_SIZE, D, 1) + errors
ood_b_xs = 5 * xs
ys_ood_b_inputs = apply_transformations(ws, ood_b_xs, 0.125, DEVICE) # ood_xs @ ws.view(BATCH_SIZE, D, 1) + errors
ood_c_xs = 10 * xs
ys_ood_c_inputs = apply_transformations(ws, ood_c_xs, 0.125, DEVICE) # ood_xs @ ws.view(BATCH_SIZE, D, 1) + errors
ood_d_xs = 100 * xs
ys_ood_d_inputs = apply_transformations(ws, ood_d_xs, 0.125, DEVICE) # ood_xs @ ws.view(BATCH_SIZE, D, 1) + errors

first_xs = deepcopy(xs[:, 0:1, :])
first_x_ws = fit_ridge(first_xs, ys[:, 0:1, :], 0.125)
first_x_ys = apply_transformations(first_x_ws, xs, 0., DEVICE)

first_xs /= torch.norm(first_xs, dim=-1, keepdim=True) ** 2
for b in range(BATCH_SIZE):
    first_xs[b] *= ys[b, 0, 0]

ys_using_first_x = xs @ first_xs.view(BATCH_SIZE, D, 1) 

def eval_all(model):
    ypreds = model(xs, ys)
    results = {
        "loss": eval_loss(ypreds, ys),
        "loss_0": eval_loss(ypreds, torch.zeros_like(ys)),
        "loss_first_x": eval_loss(ypreds, first_x_ys),
        "ood_a_inputs_loss": eval_loss(model(ood_a_xs, ys_ood_a_inputs), ys_ood_a_inputs),
        "ood_a_tasks_loss": eval_loss(model(xs, ys_ood_a_inputs), ys_ood_a_inputs),
        "ood_b_inputs_loss": eval_loss(model(ood_b_xs, ys_ood_b_inputs), ys_ood_b_inputs),
        "ood_b_tasks_loss": eval_loss(model(xs, ys_ood_b_inputs), ys_ood_b_inputs),
        "ood_c_inputs_loss": eval_loss(model(ood_c_xs, ys_ood_c_inputs), ys_ood_c_inputs),
        "ood_c_tasks_loss": eval_loss(model(xs, ys_ood_c_inputs), ys_ood_c_inputs),
        "ood_d_inputs_loss": eval_loss(model(ood_d_xs, ys_ood_d_inputs), ys_ood_d_inputs),
        "ood_d_tasks_loss": eval_loss(model(xs, ys_ood_d_inputs), ys_ood_d_inputs),
    }
    return results


for seed, (run, _models) in enumerate(zip(runs, all_models)):
    # sample a batch of random tasks
    for step, model in tqdm.tqdm(zip(steps, _models), total=len(steps), desc=f"Seed {seed}"):
        metrics = eval_all(model)
        
        for k, v in metrics.items():
            for i in range(8):
                    evals_over_time_df.loc[((evals_over_time_df.step == step) & (evals_over_time_df.model_seed==seed)), f"{k}/{i}"] = v[i]
                    
            evals_over_time_df.loc[((evals_over_time_df.step == step) & (evals_over_time_df.model_seed==seed)), f"{k}/mean"] = np.mean(v)


ModuleNotFoundError: No module named 'icl.regression'

### Hessians

In [None]:
import pyhessian
from pyhessian import hessian # Hessian computation

class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs):
        return self.model(inputs[0], inputs[1])

hessian_stats_df = pd.DataFrame()

if os.path.exists(DATA / MODELS_ID / "hessian_stats.pt"):
    stdlogger.info("Loading hessian stats from disk")
    with open(DATA / MODELS_ID / "hessian_stats.pt", 'rb') as f:
        hessian_stats_df = torch.load(f)

else:
    stdlogger.info("Running hessian stats")
    hessian_stats = []

    xs = xs.to('cpu')
    ys = ys.to('cpu')

    for i, _models in enumerate(all_models):
        for step, model in zip(steps, tqdm.tqdm(_models)):
            model = model.to('cpu')
            ref_model = ModelWrapper(model)
            hessian_comp = hessian(ref_model, F.mse_loss, data=((xs[:1024], ys[:1024]), ys[:1024]), cuda=False)

            _top_evals, _ = hessian_comp.eigenvalues(top_n=20)
            trace = hessian_comp.trace()

            hessian_stats.append({
                "model_seed": i,
                "step": step,
                "hessian/trace": trace,
            } | {f"hessian/evals/{i}": _top_evals[i] for i in range(20)})
            model.to('mps')
                
    xs = xs.to('mps')
    ys = ys.to('mps')

    hessian_stats_df = pd.DataFrame(hessian_stats)

    stdlogger.info("Saving hessian stats to disk")
    with open(DATA / MODELS_ID / "hessian_stats.pt", 'wb') as f:
        torch.save(hessian_stats_df, f)

hessian_stats_df

## Structural

### Embedding

In [None]:
def get_embedding(model):
    return model.state_dict()['token_sequence_transformer.token_embedding.weight']

def get_unembedding(model):
    return model.state_dict()['token_sequence_transformer.unembedding.1.weight']

import itertools

models1 = all_models[MODEL_SEED]

embed_sing_vals = []
postn_sing_vals = []
entangling = []

for step, model in zip(steps, models1):
    embed = model.token_sequence_transformer.token_embedding.weight.detach().cpu().numpy()
    U_embed, S_embed, Vt_embed = np.linalg.svd(embed)
    embed_trace = np.sum(S_embed ** 2)
    S_embed_normed = S_embed ** 2 / embed_trace 

    # pca = PCA(n_components=5).fit(embed)

    for i in range(len(S_embed)):
        embed_sing_vals.append({
            "step": step, "index": i,
            "embed/S": S_embed[i],
            # "embed/S_normed": pca.explained_variance_ratio_[i],
            "embed/S_normed": S_embed_normed[i], 
        })
    
    postn = model.token_sequence_transformer.postn_embedding.weight.detach().cpu().numpy()
    U_postn, S_postn, Vt_postn = np.linalg.svd(postn)
    postn_trace = np.sum(S_postn ** 2)
    S_postn_normed = S_postn ** 2 / postn_trace

    # pca = PCA(n_components=16).fit(postn)

    for i in range(len(S_postn)):
        postn_sing_vals.append({
            "step": step, "index": i,
            "postn/S": S_postn[i], 
            # "postn/S_normed": pca.explained_variance_ratio_[i], # 
            "postn/S_normed": S_postn_normed[i], 
        })

    # Compute the cossim between the singular vectors of the embedding and the positional encoding space
    # entangling = np.zeros((len(S_embed), len(S_postn)))

    # for i, j in itertools.product(range(len(S_embed)), range(len(S_postn))):
    #     v = V_embed[i]
    #     u = V_postn[j]
    #     cossim = np.dot(v, u) / (np.linalg.norm(v) * np.linalg.norm(u))
    #     entangling[i, j] = cossim


    cossims = []
    for i in range(len(S_embed)):           
        u = embed[:, i] # U_embed[:,]
        # v = Vt_postn[i]
        u_proj = np.dot(U_postn, u)
        cossim = np.abs(np.dot(u_proj, u) / (np.linalg.norm(u_proj) * np.linalg.norm(u)))
        cossims.append(cossim)

    # cossims = sorted(cossims, reverse=True) 
    
    for i in range(len(S_embed)):
        entangling.append({
            "step": step,
            "embed/trace": embed_trace,
            "postn/trace": postn_trace,
            "index": i,
            "cossim": cossims[i],   
        })

print(embed.shape, postn.shape)   
print(U_embed.shape, S_embed.shape, Vt_embed.shape)
print(U_postn.shape, S_postn.shape, Vt_postn.shape)

embed_sing_vals = pd.DataFrame(embed_sing_vals)
postn_sing_vals = pd.DataFrame(postn_sing_vals)
entangling = pd.DataFrame(entangling)


### Unembedding

In [None]:
unembeddings = []

for step, model in zip(steps, models1):
    for subset in ["weight", "bias"]:
        layer = f"ln.{subset}"
        for i, param in enumerate(getattr(model.token_sequence_transformer.unembedding[0], subset)):
            unembeddings.append({"p": param.item(), "step": step, "layer": layer, "i": i})

        layer = f"linear.{subset}"
        layer_param = getattr(model.token_sequence_transformer.unembedding[1], subset)
        if subset == "weight":
            layer_param = layer_param[0, :]
            for i, param in enumerate(layer_param):
                unembeddings.append({"p": param.item(), "step": step, "layer": layer, "i": i})
        else:
            layer_param = layer_param[0]
            unembeddings.append({"p": layer_param.item(), "step": step, "layer": layer, "i": 0})


unembeddings = pd.DataFrame(unembeddings)

In [None]:
reduced_unembeddings = []

for step, model in zip(steps, models1):
    reduced_weight = model.token_sequence_transformer.unembedding[1].weight[0, :] * model.token_sequence_transformer.unembedding[0].weight
    reduced_bias = model.token_sequence_transformer.unembedding[1].weight[0, :] @ model.token_sequence_transformer.unembedding[0].bias + model.token_sequence_transformer.unembedding[1].bias[0]

    for i, param in enumerate(reduced_weight):
        reduced_unembeddings.append({"p": param.item(), "subset": "weight", "step": step,  "i": i})

    reduced_unembeddings.append({"p": reduced_bias.item(), "subset": "bias", "step": step,  "i": 0})

reduced_unembeddings = pd.DataFrame(reduced_unembeddings)

### Layer Norm

In [None]:
from icl.analysis.slt import prepend_keys

layer_norms = [
    "token_sequence_transformer.blocks.0.layer_norms.0",
    "token_sequence_transformer.blocks.0.layer_norms.1",
    "token_sequence_transformer.blocks.1.layer_norms.0",
    "token_sequence_transformer.blocks.1.layer_norms.1",
    "token_sequence_transformer.unembedding.0",
]

list(model.state_dict().keys())

def get_ln(model, key):
    return (model.state_dict()[f'{key}.weight'], model.state_dict()[f'{key}.bias'])

unembedding_lns = [get_ln(model, 'token_sequence_transformer.unembedding.0') for model in models]
block_1_attn_lns =  [get_ln(model, 'token_sequence_transformer.blocks.0.layer_norms.0') for model in models]
block_1_mlp_lns =  [get_ln(model, 'token_sequence_transformer.blocks.0.layer_norms.1') for model in models]
block_2_attn_lns =  [get_ln(model, 'token_sequence_transformer.blocks.1.layer_norms.0') for model in models]
block_2_mlp_lns =  [get_ln(model, 'token_sequence_transformer.blocks.1.layer_norms.1') for model in models]

def ln_norm(weight, bias):
    return torch.norm(weight).detach().cpu().numpy()

def ln_norm_std(weight, bias):
    return torch.std(weight.abs()).detach().cpu().numpy()

unembedding_ln_norms = [ln_norm(weight, bias) for weight, bias in unembedding_lns]
block_1_attn_ln_norms = [ln_norm(weight, bias) for weight, bias in block_1_attn_lns]
block_1_mlp_ln_norms = [ln_norm(weight, bias) for weight, bias in block_1_mlp_lns]
block_2_attn_ln_norms = [ln_norm(weight, bias) for weight, bias in block_2_attn_lns]
block_2_mlp_ln_norms = [ln_norm(weight, bias) for weight, bias in block_2_mlp_lns]

unembedding_ln_norms_std = np.array([ln_norm_std(weight, bias) for weight, bias in unembedding_lns])
block_1_attn_ln_norms_std = np.array([ln_norm_std(weight, bias) for weight, bias in block_1_attn_lns])
block_1_mlp_ln_norms_std = np.array([ln_norm_std(weight, bias) for weight, bias in block_1_mlp_lns])
block_2_attn_ln_norms_std = np.array([ln_norm_std(weight, bias) for weight, bias in block_2_attn_lns])
block_2_mlp_ln_norms_std = np.array([ln_norm_std(weight, bias) for weight, bias in block_2_mlp_lns])

def frac_nonzero(weight, eps=1e-1):
    return (weight.abs() > eps).float().mean().detach().cpu().numpy()

unembedding_ln_norm_nonzero = [frac_nonzero(weight) for weight, bias in unembedding_lns]
block_1_attn_ln_norm_nonzero = [frac_nonzero(weight) for weight, bias in block_1_attn_lns]
block_1_mlp_ln_norm_nonzero = [frac_nonzero(weight) for weight, bias in block_1_mlp_lns]
block_2_attn_ln_norm_nonzero = [frac_nonzero(weight) for weight, bias in block_2_attn_lns]
block_2_mlp_ln_norm_nonzero = [frac_nonzero(weight) for weight, bias in block_2_mlp_lns]

ln_stats = []

def get_stats(weight):
    return {
        "norm": weight.norm().item(),
        "norm_std": weight.abs().std().item(),
        "std": weight.std().item(),
        "mean": weight.mean().item(),
        "max": weight.max().item(),
        "min": weight.min().item(),
    }
    

for step, model in zip(steps, models1):
    for layer in [ "blocks.0.layer_norms.0", "blocks.0.layer_norms.1", "blocks.1.layer_norms.0", "blocks.1.layer_norms.1", "unembedding.0"]:
        weight, bias = get_ln(model, f"token_sequence_transformer.{layer}")

        ln_stats.append({
            "step": step,
            "layer": layer,
            "layer_pretty": layer.replace("_", " ").title(),
            **prepend_keys(get_stats(weight), "weight"),
            **prepend_keys(get_stats(bias), "bias"),
        })

ln_stats = pd.DataFrame(ln_stats)
ln_stats

### Attention

In [None]:
from collections import defaultdict
from typing import List, Union, Iterable, Optional
from torchtyping import TensorType
from devinfra.utils.iterables import map_nested

from icl.experiments.utils import iter_models
from devinfra.utils.iterables import flatten_dict

from icl.train import Run

def compute_attention_entropies(attn: TensorType["B", "H", "2K", "2K"]):
    """
    Computes the entropy of each token in each head, averaged across the batch, 
    then averages this over heads. 

    """
    
    # Threshold attention weights to avoid log(0)
    log_attention = torch.where(attn > 0, torch.log(attn), torch.tensor(0.0).to(attn.device))
    entropy_per_token = - torch.sum(attn * log_attention, dim=-1).mean(dim=0).squeeze(-1) # TensorType["H", "2K"]

    num_heads, num_tokens = entropy_per_token.shape

    entropy_per_head = entropy_per_token.mean(dim=-1) # TensorType["H"]
    entropy = entropy_per_head.mean() # TensorType[]    
    
    # Each token computes entropy over a variable context length, so we normalize by the maximum possible entropy
    # for a token with a fixed context length.

    max_entropy_per_token = torch.log2(torch.arange(1, num_tokens + 1).to(attn.device)) # TensorType["H", "2K"]
    max_entropy_per_token[0] = 1. # Special case for the first token to avoid dividing by 0

    entropy_per_token_normalized = entropy_per_token / max_entropy_per_token
    entropy_per_head_normalized = entropy_per_token_normalized.mean(dim=-1) # TensorType["H"]
    entropy_normalized = entropy_per_head_normalized.mean() # TensorType[]    

    results: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] = {"mean": entropy, "mean_normalized": entropy_normalized}

    for i in range(num_heads):
        head_results = {"mean": entropy_per_head[i], "mean_normalized": entropy_per_head_normalized[i]}

        for j in range(num_tokens):
            head_results[f"token_{j}"] = entropy_per_token[i, j]
            head_results[f"token_{j}_normalized"] = entropy_per_token_normalized[i, j]

        results[f"head_{i}"] = head_results

    return map_nested(lambda x: convert_tensor(x, "np"), results)


def get_attention_entropies_trace(
    steps: List[int],
    models: Iterable[nn.Module],
    xs: torch.Tensor,
    ys: torch.Tensor,
    **paths,
):
    results = defaultdict(list)
    reverse_paths = {v: k for k, v in paths.items()}

    for activations in extract_activations_over_checkpoints(models, xs, ys, *paths.values(), return_type="pt"):
        for k, v in activations.items():
            if k == "":
                continue
            path = reverse_paths[k]
            results[path].append(compute_attention_entropies(v))

    values = []

    for i in range(len(steps)):
        value = {}

        for block in results.keys():
            value[block] = results[block][i]
        
        value["step"] = steps[i]
        values.append(flatten_dict(value, flatten_lists=True))

    return pd.DataFrame(values)

def compute_attention_variability(attn: TensorType["B", "H", "2K", "2K"]):
    """
    Computes the variability of the attention pattern of each head across the batch.
    """

    num_batches, num_heads, num_tokens, _ = attn.shape
    num_tokens = num_tokens - 1

    results: Dict[str, Union[float, Dict[str, float]]] = {}

    attn = attn[:, :, :-1, :-1] # Remove the last y token which gets to training signal
    mean_attn_pattern = attn.mean(dim=0, keepdim=True)

    log_attention = torch.where(attn > 0, torch.log2(attn), torch.tensor(0.0).to(attn.device))
    entropy_per_token = - torch.sum(attn * log_attention, dim=-1).mean(dim=0).squeeze(-1)

    variability = (((attn - mean_attn_pattern).abs().sum(dim=-1)) / (2 * mean_attn_pattern.sum(dim=-1))).mean(dim=0) 
    prev_token_attn = torch.zeros(num_heads, num_tokens, device=attn.device)

    self_attn = attn.diagonal(dim1=-2, dim2=-1).mean(dim=0)
    prev_token_attn[:, 1:] = attn.diagonal(dim1=-2, dim2=-1, offset=-1).mean(dim=0)
    x_tokens_attn = attn[:, :, :, 0::2].sum(dim=-1).mean(dim=0)
    y_tokens_attn = attn[:, :, :, 1::2].sum(dim=-1).mean(dim=0)
    hardness_per_token = (attn <= 0.001).float().mean(dim=-1).mean(dim=0)
    distances = (torch.arange(0, num_tokens, device=attn.device).view(1, 1, 1, -1) * attn).sum(dim=-1).mean(dim=0)
    # first_x_attn = attn[:, :, :, 0].mean(dim=0)

    # print(hardness_per_token.shape)
    for i in range(15):
        hardness_per_token[:, i] = hardness_per_token[:, i] / (i + 1)
        distances[:, i] = distances[:, i] / (i + 1)

    results["entropy"] = entropy_per_token.mean() 
    results["variability"] = variability.mean().item()
    results["self_attn"] = self_attn.mean().item()
    results["prev_token_attn"] = prev_token_attn.mean().item()
    results["x_tokens_attn"] = x_tokens_attn.mean().item()
    results["y_tokens_attn"] = y_tokens_attn.mean().item()
    results["hardness"] = hardness_per_token.mean().item()
    results["distance"] = distances.mean().item()

    results["x"] = {
        "entropy": entropy_per_token[::2].mean().item(),
        "variability": variability[::2].mean().item(),
        "self_attn": self_attn[::2].mean().item(),
        "prev_token_attn": prev_token_attn[::2].mean().item(),
        "x_tokens_attn": x_tokens_attn[::2].mean().item(),
        "y_tokens_attn": y_tokens_attn[::2].mean().item(),
        "hardness": hardness_per_token[::2].mean().item(),
        "distance": distances[::2].mean().item()
    }

    results["y"] = {
        "entropy": entropy_per_token[1::2].mean().item(),
        "variability": variability[1::2].mean().item(),
        "self_attn": self_attn[1::2].mean().item(),
        "prev_token_attn": prev_token_attn[1::2].mean().item(),
        "x_tokens_attn": x_tokens_attn[1::2].mean().item(),
        "y_tokens_attn": y_tokens_attn[1::2].mean().item(),
        "hardness": hardness_per_token[1::2].mean().item(),
        "distance": distances[1::2].mean().item()
    }

    for i in range(num_heads):  
        head_entropy = entropy_per_token[i]
        head_variability = variability[i]
        head_self_attn = self_attn[i]
        head_prev_token_attn = prev_token_attn[i]
        head_x_tokens_attn = x_tokens_attn[i]
        head_y_tokens_attn = y_tokens_attn[i]
        head_hardness = hardness_per_token[i]
        head_distance = distances[i]

        head_results = {
            "entropy": head_entropy.mean().item(),
            "variability": head_variability.mean().item(),
            "self_attn": head_self_attn.mean().item(),
            "prev_token_attn": head_prev_token_attn.mean().item(),
            "x_tokens_attn": head_x_tokens_attn.mean().item(),
            "y_tokens_attn": head_y_tokens_attn.mean().item(),
            "hardness": head_hardness.mean().item(),
            "distance": head_distance.mean().item()
        }

        for x_or_y in (1, 0):
            head_half_results = dict(
                entropy = head_entropy[x_or_y::2].mean().item(),
                variability = head_variability[x_or_y::2].mean().item(),
                self_attn = head_self_attn[x_or_y::2].mean().item(),
                prev_token_attn = head_prev_token_attn[x_or_y::2].mean().item(),
                x_tokens_attn = head_x_tokens_attn[x_or_y::2].mean().item(),
                y_tokens_attn = head_y_tokens_attn[x_or_y::2].mean().item(),
                hardness = head_hardness[x_or_y::2].mean().item(),
                distance = head_distance[x_or_y::2].mean().item()
            )

            head_results["x" if not x_or_y else "y"] = head_half_results

            for j in range(x_or_y, num_tokens, 2):
                head_results[f"token_{j}/entropy"] = head_entropy[j].item()
                head_results[f"token_{j}/entropy_normalized"] = head_entropy[j].item() / np.log2(j + 1)
                head_results[f"token_{j}/variability"] = head_variability[j].item()
                head_results[f"token_{j}/self_attn"] = self_attn[i, j].item()
                head_results[f"token_{j}/prev_token_attn"] = prev_token_attn[i, j].item()
                head_results[f"token_{j}/x_tokens_attn"] = x_tokens_attn[i, j].item()
                head_results[f"token_{j}/y_tokens_attn"] = y_tokens_attn[i, j].item()
                head_results[f"token_{j}/hardness"] = hardness_per_token[i, j].item()
                head_results[f"token_{j}/distance"] = distances[i, j].item()

        results[f"head_{i}"] = head_results

    return map_nested(lambda x: convert_tensor(x, "np"), results)


def get_attention_variabilities(
    steps: List[int],
    models: Iterable[nn.Module],
    xs: torch.Tensor,
    ys: torch.Tensor,
    **paths,
):
    results = defaultdict(list)
    reverse_paths = {v: k for k, v in paths.items()}

    for activations in tqdm.tqdm(extract_activations_over_checkpoints(models, xs, ys, *paths.values(), return_type="pt"), total=len(models)):
        for k, v in activations.items():
            if k == "":
                continue
            path = reverse_paths[k]
            results[path].append(compute_attention_variability(v))

    values = []

    for i in range(len(steps)):
        value = {}

        for block in results.keys():
            value[block] = results[block][i]
        
        value["step"] = steps[i]
        values.append(flatten_dict(value, flatten_lists=True))

    return pd.DataFrame(values)

attn_variabilities = get_attention_variabilities(
    run.checkpointer.file_ids,
    models, 
    xs, 
    ys, 
    **{f"block_{b}": f"token_sequence_transformer.blocks.{b}.attention.attention_softmax" for b in range(num_blocks)}
)

## Identify Transitions

In [None]:
import colorsys
from functools import partial
from matplotlib.colors import LinearSegmentedColormap
import sys

del sys.modules['icl.figures.colors']

from icl.figures.colors import plot_transitions, gen_transition_colors, get_transition_type, PRIMARY, SECONDARY, TERTIARY, BRED, BBLUE, BRED, BGREEN, decrease_brightness, increase_saturation, increase_contrast, rainbow, LR_TRANSITION_COLORS

def get_transition_indices(steps, transitions):
    transition_indices = []
    for step in steps:
        # Find the index of the transition that the current step falls into
        index = next((i for i, transition in enumerate(transitions) if transition[0] <= step < transition[1]), None)
        transition_indices.append(index if index is not None else -1)

    return transition_indices

def get_nearest_step(step):
    idx = np.argmin(np.abs(np.array(steps) - step))
    return steps[idx]

TRANSITIONS = [
    (0, 1500, 'R1'),
    (1500, 40_000, 'R2'),
    (40000, 320000, 'R3'),
    (320000, 500000, 'R4'),
]

TRANSITIONS = [(get_nearest_step(start), get_nearest_step(end), label) for start, end, label in TRANSITIONS]

transition_rainbow = list(reversed(rainbow(len(TRANSITIONS))))
transitions_cmap = LinearSegmentedColormap.from_list("transitions", LR_TRANSITION_COLORS)
gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))
_plot_transitions = partial(plot_transitions, transitions=TRANSITIONS, colors=LR_TRANSITION_COLORS)

transitions_of_steps = get_transition_indices(steps, TRANSITIONS)
highlight_steps = list(map(get_nearest_step, [t[0] for t in TRANSITIONS][1:]))

# Show the transition colors
fig, ax = plt.subplots(figsize=(10, 1))
ax.set_xscale('log')
ax.set_xlim(100, 500_000)
_plot_transitions(ax)

# Multiple-Seed Figures

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_palette('deep')

# Data setup (use your actual data here)
# For demonstration, replace evals_over_time_df with your DataFrame
# evals_over_time_df = ...

# Create figure
fig, axs = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))

axs[0].set_ylabel(r'Test loss  $\hat\ell(w_t)$' '\n')
axs[1].set_ylabel(r'Local learning coeff.  $\hat\lambda(w_t)$')

# First line plot
sns.lineplot(evals_over_time_df, x='step', y='pretrain/mse', ax=axs[0])
# axs[0].set_title(r'(b) Loss over Time')
axs[0].set_xscale('log')
# axs[0].set_yscale('log')

# Second line plot
sns.lineplot(evals_over_time_df, x='step', y='llc/mean/mean', ax=axs[1])
# axs[1].set_title(r'(d) Local Learning Coefficient over Time')
axs[1].set_xscale('log')

# Set x-label for both plots
# for ax in axs:

handles = _plot_transitions(axs, xlim=(100, 500_000)) 
axs[0].set_xlabel('')
axs[1].set_xlabel('Training step $t$')


# fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(1.05, .65), ncol=1)
fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.55, 0.01), ncol=len(TRANSITIONS))

# axs[1].legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=len(TRANSITIONS))
# Layout adjustments
plt.tight_layout()
fig.set_facecolor('white')

fig.savefig(FIGURES / f"lr-fig1-top.pdf", bbox_inches='tight')
plt.show()


# Model-Specific Figures

In [None]:
MODEL_SEED = 0  # 0, 1, 2, 3, 4
MODEL_ID = f"LR{MODEL_SEED}"

if not os.path.exists(FIGURES / MODEL_ID):
    os.makedirs(FIGURES / MODEL_ID)

models1 = all_models[MODEL_SEED]
final_model = deepcopy(models1[-1]).to('cpu')
run = runs[-1]

In [None]:
for t1, t2, _ in TRANSITIONS: # + [(499_999, None, None)]:
    print(f"Step {t1}->{t2} Delta Lambdahat:", evals_over_time_df.loc[(evals_over_time_df.step == get_nearest_step(t2)) & (evals_over_time_df.model_seed == MODEL_SEED)]['llc/mean/mean'].values[0] - evals_over_time_df.loc[(evals_over_time_df.step == get_nearest_step(t1)) & (evals_over_time_df.model_seed == MODEL_SEED)]['llc/mean/mean'].values[0])
    print(f"Step {t1}->{t2} Delta Loss:", evals_over_time_df.loc[(evals_over_time_df.step == get_nearest_step(t2)) & (evals_over_time_df.model_seed == MODEL_SEED)]['pretrain/mse'].values[0] - evals_over_time_df.loc[(evals_over_time_df.step == get_nearest_step(t1)) & (evals_over_time_df.model_seed == MODEL_SEED)]['pretrain/mse'].values[0])

In [None]:
from matplotlib import lines as mlines
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C


_df = evals_over_time_df.loc[evals_over_time_df.model_seed == MODEL_SEED]

metrics_to_plot = [
    (r"\hat\ell(w_t)", "pretrain/mse", {"logy": True, "derivative": "d_dlogt", "spline": True, "s": 0.1}, ),
    # (r"L_\mathcal{G}(t)", _df["true/mse"], {"logy": False}),
    (r"\hat \lambda(w_t)", 'llc/mean/mean', {"derivative": "d_dlogt", "spline": True}),
    (r"|w_t|", "weight/norm", {"derivative": "d_dt", "logy": True, "spline": True, "s": 0.1}),
] 
fig, axes = plt.subplots(2, len(metrics_to_plot), figsize=(FULL_WIDTH * 1.25, FULL_HEIGHT))

# axes = np.array(axes)
axes = axes.reshape(2, len(metrics_to_plot))

def str_dlog_dlogt(s):
    return r"$\delta \log " + s + r"/\delta\log t$"

for i, (metric_name, metric_key, kwargs) in enumerate(metrics_to_plot):
    use_spline = kwargs.get("spline", False)

    sns.lineplot(data=_df, x="step", y=_df[metric_key], ax=axes[0, i],label=metric_name, alpha=1 - use_spline * 0.75)
    # axes[0, i].plot(_df['step'], metric_values, label=metric_name, marker='.')
    axes[0, i].set_title(f"")
    axes[0, i].set_xlabel('Step $t$')
    axes[0, i].set_ylabel(f"${metric_name}$")

    if kwargs.get("logy", False):
        axes[0, i].set_yscale('log')

    axes[0, i].legend().remove()

    slope_type = kwargs.get("derivative", "d_dlogt")

    if slope_type == "d_dlogt":
        slope_name = str_d_dlogt(metric_name)
    elif slope_type == "d_dt":
        slope_name = str_d_dt(metric_name)
    elif slope_type == "dlog_dlogt":
        slope_name = str_dlog_dlogt(metric_name)
    else:
        raise ValueError(f"Unknown slope type {slope_type}")

    sns.lineplot(data=_df, x="step", y=f"{metric_key}/{slope_type}", ax=axes[1, i], label=metric_name + " Slope", alpha=1 - use_spline * 0.75)
    axes[1, i].axhline(0, linestyle='--', color='gray')
    axes[1, i].set_title("")
    axes[1, i].set_xlabel('Step, $t$')
    axes[1, i].set_ylabel(slope_name)
    axes[1, i].legend().remove()
    
    if use_spline:     
        _steps = np.log(np.array(steps) + 1 ).reshape((-1, 1))
        _y = _df.groupby('step').mean()[metric_key].values

        kernel = C(1.0, (1e-3, 1e3)) * RBF(3, (5e-1, 1e3))

        # Create a Gaussian Process Regressor
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10)

        # Fit the Gaussian Process
        gp.fit(_steps, _y)
        _ypred = gp.predict(_steps)

        if slope_type == "d_dlogt":
            _derivy = d_dt(_steps, _ypred)
        elif slope_type == "d_dt":
            _derivy = d_dt(np.exp(_steps), _ypred)
        elif slope_type == "dlog_dlogt":            
            _derivy = d_dt(_steps, np.log(_ypred))
        
        axes[0, i].plot(steps, _ypred, label="Spline", linestyle='--', color=BRED)
        axes[1, i].plot(steps, _derivy, label="Spline", linestyle='--', color=BRED)


for ax in axes.flatten():
    ax.set_xscale('log')
    ax.set_xlim(100, 500_000)
    # ax.set_ylabel("")

# axes[0, 1].set_ylim(0, 100)
axes[1, 0].set_ylim(-2.25, 2.25)
axes[1, 1].set_ylim(-150, 160)

patch_list = plot_transitions(axes, TRANSITIONS, xlim=True, colors=LR_TRANSITION_COLORS)

# axes[1, 1].set_yscale('symlog')
# axes[1, 0].set_yscale('symlog')
# axes[0,0].set_ylim(0, 70)

milestone_labels = [label for _, _, label in TRANSITIONS]
gp_fit_patch = mlines.Line2D([], [], color=BRED, linestyle='--', label="GP Fit")
fig.legend(patch_list + [gp_fit_patch], milestone_labels + ["Fit"], loc='upper center', bbox_to_anchor=(0.5, -0.025), ncol=len(TRANSITIONS) + 1)

fig.set_facecolor("white")
fig.tight_layout()

fig.savefig(FIGURES / MODEL_ID /"loss-llc-with-slopes.pdf", bbox_inches='tight')

# Behavioral Indicators

In [None]:
k0 = 0

titles_and_labels = {
    "loss": ("Test Loss $L_\mathrm{val}$", "MSE"),
    "loss_0": ("$\mathbb{E}[|\hat{y}_k|^2]$", r"$\|\hat y_k\|^2$"),
    # "loss_first_x": ("(MSE from Ridge", "MSE"),
    "pretrain/delta_ridge/token": ("MSE from Ridge", "MSE"),  
    "ood_inputs_loss": (r"MSE on $x_i \sim \mathcal{N}(0, 5I_D)$", "MSE"),
    "ood_targets_loss": (r"MSE on $\mathbf{t} \sim \mathcal{N}(0, 5I_D)$", "MSE"),
    "llc/mean": (r"Per-Token $\lambda_t$", "$\lambda_t^{(i)}$"),
    "icl_score": (f"$ICL_{{{k0+1}:8}}(w_t)/g$", "ICL"),
    "ood_inputs_rel_loss": (r"$\frac{\mathrm{MSE}(5 x_i)}{\mathrm{MSE}(x_i)}$", "MSE"),
}

## Loss, Prediction Norm, OOD Loss

In [None]:
fig, axes = plt.subplots(len(metrics), 1, figsize=(WIDTH, HEIGHT * 2))  # Adjust the figsize as needed

df = evals_over_time_df.loc[evals_over_time_df.model_seed == MODEL_SEED]

metrics = ['loss_0', 'icl_score'] #, "ood_inputs_rel_loss"]# "loss", "ood_loss"]

fig, axes = plt.subplots(len(metrics), 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))  # Adjust the figsize as needed
token_cmap = ScalarMappable(norm=Normalize(vmin=0, vmax=8), cmap="viridis")

ood_colors = sns.color_palette("viridis", 4)


for m, metric in enumerate(metrics):
    if metric == 'icl_score':
        data = df[f'loss/7'].values - df[f'loss/{k0}'].values
        sns.lineplot(x=df.step, y=data, ax=axes[m], color=ood_colors[0], label=f"$g=1$")

        for i, (l, g) in enumerate(zip(['a', 'b', 'c'], [3, 5, 10, 100])):
            data = (df[f'ood_{l}_inputs_loss/7'].values - df[f'ood_{l}_inputs_loss/{k0}'].values ) / (g)
            sns.lineplot(x=df.step, y=data, ax=axes[m], color=ood_colors[i+1], label=f"$g={g}$")

    elif metric == 'ood_inputs_rel_loss':
        for l, g in zip(['a', 'b', 'c', 'd'], [3, 5, 10, 100]):
            data = df[f'ood_{l}_inputs_loss/mean'].values / g ** 2 # df[f'loss/mean'].values
            sns.lineplot(x=df.step, y=data, ax=axes[m])
    else:
        for i in range(8):
            color = token_cmap.to_rgba(i)
            # if i == 0 and metric == "loss_first_x":
            #     continue

            sns.lineplot(data=df, x="step", y=f"{metric}/{i}", ax=axes[m], alpha=0.5, color=color, label=f"$k={i+1}$" if i % 2 == 0 else "_")

        if metric.endswith('/token'):
            mean_metric = metric[:-6] 
        else:
            mean_metric = f"{metric}/mean"
            
        sns.lineplot(data=df, x="step", y=mean_metric, label="Mean", ax=axes[m], color=BRED)

for ax in axes:
    ax.set_xscale("log")
    ax.set_xlabel("")

    # legend = ax.legend()
    # legend.remove()
    ax.set_xlim(100, 500_000)

ax.set_xlabel("Step $t$")
# legend = axes[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=6)
# legend.set_title("Per-Token Losses")

# Move legend to be likee fig.legend(patch_list, milestone_titles_and_labels, loc='upper center', bbox_to_anchor=(0.5, -0.025), ncol=len(TRANSITIONS))
for i, (ax, metric) in enumerate(zip(axes, metrics)):
    title, label = titles_and_labels[metric]
    # ax.set_ylabel(label)
    ax.set_ylabel("")
    l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'][i]
    # ax.set_title(f"({l}) {title}")
    ax.set_ylabel(title)

    # if settings.get(metric, {}).get("logy", False):
    # ax.set_yscale("log")
    ax.set_xscale('log')

for ax in [axes[0]]: #, axes[2]]:
    ax.set_yscale('log')

for ax in axes.flatten()[:-1]:
    ax.set_xticks([])
    ax.set_xticklabels([])
    # ax.set_ylabel("")

# axes[3].set_yscale('linear')
# axes[5].set_yscale('linear')
plt.tight_layout()

# Change right space
fig.subplots_adjust(right=0.75) 
axes[0].legend(loc='upper center', bbox_to_anchor=(1.2, .95), ncol=1)
axes[1].legend(loc='upper center', bbox_to_anchor=(1.20, 1.), ncol=1, title="$x_k\sim \mathcal{N}(0, gI_D)$")

_plot_transitions(axes)

# Add color bar on the far right

fig.set_facecolor('white')
fig.savefig(FIGURES / 'lr-behavioral-indicators.pdf');

In [None]:
titles_and_labels = {
    "loss_0": ("Average Prediction Norm", r"$\|\hat y_k\|^2$"),
    "loss_first_x": ("MSE from Ridge-Optimal 1-sample Prediction", "MSE"),
    "pretrain/delta_ridge/token": ("MSE from Ridge Regression", "MSE"),  
    "loss": ("Loss", "MSE"),
    "ood_loss": ("MSE on Large Inputs", "MSE"),
    "llc/mean": ("Per-Token $\lambda_t$", "$\lambda_t^{(i)}$"),
    "hessian": ("Hessian Statistics", "")
}

df = evals_over_time_df.loc[evals_over_time_df.model_seed == MODEL_SEED]
df['hessian/trace'] = hessian_traces_np

metrics = ["llc/mean"]# "loss", "ood_loss"]

fig, axes = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))  # Adjust the figsize as needed

token_cmap = ScalarMappable(norm=Normalize(vmin=0, vmax=8), cmap="viridis")


for i in range(8):
    color = token_cmap.to_rgba(i)

    for m, metric in enumerate(metrics):
        # if i == 0 and metric == "loss_first_x":
        #     continue

        sns.lineplot(data=df, x="step", y=f"{metric}/{i}", ax=axes[m], alpha=1, color=color, label=f"$k={i+1}$" if i % 2 == 0 else "_")

    # if i > 0:
    #     sns.lineplot(data=df, x="step", y=f"loss_first_x/{i}", ax=axes[2], alpha=0.5, color=color)

for m, metric in enumerate(metrics):
    if metric.endswith('/token'):
        mean_metric = metric[:-6] 
    else:
        mean_metric = f"{metric}/mean"
        
    sns.lineplot(data=df, x="step", y=mean_metric, label="Mean", ax=axes[m], color=BRED)

ax = axes[0]
ax.set_ylabel(r"$\hat\lambda_k(w_t)$")
ax = axes[1]

sns.lineplot(x=steps, y=top_evals_np[:, 0], color=cmap3[0], label="Eigenvalue 1", ax=ax)
sns.lineplot(x=steps, y=top_evals_np[:, 1], color=cmap3[1], label="Eigenvalue 2", ax=ax)
sns.lineplot(x=steps, y=top_evals_np[:, 2], color=cmap3[2], label="Eigenvalue 3", ax=ax)
sns.lineplot(x=steps, y=hessian_traces_np, ax=ax, color=BRED, label="Trace")

ax.set_yscale('log')

for ax in axes:
    ax.set_xscale("log")
    ax.set_xlabel("Step $t$")

    ax.set_xlim(100, 500_000)

# legend = axes[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=6)
# legend.set_title("Per-Token Losses")

# Move legend to be likee fig.legend(patch_list, milestone_titles_and_labels, loc='upper center', bbox_to_anchor=(0.5, -0.025), ncol=len(TRANSITIONS))
for i, (ax, metric) in enumerate(zip(axes, [*metrics, 'hessian'])):
    title, label = titles_and_labels[metric]
    l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'][i]
    # ax.set_title(f"({l}) {title}")
    ax.set_ylabel(f"{title}")

    if settings.get(metric, {}).get("logy", False):
        ax.set_yscale("log")

for ax in axes:
    ax.set_xscale('log')
    ax.set_xlabel('')
ax.set_xlabel('Step $t$')

axes[0].set_xticks([])
axes[0].set_xticklabels([])
        
plt.tight_layout()

fig.subplots_adjust(right=0.75) 
axes[0].legend(loc='upper center', bbox_to_anchor=(1.2, .95), ncol=1)
axes[1].legend(loc='upper center', bbox_to_anchor=(1.20, 1.), ncol=1, title="Hessian Statistics")

_plot_transitions(axes)


# Add color bar on the far right

fig.set_facecolor('white')
fig.savefig(FIGURES / 'lr-geometric-indicators.pdf')


## ICL Score

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(FULL_WIDTH, FULL_HEIGHT))

icl_score_1 =  -df['loss/0'] + df['loss/3']
icl_score_2 =  - df['loss/3'] + df['loss/7']

print(icl_score_1)

df['icl_score_1'] = icl_score_1
df['icl_score_2'] = icl_score_2

for ax in axes:
    ax.axhline(0, color='black', linestyle='-', linewidth=0.5)
    sns.lineplot(data=df, x="step", y="icl_score_1", ax=ax)
    sns.lineplot(data=df, x="step", y="icl_score_2", ax=ax)
    ax.set_ylabel("$I_{k:k'}$")

axes[0].set_title("(a) ICL Scores")
axes[1].set_title("(b) ICL Scores (Zoomed)")
axes[0].set_xlim(100, 500_000)
axes[0].set_ylim(-2.8, 0.5)
axes[1].set_xlim(1000, 10_000)
axes[1].set_ylim(-2, 0.4)

# Draw box
ylim = axes[0].get_ylim()
height = ylim[1] - ylim[0]

bottom = (-2 - ylim[0]) / height
# print(bottom, bottom * height, ylim)
height = (0.4 - ylim[0]) / height

xlim = axes[0].get_xlim()
width = np.log10(xlim[1]) - np.log10(xlim[0])

left = (np.log10(1000) - np.log10(100)) / width
right = (np.log10(10000) - np.log10(100)) / width

axes[0].axvline(1000, bottom, height, alpha=0.2, color='black', linestyle='--', linewidth=0.5)
axes[0].axvline(10000, bottom, height, alpha=0.2, color='black', linestyle='--', linewidth=0.5)
axes[0].axhline(-2, left, right, alpha=0.2, color='black', linestyle='--', linewidth=0.5)
axes[0].axhline(0.4, left, right, alpha=0.2, color='black', linestyle='--', linewidth=0.5)

# ylims1 = axes[0].get_ylim()
# ylims2 = axes[1].get_ylim()

# min_ylim = min(ylims1[0], ylims2[0])
# max_ylim = max(ylims1[1], ylims2[1])

for ax in axes:
    ax.set_xlabel("Step $t$")
    ax.set_xscale("log")
    ax.set_ylabel("")
    # ax.set_yscale("log")
    ax.legend().remove()
    # ax.set_ylim(min_ylim, max_ylim)

# handles = axes[0].get_legend_handles_labels()[0]
# ax.legend(handles=handles, labels=[f"$M = 2^{{{m}}}$" for m in [4, 8, 12, 16, 20]], bbox_to_anchor=(1.05, .9), loc='upper left', borderaxespad=0.)
ax.legend(labels=["_", "$I_{1:4}$", "_", "$I_{4:8}$"], bbox_to_anchor=(1.05, .7), loc='upper left', borderaxespad=0.)

_plot_transitions(axes)

fig.tight_layout()
fig.savefig(FIGURES / "icl-scores.pdf", bbox_inches='tight')
plt.show()

## OOD Generalization

In [None]:
ood_losses = []
final_model_1 = models1[-1]

ws = torch.normal(0, 1, size=(len(xs), D,), device=DEVICE)

def _eval_loss(yhats, ys):
    return F.mse_loss(yhats, ys).item()


for step, model in zip(steps, tqdm.tqdm(models1)):
    y_preds = model(xs, ys)
    y_norm = y_preds.norm(dim=-1).mean().item()

    for multiplier in tqdm.tqdm(np.logspace(-2.5, 4.5, 100, base=10), disable=True):
        ood_xs = multiplier * xs

        ood_ys = apply_transformations(ws, ood_xs, 0.125, DEVICE)
        ood_input_preds = model(ood_xs, ood_ys)
        ood_task_preds = model(xs, ood_ys)

        # print(ood_input_preds.shape)
        ood_input_norm = ood_input_preds.norm(dim=-1).mean().item() / y_norm
        ood_task_norm = ood_task_preds.norm(dim=-1).mean().item() / y_norm

        ood_input_loss = _eval_loss(ood_input_preds, ood_ys)
        ood_task_loss = _eval_loss(ood_task_preds, ood_ys)

        ood_losses.append({
            "step": step,
            "multiplier": multiplier,
            "ood_input_loss": ood_input_loss,
            "ood_task_loss": ood_task_loss,
            "ood_input_loss_div_gain": ood_input_loss / multiplier, 
            "ood_task_loss_div_gain": ood_task_loss / multiplier,
            "ood_input_loss_to_og": _eval_loss(ood_input_preds, ys),
            "ood_task_loss_to_og": _eval_loss(ood_task_preds, ys),
            "ood_input_norm": ood_input_norm,
            "ood_task_norm": ood_task_norm,
        })

ood_losses_df = pd.DataFrame(ood_losses)

In [None]:
ood_losses_df['ood_input_loss_div_gain'] = ood_losses_df['ood_input_loss'] / ood_losses_df['multiplier'] ** 2
ood_losses_df['ood_task_loss_div_gain'] = ood_losses_df['ood_task_loss'] / ood_losses_df['multiplier'] ** 2

for m in ['ood_input_loss', 'ood_task_loss', 'ood_input_loss_div_gain', 'ood_task_loss_div_gain', 'ood_input_loss_to_og', 'ood_task_loss_to_og', 'ood_input_norm', 'ood_task_norm']:
    ood_losses_df[f"{m}_log"] = np.log10(ood_losses_df[m])

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(FULL_WIDTH, FULL_HEIGHT * 2))

sns.lineplot(data=ood_losses_df, x="multiplier", y="ood_input_loss", hue="step", palette='viridis', ax=axes[0, 0], alpha=0.5)
sns.lineplot(data=ood_losses_df, x="multiplier", y="ood_task_loss", hue="step", palette='viridis', ax=axes[0, 1], alpha=0.5)
sns.lineplot(data=ood_losses_df, x="multiplier", y="ood_input_norm", hue="step", palette='viridis', ax=axes[1, 0], alpha=0.5)
sns.lineplot(data=ood_losses_df, x="multiplier", y="ood_task_norm", hue="step", palette='viridis', ax=axes[1, 1], alpha=0.5)

for ax in axes[0]:
    ax.set_ylabel("MSE")

for ax in axes[1]:
    ax.set_ylabel("Norm")

axes[0, 0].set_title("(a) OOD Input Loss")
axes[0, 1].set_title("(b) OOD Task Loss")
axes[1, 0].set_title("(c) OOD Input Norm Relative to Baseline")
axes[1, 1].set_title("(d) OOD Task Norm Relative to Baseline")

for ax in axes.flatten():
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("OOD Multiplier")
    ax.legend().remove()
    ax.axvline(1, color='black', linestyle='--', linewidth=0.5)

# Add a colorbar to the right showing steps
plt.tight_layout()

fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.95, 0.15, 0.025, 0.7])
cbar = fig.colorbar(ScalarMappable(norm=Normalize(vmin=0, vmax=500_000), cmap='viridis'), cax=cbar_ax)
cbar.set_label("Step $t$")
cbar.set_ticks([0, 100_000, 200_000, 300_000, 400_000, 500_000])

fig.savefig(FIGURES / "lr/ood-performance.pdf", bbox_inches='tight')

## Cossims?

In [None]:
from icl.regression.tasks import apply_transformations

def compute_eff_weight(model, xs, ws, pos=-1, norm='auto', errors=0.125):
    B, K, D = xs.shape

    xs_copy = xs.clone().detach()
    avg_norm = xs.norm(dim=-1).mean()

    if norm == 'auto':
        norm = avg_norm

    if isinstance(errors, int):
        errors = torch.normal(
            mean=0.,
            std=errors,
            size=(B, K, 1,),
            device=DEVICE,
        )

    eff_weights = torch.zeros(B, D, device=DEVICE)

    for i in range(D):
        xs_copy[:, pos, :] = 0.0
        xs_copy[:, pos, i] = norm

        ys = xs_copy @ ws.view(B, D, 1) + errors

        y_preds = model(xs_copy, ys)
        eff_weights[:, i] = y_preds[:, pos, 0]
    
    return eff_weights

num_samples = 256
_xs = torch.normal(mean=0, std=1., size=(num_samples, K, D), device=DEVICE)
ws = torch.normal(mean=0, std=1., size=(num_samples, 4), device=DEVICE)
errors = torch.normal(
    mean=0.,
    std=0,
    size=(num_samples, K, 1,),
    device=DEVICE,
)
_ys = _xs @ ws.view(num_samples, D, 1) + errors
first_xs = deepcopy(_xs[:, 0:1, :])
first_xs /= torch.norm(first_xs, dim=-1, keepdim=True) ** 2
first_xs *= _ys[:, 0:1, :].repeat(1, 1, 4)

def prev_x_proj(xs, ys):
    prev_xs = deepcopy(xs[:, :, :])
    prev_xs /= torch.norm(prev_xs, dim=-1, keepdim=True) ** 2

    for i in range(K):   
        for j in range(D):
            prev_xs[:, i, j] *= ys[:, i, 0]

    prev_xs = prev_xs.roll(1, dims=1)
    prev_xs[:, 0, :] = 0.

    return prev_xs

prev_xs = prev_x_proj(_xs, _ys)

cumulative_xs = prev_xs.cumsum(dim=1)
for i in range(K):
    cumulative_xs[:, i, :] /= (i + 1)

def compute_eff_weights(_models, steps, xs, ws, errors, ref_ws = None):
    if ref_ws is None:
        ref_ws = ws

    eff_weights_df = []
    for step, model in zip(steps, tqdm.tqdm(_models)):
        if ref_ws == 'final_w':
            ref_ws = compute_eff_weight(model, xs, ws, errors=errors, pos=-1)
        for p in range(0, 8):
            eff_weights = compute_eff_weight(model, xs, ws, errors=errors, pos=p)

            if len(ref_ws.shape) > 2:
                _ws = ref_ws[:,p, :]
            else:
                _ws = ref_ws

            # print(_ws.shape, eff_weights.shape)
            dot_prods = (_ws * eff_weights).sum(dim=-1)
            cossim = dot_prods / (_ws.norm(dim=-1) * eff_weights.norm(dim=-1))
            cossim_mean = cossim.mean(dim=0)

            relnorm = (eff_weights.norm(dim=-1) / _ws.norm(dim=-1))

            eff_weights_df.append({
                "step": step,
                "cossim": cossim_mean.item(),
                "cossim_std": cossim.std(dim=0).item(),
                "relnorm": relnorm.mean().item(),
                "relnorm_std": relnorm.std().item(),
                "mse": (((eff_weights - _ws) ** 2).sum(dim=-1) ** 0.5).mean().item(),
                "position": p
            })

    return pd.DataFrame(eff_weights_df)

eff_weights_df = compute_eff_weights(models1, steps, _xs, ws, errors, ref_ws=ws)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(WIDTH * 1.5, HEIGHT))

ax = axes[0]
sns.lineplot(data=eff_weights_df, x="step", y="cossim", hue="position", ax=ax, palette='viridis', alpha=0.8)
ax.set_ylabel(r"$\mathbb{E}[S_C(\hat{\mathbf{t}}, \mathbf{t})]$")
ax.set_title("(a) Avg. Cosine Similarity")

ax = axes[1]
sns.lineplot(data=eff_weights_df, x="step", y="relnorm", hue="position", ax=ax, palette='viridis', alpha=0.8)
ax.set_ylabel(r"$\mathbb{E}[\hat{\mathbf{t}}/\mathbf{t}]$")
ax.set_title("(b) Avg. Relative Norm")

# axes[1].set_yscale('log')

for ax in axes:
    ax.set_xscale('log')
    ax.set_xlabel("Step $t$")
    ax.set_xlim(100, 500_000)
    ax.legend().remove()
    ax.axvline(4100)

_plot_transitions(axes)
plt.tight_layout()
fig.savefig(FIGURES / "lr/eff-s-gain-w-10.pdf", bbox_inches='tight')

# Essential Dynamics

In [None]:
from icl.figures.plotting import plot_explained_variance

def plot_multiple_slices(steps, samples, pca, transitions, highlighted_steps=None, connect_dots=False, palette='tab10', alpha=0.8, save=False, line_color="auto", figsize=(20, 4)):
    transition_idxs = get_transition_indices(steps, transitions)
    # transition_idxs = [(0 if i != 4 else 1) for i in transition_idxs]

    # for i in range(1, 5):
    #     transition_idxs[-i] = 10  

    if highlighted_steps is None:
        highlighted_steps = list(map(get_nearest_step, [t[0] for t in transitions][1:]))

    num_pca_components = samples.shape[-1]
    
    # Create a single row of subplots
    num_pca_combos = (num_pca_components * (num_pca_components-1)) // 2
    
    fig, axes = plt.subplots(1, num_pca_combos + 1, figsize=figsize)
    # fig.suptitle(title)

    # Ensure ax is iterable by converting to a list if there's only one subplot
    if num_pca_components == 2:
        axes = [axes]

    I = 0
    for i in range(1, num_pca_components):
        for j in range(i):

            if connect_dots:
                axes[I].plot(samples[:, i], samples[:, j], c='black', alpha=0.2)

            # sc = axes[I].scatter(samples[:, i], samples[:, j], c=transition_idxs, cmap=cmap, s=50, alpha=alpha)
            sns.scatterplot(x=samples[:, i], y=samples[:, j], hue=transition_idxs, palette=palette, s=50, alpha=alpha, ax=axes[I], legend=False)
            axes[I].set_xlabel(f'PC {i}')
            axes[I].set_ylabel(f'PC {j}')
            axes[I].set_title(f'PC {i} vs PC {j}')

            # Label some points
            total_samples = len(samples)
            for step in highlighted_steps:
                k = steps.index(step)  # Find the index of the highlighted step
                axes[I].text(samples[k, i], samples[k, j], str(step), fontsize=8, ha='right', va='bottom', alpha=0.8)

            I += 1

    plot_explained_variance(pca, ax=axes[-1], num_pca_components=num_pca_components)
    # for I in range( num_pca_combos):
    #     axes[I].axis('off')
            
    # Colorbar for the last plot
    # cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # Adjust as necessary
        # plt.colorbar(sc, cax=cbar_ax, label='Milestones')

    cmap = sns.palettes.color_palette(palette, n_colors=len(transitions) + 1)

    # Plot the legend on the first subplot on the left
    legend_ax = axes[0]
    scatter_proxy = [plt.Line2D([0], [0], linestyle='none', marker='o', alpha=alpha, color=cmap[i]) for i in range(len(transitions))]
    legend_labels = [label for _, _, label in transitions]
    legend_ax.legend(scatter_proxy, legend_labels, loc='center', ncol=1, frameon=False, bbox_to_anchor=(-0.5, 0.5), title='Developmental Stages')
    # legend_ax.set_title()

    plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust the right side to make room for the colorbar

    if save:
        parent_dir = os.path.dirname(save)
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        plt.savefig(save)


    fig.set_facecolor('white')
    
# Usage of the function
# Call the function with your data and the list of highlighted steps
# plot_multiple_slices(steps, samples, pca, highlighted_steps=[100, 1000, 10000], title="Your Title", num_points_to_label=10, save="path/to/save.png", connect_dots=True)

In [None]:
all_outputs = []
all_weights = []
all_gradients = []

all_running_grads = []
all_running_grads_squared = []

def get_weights_vector(model):
    return np.concatenate([param.detach().cpu().numpy().flatten() for name, param in model.named_parameters() if param is not None])

def get_gradients_vector(model):
    return np.concatenate([param.grad.detach().cpu().numpy().flatten() for name, param in model.named_parameters() if param.grad is not None])


def get_exp_avg_sq_grads(optimizer_state_dict):
    return np.concatenate([g["exp_avg_sq"].cpu().numpy().flatten() for g in optimizer_state_dict["state"].values()])

def get_exp_avg_grads(optimizer_state_dict):
    return np.concatenate([g["exp_avg"].cpu().numpy().flatten() for g in optimizer_state_dict["state"].values()])


for i, (_models, opt_state_dicts) in enumerate(zip(tqdm.tqdm(all_models, desc="Getting outputs..."), all_optimizer_state_dicts)):
    # Outputs of token sequence transformer
    outputs = []
    for path, activations in get_vectorized_activations_trace(_models, xs, ys, 'token_sequence_transformer', normalize=False).items():
        outputs.append(activations)

    all_outputs.append(np.concatenate(outputs, axis=1))

    # Weights & Gradients
    gradients = []
    weights = []

    for model in _models:
        model.to(DEVICE)
        xs.to(DEVICE)
        ys.to(DEVICE)
        model.train()
        model.zero_grad()
        ys_pred = model(xs, ys)
        loss = F.mse_loss(ys_pred, ys)
        loss.backward()

        gradients.append(get_gradients_vector(model))
        weights.append(get_weights_vector(model))

    all_gradients.append(np.array(gradients))
    all_weights.append(np.array(weights))

    # Optimizer states
    running_grads = []
    running_grads_squared = []
    for opt_state_dict in opt_state_dicts:
        running_grads.append(get_exp_avg_grads(opt_state_dict))
        running_grads_squared.append(get_exp_avg_sq_grads(opt_state_dict))

    all_running_grads.append(np.array(running_grads))
    all_running_grads_squared.append(np.array(running_grads_squared))

In [None]:
import pickle


def plot_essential_dynamics_grid(steps, all_samples, transitions, palette='tab10', save=False, figsize=(20, 4), num_pca_components=3, max_step=None, normalize=False, labels=None, max_plot=-1):
    num_samples = len(all_samples)  

    # Create a single row of subplots
    num_pca_combos = (num_pca_components * (num_pca_components-1)) // 2
    fig, all_axes = plt.subplots(num_samples, num_pca_combos + 1, figsize=figsize)
    
    if num_samples == 1:
        all_axes = [all_axes]

    labels = labels or [f"Model {i+1}" for i in range(num_samples)]

    for samples_idx, _samples in enumerate(tqdm.tqdm(all_samples, desc="Plotting...")):
        if max_step is not None:
            max_step_idx = steps.index(max_step)
            _samples = _samples[:max_step_idx, :]
        if normalize:
            _samples = _samples / np.linalg.norm(_samples, axis=1, keepdims=True)

        pca = PCA(n_components=num_pca_components)
        samples = pca.fit_transform(_samples)

        with open(DATA / f"pca-{samples_idx}.pkl", "wb") as f:
            pickle.dump((pca, samples), f)

        axes = all_axes[samples_idx]

        # Ensure ax is iterable by converting to a list if there's only one subplot
        if num_pca_components == 2:
            axes = [axes]

        # colors = list(reversed(rainbow(len(transitions) + 1)))
        if isinstance(palette, str):
            colors = sns.palettes.color_palette(palette, n_colors=len(transitions) + 1)
        else:
            colors = palette

        I = 0
        for i in range(1, num_pca_components):
            for j in range(i):
                sns.scatterplot(x=samples[:max_plot, i], y=samples[:max_plot, j], ax=axes[I], alpha=0.5, color="gray", s=10, legend=False)
                for k, (start, end, stage) in enumerate(transitions):
                    start_idx = steps.index(start)
                    end_idx = steps.index(end) + 1

                    if max_plot > 0:
                        if start > max_plot:
                            continue

                        end_idx = min(end_idx, max_plot)

                    # sc = axes[I].scatter(samples[:, i], samples[:, j], c=transition_idxs, cmap=cmap, s=50, alpha=alpha)
                    axes[I].plot(samples[start_idx:end_idx, i], samples[start_idx:end_idx, j], color=colors[k])

                if not transitions:
                    axes[I].plot(samples[:max_plot, i], samples[:maxplot, j])

                axes[I].set_xlabel(f'PC {i+1}')
                axes[I].set_ylabel(f'PC {j+1}')
                axes[I].set_title(f'PC {j+1} vs PC {i+1}')

                I += 1

        axes[0].set_ylabel(f"{labels[samples_idx]}\n\nPC 1")

        plot_explained_variance(pca, ax=axes[-1], num_pca_components=num_pca_components)

    # cmap = sns.palettes.color_palette(palette, n_colors=len(transitions) + 1)
    # Plot the legend on the first subplot on the left
    # legend_ax = axes[0]
    # scatter_proxy = [plt.Line2D([0], [0], linestyle='none', marker='o', alpha=alpha, color=cmap[i]) for i in range(len(transitions))]
    # legend_labels = [label for _, _, label in transitions]
    # legend_ax.legend(scatter_proxy, legend_labels, loc='center', ncol=1, frameon=False, bbox_to_anchor=(-0.5, 0.5), title='Developmental Stages')
    # legend_ax.set_title()

    # plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust the right side to make room for the colorbar
    # plt.tight_layout(rect=[0, 0, 1, 1])

    if transitions:
        # Create an axis for the legend

        # Create a list of handles for the legend
        handles = [plt.Line2D([0], [0], color=colors[i], linestyle='-') for i in range(len(transitions))]
        labels = [label for _, _, label in transitions]

        # Add legend to the new axis
        fig.legend(handles, labels, loc='center', ncol=len(labels), frameon=False, bbox_to_anchor=(0.5, 0.02))
        # Add some space at the bottom for the legend

    plt.tight_layout()  # Adjust layout first
    plt.subplots_adjust(bottom=0.1, top=0.9)  # Fine-tune spacing, adjust these values as needed

    if save:
        parent_dir = os.path.dirname(save)
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        plt.savefig(save)

    fig.set_facecolor('white')
    return fig 

In [None]:
colors

In [None]:
# For model 2
FULL_WIDTH = 6.75
FULL_HEIGHT = FULL_WIDTH / golden_ratio

red =sns.color_palette('tab10', 4)[3]
deep_orange = sns.color_palette('tab20c', 5)[4]
light_orange = sns.color_palette('tab20', 4)[2]
light_orange = decrease_brightness(light_orange, 1.5)
light_blue = sns.color_palette('tab20', 4)[1]
gray = sns.color_palette('tab20c', 18)[17]
gray = decrease_brightness(gray, 0.5)
colors = [light_orange, red, light_blue, gray]

print(label)
fig = plot_essential_dynamics_grid(steps, [all_outputs[MODEL_SEED], all_weights[MODEL_SEED], all_gradients[MODEL_SEED], all_running_grads[MODEL_SEED], all_running_grads_squared[MODEL_SEED]], TRANSITIONS, num_pca_components=4, figsize=(FULL_WIDTH * 2, FULL_HEIGHT * .8), labels=['(a) Behavioral ED', '(b) Weight ED', '(c) Gradient ED', '(d) 1st Moment Gradient ED', '(e) 2nd Moment Gradient ED'],
                                   palette=colors)        
fig.savefig(FIGURES / 'lr-essential-dynamics.pdf')
plt.show()


### Default

In [None]:
def plot_essential_dynamics_grid(steps, samples, transitions, save=False, figsize=(20, 4), labels=None):
    fig, axes = plt.subplots(1, len(transitions) - 1, figsize=figsize)

    colors = sns.color_palette("tab10", n_colors=len(transitions) + 1)

    for i, (ax, t1, t2) in enumerate(zip(axes, transitions, transitions[1:])):
        t1_start_idx = steps.index(t1[0])
        boundary_idx = steps.index(t1[1])
        t2_start_idx = steps.index(t2[1])

        min_ivl = min(boundary_idx - t1_start_idx, t2_start_idx - boundary_idx)

        t1_start_idx = boundary_idx - min_ivl
        t2_start_idx = boundary_idx + min_ivl

        print(t1_start_idx, boundary_idx, t2_start_idx)
        print(steps)

        _steps = steps[t1_start_idx:t2_start_idx]
        _samples = samples[t1_start_idx:t2_start_idx, :]

        pca = PCA(n_components=3)
        projected = pca.fit_transform(_samples)

        sc = ax.scatter(projected[:, 0], projected[:, 1], c='gray', s=10, alpha=0.2)

        ax.plot(projected[:min_ivl+1, 0], projected[:min_ivl+1, 1], color=colors[i])
        ax.plot(projected[min_ivl:, 0], projected[min_ivl:, 1], color=colors[i+1])

        ax.set_xlabel('PC 1')
        ax.set_ylabel('PC 2')

        ax.set_title(f"{t1[2]}-{t2[2]}")

    return fig

plot_essential_dynamics_grid(steps, all_outputs[1], TRANSITIONS)



In [None]:
def plot_essential_dynamics_grid(steps, samples, transitions, save=False, figsize=(20, 4), labels=None):
    fig, axes = plt.subplots(1, len(transitions) - 1, figsize=figsize)

    colors = sns.color_palette("tab10", n_colors=len(transitions) + 1)
    pca = PCA(n_components=3)
    projected = pca.fit_transform(samples)

    for i, (ax, t1, t2) in enumerate(zip(axes, transitions, transitions[1:])):
        t1_start_idx = steps.index(t1[0])
        boundary_idx = steps.index(t1[1])
        t2_start_idx = steps.index(t2[1])

        min_ivl = min(boundary_idx - t1_start_idx, t2_start_idx - boundary_idx)

        t1_start_idx = boundary_idx - min_ivl
        t2_start_idx = boundary_idx + min_ivl

        print(t1_start_idx, boundary_idx, t2_start_idx)
        print(steps)

        _steps = steps[t1_start_idx:t2_start_idx]
        _samples = projected[t1_start_idx:t2_start_idx, :]
        sc = ax.scatter(_samples[:, 0], _samples[:, 1], c='gray', s=10, alpha=0.2)

        ax.plot(_samples[:min_ivl+1, 0], _samples[:min_ivl+1, 1], color=colors[i])
        ax.plot(_samples[min_ivl:, 0], _samples[min_ivl:, 1], color=colors[i+1])

        ax.set_xlabel('PC 1')
        ax.set_ylabel('PC 2')

        ax.set_title(f"{t1[2]}-{t2[2]}")

    return fig

plot_essential_dynamics_grid(steps, all_outputs[1], TRANSITIONS)



In [None]:
sns.set_palette("tab10")
for label, all_samples in zip(["Outputs", "Weights", "Gradients", "Running Grads", "Running Grads Squared"], [all_outputs, all_weights, all_gradients, all_running_grads, all_running_grads_squared]):
    print(label)
    plot_essential_dynamics_grid(steps, all_samples, TRANSITIONS, num_pca_components=4, figsize=(16, 4))        
    plt.show()

### Show only first N steps (but fit on all)

In [None]:
sns.set_palette("tab10")
for label, all_samples in zip(["Outputs", "Weights", "Gradients", "Running Grads", "Running Grads Squared"], [all_outputs, all_weights, all_gradients, all_running_grads, all_running_grads_squared]):
    print(label)
    plot_essential_dynamics_grid(steps, all_samples, TRANSITIONS, num_pca_components=4, figsize=(14, 12), max_plot=70)        
    plt.show()

### First 60k steps

In [None]:
# First 60k steps
MAX_STEP = get_nearest_step(TRANSITIONS[4][1])

for label, all_samples in zip(["Outputs", "Weights", "Gradients", "Running Grads", "Running Grads Squared"], [all_outputs, all_weights, all_gradients, all_running_grads, all_running_grads_squared]):
    print(label)
    plot_essential_dynamics_grid(steps, [all_samples[0]], TRANSITIONS[:5], num_pca_components=4, figsize=(15, 3), max_step=MAX_STEP)        
    plt.show()

### Normalized trajectories

In [None]:
# Normalized

for label, all_samples in zip(["Outputs", "Weights", "Gradients", "Running Grads", "Running Grads Squared"], [all_outputs, all_weights, all_gradients, all_running_grads, all_running_grads_squared]):
    print(label)
    plot_essential_dynamics_grid(steps, [all_samples[0]], TRANSITIONS, num_pca_components=4, figsize=(15, 3), normalize=True)        
    plt.show()

### Role of checkpoint interval

In [None]:
from devinfra.utils.iterables import int_linspace

linear_steps = int_linspace(0, 500_000, 100)[:-1]
linear_step_idxs = [steps.index(get_nearest_step(step)) for step in linear_steps]

logarithmic_steps = int_logspace(1, 500_000, 100)
logarithmic_step_idxs = [steps.index(get_nearest_step(step)) for step in logarithmic_steps]

assert (all([step in steps for step in logarithmic_steps]))

num_pca_components = 4
transitions = TRANSITIONS
    

for label, all_samples in zip(["Outputs", "Weights", "Gradients", "Running Grads", "Running Grads Squared"], [all_outputs, all_weights, all_gradients, all_running_grads, all_running_grads_squared]):
    linear_samples0s = [all_samples[0][idx] for idx in linear_step_idxs]
    logarithmic_samples0s = [all_samples[0][idx] for idx in logarithmic_step_idxs]
    print(label)
    all_samples = [all_samples[0], linear_samples0s, logarithmic_samples0s]  
    num_samples = len(all_samples)  

    # Create a single row of subplots
    num_pca_combos = (num_pca_components * (num_pca_components-1)) // 2
    fig, all_axes = plt.subplots(num_samples, num_pca_combos + 1, figsize=(15, 9))
    
    if num_samples == 1:
        all_axes = [all_axes]

    labels = ["Linear + Logarithmic Interval", "Linear Interval", "Logarithmic Interval"]
    
    pca = PCA(n_components=num_pca_components)
    reduced_all_samples0 = pca.fit_transform(all_samples[0])

    pca = PCA(n_components=num_pca_components)
    pca.fit(linear_samples0s)
    reduced_linear_samples0 = pca.transform(all_samples[0])

    pca = PCA(n_components=num_pca_components)
    pca.fit(logarithmic_samples0s)
    reduced_logarithmic_samples0 = pca.transform(all_samples[0])

    reduced_all_samples = [reduced_all_samples0, reduced_linear_samples0, reduced_logarithmic_samples0]

    for samples_idx, samples in enumerate(tqdm.tqdm(reduced_all_samples, desc="Plotting...")):

        axes = all_axes[samples_idx]

        # Ensure ax is iterable by converting to a list if there's only one subplot
        if num_pca_components == 2:
            axes = [axes]

        I = 0
        for i in range(1, num_pca_components):
            for j in range(i):
                sns.scatterplot(x=samples[:, i], y=samples[:, j], ax=axes[I], alpha=0.5, color="gray", s=10, legend=False)
                for k, (start, end, stage) in enumerate(transitions):
                    start_idx = steps.index(start)
                    end_idx = steps.index(end) + 1
                        
                    # sc = axes[I].scatter(samples[:, i], samples[:, j], c=transition_idxs, cmap=cmap, s=50, alpha=alpha)
                    axes[I].plot(samples[start_idx:end_idx, i], samples[start_idx:end_idx, j])

                axes[I].set_xlabel(f'PC {i+1}')
                axes[I].set_ylabel(f'PC {j+1}')
                axes[I].set_title(f'PC {j+1} vs PC {i+1}')

                I += 1

        axes[0].set_ylabel(f"{labels[samples_idx]}\n\nPC 1")

        plot_explained_variance(pca, ax=axes[-1], num_pca_components=num_pca_components)

    # cmap = sns.palettes.color_palette(palette, n_colors=len(transitions) + 1)
    # Plot the legend on the first subplot on the left
    # legend_ax = axes[0]
    # scatter_proxy = [plt.Line2D([0], [0], linestyle='none', marker='o', alpha=alpha, color=cmap[i]) for i in range(len(transitions))]
    # legend_labels = [label for _, _, label in transitions]
    # legend_ax.legend(scatter_proxy, legend_labels, loc='center', ncol=1, frameon=False, bbox_to_anchor=(-0.5, 0.5), title='Developmental Stages')
    # legend_ax.set_title()

    # plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust the right side to make room for the colorbar
    plt.tight_layout(rect=[0, 0, 1, 1])

    if transitions:
        # Create an axis for the legend
        legend_ax = fig.add_axes([0.1, -0.03, 0.95, 0.05])  # Adjust these values as needed

        # Create a list of handles for the legend
        handles = [plt.Line2D([0], [0], color=sns.color_palette('tab10')[i], linestyle='-') for i in range(len(transitions))]
        labels = [label for _, _, label in transitions]

        # Add legend to the new axis
        legend_ax.legend(handles, labels, loc='center', ncol=len(labels), frameon=False)
        legend_ax.axis('off')  # Turn off axis lines and labels

    fig.set_facecolor('white')
    plt.show()

# Structural indicators


## Embedding

In [None]:

fig, axes = plt.subplots(1, 3, figsize=(FULL_WIDTH * 4 / 3, FULL_HEIGHT * .5 ))

# PCA explained Variance over time

ax = axes[0]

sns.lineplot(data=embed_sing_vals, x="step", y="embed/S_normed", hue="index", palette="viridis", ax=ax)
ax.set_xscale('log')
ax.set_title("(a) Singular Values of Token Embedding")

ax = axes[1]

sns.lineplot(data=postn_sing_vals, x="step", y="postn/S_normed", hue="index", palette="viridis", ax=ax)
ax.set_xscale('log')
ax.set_title("(b) Singular Values of Positional Encoding")

for ax in axes[:2]:
    ax.set_ylabel(r"$\sigma_i^2/\,\mathrm{Tr}\,\Sigma$")
    ax.set_yscale('log')

ax = axes[2]

sns.lineplot(data=entangling, x="step", y="cossim", hue="index", palette="viridis", ax=ax)
ax.set_xscale('log')
ax.set_title(r"(c) Cossim between $W_{\mathrm{embed}}$ and $W_{\mathrm{postn}}$")
ax.set_ylabel("$S_C$")

for ax in axes:
    ax.set_xlabel("Step $t$")
    ax.legend().remove()
    ax.set_xlim(100, 500_000)

plt.tight_layout()
fig.set_facecolor('white')
_plot_transitions(axes)

fig.savefig(FIGURES / "embedding.pdf")

## Unembedding

In [None]:
FULL_HEIGHT = FULL_WIDTH * golden_ratio
fig, axes = plt.subplots(2, 3, figsize=(FULL_WIDTH, FULL_HEIGHT))

fancy = {
    'ln.weight': 'Layer Norm Weights',
    'ln.bias': 'Layer Norm Biases',
    'linear.weight': 'Linear Weights',
    'linear.bias': 'Linear Bias',
}

for i, layer in enumerate(["ln", "linear"]):
    for j, layer_subset in enumerate(["weight", "bias"]):
        data_subset = unembeddings.loc[unembeddings.layer == f"{layer}.{layer_subset}"]
        means = data_subset.groupby('step').mean()
        stds = data_subset.groupby('step').std()

        if i + j != 2:
            sns.lineplot(data=means, x="step", y='p', color=BRED, ax=axes[j, i])
            sns.lineplot(data=data_subset, x="step", y="p", hue="i", palette="viridis", ax=axes[j, i], alpha=0.8, linewidth=0.5)
            sns.lineplot(data=means, x="step", y='p', color=BRED, ax=axes[j, i])
            axes[j, i].fill_between(means.index, means.p - stds.p, means.p + stds.p, alpha=0.25, color=BRED)
            axes[j, i].legend().remove()
            # handles = [axes[j, i].get_legend_handles_labels()[0][0]]
            # axes[j, i].legend(labels=['Mean'], handles=handles, loc='upper left', frameon=False)

        else:
            sns.lineplot(data=data_subset, x="step", y="p", hue="i", ax=axes[j, i], color=BRED)
            axes[j, i].legend().remove()   

        axes[j, i].set_xscale('log')
        axes[j, i].set_ylabel(fancy[f"{layer}.{layer_subset}"])
        axes[j, i].set_xlim(200, 500000)
        axes[j, i].set_xlabel("Training step $t$")
        # axes[j, i].set_ylabel("Weight")


axes[0,0].hlines(0, 0, 500_000, color='black', linestyle='--', alpha=0.5)

# fig.suptitle("Unembedding Weights over Time", fontsize=14)

axes[1,1].legend().remove()
plt.tight_layout()

ax = axes[0, 2]

means = reduced_unembeddings.loc[reduced_unembeddings.subset == "weight"].groupby('step').mean()["p"]
stds = reduced_unembeddings.loc[reduced_unembeddings.subset == "weight"].groupby('step').std()["p"]

sns.lineplot(data=reduced_unembeddings.loc[reduced_unembeddings.subset == "weight"], x="step", y=means, ax=ax, color=BRED, label="Mean")
sns.lineplot(data=reduced_unembeddings.loc[reduced_unembeddings.subset == "weight"], x="step", y="p", hue="i", palette="viridis", ax=ax, alpha=0.8, linewidth=0.5)
sns.lineplot(data=reduced_unembeddings.loc[reduced_unembeddings.subset == "weight"], x="step", y=means, ax=ax, color=BRED)
ax.fill_between(means.index, means - stds, means + stds, alpha=0.25, color=BRED)
ax.set_title(f"{layer}.{layer_subset}")

ax = axes[1, 2]
sns.lineplot(data=reduced_unembeddings.loc[reduced_unembeddings.subset == "bias"], x="step", y="p", hue="i", ax=ax, color=BRED)

for ax, subset in zip([axes[0, 2], axes[1, 2]], ["Effective Unembedding Weights", "Effective Unembedding Bias"]):
    ax.legend().remove()
    ax.set_xscale('log')
    ax.set_ylabel(f"{subset}")
    ax.set_xlabel("Training step $t$")
    # ax.set_ylabel("Weight")
    ax.set_title('')
    ax.set_xlim(200, 500000)

_plot_transitions(axes)

fig.savefig(FIGURES / "lr-unembed.pdf")
# handles = [ax.get_legend_handles_labels()[0][0] for ax in axes]
# axes[0, 2].legend(labels=['Mean'], handles=handles, loc='upper left', frameon=False)
# ax.get_legend_handles_labels();

## Layer norms

In [None]:
HEIGHT = WIDTH * golden_ratio
fig, axes = plt.subplots(2, 1, figsize=(WIDTH, HEIGHT * 1.5))

# Fill between using the stdp
colors = sns.color_palette("viridis", 5)

for i, type_ in enumerate(["weight", "bias"]):
    sns.lineplot(data=ln_stats, x="step", y=f"{type_}/mean", hue="layer", palette="viridis", ax=axes[i], alpha=0.8)

    for l, layer in enumerate(ln_stats.layer.unique()):
        layer_ln_stats = ln_stats.loc[ln_stats.layer == layer]

        axes[i].fill_between(steps, layer_ln_stats[f"{type_}/mean"] - layer_ln_stats[f"{type_}/std"], layer_ln_stats[f"{type_}/mean"] + layer_ln_stats[f"{type_}/std"], alpha=0.1, color=colors[l])

for ax in axes.flatten():
    plot_transitions(ax, TRANSITIONS, colors=LR_TRANSITION_COLORS)
    ax.set_xscale('log')
    # ax.set_yscale('log')
    ax.set_xlim(100, 500_000)
    ax.legend().remove()

layers = ["Attn. 1", '_', "MLP 1",  '_', "Attn. 2", '_', "MLP 2", '_', "Unembed"]
axes[0].set_title(r"(a) LN Weights $\gamma$")
axes[0].set_xlabel("")
axes[0].set_xticklabels([])
axes[1].set_title(r"(b) LN Biases $\beta$")

axes[0].set_ylabel(r'$\mathbb{E}[\gamma^{(l)}_i]$')
axes[1].set_ylabel(r'$\mathbb{E}[\beta^{(l)}_i]$')
# axes[1].legend(layers, title="", loc='lower left', bbox_to_anchor=(-0.25, -.5), frameon=False, ncols=5)
axes[1].legend(layers, title="Layer", loc='center', bbox_to_anchor=(1.3, 1.05), frameon=False, ncols=1)
ax.set_xlabel('Step $t$')

# plt.tight_layout(pad=0.1)
fig.set_facecolor('white')
fig.subplots_adjust(hspace=0.5)
fig.savefig(FIGURES / "lr-ln.pdf")
# Increase space between subplots to fit legend
# plt.tight_layout()

In [None]:
all_lns = {
    "Unembedding": unembedding_lns,
    "Block 1 Pre-Attention": block_1_attn_lns,
    "Block 1 Pre-MLP": block_1_mlp_lns,
    "Block 2 Pre-Attention": block_2_attn_lns,
    "Block 2 Pre-MLP": block_2_mlp_lns,
}

In [None]:
fig, axes = plt.subplots(2, len(all_lns), figsize=(FULL_WIDTH * 2, FULL_HEIGHT ))

def frac_eff_zero(w):
    return (w.abs() < 1e-1).float().mean().detach().cpu().numpy()

insets = []

for i, (name, lns) in enumerate(list(all_lns.items())[1:] + list(all_lns.items())[:1]):
    axes[0, i].axhline(0, color='black', linestyle='--', alpha=0.5)
    sns.set_palette(sns.color_palette("viridis", 64))
    
    axes[0, i].plot(steps, np.array([w.detach().cpu().numpy() for w, b in lns]), alpha=0.8, linewidth=0.5)
    axes[1, i].plot(steps, np.array([b.detach().cpu().numpy() for w, b in lns]), alpha=0.8, linewidth=0.5)

    w_means = np.array([w.detach().cpu().numpy() for w, b in lns]).mean(axis=-1)
    b_means = np.array([b.detach().cpu().numpy() for w, b in lns]).mean(axis=-1)
    w_stds = np.array([w.detach().cpu().numpy() for w, b in lns]).std(axis=-1)
    b_stds = np.array([b.detach().cpu().numpy() for w, b in lns]).std(axis=-1)

    axes[0, i].plot(steps, w_means, color=BRED)
    axes[0, i].fill_between(steps, w_means - w_stds, w_means + w_stds, alpha=0.25, color=BRED)
    axes[1, i].plot(steps, b_means, color=BRED)
    axes[1, i].fill_between(steps, b_means - b_stds, b_means + b_stds, alpha=0.25, color=BRED)

    axes[0, i].set_title(f"{name} LN Weight")
    axes[1, i].set_title(f"{name} LN Bias")
    
    # axes[0, i].legend(labels=['Mean'], handles=handles, loc='upper left', frameon=False)
    # axes[1, i].legend(labels=['Mean'], handles=handles, loc='upper left', frameon=False)
    # axes[0, i].legend().remove()
    # axes[1, i].legend().remove()

    axinset0 = axes[0, i].inset_axes([0.10, 0.1, 0.3, 0.3])
    axinset1 = axes[1, i].inset_axes([0.10, 0.1, 0.3, 0.3])

    frac_w_eff_zero = np.array([frac_eff_zero(w) for w, b in lns])
    frac_b_eff_zero = np.array([frac_eff_zero(b) for w, b in lns])
    
    axinset0.plot(steps, frac_w_eff_zero, color=BRED)
    axinset1.plot(steps, frac_b_eff_zero, color=BRED)

    insets.extend([axinset0, axinset1])
    

plot_transitions(axes, TRANSITIONS, colors=LR_TRANSITION_COLORS)
plot_transitions(insets, TRANSITIONS, colors=LR_TRANSITION_COLORS)
plt.tight_layout()

for ax in [*axes.flatten(), *insets]:
    ax.set_xscale('log')
    # ax.set_yscale('symlog')
    ax.set_xlim(100, 500000)

for ax in insets:
    ax.set_ylim(0, 1)
    ax.set_ylabel("% < 0.1")
    ax.set_yticks([])
    ax.set_xticks([])  # Remove x-ticks
    ax.set_xticklabels([])  # Remove x-tick labels
    ax.grid(False)

fig.savefig(FIGURES / "lr-ln-all.pdf")

## Attention 

In [None]:
def plot_attention_patterns(df: pd.DataFrame, num_blocks: int, num_heads: int, num_tokens: int, title=None, save: Optional[str] = None, figsize=(20, 25), logx=False, logy=False, metric="mean", label="Entropy", full_block=True, full_head=True, y_axis=True):
    fig = plt.figure(figsize=figsize)

    if title:
        plt.suptitle(label + "\n" + title)

    num_cols = num_blocks * 2
    num_rows = int(full_block) + int(full_head) + num_heads

    fig.set_facecolor('white')

    axes = []

    if full_block:
        # Create subplot for mean entropy of first two blocks
        ax0 = plt.subplot2grid((num_rows, num_cols), (0, 0), colspan=num_cols)
        block_cmap = sns.color_palette("viridis", num_blocks * 2)

        for b in range(num_blocks):
            for x_or_y in (1, 0):
                ax0.plot(df.step, df[f"block_{b}/{'x' if not x_or_y else 'y'}/{metric}"], label=f"Block {b + 1} {'xs' if not x_or_y else 'ys'}", color=block_cmap[b+1-x_or_y])

        ax0.set_title("Blocks")
        ax0.set_xlabel("Step, $t$")
        ax0.set_ylabel(label)
        ax0.legend()

        axes.append(ax0)

    if full_head:
        # Create subplots for each block, showing entropy in different heads
        ax1 = [plt.subplot2grid((num_rows, num_cols), (int(full_block), i), colspan=1) for i in range(num_blocks * 2)]
        head_cmap = sns.color_palette("viridis", num_heads)
            
        for b in range(num_blocks):
            for x_or_y in (1, 0):
                _ax1 = ax1[2 * b + 1-x_or_y]
                _ax1.set_title(f"Block {b + 1} {'xs' if not x_or_y else 'ys'}")
                # _ax1.set_xlabel("Step, $t$")
                # _ax1.set_ylabel(label)
                
                for h in range(num_heads):
                    series = df[f"block_{b}/head_{h}/{'x' if not x_or_y else 'y'}/{metric}"]
                    _ax1.plot(df.step, series, label=f"Head {h + 1}", color=head_cmap[h])

                _ax1.set_xticks([])
                _ax1.set_xticklabels([])
                
        ax1[0].legend()
        ax1[0].set_ylabel(label)

        axes.extend(ax1)

    # Create subplots for each head in each block, detailing entropy for each token
    ax2 = [plt.subplot2grid((num_rows, num_cols), (i//(num_cols) + int(full_block) + int(full_head), i%(num_cols))) for i in range(num_heads * num_blocks * 2)]
    ax_idx = 0

    for b in range(num_blocks):
        for x_or_y in (0, 1):
            ax2[ax_idx].set_ylabel(label)
            token_cmap = sns.color_palette("viridis" if not x_or_y else "magma", num_tokens)
            for h in range(num_heads):
                _ax2 = ax2[ax_idx]
                _ax2.set_title(f"$b={b + 1}, h={h + 1}, {'x' if not x_or_y else 'y'}_k$")
                # ax2[ax_idx].set_xlabel("Step, $t$")
                # ax2[ax_idx].set_ylabel(label)

                for t in range(int(x_or_y), num_tokens, 2):
                    series = df[f"block_{b}/head_{h}/token_{t}/{metric}"]

                    if y_axis:
                        _ax2.axhline(0, color='black', linestyle='--', linewidth=0.5)

                    _ax2.plot(df.step, series, label=f"Token {'x' if not x_or_y else 'y'} {t + 1}", color=token_cmap[t])
                    
                _ax2.set_xticks([])
                _ax2.set_xticklabels([])
                ax_idx += 1

    for ax in ax2[-4:]:
        ax.set_xlabel("Step $t$")
        ax.set_xticks([10**i for i in range(6)]) + [500_000]
        ax.set_xticklabels([10**i for i in range(6)]) + [500_000]

    axes.extend(ax2)
    for ax in axes:
        if logx:
            ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")

        ax.set_xlim(100, 500_000)

    _plot_transitions(axes)


    # ax2[0].legend()
    # ax2[1].legend()
    viridis = sns.color_palette("viridis", 8)
    magma = sns.color_palette("magma", 8)
    colors = [(viridis[c] if not x_or_y else magma[c]) for c in range(8) for x_or_y in (0, 1)]
    labels = [f"${'x' if not x_or_y else 'y'}_{i + 1}$" for i in range(8) for x_or_y in (0, 1)]

    # Handles (patches)
    # handles = [mpatches.Line(color=colors[i], label=labels[i * 2 + x_or_y]) for i in range(8) for x_or_y in (0, 1)]
    # Handles (lines)
    handles = [plt.Line2D([0,0],[0,0], color=colors[i], label=labels[i]) for i in range(16)]
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.subplots_adjust(bottom=0.2)      
    fig.legend(handles=handles, loc='lower center', ncol=8, frameon=False, bbox_to_anchor=(0.5, 0.01))

    if save:
        parent_dir = os.path.dirname(save)
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        plt.savefig(save)

    return fig, axes

In [None]:
for i, (title, metric) in enumerate(zip(("$\hat{H}_{k}^{(b, h)}$", "$V_{k}^{(b, h)}$", "$D_{k}^{(b, h)}$", "Self attention", "Previous token attention", "X tokens attention", "Hardness"), ("entropy_normalized", "variability", "distance", "self_attn", "prev_token_attn", "x_tokens_attn", "hardness"))):
    fig, axes = plot_attention_patterns(
        attn_variabilities, 
        num_blocks=num_blocks, 
        num_heads=num_heads, 
        num_tokens=num_tokens-1, 
        save=False,
        figsize=(FULL_WIDTH, 2.5 * FULL_HEIGHT),
        logx=True,
        metric=metric,
        title=None, # run.config.to_latex(), 
        label=title,
        full_block = False,
        full_head = False,
    )
    for ax in axes:
        ax.set_yscale('log')
        ax.set_ylim(10e-4, 1)
    fig.savefig(FIGURES / (f"lr/lr-attn-{metric}" + ".pdf"))
    plt.show()
    

In [None]:
import warnings

metrics_to_labels = {
    "first_x": r"\alpha_{k, 1}^{(b, h)}",
    "first_y": r"\alpha_{k, 2}^{(b, h)}",
    "prev_token_attn": r"\alpha_{k, k-1}^{(b, h)}",
    "previous_x": r"\alpha_{k, y_{k-1}}^{(b, ha)}",
    "previous_y": r"\alpha_{k, x_{k-1}}^{(b, h)}",
    "x_tokens_attn": r"\alpha_{k, x}^{(b, h)}",
    "y_tokens_attn": r"\alpha_{k, y}^{(b, h)}",
    "self_attn": r"\alpha_{k, k}^{(b, h)}",
}

head_metrics = {
    # "b1h1x": "?",
    "b1h1y": "x_tokens_attn", # It's attending to x. For first y 100% on previous token. For others 0% on previous token but 100% on x
    # "b1h2x": "?",
    # "b1h2y": "?",
    "b1h3x": "y_tokens_attn", # Maybe close to uniform?
    "b1h3y": "prev_token_attn",
    # "b1h4x": "?",
    "b1h4y": "y_tokens_attn", # Almost a particular one?
    # "b2h1x": "?",
    # "b2h1y": "?",
    # "b2h2x": "?",
    # "b2h2y": "?",
    # "b2h3x": "?",
    "b2h3y": "self_attn", # relaxes in R3
    # "b2h4x": "?",
    # "b2h4y": "?",
}

fig, axes = plt.subplots(1, len(head_metrics), figsize=(FULL_WIDTH * 1.5, FULL_HEIGHT))

df = attn_variabilities
colors = sns.color_palette("viridis", 8)

for i, ((id_, metric), ax) in enumerate(zip(head_metrics.items(), axes)):

    try:
        x_or_y = 1 if id_[4] == "y" else 0
        for token in range(x_or_y, 15, 2):
            ax.plot(df.step, df[f"block_{int(id_[1])-1}/head_{int(id_[3])-1}/token_{token}/{metric}"], color=colors[token])

    except:
        warnings.warn(f"Failed to plot {id_} {metric}")
        
    ax.set_title(f"$b={id_[1]}, h={id_[3:5]}$")
    ax.set_xscale("log")
    ax.set_xlim(100, 500_000)
    ax.set_xlabel("Step $t$")
    ax.set_ylabel(f"${metrics_to_labels[metric]}$")

_plot_transitions(axes)
plt.tight_layout()

fig.savefig(FIGURES / "lr/lr-attn-ids.pdf")