# Small-model report

On the small model (L=2, H=4):

- Replication of Raventós et al. (2023) + fitting the various algorithms
- All the analyses (RLCT, PCA, Attention Entropies, Covariance, Weight-staring). 

In [None]:
import os
from dotenv import load_dotenv

load_dotenv()

# Fail fast if the AWS credentials are still missing after load_dotenv() has
# had the chance to populate the environment from a .env file.
_missing_aws_vars = [
    name
    for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    if name not in os.environ
]
if _missing_aws_vars:
    # RuntimeError is a subclass of Exception, so existing handlers still match;
    # the message now names exactly which variable(s) are absent.
    raise RuntimeError(
        f"{', '.join(_missing_aws_vars)} not found in environment variables. "
        "Please set them in .env file."
    )

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pprint import pp
from pathlib import Path
from typing import Optional, Iterable, List, Tuple, Dict, Union, Callable

import seaborn as sns
import pandas as pd
from tqdm import tqdm
import torch 
from torch import nn

import devinterp
import devinfra

from icl.constants import SWEEPS, FIGURES, ANALYSIS
from icl.analysis.utils import get_unique_run

In [None]:
sns.set_theme(style="whitegrid")

SWEEP_ID = "n698i1jy"
SWEEP_FILENAME = "training-runs/small-L-2.yaml"

K = 2

from icl.constants import FIGURES, ANALYSIS
from icl.constants import DEVICE

In [None]:
# LaTeX label for the validation loss. Raw string: "\m" is not a valid Python
# escape sequence, so the non-raw literal triggers an invalid-escape warning.
LVAL = r"L_\mathrm{val}"

## Set-up

In [None]:
from icl.analysis.utils import get_sweep_configs

filters = {"task_config": {"num_layers": 2, "num_heads": 4}, "optimizer_config": {"lr": 0.01}}  # TODO: Where are the H=2 runs?
configs = list(get_sweep_configs(SWEEPS / SWEEP_FILENAME, **filters))

print(f"Found {len(configs)} runs.")

In [None]:
# Figure out which checkpoints are available

checkpointers = [config.checkpointer_config.factory() for config in tqdm(configs, desc="Reading checkpoints")]

for checkpointer in tqdm(checkpointers, desc="Loading checkpoints"):
    print(f"Found {len(checkpointer.file_ids)} checkpoints for {checkpointer}")

In [None]:
# Number of pretraining tasks of each selected run, in `configs` order.
MS = [config.task_config.num_tasks for config in configs] # [1, 4, 64, 2**10, 2**20]
# Checkpoint ids of the last run. The previous version read the `checkpointer`
# variable leaked from the loop in the cell above, which is that same last
# element; indexing `checkpointers` directly removes the hidden dependency on
# that loop having been executed.
STEPS = checkpointers[-1].file_ids

## Replication 

In [None]:
import wandb
from devinfra.utils.iterables import filter_objs

api= wandb.Api()
sweep = api.sweep(f"devinterp/icl/{SWEEP_ID}")
runs = list(filter_objs([r for r in sweep.runs], config=filters))

print(f"Found {len(runs)} runs.")

In [None]:
from devinfra.utils.iterables import flatten_dict
from icl.analysis.utils import wandb_runs_to_df

df = wandb_runs_to_df(runs)

In [None]:
pp(list(df.columns))

df

In [None]:
torch.Tensor([1, 2,3]).tolist()

In [None]:
# Per-chain LLC columns logged for each row of `df`.
# NOTE(review): assumes exactly 25 chains ("llc-chain/0" .. "llc-chain/24") were
# logged -- confirm against the logging code.
llc_chain_columns = [f'llc-chain/{i}' for i in range(25)]
# Entries equal to the string "NaN" are turned into real NaNs so that they are
# skipped by the mean and counted by isna() below.
df[llc_chain_columns] = df[llc_chain_columns].replace("NaN", np.nan)


# Calculate the average of non-NaN values in llc-chain columns
# and the fraction of NaN values
llc_chain_values = df[llc_chain_columns]
mean_llc_chain = llc_chain_values.mean(axis=1, skipna=True)
frac_nan = llc_chain_values.isna().mean(axis=1)

df["llc/mean-fixed"] = mean_llc_chain
df["llc/frac-nan"] = frac_nan
# Base-2 logarithm: later cells threshold this column at 5 side by side with the
# log2 run index `m` and label the colour bar "$\log_2 M$". The natural log used
# previously made `log_num_tasks > 5` select a different set of runs than `m > 5`.
df["log_num_tasks"] = np.log2(df["task_config/num_tasks"])

mean_llc_chain, frac_nan

In [None]:
sns.lineplot(data=df, x="_step", y="llc/frac-nan", hue="task_config/num_tasks")
plt.xscale("log")

In [None]:
from matplotlib import patches


TRANSITIONS = [
    (100, 800, 'A1'),
    (800, 10_000, 'A2'),
    (10_000, 28_000, 'B1'),
    (28_000, 280_000, 'B2'),
]

INIT_X = TRANSITIONS[0][0]
FINAL_X = TRANSITIONS[-1][1]

def plot_transitions(axes, **kwargs):
    """Mark the intervals listed in ``TRANSITIONS`` on ``axes``.

    Thin wrapper around ``icl.figures.colors.plot_transitions`` that supplies
    the notebook-level ``TRANSITIONS`` list. Extra keyword arguments are
    forwarded unchanged.

    Returns whatever the project helper returns; elsewhere in this notebook the
    result is passed as ``handles=`` to a legend.
    """
    # Imported lazily under a private alias so it does not clash with this wrapper's name.
    from icl.figures.colors import plot_transitions as _plot_transitions
    return _plot_transitions(axes, TRANSITIONS, **kwargs)


# Several later cells call ``add_milestones(axes, ...)``, which is neither
# defined nor imported anywhere in the visible notebook. Keep the name working
# by aliasing it to the wrapper above.
# NOTE(review): assumes the project helper accepts the ``alpha=`` keyword those
# cells pass -- confirm.
add_milestones = plot_transitions

In [None]:
steps = checkpointers[0].file_ids
steps

In [None]:
from collections import defaultdict
from icl.analysis.evals import ICLEvaluator
from icl.experiments.activations_analysis import iter_models
from icl.train import Run

evals = []
functional_metrics = []
gradient_norms = []

B = 8192
K = 8
D = 4
OOD_MULTIPLIER = 3

def eval_loss(yhats, ys):
    """Mean squared error per position, averaged over the leading (batch) axis.

    The squared difference is averaged over dim 0 and then index 0 of the last
    remaining axis is selected, so for inputs of shape (B, K, 1) the result has
    shape (K,).
    """
    squared_errors = (yhats - ys).pow(2)
    batch_mean = squared_errors.mean(dim=0)
    return batch_mean[:, 0]

def apply_transformations(ws, xs):
    """Apply one linear task per batch element to its inputs.

    Args:
        ws: task vectors, shape (B, D).
        xs: inputs, shape (B, K, D).

    Returns:
        Tensor of shape (B, K, 1) holding ``xs[b, k] . ws[b]``.

    The task vectors are turned into (B, D, 1) column vectors with
    ``unsqueeze(-1)`` rather than ``view(B, D, 1)``, so the function no longer
    depends on the module-level ``B`` and ``D`` and works for any batch size or
    dimension (and for non-contiguous ``ws``).
    """
    return xs @ ws.unsqueeze(-1)

# For every run in `configs`, replay its checkpoints on one fixed synthetic batch
# and record three kinds of rows:
#   * `evals`              - evaluator metrics and the global weight norm per checkpoint,
#   * `functional_metrics` - per-token losses per checkpoint,
#   * `gradient_norms`     - per-parameter gradient statistics per checkpoint.
# NOTE(review): `log2_M` is just the position of the run in `configs`; recording
# "M": 2 ** log2_M assumes run i was trained on 2**i tasks -- TODO confirm.
for log2_M, config in tqdm(enumerate(configs)):
    run = Run(config)
    # NOTE(review): 8192 duplicates the value of `B` defined above in this cell.
    run.evaluator = ICLEvaluator(
        pretrain_dist=run.pretrain_dist,
        true_dist=run.true_dist,
        max_examples=config.task_config.max_examples,
        eval_batch_size=8192,
        seed=config.task_config.true_seed,
    )
    pretrain_dist_noiseless = run.config.task_config.pretrain_dist_factory().to(
        DEVICE
    )
    # Remember the configured noise level before zeroing it on this copy; the
    # noise is re-added explicitly below via `errors`.
    noise_std = pretrain_dist_noiseless.std
    pretrain_dist_noiseless.std = 0.

    ws = pretrain_dist_noiseless.task_distribution.sample_tasks(B) # -> B D
    # Mean of the task set, used as the "prior" prediction target below.
    wpriors = pretrain_dist_noiseless.task_distribution.tasks.mean(dim=0) # -> D
    wpriors = wpriors.repeat(B, 1) # -> B D

    # Standard-normal inputs, shape (B, K, D).
    xs = torch.normal(
        mean=0.,
        std=1.,
        size=(B, K, D,),
        device=DEVICE
    )
    # NOTE(review): `ood_xs` is only referenced by the commented-out line below;
    # the model is later evaluated on (xs, ood_ys). Confirm whether that pairing
    # is intended.
    ood_xs = OOD_MULTIPLIER * xs

    # One noise sample, shape (B, K, 1), shared by `ys` and `ood_ys`.
    errors = torch.normal(
        mean=0.,
        std=noise_std,
        size=(B, K, 1,),
        device=DEVICE,
    )

    ys_without_noise = apply_transformations(ws, xs)
    ood_ys_without_noise= OOD_MULTIPLIER * ys_without_noise

    ys = ys_without_noise + errors
    ood_ys = ood_ys_without_noise + errors

    # Reference predictions: the mean task applied to xs, and the all-zero prediction.
    yhats_prior = apply_transformations(wpriors, xs)
    yhats_zero = torch.zeros_like(ys)
    # ood_yhats_prior = apply_transformations(wpriors, ood_xs)

    # NOTE(review): `steps` comes from the first checkpointer only; zip() silently
    # truncates if this run has a different number of checkpoints.
    for step, model in zip(steps, iter_models(run.model, run.checkpointer)):
        yhats = model(xs, ys)
        ood_yhats = model(xs, ood_ys)
        # yhats_without_noise = model(xs, ys_without_noise)

        # Per-token losses (see eval_loss): against the noisy targets, and the
        # distance of the model's predictions from the two reference predictors.
        losses = eval_loss(yhats, ys)
        # losses_without_noise = eval_loss(yhats_without_noise, ys_without_noise)
        losses_prior = eval_loss(yhats, yhats_prior)
        losses_zero = eval_loss(yhats, yhats_zero)

        ood_losses = eval_loss(ood_yhats, ood_ys)
        # ood_losses_midpoint = eval_loss(ood_yhats, ood_yhats_prior)
        # ood_losses_without_noise = eval_loss(yhats, ood_ys_without_noise)

        # Gradients are taken of the token-averaged in-distribution loss only.
        loss = losses.mean()
        loss.backward()

        for n, p in model.named_parameters():
            if p.grad is None:
                continue

            grad_sq_mean = (p.grad ** 2).mean().item()
            grad_sq_std = (p.grad ** 2).std().item()

            # "grad/norm" is the root-mean-square gradient entry (sqrt of the mean
            # squared entry), not the L2 norm of the whole tensor.
            gradient_norms.append({
                "m": log2_M,
                "M": 2 ** log2_M,
                "step": step,
                "layer": n,
                "grad/norm": grad_sq_mean ** 0.5,
                "grad_sq/mean": grad_sq_mean,
                "grad_sq/std": grad_sq_std,
                "numel": p.numel(),
                "loss": loss.item(),
            })

            # Drop the gradient once recorded so it cannot accumulate into a
            # later backward() call on the same parameter.
            p.grad = None

        # NOTE(review): 8 duplicates the value of `K` defined above in this cell.
        for token in range(8):
            functional_metrics.append({
                "m": log2_M,
                "M": 2 ** log2_M,
                "step": step,
                "loss": losses[token].item(),
                "ood_loss": ood_losses[token].item(),
                # "loss_without_noise": losses_without_noise[i],
                # "ood_loss_without_noise": ood_losses_without_noise[i],
                "loss_prior": losses_prior[token].item(),
                "loss_zero": losses_zero[token].item(),
                # "ood_loss_midpoint": ood_losses_midpoint[i],
                "token": token
            })

        # "weight_norm" is the L2 norm over all parameters; the remaining keys are
        # whatever the evaluator returns for this checkpoint.
        evals.append({
            "m": log2_M,
            "M": 2 ** log2_M,
            "step": step,
            "weight_norm": (sum([(p ** 2).sum() for p in model.parameters()]) ** 0.5).item(),
            **run.evaluator(model),
        })


evals = pd.DataFrame(evals)
evals.to_csv(ANALYSIS / "small-model-evals.csv", index=False)
functional_metrics = pd.DataFrame(functional_metrics)
functional_metrics.to_csv(ANALYSIS / "small-model-functional-metrics.csv", index=False)
gradient_norms = pd.DataFrame(gradient_norms)
gradient_norms.to_csv(ANALYSIS / "small-model-gradient-norms.csv", index=False)

In [None]:
evals = pd.read_csv(ANALYSIS / "small-model-evals.csv")
functional_metrics = pd.read_csv(ANALYSIS / "small-model-functional-metrics.csv")
gradient_norms = pd.read_csv(ANALYSIS / "small-model-gradient-norms.csv")

In [None]:
evals.columns

In [None]:
from icl.figures.derivatives import d_dt, d_dlogt, dlog_dlogt

In [None]:
# Add d/dlog(t) of the validation loss, the LLC estimate and the weight norm to
# `evals`, one run at a time.
#
# Iterate over the runs actually present in `evals` instead of a hard-coded
# range(20): with 21 runs (m = 0..20, as assumed by the other cells) the old
# bound left the derivative columns of the m == 20 run unset (NaN).
# NOTE(review): as before, this assumes `df` holds exactly one LLC value per
# entry of `steps` for every run, in step order -- confirm for the last run.
for log2_M in sorted(evals.m.unique()):
    run_mask = evals.m == log2_M

    # Filter the DataFrames down to this run
    mse_values = evals.loc[run_mask, "pretrain/mse"].values
    llc_values = df.loc[df["task_config/num_tasks"] == 2 ** int(log2_M), "llc/mean-fixed"].values
    weightnorm_values = evals.loc[run_mask, "weight_norm"].values

    # Compute the derivatives with respect to log(step)
    dloss_dlogt_values = d_dlogt(steps, mse_values)
    dllc_dlogt_values = d_dlogt(steps, llc_values)
    dweightnorm_dlogt_values = d_dlogt(steps, weightnorm_values)

    # Assign the computed derivatives back to the rows of this run
    evals.loc[run_mask, "dloss_dlogt"] = dloss_dlogt_values
    evals.loc[run_mask, "dllc_dlogt"] = dllc_dlogt_values
    evals.loc[run_mask, "dweightnorm_dlogt"] = dweightnorm_dlogt_values


In [None]:
evals.columns

In [None]:
from matplotlib import colors, lines, patches

def get_reduced_viridis_palette(num_colors, ratio=3 / 3.5):
    """Return the first ``num_colors`` colours of an over-sampled viridis palette.

    A palette with ``int(num_colors // ratio)`` entries is requested and then
    truncated to ``num_colors``. With the default ``ratio`` (< 1) the requested
    palette is longer than needed, so the tail of the colormap is dropped.
    """
    oversampled_size = int(num_colors // ratio)
    full_palette = sns.color_palette("viridis", oversampled_size)
    return full_palette[:num_colors]

LINE_PALETTE = get_reduced_viridis_palette(21-5)
# num_palette_steps = int((21 * 3.75) // 3)
# LINE_PALETTE = [sns.color_palette("coolwarm", num_palette_steps)[i] for i in [*range(10), *range(num_palette_steps - 10, num_palette_steps)]]

print(LINE_PALETTE)
# "viridis"
ALPHA=0.75

fig, axes = plt.subplots(2, 2, figsize=(20, 6))

ax = axes[0, 0]

filtered_evals = evals.loc[(evals.step != 20408) & (evals['m'] > 5)]

sns.lineplot(data=filtered_evals, x="step", y="pretrain/mse", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$L_\mathrm{val}$")
ax.legend().remove()

ax = axes[1, 0]

sns.lineplot(data=filtered_evals, x="step", y="dloss_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\delta L_\mathrm{val}/\delta \log t$")
ax.legend().remove()
ax.set_ylim(-2.5, 2.5)

ax = axes[0, 1]

sns.lineplot(data=df.loc[(df._step != 20408) & (df.log_num_tasks > 5)], x="_step", y="llc/mean-fixed", hue="log_num_tasks", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\hat\lambda$")
ax.legend().remove()

ax = axes[1, 1]

sns.lineplot(data=filtered_evals, x="step", y="dllc_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\delta \hat\lambda/\delta \log t$")
ax.set_ylim(-500, 500)
ax.legend().remove()

# ax = axes[0, 2]

# sns.lineplot(data=filtered_evals, x="step", y="weight_norm", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
# ax.set_xscale("log")
# ax.set_xlim(INIT_X, FINAL_X)
# ax.set_xlabel("Step, $t$")
# ax.set_ylabel("$|w_t|$")
# ax.legend().remove()

# ax.set_ylim(20, 800)

# ax = axes[1, 2]

# sns.lineplot(data=filtered_evals, x="step", y="dweightnorm_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
# ax.set_xscale("log")
# ax.set_xlim(INIT_X, FINAL_X)
# ax.set_xlabel("Step, $t$")
# ax.set_ylabel("$\delta|w_t|/\delta\log t$")
# ax.legend().remove()

_patches = plot_transitions(axes)


handles, labels = ax.get_legend_handles_labels()

# Create custom handles
# handles = [lines.Line2D([], [], color=custom_colors[i], marker='o', linestyle='', label=custom_labels[i]) for i in range(len(custom_labels))]

# Add the custom handles to the existing ones
# handles.extend(custom_handles)

# Now, you can create the legend with the updated handles and custom labels
# axes[0, 2].legend(handles=handles, title="$\log_2 M$", loc='center right', bbox_to_anchor=(0.9, .5))

# plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust layout to make room for colorbar
# plt.tight_layout()

cbar_ax = fig.add_axes([0.93, 0.125, 0.02, 0.33])  # Adjust as necessary for position and size
custom_cmap = colors.LinearSegmentedColormap.from_list("custom_cmap", LINE_PALETTE)

sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=0+5, vmax=20), )
sm._A = []  # Dummy array for the ScalarMappable. 
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions =  [5, 10, 15, 20]  # Positions for each color
tick_labels = map(str, tick_positions)  # Labels for each color
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label("$\log_2 M$")


stages_legend_ax = fig.add_axes([0.945, 0.68, 0.02, 0.25])  # Adjust as necessary for position and size
stages_legend_ax.axis('off')
stages_legend_ax.legend(handles=_patches, title="Stage", loc='upper center', bbox_to_anchor=(0, .5))

fig.set_facecolor('white')


In [None]:
from matplotlib import colors, lines, patches

def get_reduced_viridis_palette(num_colors, ratio=3 / 3.5):
    """Return the first ``num_colors`` colours of an over-sampled viridis palette.

    A palette with ``int(num_colors // ratio)`` entries is requested and then
    truncated to ``num_colors``. With the default ``ratio`` (< 1) the requested
    palette is longer than needed, so the tail of the colormap is dropped.
    """
    oversampled_size = int(num_colors // ratio)
    full_palette = sns.color_palette("viridis", oversampled_size)
    return full_palette[:num_colors]

LINE_PALETTE = get_reduced_viridis_palette(5)
# num_palette_steps = int((21 * 3.75) // 3)
# LINE_PALETTE = [sns.color_palette("coolwarm", num_palette_steps)[i] for i in [*range(10), *range(num_palette_steps - 10, num_palette_steps)]]

print(LINE_PALETTE)
# "viridis"
ALPHA=0.75

fig, axes = plt.subplots(2, 2, figsize=(20, 6))

ax = axes[0, 0]

filtered_evals = evals.loc[(evals.step != 20408) & (evals['m'] <= 5)]

sns.lineplot(data=filtered_evals, x="step", y="pretrain/mse", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$L_\mathrm{val}$")
ax.legend().remove()

ax = axes[1, 0]

sns.lineplot(data=filtered_evals, x="step", y="dloss_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\delta L_\mathrm{val}/\delta \log t$")
ax.legend().remove()
ax.set_ylim(-2.5, 2.5)

ax = axes[0, 1]

sns.lineplot(data=df.loc[(df._step != 20408) & (df.log_num_tasks <= 5)], x="_step", y="llc/mean-fixed", hue="log_num_tasks", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\hat\lambda$")
ax.legend().remove()

ax = axes[1, 1]

sns.lineplot(data=filtered_evals, x="step", y="dllc_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_xscale("log")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\delta \hat\lambda/\delta \log t$")
ax.set_ylim(-500, 500)
ax.legend().remove()

# ax = axes[0, 2]

# sns.lineplot(data=filtered_evals, x="step", y="weight_norm", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
# ax.set_xscale("log")
# ax.set_xlim(INIT_X, FINAL_X)
# ax.set_xlabel("Step, $t$")
# ax.set_ylabel("$|w_t|$")
# ax.legend().remove()

# ax.set_ylim(20, 800)

# ax = axes[1, 2]

# sns.lineplot(data=filtered_evals, x="step", y="dweightnorm_dlogt", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
# ax.set_xscale("log")
# ax.set_xlim(INIT_X, FINAL_X)
# ax.set_xlabel("Step, $t$")
# ax.set_ylabel("$\delta|w_t|/\delta\log t$")
# ax.legend().remove()

_patches = plot_transitions(axes)


handles, labels = ax.get_legend_handles_labels()

# Create custom handles
# handles = [lines.Line2D([], [], color=custom_colors[i], marker='o', linestyle='', label=custom_labels[i]) for i in range(len(custom_labels))]

# Add the custom handles to the existing ones
# handles.extend(custom_handles)

# Now, you can create the legend with the updated handles and custom labels
# axes[0, 2].legend(handles=handles, title="$\log_2 M$", loc='center right', bbox_to_anchor=(0.9, .5))

# plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust layout to make room for colorbar
# plt.tight_layout()

cbar_ax = fig.add_axes([0.93, 0.125, 0.02, 0.33])  # Adjust as necessary for position and size
custom_cmap = colors.LinearSegmentedColormap.from_list("custom_cmap", LINE_PALETTE)

sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=0, vmax=5), )
sm._A = []  # Dummy array for the ScalarMappable. 
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions =  [0, 1, 2, 3, 4, 5]  # Positions for each color
tick_labels = map(str, tick_positions)  # Labels for each color
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label("$\log_2 M$")


stages_legend_ax = fig.add_axes([0.945, 0.68, 0.02, 0.25])  # Adjust as necessary for position and size
stages_legend_ax.axis('off')
stages_legend_ax.legend(handles=_patches, title="Stage", loc='upper center', bbox_to_anchor=(0, .5))

fig.set_facecolor('white')

# Functional

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Average across the token column
functional_metrics_averages = functional_metrics.groupby(["m", "step"]).mean().reset_index()

ax = axes[0]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)

ax = axes[1]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss_prior", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_ylim(0.01, 10)

ax = axes[2]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss_zero", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_ylim(0.01, 10)

ax = axes[3]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="ood_loss", hue="m", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)

for ax in axes.flatten():
    ax.set_yscale('log')
    ax.set_xscale("log")
    ax.set_xlim(100, 400_000)
    ax.set_xlabel("Step, $t$")
    ax.legend().remove()

add_milestones(axes) #, alpha=0.8)

In [None]:
LINE_PALETTE = get_reduced_viridis_palette(8)
ALPHA = 1

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Select the m == 20 run; per-token rows are kept (token is the hue below)
functional_metrics_m20 = functional_metrics.loc[functional_metrics.m == 20]

ax = axes[0]
sns.lineplot(data=functional_metrics_m20.loc[functional_metrics_m20.step != 20408], x="step", y="loss", hue="token", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)

ax = axes[1]
sns.lineplot(data=functional_metrics_m20.loc[functional_metrics_m20.step != 20408], x="step", y="loss_prior", hue="token", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_ylim(0.01, 10)

ax = axes[2]
sns.lineplot(data=functional_metrics_m20.loc[functional_metrics_m20.step != 20408], x="step", y="loss_zero", hue="token", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
ax.set_ylim(0.01, 10)

ax = axes[3]
sns.lineplot(data=functional_metrics_m20.loc[functional_metrics_m20.step != 20408], x="step", y="ood_loss", hue="token", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)

for ax in axes.flatten():
    ax.set_yscale('log')
    ax.set_xscale("log")
    ax.set_xlim(100, 400_000)
    ax.set_xlabel("Step, $t$")
    ax.legend().remove()

add_milestones(axes) #, alpha=0.8)

In [None]:
from matplotlib import colors as mcolors

# LINE_PALETTE="viridis"
# ALPHA=1

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for token, log2_M in enumerate([0, 1, 3, 6, 20]):
    # One panel per selected run (rows with m == log2_M)
    ax = axes[token]
    functional_metrics_specific = functional_metrics.loc[functional_metrics.m == log2_M]
    sns.lineplot(data=functional_metrics_specific, x="step", y="loss", hue="token", palette=LINE_PALETTE, alpha=ALPHA, ax=ax)
    ax.set_title(f"$M = 2^{{{log2_M}}}$")


for ax in axes.flatten():
    ax.set_yscale('log')
    ax.set_xscale("log")
    ax.set_xlim(INIT_X, FINAL_X)
    ax.set_xlabel("Step, $t$")
    ax.set_ylabel("$L_\mathrm{val}$")
    ax.legend().remove()

add_milestones(axes, alpha=0.2) #, alpha=0.25)

cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7])  # Adjust as necessary for position and size

custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", LINE_PALETTE)

sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=1, vmax=8), )
sm._A = []  # Dummy array for the ScalarMappable. 
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions = range(1, len(LINE_PALETTE)+1)  # Positions for each color
tick_labels = [f"${i}$" for i in range(1, len(LINE_PALETTE) + 1)] # Replace with your labels
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label("$k$")

plt.tight_layout(rect=[0, 0, 0.91, 1])  # Adjust layout to make room for colorbar


In [None]:
from matplotlib import colors as mcolors

# LINE_PALETTE="viridis"
# ALPHA=1

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for token, log2_M in enumerate([0, 1, 3, 6, 20]):
    # One panel per selected run (rows with m == log2_M)
    ax = axes[token]
    functional_metrics_specific = functional_metrics.loc[functional_metrics.m == log2_M]

    icl_score = functional_metrics_specific.loc[functional_metrics_specific.token == 7, "loss"].values - functional_metrics_specific.loc[functional_metrics_specific.token == 4, "loss"].values
    sns.lineplot(x=steps, y=icl_score, alpha=ALPHA, ax=ax)
    ax.set_title(f"$M = 2^{{{log2_M}}}$")


for ax in axes.flatten():
    # ax.set_yscale('symlog')
    ax.set_xscale("log")
    ax.set_xlim(INIT_X, FINAL_X)
    ax.set_xlabel("Step, $t$")
    ax.set_ylabel("$L_\mathrm{val}$")
    ax.legend().remove()

add_milestones(axes, alpha=0.2) #, alpha=0.25)

cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7])  # Adjust as necessary for position and size

custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", LINE_PALETTE)

sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=1, vmax=8), )
sm._A = []  # Dummy array for the ScalarMappable. 
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions = range(1, len(LINE_PALETTE)+1)  # Positions for each color
tick_labels = [f"${i}$" for i in range(1, len(LINE_PALETTE) + 1)] # Replace with your labels
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label("$k$")

plt.tight_layout(rect=[0, 0, 0.91, 1])  # Adjust layout to make room for colorbar


In [None]:
wpriors_over_m = []
wprior_norms_over_m = []

for log2_M, config in tqdm(enumerate(configs)):
    run = Run(config)
    run.evaluator = ICLEvaluator(
        pretrain_dist=run.pretrain_dist,
        true_dist=run.true_dist,
        max_examples=config.task_config.max_examples,
        eval_batch_size=8192,
        seed=config.task_config.true_seed,
    )
    pretrain_dist_noiseless = run.config.task_config.pretrain_dist_factory().to(
        DEVICE
    )
    noise_std = pretrain_dist_noiseless.std
    pretrain_dist_noiseless.std = 0.

    ws = pretrain_dist_noiseless.task_distribution.sample_tasks(B) # -> B D 
    wpriors = pretrain_dist_noiseless.task_distribution.tasks.mean(dim=0) # -> D
    wpriors_over_m.append(wpriors)
    wprior_norms_over_m.append(wpriors.norm())

In [None]:
plt.plot(np.arange(21), [w.item() for w in wprior_norms_over_m])
plt.yscale("log")

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Normalise the token-averaged losses of each run by the norm of its mean task vector
for log2_M in range(21):
    functional_metrics_averages.loc[functional_metrics_averages.m == log2_M, "loss_zero_norm"] = functional_metrics_averages.loc[functional_metrics_averages.m == log2_M, "loss_zero"] / wprior_norms_over_m[log2_M].item()
    functional_metrics_averages.loc[functional_metrics_averages.m == log2_M, "loss_prior_norm"] = functional_metrics_averages.loc[functional_metrics_averages.m == log2_M, "loss_prior"] / wprior_norms_over_m[log2_M].item()

ax = axes[0]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss", hue="m", palette="viridis", alpha=0.5, ax=ax)

ax = axes[1]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss_prior_norm", hue="m", palette="viridis", alpha=0.5, ax=ax)
ax.set_ylim(0.01, 1000)

ax = axes[2]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="loss_zero_norm", hue="m", palette="viridis", alpha=0.5, ax=ax)
ax.set_ylim(0.01, 1000)

ax = axes[3]
sns.lineplot(data=functional_metrics_averages.loc[functional_metrics_averages.step != 20408], x="step", y="ood_loss", hue="m", palette="viridis", alpha=0.5, ax=ax)

for ax in axes.flatten():
    ax.set_yscale('log')
    ax.set_xscale("log")
    ax.set_xlim(100, 400_000)
    ax.set_xlabel("Step, $t$")
    ax.legend().remove()

add_milestones(axes)

# Show that using this validation loss is reasonable. 

In [None]:
from torch.nn import functional as F
from icl.config import ICLConfig
from devinfra.utils.seed import set_seed

def get_first_t_batches(config: ICLConfig, t=100):
    """
    Re-draw the first `t` batches from the pretraining distribution of `config`.

    Before each draw the global seed is set to `sampling_seed + step`, where
    `sampling_seed` is `config.task_config.sampling_seed` when that is set and
    `pretrain_seed * num_steps` otherwise (presumably mirroring the per-step
    seeding of the training loop -- confirm against `icl.train`).

    Returns a list of `t` `(xs, ys)` tuples. No model is trained or evaluated.
    """
    run = Run(config)
    total_steps = config.num_steps
    task_config = config.task_config

    if task_config.sampling_seed is not None:
        sampling_seed = task_config.sampling_seed
    else:
        sampling_seed = task_config.pretrain_seed * total_steps

    drawn = []
    for offset in range(t):
        # Re-seed before every draw so batch `offset` is reproducible on its own.
        set_seed(sampling_seed + offset)

        xs, ys = run.pretrain_dist.get_batch(
            num_examples=config.task_config.max_examples,
            batch_size=config.batch_size,
        )
        drawn.append((xs, ys))

    return drawn

def get_first_t_batch_losses(config: ICLConfig, t=100):
    """
    Evaluate every checkpoint of the run described by `config` on its first `t` training batches.

    Returns a flat list of MSE values ordered checkpoint-major: the `t` losses
    of the first checkpoint, then the `t` losses of the second, and so on.

    The run is now built from `config`. It used to be hard-coded to
    `configs[20]`, so the checkpoints could belong to a different run than the
    batches. The only call in this notebook passes `configs[20]`, so its
    result is unchanged.
    """
    run = Run(config)
    batches = get_first_t_batches(config, t)
    batch_losses = []

    for model in iter_models(run.model, run.checkpointer, verbose=True):
        for xs, ys in batches:
            # Only the scalar loss is kept, so skip building the autograd graph.
            with torch.no_grad():
                yhats = model(xs, ys)
            batch_losses.append(F.mse_loss(yhats, ys).item())

    return batch_losses

m20_batch_losses = get_first_t_batch_losses(configs[20], t=100)


In [None]:
m20_batch_losses_np = [m20_batch_losses[I:I + 100] for I in range(0, len(m20_batch_losses), 100)]
m20_batch_losses_np = np.array(m20_batch_losses_np)
m20_batch_losses_np.shape

In [None]:
m20_batch_losses_cumsum = np.cumsum(m20_batch_losses_np, axis=1)
m20_batch_losses_cumavg = m20_batch_losses_cumsum / np.arange(1, 101)

m20_batch_losses_df = pd.DataFrame([{"step": step, "loss": loss, "b": b} for step, losses in zip(steps, m20_batch_losses_cumavg) for b, loss in enumerate(losses)])

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

fig, axes = plt.subplots(1, 2, figsize=(20, 5))


ax = axes[0]
# First plot
m20_functional_metrics = functional_metrics_averages.loc[functional_metrics_averages.m == 20]
sns.lineplot(data=m20_batch_losses_df, x="step", y="loss", hue="b", palette=get_reduced_viridis_palette(100), alpha=0.5, ax=ax, legend=None)
sns.lineplot(data=m20_functional_metrics, x="step", y="loss", color=sns.color_palette('bright')[1], alpha=1, ax=ax, linewidth=2, label="$L_\mathrm{val}$")
ax.set_yscale('log')
ax.set_xscale("log")
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\mathrm{Loss}$")
ax.set_xlim(INIT_X, FINAL_X)
ax.legend(loc="lower left")

ax=axes[1]

for b in range(100):
    m20_batch_loss_slope = d_dlogt(steps, m20_batch_losses_cumavg[:, b])
    m20_batch_losses_df.loc[m20_batch_losses_df.b == b, "slope"] = m20_batch_loss_slope

m20_val_loss_slopes = d_dlogt(steps, m20_functional_metrics.loss.values)
sns.lineplot(data=m20_batch_losses_df, x="step", y="slope", hue="b", palette=get_reduced_viridis_palette(100), alpha=0.5, ax=ax, legend=None)
sns.lineplot(data=m20_functional_metrics, x="step", y=m20_val_loss_slopes, color=sns.color_palette('bright')[1], alpha=0.75, ax=ax, linewidth=2, label="$\delta L_\mathrm{val}/\delta \log t$")
# ax.set_yscale('log')
ax.set_xscale("log")
ax.set_xlabel("Step, $t$")
ax.set_ylabel("$\delta \mathrm{Loss}/\delta \log t$")
ax.set_xlim(INIT_X, FINAL_X)
ax.set_ylim(-2, 1)
ax.legend(loc="lower left")

# Create an inset for the second plot

# ax_inset = inset_axes(ax, width="30%", height="30%", loc='upper right')

# # Second plot (inset)
# sns.lineplot(data=m20_batch_losses_df, x="step", y="loss", hue="b", palette=get_reduced_viridis_palette(100), alpha=0.5, ax=ax_inset, legend=None)
# sns.lineplot(data=m20_functional_metrics, x="step", y="loss", color=sns.color_palette('deep')[3], alpha=1, ax=ax_inset)
# ax_inset.set_xlim(10_000, 250_000)
# ax_inset.set_ylim(1.75, 2.25)
# ax_inset.set_xscale("log")
# ax_inset.set_yscale('log')

add_milestones(axes)

# plt.show()

cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7])  # Adjust as necessary for position and size

custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", get_reduced_viridis_palette(100))

sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=1, vmax=101), )
sm._A = []  # Dummy array for the ScalarMappable. 
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions = list(range(1, 101, 10)) + [100] # Positions for each color
tick_labels = ["1"] +  [f"${i}$" for i in range(10, 101, 10)]  # Replace with your labels
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label("$b$")


## Fitting

In [None]:
def get_milestone_indices(steps, milestones):
    """Map each step to the index of the first milestone interval containing it.

    Each milestone is indexable, with ``milestone[0]`` the inclusive start and
    ``milestone[1]`` the exclusive end. Steps that fall in no interval are
    mapped to the string ``'Out of defined milestones'``.
    """
    def _locate(step):
        # First interval [start, end) that contains `step`, if any.
        for position, milestone in enumerate(milestones):
            if milestone[0] <= step < milestone[1]:
                return position
        return 'Out of defined milestones'

    return [_locate(step) for step in steps]

milestones_of_steps = get_milestone_indices(steps, TRANSITIONS)

In [None]:
run = Run(configs[0])
sum(p.numel() for p in run.model.parameters())

In [None]:
from icl.analysis.baselines import dmmse_predictor, ridge_predictor
from icl.tasks import TaskDistribution, DiscreteTaskDistribution, RegressionSequenceDistribution

class DMMSE(nn.Module):
    """`dmmse_predictor` wrapped as a module with a learnable noise variance.

    The prior is the task distribution of `dist`. With `learn_prior=True` the
    prior's task tensor also becomes a learnable parameter.
    """

    def __init__(self, dist: RegressionSequenceDistribution, noise_variance: float, learn_prior: bool = False):
        super().__init__()

        self.prior = dist.task_distribution
        self.noise_variance = nn.Parameter(torch.tensor(noise_variance))

        if not learn_prior:
            return

        # Register the tasks as a parameter of this module and point the prior at
        # the same object. Note that this rebinds `tasks` on the task
        # distribution that was passed in (it is shared, not copied).
        learnable_tasks = nn.Parameter(self.prior.tasks)
        self.tasks = learnable_tasks
        self.prior.tasks = learnable_tasks

    def forward(self, xs, ys):
        prediction = dmmse_predictor(xs, ys, self.prior, self.noise_variance)
        return prediction


class Ridge(nn.Module):
    """`ridge_predictor` wrapped as a module whose only parameter is the noise variance."""

    def __init__(self, noise_variance: float):
        super().__init__()

        initial_value = torch.tensor(noise_variance)
        self.noise_variance = nn.Parameter(initial_value)

    def forward(self, xs, ys):
        prediction = ridge_predictor(xs, ys, self.noise_variance)
        return prediction


def fit_baseline_predictor(baseline: nn.Module, model: nn.Module, dist: RegressionSequenceDistribution, num_steps: int=1000, lr: float=0.0001, device: str = "cpu", batch_size=128, num_examples=8, verbose=True):
    """Fit the parameters of `baseline` so that its predictions match those of `model`.

    For `num_steps` Adam steps a fresh batch is drawn from `dist`, `model` is
    evaluated on it without gradients, and `baseline` is updated to minimise
    the MSE between its predictions and the model's predictions (not the true
    targets).

    With `verbose=True` the per-step loss and `baseline.noise_variance` are
    recorded and plotted once fitting is done; this requires `baseline` to have
    a scalar `noise_variance` attribute.

    Returns the same `baseline` object, updated in place.
    """
    optimizer = torch.optim.Adam(baseline.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # Histories are only filled (and plotted) when `verbose` is set.
    loss_history = []
    variance_history = []

    for _ in tqdm(range(num_steps), desc="Fitting..."):
        optimizer.zero_grad()

        # Fresh batch from the distribution, moved to the requested device
        batch_xs, batch_ys = dist.get_batch(batch_size=batch_size, num_examples=num_examples)
        batch_xs = batch_xs.to(device)
        batch_ys = batch_ys.to(device)

        # The reference model provides the regression targets; it is not trained
        with torch.no_grad():
            reference_preds = model(batch_xs, batch_ys)

        baseline_preds = baseline(batch_xs, batch_ys)

        # Pull the baseline towards the reference model
        mismatch = criterion(baseline_preds, reference_preds)
        mismatch.backward()
        optimizer.step()

        if verbose:
            loss_history.append(mismatch.item())
            variance_history.append(baseline.noise_variance.item())

    if not verbose:
        return baseline

    _, (loss_ax, variance_ax) = plt.subplots(1, 2, figsize=(10, 5))
    plt.suptitle(f"Baseline fitting {baseline.__class__.__name__} on {dist.task_distribution.__class__.__name__}")
    loss_ax.plot(loss_history)
    loss_ax.set_title("Loss")
    variance_ax.plot(variance_history)
    variance_ax.set_title("Noise variance")
    plt.show()

    return baseline

def eval_delta_predictor(baseline: nn.Module, model: nn.Module, xs, ys, device: str = "cpu"):
    """Return the MSE between the predictions of ``baseline`` and ``model``.

    This is a pure evaluation, so both forward passes run under
    ``torch.no_grad()``: no autograd graph is built and the returned scalar
    tensor does not require gradients.

    Args:
        baseline: callable mapping ``(xs, ys)`` to predictions.
        model: callable mapping ``(xs, ys)`` to predictions.
        xs: inputs, passed unchanged to both callables.
        ys: targets, passed unchanged to both callables.
        device: unused; kept so existing callers keep working. The inputs
            are NOT moved to this device.

    Returns:
        A zero-dimensional tensor holding the mean squared difference.
    """
    with torch.no_grad():
        baseline_preds = baseline(xs, ys)
        preds = model(xs, ys)
        return nn.MSELoss()(baseline_preds, preds)

In [None]:
from devinfra.utils.iterables import flatten_dict
from icl.train import Run
import random

# For every run: fit three baselines (DMMSE on the pretraining distribution,
# ridge on the pretraining distribution, ridge on the true distribution) to the
# trained model, and record how far each baseline is from the model before and
# after fitting, together with the fitted noise variance.
fit_baseline_results = []

lr = 0.01
num_steps = 2_00  # i.e. 200 optimisation steps per baseline
LEARN_PRIOR = False  # forwarded to DMMSE as `learn_prior`

for config, checkpointer in zip(configs, checkpointers):
    run = Run(config)
    # Load the weights stored at `checkpointer[-1]` — presumably the final
    # checkpoint, whose id is recorded as `checkpointer.file_ids[-1]` below.
    run.model.load_state_dict(checkpointer[-1]["model"])

    print("Evaluating", run.config.to_slug())

    batch_size = run.config.batch_size
    num_examples = run.config.task_config.max_examples
    
    # Every baseline starts from the noise variance the run was configured with.
    noise = run.config.task_config.noise_variance

    learned_dmmse_pretrain = DMMSE(run.pretrain_dist, noise_variance=noise, learn_prior=LEARN_PRIOR)
    learned_ridge_pretrain = Ridge(noise_variance=noise)
    learned_ridge_true = Ridge(noise_variance=noise)

    # Distance between each (still unfitted) baseline and the model on the evaluator's fixed batches.
    init_learned_dmmse_pretrain_delta = eval_delta_predictor(learned_dmmse_pretrain, run.model, run.evaluator.pretrain_xs, run.evaluator.pretrain_ys, device=DEVICE)
    init_learned_ridge_pretrain_delta = eval_delta_predictor(learned_ridge_pretrain, run.model, run.evaluator.pretrain_xs, run.evaluator.pretrain_ys, device=DEVICE)
    init_learned_ridge_true_delta = eval_delta_predictor(learned_ridge_true, run.model, run.evaluator.true_xs, run.evaluator.true_ys, device=DEVICE)
    
    # Fit each baseline in place against the model on freshly sampled batches.
    # NOTE(review): batches are moved to DEVICE but neither the model nor the
    # baselines are moved here — confirm they already live on DEVICE.
    fit_baseline_predictor(learned_dmmse_pretrain, run.model, run.pretrain_dist, num_steps=num_steps, lr=lr, device=DEVICE, batch_size=batch_size, num_examples=num_examples)
    fit_baseline_predictor(learned_ridge_pretrain, run.model, run.pretrain_dist, num_steps=num_steps, lr=lr, device=DEVICE, batch_size=batch_size, num_examples=num_examples)
    fit_baseline_predictor(learned_ridge_true, run.model, run.true_dist, num_steps=num_steps, lr=lr, device=DEVICE, batch_size=batch_size, num_examples=num_examples)

    # Same distances, now measured after fitting.
    learned_dmmse_pretrain_delta = eval_delta_predictor(learned_dmmse_pretrain, run.model, run.evaluator.pretrain_xs, run.evaluator.pretrain_ys, device=DEVICE)
    learned_ridge_pretrain_delta = eval_delta_predictor(learned_ridge_pretrain, run.model, run.evaluator.pretrain_xs, run.evaluator.pretrain_ys, device=DEVICE)
    learned_ridge_true_delta = eval_delta_predictor(learned_ridge_true, run.model, run.evaluator.true_xs, run.evaluator.true_ys, device=DEVICE)

    # One flat record per run. `*/delta_delta` is (after - before), so a
    # negative value means fitting moved the baseline closer to the model.
    fit_baseline_results.append({
        "step": checkpointer.file_ids[-1],
        "config": run.config.to_slug(),
        "learned_dmmse_pretrain/init_delta": init_learned_dmmse_pretrain_delta.item(),
        "learned_ridge_pretrain/init_delta": init_learned_ridge_pretrain_delta.item(),
        "learned_ridge_true/init_delta": init_learned_ridge_true_delta.item(),
        "learned_dmmse_pretrain/delta": learned_dmmse_pretrain_delta.item(),
        "learned_ridge_pretrain/delta": learned_ridge_pretrain_delta.item(),
        "learned_ridge_true/delta": learned_ridge_true_delta.item(),
        "learned_dmmse_pretrain/delta_delta": learned_dmmse_pretrain_delta.item() - init_learned_dmmse_pretrain_delta.item(),
        "learned_ridge_pretrain/delta_delta": learned_ridge_pretrain_delta.item() - init_learned_ridge_pretrain_delta.item(),        
        "learned_ridge_true/delta_delta": learned_ridge_true_delta.item() - init_learned_ridge_true_delta.item(),
        "learned_dmmse_pretrain/noise_variance": learned_dmmse_pretrain.noise_variance.item(),
        "learned_ridge_pretrain/noise_variance": learned_ridge_pretrain.noise_variance.item(),
        "learned_ridge_true/noise_variance": learned_ridge_true.noise_variance.item(),
        # Task-config fields are merged in as extra columns.
        **flatten_dict(run.config.task_config.model_dump(), flatten_lists=True)
    })

    pp(fit_baseline_results[-1])



In [None]:
baseline_fits_df = pd.DataFrame(fit_baseline_results)

# One row of panels per fitted baseline (3 rows x 2 columns):
# left = distance to the model before/after fitting, right = fitted noise variance.
row_labels = ['dmmse_pretrain', 'ridge_pretrain', 'ridge_true']
fig, axes = plt.subplots(3, 2, figsize=(12, 18))

plt.suptitle(
    "L2-H4-K8-D4-err0.125-dmlp64-dembed64-seeds0-1-2-3-n128000000-lr0.01-B256-T500000@t=499999"
)

for (delta_ax, variance_ax), row_label in zip(axes, row_labels):

    # Left panel: distance to the model before ("init_delta") and after ("delta") fitting.
    for metric in ('init_delta', 'delta'):
        baseline_fits_df.plot(x='num_tasks', y=f'learned_{row_label}/{metric}', ax=delta_ax, label=f'{row_label} {metric}')
    delta_ax.set_title(f"{row_label} init_delta and delta")
    delta_ax.set_xlabel('num_tasks')
    delta_ax.set_ylabel('Value')

    # Right panel: fitted noise variance against a fixed reference line at 0.125.
    baseline_fits_df.plot(x='num_tasks', y=f'learned_{row_label}/noise_variance', ax=variance_ax, label=f'{row_label} noise_variance')
    variance_ax.axhline(y=0.125, color='r', linestyle='-', label='True noise_variance')

    variance_ax.legend()

    variance_ax.set_title(f"{row_label} noise_variance")
    variance_ax.set_xlabel('num_tasks')
    variance_ax.set_ylabel('Noise Variance')

# num_tasks is shown on a log axis in every panel.
for panel in axes.ravel():
    panel.set_xscale("log")

plt.tight_layout(rect=[0.1, 0.1, 1, 1])
plt.show()


## PCA

In [None]:
def iter_enumerated_models(model, checkpointer, verbose=False):
    """Yield ``(file_id, model)`` for every checkpoint id in ``checkpointer.file_ids``.

    The SAME ``model`` object is yielded each time, with that checkpoint's
    ``"model"`` state dict loaded into it, so consume each item before
    advancing the generator. ``verbose`` enables the progress bar.
    """
    progress = tqdm(checkpointer.file_ids, desc="Iterating over checkpoints", disable=not verbose)
    for checkpoint_id in progress:
        checkpoint = checkpointer.load_file(checkpoint_id)
        model.load_state_dict(checkpoint["model"])
        yield checkpoint_id, model

def iter_models(model, checkpointer, verbose=False):
    """Yield ``model`` once per checkpoint id, with that checkpoint's weights loaded.

    The SAME ``model`` object is yielded each time (its state dict is
    overwritten in place), so consume each item before advancing the
    generator. ``verbose`` enables the progress bar.
    """
    checkpoint_ids = checkpointer.file_ids
    for checkpoint_id in tqdm(checkpoint_ids, desc="Iterating over checkpoints", disable=not verbose):
        checkpoint = checkpointer.load_file(checkpoint_id)
        model.load_state_dict(checkpoint["model"])
        yield model

In [None]:
from typing import Dict, Iterable, Tuple
from sklearn.decomposition import PCA
from collections import defaultdict
from devinterp.mechinterp.hooks import hook
import numpy as np
from icl.analysis.utils import map_evals_over_checkpoints, get_unique_run
from icl.train import Run
from devinfra.utils.tensors import convert_tensor, ReturnTensor


def extract_activations_over_checkpoints(models: Iterable[nn.Module], xs, ys, *paths, return_type: ReturnTensor="np"):
    """Lazily yield one ``{path: activation}`` dict per model in ``models``.

    Each model is wrapped with ``hook(model, *paths)`` and run on ``(xs, ys)``
    via ``run_with_cache``; the second element of that call's result is used
    as the activation cache. Only entries whose key is one of ``paths`` and
    whose value is not ``None`` are kept, each converted with
    ``convert_tensor(value, return_type)``.
    """
    for model in models:
        cache = hook(model, *paths).run_with_cache(xs, ys)[1]
        yield {
            name: convert_tensor(activation, return_type)
            for name, activation in cache.items()
            if name in paths and activation is not None
        }


def get_vectorized_activations_trace(models: Iterable[nn.Module], xs, ys, *paths):
    """Collect the activations at ``paths`` for every model and flatten them.

    Returns a dict mapping each path to a 2-D array with one row per model
    (in iteration order); each row is that model's activation flattened to a
    vector.
    """
    per_path: Dict[str, list] = defaultdict(list)

    for snapshot in extract_activations_over_checkpoints(models, xs, ys, *paths):
        for path, activation in snapshot.items():
            per_path[path].append(activation)

    flattened = {}
    for path, trace in per_path.items():
        # Stack the per-model activations and collapse everything but the model axis.
        flattened[path] = np.array(trace).reshape(len(trace), -1)

    return flattened


def get_pca_activations_trace(models: Iterable[nn.Module], xs, ys, *paths, num_components=3) -> Dict[str, Tuple[PCA, np.ndarray]]:
    """Run a separate PCA on the activation trace of every path.

    Returns a dict mapping each path to ``(fitted_pca, reduced)``, where
    ``reduced`` is the result of ``fit_transform`` on that path's
    (num_models, num_features) activation matrix.
    """
    traces = get_vectorized_activations_trace(models, xs, ys, *paths)

    def _fit(trace):
        # A fresh PCA per path: the components are not shared between paths.
        pca = PCA(n_components=num_components)
        return pca, pca.fit_transform(trace)

    return {path: _fit(trace) for path, trace in traces.items()}

In [None]:
# demo = Run(configs[2])
# demo_models = iter_models(demo.model, demo.checkpointer, verbose=True)

# demo_logits_pca_3, demo_logits_reduced_3  = get_pca_activations_trace(
#     demo_models, 
#     demo.evaluator.pretrain_xs, 
#     demo.evaluator.pretrain_ys, 
#     "token_sequence_transformer",
#     num_components=3
# )['token_sequence_transformer']

# steps = demo.checkpointer.file_ids

In [None]:
from typing import Optional

def plot_sample_evolution(steps, samples, title="Sample Evolution in 2D Plane", num_points_to_label=10, save: Optional[str] = None, ax: Optional = None, connect_dots=False):
    """Scatter the first two columns of ``samples``, coloured by ``steps``.

    Args:
        steps: one value per row of ``samples``; used for colouring and for
            the text labels.
        samples: 2-D array-like; column 0 is plotted on x, column 1 on y.
        title: axes title.
        num_points_to_label: approximate number of evenly spaced points that
            get their step written next to them; ``0`` disables labelling.
        save: optional output path; missing parent directories are created.
        ax: axes to draw on; a new 15x8 figure is created when omitted.
        connect_dots: also draw a faint line through the points in row order.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(15, 8))

    # Main plot
    sc = ax.scatter(samples[:, 0], samples[:, 1], c=steps, cmap='viridis', s=50, alpha=0.6)

    if connect_dots:
        ax.plot(samples[:, 0], samples[:, 1], c='black', alpha=0.2)

    plt.colorbar(sc, ax=ax, label='Steps')

    # Label some points. The stride is clamped to 1 so that having fewer
    # samples than requested labels labels every point instead of handing
    # range() a zero step.
    total_samples = len(samples)
    if num_points_to_label > 0:
        stride = max(1, total_samples // num_points_to_label)
        for i in range(0, total_samples, stride):
            ax.text(samples[i, 0], samples[i, 1], str(steps[i]), fontsize=12, ha='right', va='bottom')

    ax.set_xlabel('Feature 1')
    ax.set_ylabel('Feature 2')
    ax.set_title(title)

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        plt.savefig(save)


def plot_explained_variance(pca, title="Explained Variance", ax: Optional[plt.Axes] = None):
    """Draw ``pca.explained_variance_ratio_`` as a bar chart with value labels.

    Draws on ``ax`` when given, otherwise on a new 15x8 figure.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(15, 8))

    ratios = pca.explained_variance_ratio_
    ax.bar(range(len(ratios)), ratios)

    # Write each ratio (two decimals) just above its bar.
    for component, ratio in enumerate(ratios):
        ax.text(component, ratio, f"{ratio:.2f}", fontsize=12, ha='center', va='bottom')

    ax.set_title(title)
    ax.set_xlabel('Component')
    ax.set_ylabel('Variance')


def plot_sample_evolution_with_inset(steps, samples, pca, title="Sample Evolution in 2D Plane", num_points_to_label=10, save: Optional[str] = None, ax: Optional = None, connect_dots=False):
    """Plot the sample trajectory with an explained-variance inset.

    The scatter itself is drawn by ``plot_sample_evolution``; ``save`` is
    deliberately not forwarded to it so the figure is written only once,
    after the inset has been added.

    Args:
        steps, samples, title, num_points_to_label, connect_dots: forwarded
            to ``plot_sample_evolution``.
        pca: fitted PCA whose explained-variance ratios fill the inset.
        save: optional output path; missing parent directories are created.
        ax: axes to draw on; a new 15x8 figure is created when omitted.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(15, 8))

    plot_sample_evolution(steps, samples, title=title, num_points_to_label=num_points_to_label, ax=ax, connect_dots=connect_dots)

    # Inset for explained variance at the bottom right corner with slight transparency
    axins = ax.inset_axes([0.7, 0.05, 0.25, 0.25])  # x, y, width, height
    axins.patch.set_alpha(0.5)
    plot_explained_variance(pca, ax=axins)

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        plt.savefig(save)

    
def plot_multiple_slices(steps, samples, pca, title="Sample Evolution in 2D Plane", num_points_to_label=10, save: Optional[str] = None, ax: Optional = None, connect_dots=False):
    """Plot every pair of components of ``samples`` in a lower-triangular grid.

    For ``C`` components a ``(C-1) x (C-1)`` grid is created; panel
    ``[i-1, j]`` (``j < i``) shows component ``i`` on x against component
    ``j`` on y. When the grid has more than one row, the otherwise unused
    top-right panel shows the PCA explained-variance ratios.

    Args:
        steps: one value per row of ``samples``; used for colour and labels.
        samples: 2-D array-like with at least two columns.
        pca: fitted PCA passed to ``plot_explained_variance``.
        title: figure title.
        num_points_to_label: approximate number of evenly spaced points that
            get their step written next to them; ``0`` disables labelling.
        save: optional output path; missing parent directories are created.
        ax: ignored (a new figure is always created); kept so existing
            callers keep working.
        connect_dots: also draw a faint line through the points in row order.

    Raises:
        ValueError: if ``samples`` has fewer than two columns.
    """
    num_pca_components = samples.shape[-1]
    num_rows = num_pca_components - 1
    if num_rows < 1:
        raise ValueError("plot_multiple_slices needs samples with at least two components")

    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
    fig, axes = plt.subplots(num_rows, num_rows, figsize=(20, 20), squeeze=False)

    fig.suptitle(title)

    # Rows whose step gets written next to the point. The stride is clamped
    # to 1 so that short traces do not hand range() a zero step.
    total_samples = len(samples)
    if num_points_to_label > 0:
        labelled = range(0, total_samples, max(1, total_samples // num_points_to_label))
    else:
        labelled = range(0)

    for i in range(num_pca_components):
        for j in range(i):
            panel = axes[i-1, j]
            sc = panel.scatter(samples[:, i], samples[:, j], c=steps, cmap='viridis', s=50, alpha=0.6)
            panel.set_xlabel(f'Feature {i}')
            panel.set_ylabel(f'Feature {j}')
            panel.set_title(f'Feature {i} vs Feature {j}')

            if connect_dots:
                panel.plot(samples[:, i], samples[:, j], c='black', alpha=0.2)

            # Label some points
            for k in labelled:
                panel.text(samples[k, i], samples[k, j], str(steps[k]), fontsize=12, ha='right', va='bottom')

        # Panels above the diagonal carry no pair plot.
        for j in range(i + 1, num_rows):
            axes[i, j].axis('off')

    if num_rows > 1:
        # Reuse the empty top-right panel for the explained-variance bars.
        # (In a 1x1 grid that panel is the only scatter, so it is left alone.)
        axes[0, -1].axis('on')
        plot_explained_variance(pca, ax=axes[0, -1])

    plt.colorbar(sc, ax=axes[0, -1], label='Steps')
    plt.tight_layout()

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        plt.savefig(save)

# plot_multiple_slices(steps, demo_logits_reduced_3, demo_logits_pca_3, title=demo.config.to_latex(), connect_dots=True)

In [None]:
# For every run: PCA (3 components) of the "token_sequence_transformer"
# activations across checkpoints, plotted as pairwise slices and saved
# under FIGURES.
for config, checkpointer in zip(configs, checkpointers):
    run = Run(config)
    _steps = checkpointer.file_ids

    _traces = get_pca_activations_trace(
        iter_models(run.model, run.checkpointer, verbose=True),
        run.evaluator.pretrain_xs,
        run.evaluator.pretrain_ys,
        "token_sequence_transformer",
        num_components=3,
    )
    _pca, _logits_reduced = _traces['token_sequence_transformer']

    _title = config.to_latex()
    _figure_path = FIGURES / ("pca3-logits-" + config.to_slug(delimiter="-") + ".png")

    plot_multiple_slices(_steps, _logits_reduced, _pca, connect_dots=True, title=_title, save=_figure_path)

## Attention Entropies

In [None]:
from typing import List, Union
from torchtyping import TensorType
from devinfra.utils.iterables import map_nested

from devinfra.utils.iterables import flatten_dict

from icl.train import Run

def compute_attention_entropies(attn: TensorType["B", "H", "2K", "2K"]):
    """Compute attention entropies per token, per head and overall.

    The entropy ``-sum(p * ln p)`` is taken over the last axis of ``attn``
    (the attended-to positions) and averaged over the batch axis, giving one
    value per (head, token). Head values are the mean over tokens; the
    overall value is the mean over heads. All entropies are in nats.

    Normalised variants divide each token's entropy by ``ln(j + 1)``, the
    largest entropy possible when token ``j`` can attend to ``j + 1``
    positions (assumes causal attention — TODO confirm). The natural log is
    used so the bound is in the same unit as the entropy; dividing by
    ``log2`` would cap normalised values at ln 2 instead of 1. Token 0 has a
    bound of 0 and is divided by 1 instead.

    Returns:
        A nested dict ``{"mean", "mean_normalized", "head_{i}": {"mean",
        "mean_normalized", "token_{j}", "token_{j}_normalized"}}`` with every
        leaf passed through ``convert_tensor(x, "np")``.
    """

    # Threshold attention weights to avoid log(0): zero weights contribute 0.
    zero = torch.zeros((), dtype=attn.dtype, device=attn.device)
    log_attention = torch.where(attn > 0, torch.log(attn), zero)
    entropy_per_token = -torch.sum(attn * log_attention, dim=-1).mean(dim=0)  # TensorType["H", "2K"]

    num_heads, num_tokens = entropy_per_token.shape

    entropy_per_head = entropy_per_token.mean(dim=-1)  # TensorType["H"]
    entropy = entropy_per_head.mean()  # TensorType[]

    # Each token computes entropy over a variable context length, so we
    # normalize by the maximum possible entropy for that context length.
    context_lengths = torch.arange(1, num_tokens + 1, device=attn.device)
    max_entropy_per_token = torch.log(context_lengths.to(entropy_per_token.dtype))  # TensorType["2K"]
    max_entropy_per_token[0] = 1.  # Special case for the first token to avoid dividing by 0

    entropy_per_token_normalized = entropy_per_token / max_entropy_per_token
    entropy_per_head_normalized = entropy_per_token_normalized.mean(dim=-1)  # TensorType["H"]
    entropy_normalized = entropy_per_head_normalized.mean()  # TensorType[]

    results: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] = {"mean": entropy, "mean_normalized": entropy_normalized}

    for i in range(num_heads):
        head_results = {"mean": entropy_per_head[i], "mean_normalized": entropy_per_head_normalized[i]}

        for j in range(num_tokens):
            head_results[f"token_{j}"] = entropy_per_token[i, j]
            head_results[f"token_{j}_normalized"] = entropy_per_token_normalized[i, j]

        results[f"head_{i}"] = head_results

    return map_nested(lambda x: convert_tensor(x, "np"), results)


def get_attention_entropies_trace(
    steps: List[int],
    models: Iterable[nn.Module],
    xs: torch.Tensor,
    ys: torch.Tensor,
    **paths,
):
    """Build a table of attention entropies with one row per checkpoint.

    ``paths`` maps a short alias (e.g. ``block_0``) to the module path whose
    activation is an attention pattern. For every model the pattern is
    extracted and summarised with ``compute_attention_entropies``; the
    nested results are flattened into columns, with the alias as the
    outermost key, plus a ``step`` column taken from ``steps``.
    """
    alias_of = {module_path: alias for alias, module_path in paths.items()}
    per_alias = defaultdict(list)

    activation_stream = extract_activations_over_checkpoints(models, xs, ys, *paths.values(), return_type="pt")
    for activations in activation_stream:
        for module_path, attn in activations.items():
            alias = alias_of[module_path]
            per_alias[alias].append(compute_attention_entropies(attn))

    rows = []

    for index, step in enumerate(steps):
        # The index-th entry of every alias belongs to the index-th step.
        row = {alias: trace[index] for alias, trace in per_alias.items()}
        row["step"] = step
        rows.append(flatten_dict(row, flatten_lists=True))

    return pd.DataFrame(rows)

In [None]:
def plot_attention_patterns(df: pd.DataFrame, num_blocks: int, num_heads: int, num_tokens: int, title="", save: Optional[str] = None, normalized=False, figsize=(20, 25), logx=False, logy=False):
    """Plot attention-entropy traces at block, head and token resolution.

    Layout (``2 + num_heads`` rows by ``2 * num_blocks`` columns):
      * row 0: one wide panel with the mean entropy of every block;
      * row 1: one panel per block with one line per head;
      * rows 2+: one row per head; for each block two panels, the first with
        the even-indexed tokens and the second with the odd-indexed tokens.

    Args:
        df: table with a ``step`` column and columns named
            ``block_{b}/mean``, ``block_{b}/head_{h}/mean`` and
            ``block_{b}/head_{h}/token_{t}`` (each with a ``_normalized``
            suffix when ``normalized`` is set).
        num_blocks, num_heads, num_tokens: sizes used to build column names.
        title: figure title.
        save: optional output path; missing parent directories are created.
        normalized: plot the ``_normalized`` columns instead of the raw ones.
        figsize: figure size.
        logx, logy: use a log scale on the respective axis of every panel.
    """
    fig = plt.figure(figsize=figsize)
    plt.suptitle(title)

    num_cols = num_blocks * 2
    num_rows = 1 + 1 + num_heads

    suffix = "" if not normalized else "_normalized"
    suffix_title = "" if not normalized else " (Normalized)"

    fig.set_facecolor('white')

    # Create subplot for the mean entropy of every block
    ax0 = plt.subplot2grid((num_rows, num_cols), (0, 0), colspan=num_cols)
    block_cmap = sns.color_palette("viridis", num_blocks)

    for b in range(num_blocks):
        ax0.plot(df.step, df[f"block_{b}/mean{suffix}"], label=f"block_{b}", color=block_cmap[b])

    ax0.set_title("Blocks")
    ax0.set_xlabel("Step")
    ax0.set_ylabel(f"Entropy{suffix_title}")
    ax0.legend()

    # Create subplots for each block, showing entropy in different heads
    ax1 = [plt.subplot2grid((num_rows, num_cols), (1, i*2), colspan=2) for i in range(num_blocks)]
    head_cmap = sns.color_palette("viridis", num_heads)

    for b in range(num_blocks):
        ax1[b].set_title(f"Block {b}")
        ax1[b].set_xlabel("Step")
        ax1[b].set_ylabel(f"Entropy{suffix_title}")
        for h in range(num_heads):
            series = df[f"block_{b}/head_{h}/mean{suffix}"]
            ax1[b].plot(df.step, series, label=f"Head {h}", color=head_cmap[h])

    ax1[0].legend()

    # Create subplots for each head in each block, detailing entropy for each token.
    # Panels are laid out row-major starting at grid row 2.
    ax2 = [plt.subplot2grid((num_rows, num_cols), (i//(num_cols) + 2, i%(num_cols))) for i in range(num_heads * num_blocks * 2)]
    ax_idx = 0
    token_cmap = sns.color_palette("viridis", num_tokens)

    for h in range(num_heads):
        for b in range(num_blocks):
            # x_or_y == 1 plots tokens 0, 2, 4, ...; x_or_y == 0 plots tokens 1, 3, 5, ...
            for x_or_y in (1, 0):
                ax2[ax_idx].set_title(f"Block {b} Head {h}")
                ax2[ax_idx].set_xlabel("Step")
                ax2[ax_idx].set_ylabel(f"Entropy{suffix_title}")

                for t in range(1-int(x_or_y), num_tokens, 2):
                    series = df[f"block_{b}/head_{h}/token_{t}{suffix}"]
                    ax2[ax_idx].plot(df.step, series, label=f"Token {t}", color=token_cmap[t])

                ax_idx += 1

    # One legend for the even-token panels and one for the odd-token panels.
    ax2[0].legend()
    ax2[1].legend()

    for ax in [ax0, *ax1, *ax2]:
        if logx:
            ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save)


In [None]:
# Single-run demo of the attention-entropy plots. The set-up is part of the
# cell so that it runs on a fresh kernel instead of relying on variables left
# behind by other cells.
demo = Run(configs[2])

num_blocks = demo.config.task_config.num_layers
num_heads = demo.config.task_config.num_heads
num_tokens = demo.config.task_config.max_examples * 2

df = get_attention_entropies_trace(
    demo.checkpointer.file_ids,
    iter_models(demo.model, demo.checkpointer, verbose=True), 
    demo.evaluator.pretrain_xs, 
    demo.evaluator.pretrain_ys, 
    **{f"block_{b}": f"token_sequence_transformer.blocks.{b}.attention.attention_softmax" for b in range(num_blocks)}
)

demo_attn_entropy_slug = "attn-S-" + demo.config.to_slug(delimiter="-")

for normalized in (True, False):
    # The normalized figure gets its own file so that the unnormalized one
    # (which keeps the plain name) no longer overwrites it.
    variant = "-normalized" if normalized else ""

    plot_attention_patterns(
        df, 
        num_blocks=num_blocks, 
        num_heads=num_heads, 
        num_tokens=num_tokens, 
        title=demo.config.to_latex(), 
        save=FIGURES / (demo_attn_entropy_slug + variant + ".png"),
        figsize=(25, 25),
        normalized=normalized
    )

# df.to_csv(ANALYSIS / (demo_attn_entropy_slug + ".csv"))

In [None]:
# For every run: trace the attention entropies over all checkpoints, save the
# normalized figure under FIGURES and the full table under ANALYSIS.
for config, checkpointer in zip(configs, checkpointers):
    run = Run(config)

    task_config = run.config.task_config
    num_blocks = task_config.num_layers
    num_heads = task_config.num_heads
    num_tokens = task_config.max_examples * 2

    # alias -> module path of each block's attention softmax
    attention_paths = {
        f"block_{b}": f"token_sequence_transformer.blocks.{b}.attention.attention_softmax"
        for b in range(num_blocks)
    }

    subdf = get_attention_entropies_trace(
        checkpointer.file_ids,
        iter_models(run.model, checkpointer, verbose=True),
        run.evaluator.pretrain_xs,
        run.evaluator.pretrain_ys,
        **attention_paths,
    )

    slug = "attn-S-" + run.config.to_slug(delimiter="-")

    plot_attention_patterns(
        subdf,
        num_blocks=num_blocks,
        num_heads=num_heads,
        num_tokens=num_tokens,
        title=run.config.to_latex(),
        save=FIGURES / (slug + ".png"),
        figsize=(25, 25),
        normalized=True,
    )

    subdf.to_csv(ANALYSIS / (slug + ".csv"))

# os.system('say "Your program has finished."')

In [None]:
from icl.train import Run
# Grab the weight of the first block's attention projection from a demo run
# and display its shape (last expression of the cell).
demo = Run(configs[2])
first_block = demo.model.token_sequence_transformer.blocks[0]
attn_weights = first_block.attention.attention.weight
attn_weights.shape


In [None]:
# Number of scalar entries in `attn_weights`, used below as the per-layer parameter count.
numel_per_layer = attn_weights.numel()

def num_params_to_gb(num: int):
    """Format the memory needed for ``num`` 32-bit values as gigabytes.

    Each value takes 32 / 8 = 4 bytes and one gigabyte is 10**9 bytes. The
    unit is written ``GB`` (gigabytes); ``Gb`` would denote gigabits.
    """
    return f"{num * (32 / 8) / (10 ** 9):.2f} GB"

# Back-of-the-envelope sizes of the attention-weight covariance matrix for a
# few architectures: the full matrix versus only within-head blocks plus
# between-head blocks.
for num_blocks in [2, 4, 8]:
    for num_heads in [2, 4]:
        numel_per_head = numel_per_layer // num_heads
        head_block_size = numel_per_head ** 2

        within_head_cov_size = head_block_size * num_heads * num_blocks
        between_head_cov_size = head_block_size * num_heads * num_heads * (num_blocks-1)
        full_cov_size = (numel_per_layer * num_blocks) ** 2

        # Entries saved by storing only the within-head and between-head blocks.
        reduction = full_cov_size - within_head_cov_size - between_head_cov_size

        print(f"\nL{num_blocks}H{num_heads}")
        for label, size in (
            ("Full:", full_cov_size),
            ("Within heads:", within_head_cov_size),
            ("Between heads:", between_head_cov_size),
        ):
            print(label, f"{size:,} ({num_params_to_gb(size)})")
        print("Reduction:", f"-{reduction:,} (-{reduction/full_cov_size * 100:.2f}%)")

# attn_weights.numel(), f"{(32 // 8 * (attn_weights.numel() * 2 ) ** 2):,}", attn_weights.dtype

In [None]:
def split_attn_weights(W: torch.Tensor, num_heads: int, embed_dim: int, head_size: int):
    """Yield one 3-tuple of ``(embed_dim, head_size)`` views of ``W`` per head.

    ``W`` is viewed as ``(embed_dim, num_heads, 3, head_size)``; for each
    head the three slices along the third axis are yielded in order
    (callers in this file label them Q, K and V).
    """
    per_head_parts = W.view((embed_dim, num_heads, 3, head_size))

    for head_index in range(num_heads):
        yield tuple(per_head_parts[:, head_index, part] for part in range(3))


def plot_attn_weights(W: torch.Tensor, num_heads: int, embed_dim: int, head_size: int, subtitles=("$W_Q^{(h)}$", "$W_K^{(h)}$", "$W_V^{(h)}$"), title="", save: Optional[str] = None):
    """Show the three per-head weight slices of ``W`` as a grid of heatmaps.

    One row per head and one column per slice returned by
    ``split_attn_weights``; each matrix is drawn transposed, so the
    embedding dimension runs along x and the head dimension along y.

    Args:
        W: tensor that can be viewed as ``(embed_dim, num_heads, 3 * head_size)``.
        num_heads, embed_dim, head_size: geometry used to split ``W``.
        subtitles: column titles, one per slice.
        title: figure title.
        save: optional output path; missing parent directories are created.
    """
    heads = list(split_attn_weights(W, num_heads, embed_dim, head_size))

    # squeeze=False keeps `axs` two-dimensional even for a single head.
    fig, axs = plt.subplots(num_heads, 3, figsize=(25, 10), squeeze=False)
    plt.suptitle(title)

    for h, head in enumerate(heads):
        axs[h, 0].set_ylabel(f"Head {h}\nHead Size")

        for i, mat in enumerate(head):
            axs[h, i].matshow(mat.detach().cpu().numpy().T, cmap='viridis')

    for i, subtitle in enumerate(subtitles):
        axs[0, i].set_title(subtitle)
        axs[-1, i].set_xlabel("Embedding Dimension")

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save)

    plt.show()


def plot_attn_head_weights(head: torch.Tensor, embed_dim, head_size: int, title="", subtitles=("$W_Q$", "$W_K$", "$W_V$"), save: Optional[str] = None):
    """Show the three weight slices of a single attention head as heatmaps.

    ``head`` is viewed as ``(embed_dim, 3 * head_size)`` and cut into three
    consecutive ``(embed_dim, head_size)`` slices, each drawn transposed.

    Args:
        head: tensor with ``embed_dim * 3 * head_size`` entries.
        embed_dim, head_size: geometry used to split ``head``.
        title: figure title.
        subtitles: panel titles, one per slice.
        save: optional output path; missing parent directories are created.
    """
    head_Ex3c = head.view((embed_dim, head_size * 3))
    q, k, v = tuple(head_Ex3c[:, i*head_size:(i+1)*head_size].detach().cpu().numpy() for i in range(3))

    fig, ax = plt.subplots(1, 3, figsize=(30, 3.5))
    plt.suptitle(title)

    for i, (mat, subtitle) in enumerate(zip((q, k, v), subtitles)):
        ax[i].set_title(subtitle)
        ax[i].matshow(mat.T, cmap='viridis')
        ax[i].set_xlabel("Embedding Dimension")
        ax[i].set_ylabel("Head Size")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save)

    plt.show()

In [None]:
def plot_attn_eigencomponents(evecs, evals, slug: Optional[str] = None, num_heads: int = 4, embed_dim: int = 64, head_size: int = 16):
    """Plot each covariance eigenvector as per-layer attention-weight heatmaps.

    Eigenpairs are visited from the last column of ``evecs`` backwards, so
    with eigenvalues in ascending order (as ``scipy.sparse.linalg.eigsh``
    returns them) index 0 is the largest. Each eigenvector is cut into two
    halves, the first plotted as attention layer 0 and the second as layer 1.

    Args:
        evecs: 2-D array whose columns are eigenvectors.
        evals: eigenvalues matching the columns of ``evecs``.
        slug: when given, figures are saved under ``FIGURES`` as
            ``cov-attn{layer}-evec{index}-{slug}.png``.
        num_heads, embed_dim, head_size: attention geometry forwarded to
            ``plot_attn_weights``; the defaults are the values that used to
            be hard-coded here.
    """
    half = evecs.shape[0] // 2

    for i in range(1, 1 + len(evals)):
        attn0, attn1 = evecs[:half, -i], evecs[half:, -i]

        for layer, attn in enumerate((attn0, attn1)):
            plot_attn_weights(
                torch.Tensor(attn),
                num_heads=num_heads,
                embed_dim=embed_dim,
                head_size=head_size,
                # Raw f-string: "\l" is not a valid escape sequence. The title
                # names the layer actually shown rather than always layer 0.
                title=rf"Eigenvector {i-1} of covariance matrix within attention layer {layer} ($\lambda_{{{i-1}}}={evals[-i]}$)",
                subtitles=(f"$u_{{Q,{i-1}}}^{{({layer})}}$", f"$u_{{K,{i-1}}}^{{({layer})}}$", f"$u_{{V,{i-1}}}^{{({layer})}}$"),
                save=(FIGURES / (f"cov-attn{layer}-evec{i-1}-" + slug + ".png") if slug else None)
            )

In [None]:
# Scratch cell: eigen-decomposition of an outer product of two heads' weights.
# NOTE(review): `attn0` and `eigsh` are not defined in the cells visible above
# this one (`attn0` is assigned and `eigsh` imported in later cells), so this
# presumably relies on out-of-order kernel state — confirm before re-running.
plot_attn_weights(attn0, 4, 64, 16, title="Attention layer 0")

num_heads = 4
# View as (embed_dim=64, num_heads, 3 * head_size=48) and take one slice per head.
attn0_view = attn0.view((64, num_heads, 16 * 3))
heads = [attn0_view[:, h, :] for h in range(num_heads)]
full_head_size = 16 * 3 * 64
# Outer product of the flattened weights of head 0 and head 1.
# NOTE(review): this matrix is not symmetric in general, while `eigsh` is
# documented for real symmetric / Hermitian input — confirm this is intended.
pseudo_cov = heads[0].reshape((full_head_size, 1)) * heads[1].reshape((1, full_head_size)) 
head_evals, head_evecs = eigsh(pseudo_cov.detach().cpu().numpy(), k=3, which="LM")
del pseudo_cov


print(head_evals)
# Plots the last returned eigenvector.
# NOTE(review): the title says "eigenvalue" and "within head 1", but an
# eigenvector of the head 0 x head 1 product is plotted — confirm the wording.
plot_attn_head_weights(
    torch.Tensor(head_evecs[:, -1]), 
    64, 
    16, 
    title="Principal eigenvalue of covariance matrix within head 1",
    subtitles=("$u_{Q,1}^{(1)}$", "$u_{K,1}^{(1)}$", "$u_{V,1}^{(1)}$")   
)


In [None]:
from icl.analysis.sample import make_slt_evals

def generate_slt_observables(
    steps: List[int],
    models: Iterable[nn.Module],
    xs: torch.Tensor,
    ys: torch.Tensor,
    **kwargs
):
    """Lazily yield ``(step, observables)`` for each ``(step, model)`` pair.

    ``xs`` and ``ys`` are wrapped in a ``TensorDataset`` with a loader whose
    batch size is ``len(xs)``; both are handed to ``make_slt_evals`` together
    with ``kwargs``. The callable it returns is applied to every model, with
    ``steps`` and ``models`` paired up by ``zip``.
    """
    dataset = torch.utils.data.TensorDataset(xs, ys)
    full_batch_loader = torch.utils.data.DataLoader(dataset, batch_size=len(xs))

    evaluate = make_slt_evals(dataset=dataset, loader=full_batch_loader, **kwargs)

    for checkpoint_step, checkpoint_model in zip(steps, models):
        yield checkpoint_step, evaluate(checkpoint_model)

In [None]:
# Show the checkpoint ids of whichever `checkpointer` was bound last by an earlier cell.
print(checkpointer.file_ids)

In [None]:
def plot_learning_coeff_over_time(steps, lcs, lc_stds, title="", save: Optional[str] = None):
    """Plot the learning-coefficient estimate over steps with a +/- std band.

    Args:
        steps: x values.
        lcs: estimate per step; any sequence accepted by ``np.asarray``.
        lc_stds: standard deviation per step; any sequence accepted by
            ``np.asarray``.
        title: axes title.
        save: optional output path; missing parent directories are created.
    """
    # Plain lists do not support elementwise +/-, so convert up front.
    lcs = np.asarray(lcs)
    lc_stds = np.asarray(lc_stds)

    # Initialize the figure
    fig, ax = plt.subplots(1, 1, figsize=(18, 12))
    ax.set_title(title)

    # Plot mean values as a line
    ax.plot(steps, lcs, 'o-', linewidth=2)

    # Add shaded area for error
    ax.fill_between(steps, lcs - lc_stds, lcs + lc_stds, color='gray', alpha=0.4)

    # Labels and scales
    ax.set_xlabel("Steps")
    ax.set_ylabel(r"$\hat\lambda$")

    fig.tight_layout()

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save)

    plt.show()


def plot_cov_evals_over_time(steps, *eval_traces, title="", save: Optional[str] = None):
    """Plot one line per eigenvalue trace against ``steps``.

    Args:
        steps: x values.
        *eval_traces: one sequence of values per eigenvalue; the i-th is
            labelled ``Eigenvalue i``.
        title: axes title.
        save: optional output path; missing parent directories are created.
    """
    # Initialize the figure
    fig, ax = plt.subplots(1, 1, figsize=(18, 12))
    ax.set_title(title)

    # One line per eigenvalue trace
    for i, eval_trace in enumerate(eval_traces):
        ax.plot(steps, eval_trace, 'o-', label=f"Eigenvalue {i}", linewidth=2)

    # Labels and scales
    ax.set_xlabel("Steps")
    ax.set_ylabel(r"$\hat\lambda$")

    # Show legend (skipped when nothing was plotted, which would only trigger a warning)
    if eval_traces:
        ax.legend()

    fig.tight_layout()

    if save:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save)

    plt.show()

In [None]:
import yaml
from scipy.sparse.linalg import eigsh



# `gc` is needed below and is otherwise only imported by a later cell.
import gc

# For every run: estimate the SLT observables (including the attention-weight
# covariance) at each checkpoint, keep the K largest-magnitude covariance
# eigenvalues, plot the eigenvectors, and save per-run figures and a CSV.
for log2_M in MS:
    wandb.init(entity="devinterp", project="icl")

    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): despite its name, `log2_M` is an entry of MS (a number of
    # tasks); its base-2 logarithm is used as the index into `configs`. This
    # assumes configs[i] has 2**i tasks — confirm.
    log2_m = int(np.log2(log2_M))
    config, checkpointer = configs[log2_m], checkpointers[log2_m]
    run = Run(config)

    xs, ys = run.evaluator.pretrain_xs, run.evaluator.pretrain_ys
    trainset = torch.utils.data.TensorDataset(xs, ys)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(xs))
    observables_over_time = []
    
    slt_evals = make_slt_evals(
        dataset=trainset,
        loader=trainloader,
        cores=1,
        lr=1e-5,
        num_draws=100,
        elasticity=1.,
        num_chains=20,
        device="cuda",
        covariance_paths=[
            f"token_sequence_transformer.blocks.{b}.attention.attention"
            for b in range(run.config.task_config.num_layers)
        ],
    )

    slug = run.config.to_slug(delimiter="-")

    min_step = -1

    # NOTE(review): the temporary file is loaded, but with the `step > min_step`
    # guard below commented out nothing is skipped and the loaded observables
    # are overwritten on the first iteration — resuming is effectively disabled.
    if os.path.exists(ANALYSIS / f"cov-tmp-{slug}.pt"):
        min_step, observables = torch.load(ANALYSIS / f"cov-tmp-{slug}.pt")
        print(f"Loaded observables from previous step {min_step} from {ANALYSIS / f'cov-tmp-{slug}.pt'}")
    
    for step in STEPS:
        # if step > min_step:
        run.model.load_state_dict(checkpointer.load_file(step)["model"])
        observables = slt_evals(run.model)
        torch.save((step, observables), ANALYSIS / f"cov-tmp-{slug}.pt")

        cov = observables.pop("covariance")
        evals, evecs = eigsh(cov, k=K, which='LM')

        # eigsh returns eigenvalues in ascending order; store them largest first.
        for token in range(1, 1+K):
            observables[f"cov-eval/{token-1}"] = evals[-token]

        observables_over_time.append(observables)
        del cov
        pp(observables)
        plot_attn_eigencomponents(evecs, evals, slug=slug + f"@t={step}")

    # K is an int, so iterate over range(K) to get one trace per eigenvalue.
    plot_cov_evals_over_time(
        STEPS,
        *[[o[f"cov-eval/{k}"] for o in observables_over_time] for k in range(K)],
        title=run.config.to_latex(),
        save=FIGURES / f"cov-eval-of-t-{slug}.png"
    )

    # Arrays rather than lists: the plot computes `lcs - lc_stds` elementwise.
    plot_learning_coeff_over_time(
        STEPS,
        np.array([o["mean"] for o in observables_over_time]),
        np.array([o["std"] for o in observables_over_time]),
        title=run.config.to_latex(),
        save=FIGURES / f"lc-of-t-{slug}.png"
    )

    observables_df = pd.DataFrame(observables_over_time)
    # to_csv does not create missing directories.
    os.makedirs(ANALYSIS / "cov", exist_ok=True)
    observables_df.to_csv(ANALYSIS / f"cov/cov-{slug}.csv")
    os.remove(ANALYSIS / f"cov-tmp-{slug}.pt")

In [None]:
# Re-plot and re-save the results of the most recently processed run, using
# the `run`, `slug` and `observables_over_time` left behind by the sweep cell.
print(observables_over_time[0].keys())

# Traces of the two recorded covariance eigenvalues (keys "cov-eval/0" and "cov-eval/1").
_eval_traces = [
    [o[f"cov-eval/{rank}"] for o in observables_over_time]
    for rank in (0, 1)
]
plot_cov_evals_over_time(
    STEPS,
    *_eval_traces,
    title=run.config.to_latex(),
    save=FIGURES / f"cov-eval-of-t-{slug}.png",
)

# Arrays, because the plot adds and subtracts the two elementwise.
_lc_means = np.array([o["mean"] for o in observables_over_time])
_lc_stds = np.array([o["std"] for o in observables_over_time])
plot_learning_coeff_over_time(
    STEPS,
    _lc_means,
    _lc_stds,
    title=run.config.to_latex(),
    save=FIGURES / f"lc-of-t-{slug}.png",
)

observables_df = pd.DataFrame(observables_over_time)
observables_df.to_csv(ANALYSIS / f"cov/cov-{slug}.csv")

In [None]:
import gc

# Release the last run's objects. The names are dropped first: a collection
# that runs while they are still bound cannot reclaim what they refer to, so
# a single collect / cache flush after the `del` is sufficient.
del run, trainset, trainloader
del slt_evals
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Scratch cell: split the leading covariance eigenvectors into per-layer parts.
# NOTE(review): this relies on `observables`, `run` and `step` left over from
# earlier cells. The sweep cell pops the key "covariance" (singular) and the
# clean-up cell deletes `run`, so this presumably fails when the notebook is
# run top to bottom — confirm.
covariances = observables.pop("covariances")
evals, evecs = eigsh(covariances, k=3, which='LM')

# NOTE(review): `eval` shadows the builtin of the same name inside this loop.
for token, (eval, evec) in enumerate(zip(evals, evecs.T)):
    slug = f"cov-u{token}" + run.config.to_slug(delimiter="-") + f"@t={step}"
    # 64 * 16 * 3 matches the per-layer size used elsewhere in this file
    # (embed_dim * head_size * 3).
    # NOTE(review): `eigsh` returns NumPy arrays, which have no `.split`
    # method (that is a torch.Tensor method) — confirm the intended type.
    attn0, attn1 = evec.split(64 * 16 * 3)
    

# TODO: Need to rename the new files, otherwise you can't easily tell what step they come from.


os.system('say "Your program has finished."')


In [None]:
import yaml
from scipy.sparse.linalg import eigsh


# wandb.init(entity="devinterp", project="icl")

# For every run: compute the SLT observables at each checkpoint and write them
# to a CSV under ANALYSIS. Work in progress: the loop stops at the first
# checkpoint with NotImplementedError because covariances are not saved yet.
for config, checkpointer in zip(configs, checkpointers):
    run = Run(config)
    observables_over_time = []
    
    for step, observables in generate_slt_observables(
        checkpointer.file_ids,
        iter_models(run.model, checkpointer, verbose=True), 
        run.evaluator.pretrain_xs, 
        run.evaluator.pretrain_ys, 
        cores=4,
        lr=1e-5,
        num_draws=100,
        elasticity=1.,
        num_chains=20,
        device="cuda",
        covariance_paths=[
            f"token_sequence_transformer.blocks.{b}.attention.attention"
            for b in range(run.config.task_config.num_layers)
        ]
    ):
        # wandb.log(observables, step=step)
        observables["step"] = step
        covariances = observables.pop("covariances")

        # Only the three largest-magnitude eigenpairs are needed. `eigsh`
        # must receive the square covariance matrix itself, so the matrix is
        # not replaced by its eigenvalue vector beforehand.
        evals, evecs = eigsh(covariances, k=3, which='LM')
        
        observables_over_time.append(observables)
        print(yaml.dump({
            **observables,
            "covariances": covariances.shape
        }))

        raise NotImplementedError("TODO: Save covariances")

    subdf = pd.DataFrame(observables_over_time)
    slug = "slt-" + run.config.to_slug(delimiter="-")
    subdf.to_csv(ANALYSIS / (slug + ".csv"))

# wandb.finish()
os.system('say "Your program has finished."')

## Activations

In [None]:
from pathlib import Path
import shutil
from icl.config import ICLConfig

import os
from PIL import Image

def gather_images_side_by_side(folder, save: Optional[str] = None, delete: bool = True):
    """
    Stitch same-named PNGs from the sub-folders of ``folder`` into side-by-side images.

    ``folder`` is expected to contain sub-folders that contain PNG files. For every PNG
    file name, the images with that name are pasted left to right (sub-folders in sorted
    order) onto one RGB canvas and written to ``save`` under the same file name.

    Args:
        folder: Directory whose sub-folders hold the PNGs. Entries that are not
            directories are ignored.
        save: Output directory (created, including parents, if missing). When falsy,
            nothing is composed or written.
        delete: If true, ``folder`` is removed recursively afterwards.

    Raises:
        FileNotFoundError: If ``delete`` is true and ``folder`` does not exist.
    """
    from contextlib import ExitStack

    folder = Path(folder)

    if save:
        save = Path(save)
        save.mkdir(parents=True, exist_ok=True)

    # Sorted so the left-to-right panel order is reproducible. Stray files sitting directly
    # in `folder` are skipped; listing them as directories used to raise NotADirectoryError.
    subfolders = sorted(entry for entry in folder.glob("*") if entry.is_dir())

    # Group the PNG paths by file name; images are only opened when they are composed.
    paths_by_filename = {}
    for subfolder in subfolders:
        for entry in sorted(subfolder.iterdir()):
            if entry.name.endswith('.png'):
                paths_by_filename.setdefault(entry.name, []).append(entry)

    if save:
        for filename, paths in paths_by_filename.items():
            # ExitStack closes every opened image, also when composing fails part-way.
            with ExitStack() as stack:
                image_list = [stack.enter_context(Image.open(path)) for path in paths]

                width = sum(img.width for img in image_list)
                height = max(img.height for img in image_list)
                result_image = Image.new('RGB', (width, height))

                x_offset = 0
                for img in image_list:
                    result_image.paste(img, (x_offset, 0))
                    x_offset += img.width

                result_image.save(save / filename)

    if delete:
        # Delete the temporary folder
        shutil.rmtree(folder)


def plot_activations(config: ICLConfig, activations: Dict[str, torch.Tensor], save: Optional[str] = None):
    """
    Draw one heat-map figure per entry of ``activations``.

    Only index 0 along the leading dimension of each tensor is plotted. Entries whose name
    ends in ``attention.attention`` are split into per-head Q, K and V and drawn as an
    ``H x 4`` grid (Q, K, QK^T, V). Any other entry is drawn as a single matrix when its
    slice is 2-D, or as one matrix per leading index when its slice is 3-D. Matrices that
    are taller than wide are transposed and their name gets a ``.T`` suffix.

    Args:
        config: Supplies ``task_config.embed_size``, ``task_config.max_examples`` and
            ``task_config.num_heads``, which fix how the attention activation is reshaped.
        activations: Mapping from location name to activation tensor.
        save: Optional output directory (created if missing). When given, every figure is
            written to ``<save>/<name>.png`` and closed; otherwise figures stay open.

    Raises:
        ValueError: If a non-attention slice is neither 2-D nor 3-D.
    """
    B = 1
    E = config.task_config.embed_size
    T = 2 * config.task_config.max_examples
    H = config.task_config.num_heads

    def optionally_rotate(x, name):
        if len(x.shape) != 2:
            raise ValueError("Tensor should have two dimensions.")

        if x.shape[0] > x.shape[1]:
            return x.T, f"{name}.T"
        
        return x, name 

    def separate_attention(qkv: TensorType["B", "T", "C"], num_heads: int, batch_size: int, head_size: int, num_tokens: int):
        # Split the last dimension into heads, move heads in front of tokens, then cut each
        # head's 3 * head_size features into the (q, k, v) triple.
        return (qkv   
            .view(batch_size, num_tokens, num_heads, 3*head_size)
            .transpose(-2, -3)     
            .split(head_size, dim=-1)
        )

    def to_numpy(x):
        return x.detach().to("cpu").numpy()

    def save_and_close(fig, name):
        # Only figures that are written to disk get closed; without `save` they stay open.
        if save:
            fig.savefig(save / (name + ".png"))
            plt.close(fig)

    if save:
        save = Path(save)

        if not os.path.exists(save):
            os.makedirs(save)

    for location, activation in activations.items():
        activation_slice = activation[0]

        if location.endswith("attention.attention"):
            q, k, v = separate_attention(activation, num_heads=H, batch_size=B, head_size=E//H, num_tokens=T)
            qk = q @ k.transpose(-2, -1)
            # Order must match the panel titles below (QK and V used to be swapped here,
            # so the "QK" column showed V and vice versa).
            q, k, qk, v = q[0], k[0], qk[0], v[0]
            
            # squeeze=False keeps `axs` two-dimensional even when there is a single head.
            fig, axs = plt.subplots(H, 4, figsize=(15, 15), squeeze=False)

            for j, (name, x) in enumerate(zip(["Q", "K", "QK", "V"], [q, k, qk, v])):
                for h in range(H):
                    ax = axs[h, j]
                    ax.matshow(to_numpy(x[h]))
                    ax.set_title(f"{h}.{name}")
                    # fig.colorbar(im, ax=ax)

            plt.suptitle(location)
            plt.tight_layout()
            save_and_close(fig, location)

        elif len(activation_slice.shape) == 2:
            x, name = optionally_rotate(activation_slice, location)
            # plt.matshow opens its own figure, so grab that one: creating a figure up
            # front left the matshow figure open forever and closed an empty one instead.
            plt.matshow(to_numpy(x))
            fig = plt.gcf()
            plt.title(f"{name}")
            # fig.colorbar(im)
            save_and_close(fig, name)

        elif len(activation_slice.shape) == 3:  # [heads, xs, ys]
            heads, xs, ys = activation_slice.shape
            fig, axs = plt.subplots(1, heads, figsize=(15, 15), squeeze=False)

            for j in range(heads):
                ax = axs[0, j]
                x, name = optionally_rotate(activation_slice[j], str(j))
                ax.matshow(to_numpy(x))
                ax.set_title(f"{name}")
                # fig.colorbar(im, ax=ax)
            
            plt.suptitle(f"{location}.#")
            plt.tight_layout()
            save_and_close(fig, location)

        else:
            raise ValueError("Unsupported number of dimensions.")


def compare_activations(config: ICLConfig, model, x: TensorType["B", "D"], y: TensorType["B", 1], save: Optional[str] = None, names: Optional[List[str]] = None):
    """
    Plot the activations of ``model`` for each sample and stitch the samples side by side.

    The model is wrapped with ``hook`` and run once on ``(x, y)``. The inputs, the output
    and every cached activation are then plotted per sample into ``tmp/<name>/`` via
    ``plot_activations``, and ``gather_images_side_by_side`` merges the per-sample PNGs
    into ``save`` and removes the ``tmp`` folder again.

    Args:
        config: Passed through to ``plot_activations``.
        model: Model to wrap with ``hook``.
        x: Inputs; ``len(x)`` is taken as the number of samples.
        y: Targets, indexed per sample alongside ``x``.
        save: Output directory for the stitched images.
        names: One folder name per sample; defaults to ``"0", "1", ...``. Samples beyond
            the number of names are not plotted.
    """
    B = len(x)
    hooked_model = hook(model)

    output, cached_activations = hooked_model.run_with_cache(x, y)
    activations = {"x": x, "y": y, "output": output}
    activations.update(cached_activations)

    def activations_for_sample(index):
        # Re-add a leading dimension of size 1: plot_activations plots index 0 of it.
        return {k: v[index].unsqueeze(0) for k, v in activations.items() if v is not None}

    tmp_folder = Path("tmp")
    # The folder is scratch space that is deleted at the end of every successful call.
    # Clear leftovers of an interrupted call so they are not stitched into the new images.
    shutil.rmtree(tmp_folder, ignore_errors=True)

    names = names or list(map(str, range(B)))

    for name, b in zip(names, range(B)):
        plot_activations(config, activations_for_sample(b), save=tmp_folder / str(name))

    gather_images_side_by_side(tmp_folder, save=save, delete=True)    

In [None]:
# Demo: compare the activations of the run at index 2 on its first three pretraining samples.
demo = Run.create_and_restore(configs[2])
compare_activations(
    demo.config,
    demo.model,
    demo.evaluator.pretrain_xs[:3],
    demo.evaluator.pretrain_ys[:3],
    save=FIGURES / "demo",
    names=["$x_0$", "$x_1$", "$x_2$"],
)
# gather_images_side_by_side("tmp", save=FIGURES/"demo", delete=True)    

In [None]:
# Plot a few samples for each model at the end of training
from icl.train import Run

NUM_SAMPLES = 4

# One activation comparison per run, written to FIGURES/activations-<run slug>.
for config, checkpointer in zip(configs, checkpointers):
    run = Run.create_and_restore(config)

    slug = "activations-" + run.config.to_slug(delimiter="-")
    sample_names = [f"$x_{i}$" for i in range(NUM_SAMPLES)]
    xs = run.evaluator.pretrain_xs[:NUM_SAMPLES]
    ys = run.evaluator.pretrain_ys[:NUM_SAMPLES]

    compare_activations(run.config, run.model, xs, ys, save=FIGURES / slug, names=sample_names)

In [None]:
# Checkpoint ids of whichever `checkpointer` was bound last by an earlier loop.
print(checkpointer.file_ids)

In [None]:
# Plot a few samples for a subset of models over training

MS = [1, 4, 64, 2**10, 2**20]
STEPS = [0, 1_805, 3_084, 15_381, 26_279, 100_262, 153_061, 193_877, 255_102, 306_122, 408_163]

# The sample labels are the same for every model and step.
sample_names = [f"$x_{i}$" for i in range(NUM_SAMPLES)]

for M in MS:
    # `configs`/`checkpointers` are indexed by log2(M) here.
    # NOTE(review): this assumes they are ordered as M = 2**0, 2**1, ... — confirm.
    log2_m = int(np.log2(M))
    config, checkpointer = configs[log2_m], checkpointers[log2_m]
    run = Run(config)

    for step in STEPS:
        run.model.load_state_dict(checkpointer.load_file(step)["model"])

        slug = "activations-" + run.config.to_slug(delimiter="-") + f"@t={step}"

        # TODO: Rename the new files; otherwise you can't easily tell which step they come from.
        compare_activations(
            run.config, 
            run.model, 
            run.evaluator.pretrain_xs[:NUM_SAMPLES], 
            run.evaluator.pretrain_ys[:NUM_SAMPLES], 
            save=FIGURES / slug, 
            names=sample_names
        )

# Announce completion once, after all models and steps (this used to sit inside the inner
# loop and therefore fired after every single checkpoint).
os.system('say "Your program has finished."')

# LLC hyperparams

In [None]:
import yaml
import wandb
from icl.config import get_config
import pandas as pd
from tqdm import tqdm
from devinfra.utils.iterables import flatten_dict

api = wandb.Api()
# sweep = api.sweep("devinterp/icl-llc/d3ctawc7")  # L2H4
sweep = api.sweep("devinterp/icl-llc/ebu13rjw")  # L4H4


def wandb_run_to_df(run):
    """
    Return the logged history of a wandb run with its flattened config added as columns.

    The run config is parsed with ``get_config`` and dumped to a dict, ``analysis_config``
    is taken verbatim from the raw run config, and the logger/checkpointer sections are
    dropped before flattening.
    """
    # NOTE(review): `run.history()` may return a sampled subset of the logged rows by
    # default — confirm against the wandb API docs if every step is needed.
    history_df = run.history()

    config_dict = get_config(**run.config).model_dump()
    config_dict["analysis_config"] = run.config["analysis_config"]
    for unwanted in ("logger_config", "checkpointer_config"):
        del config_dict[unwanted]

    for column, value in flatten_dict(config_dict, flatten_lists=True).items():
        # Tuples are repeated once per row so that each row holds the whole tuple.
        history_df[column] = [value] * len(history_df) if isinstance(value, tuple) else value

    return history_df


def wandb_runs_to_df(runs):
    """Concatenate the per-run DataFrames of ``runs`` into one DataFrame."""
    frames = []
    for run in tqdm(runs, desc="Converting runs to dfs"):
        frames.append(wandb_run_to_df(run))

    return pd.concat(frames)


subdf = wandb_runs_to_df(sweep.runs)

In [None]:
# Load the stored LLC grid search for the chosen architecture (replaces `subdf`).
num_layers, num_heads = 2, 4
# df.to_csv("../analysis/L4H4-llc-grid-search.csv") 
grid_search_csv = f"../analysis/L{num_layers}H{num_heads}-llc-grid-search.csv"
subdf = pd.read_csv(grid_search_csv)
# df = pd.read_csv("../analysis/L2H4-llc-grid-search.csv")

In [None]:
# del df
# Show which columns the grid-search table has.
subdf.columns

In [None]:
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
import numpy as np
import seaborn as sns

# LLC estimate against SGLD step: one panel per (learning rate, elasticity) pair and one
# line per number of tasks, colored by log2(num_tasks).
num_chains = 25
unique_lrs = subdf['analysis_config/lr'].unique()
unique_elasticities = subdf['analysis_config/elasticity'].unique()
unique_num_tasks = subdf['task_config/num_tasks'].unique()

show_std = False

# Sort for visual consistency
unique_lrs.sort()
unique_lrs = unique_lrs[:-1]  # the largest learning rate is left out of the grid
unique_elasticities.sort()
unique_num_tasks.sort()

prefix = "" # "thresholded-" # ""
Prefix = "" # "Thresholded " # ""
mean_col = f'{prefix}llc/mean'
std_col = f'{prefix}llc/std'

# Initialize colormap. Lines and colorbar share one normalization, so the colorbar stays
# correct when the largest number of tasks is not 2**20 (`or 1.0` avoids a 0/0 for M = 1).
cmap = plt.cm.viridis
norm = Normalize(vmin=0, vmax=np.log2(max(unique_num_tasks)) or 1.0)

# Create subplots (squeeze=False keeps `axes` two-dimensional for a single row or column)
fig, axes = plt.subplots(len(unique_lrs), len(unique_elasticities), figsize=(15, 15), squeeze=False)
fig.set_facecolor('white')
fig.suptitle(
    rf"{Prefix}$\hat\lambda$ hyperparameter sweep ($n_\mathrm{{chains}}={num_chains}$)"
    "\n"
    rf"$L={num_layers}, H={num_heads}, t=500k$"
)

# Loop through the grid
for row_idx, lr in enumerate(unique_lrs):
    for col_idx, elasticity in enumerate(unique_elasticities):
        ax = axes[row_idx, col_idx]

        # Filter DataFrame for specific lr and elasticity
        filtered_df = subdf[(subdf['analysis_config/lr'] == lr) & (subdf['analysis_config/elasticity'] == elasticity)]

        for num_tasks in unique_num_tasks:
            task_specific_df = filtered_df[filtered_df['task_config/num_tasks'] == num_tasks]
            task_specific_df = task_specific_df.sort_values('_step')

            color = cmap(norm(np.log2(num_tasks)))

            # Coerce to numbers first: "NaN" strings become NaN, and NaN rows fail the
            # `< 1_000` test, so both missing and diverged estimates are dropped.
            means = pd.to_numeric(task_specific_df[mean_col], errors='coerce')
            filtered_data = task_specific_df.assign(**{mean_col: means})
            filtered_data = filtered_data[filtered_data[mean_col] < 1_000]
            sns.lineplot(x='_step', y=mean_col, data=filtered_data, ax=ax, label=f'_M={num_tasks}', color=color)

            if show_std:
                # Shade mean +/- std over the same rows that were plotted.
                steps = filtered_data['_step'].to_numpy()
                mean_vals = filtered_data[mean_col].to_numpy()
                std_vals = pd.to_numeric(filtered_data[std_col], errors='coerce').to_numpy()
                ax.fill_between(steps, mean_vals - std_vals, mean_vals + std_vals, color=color, alpha=0.1)

        ax.set_title(rf"$\epsilon={lr}, \gamma={elasticity}$")
        ax.set_xlabel(r"$t_\mathrm{SGLD}$")
        ax.set_ylabel(r"$\hat\lambda$")

# plt.legend()
plt.tight_layout()

# Plot a color bar to the right of the grid
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=axes)
# cbar.ax.set_clim(0, 20)
cbar.ax.set_ylabel(r"$\log_2(M)$")
cbar.locator = MaxNLocator(integer=True)
cbar.update_ticks()


# Both variants used to be written to the same file; the shaded variant now gets a suffix.
std_suffix = "-std" if show_std else ""
plt.savefig(f"../figures/llc-grid-over-t-L{num_layers}_H{num_heads}{std_suffix}.png")
plt.show()

In [None]:
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
import numpy as np
import seaborn as sns

# Settings shared by all windows
num_chains = 25
show_std = True
prefix = "thresholded-" # ""
Prefix = "Thresholded " # ""
mean_col = f'{prefix}llc/mean'
std_col = f'{prefix}llc/std'

# Initialize colormap
cmap = plt.cm.viridis

# One figure per window of task counts, 2**M <= num_tasks < 2**upper_M.
for M in range(0, 21, 3):
    upper_M = min(21, M + 5)

    data = subdf[(subdf['task_config/num_tasks'] >= 2**M) & (subdf['task_config/num_tasks'] < 2**upper_M)]

    # Get unique values for lrs, elasticitys, and num_tasks
    unique_lrs = data['analysis_config/lr'].unique()
    unique_elasticities = data['analysis_config/elasticity'].unique()
    unique_num_tasks = data['task_config/num_tasks'].unique()

    # Sort for visual consistency
    unique_lrs.sort()
    unique_lrs = unique_lrs[:-1]  # the largest learning rate is left out of the grid
    unique_elasticities.sort()
    unique_num_tasks.sort()

    # A window without runs (or with a single learning rate) has no grid to draw, and
    # plt.subplots cannot create a figure with zero rows or columns.
    if len(unique_lrs) == 0 or len(unique_elasticities) == 0:
        continue

    # Lines and colorbar share this normalization. The last window is narrower than five
    # octaves, where the previous fixed `/ 5` scaling disagreed with the colorbar.
    norm = Normalize(vmin=M, vmax=upper_M)

    # Create subplots (squeeze=False keeps `axes` two-dimensional for a single row or column)
    fig, axes = plt.subplots(len(unique_lrs), len(unique_elasticities), figsize=(15, 15), squeeze=False)
    fig.set_facecolor('white')

    # The window is half-open, matching the `>=` / `<` filter above.
    fig.suptitle(
        rf"{Prefix}$\hat\lambda$ hyperparameter sweep ($n_\mathrm{{chains}}={num_chains}$)"
        "\n$L=2, H=4$\n"
        rf"$M \in [{2**M}, {2**upper_M})$, $t=500k$"
    )

    # Loop through the grid
    for row_idx, lr in enumerate(unique_lrs):
        for col_idx, elasticity in enumerate(unique_elasticities):
            ax = axes[row_idx, col_idx]

            # Filter DataFrame for specific lr and elasticity
            filtered_data = data[(data['analysis_config/lr'] == lr) & (data['analysis_config/elasticity'] == elasticity)]

            for num_tasks in unique_num_tasks:
                task_specific_data = filtered_data[filtered_data['task_config/num_tasks'] == num_tasks]
                task_specific_data = task_specific_data.sort_values('_step')

                color = cmap(norm(np.log2(num_tasks)))

                # Drop rows whose mean or std was logged as the string "NaN"
                more_filtered_data = task_specific_data.loc[(task_specific_data[mean_col] != "NaN") & (task_specific_data[std_col] != "NaN")]
                sns.lineplot(x='_step', y=mean_col, data=more_filtered_data, ax=ax, label=f'_M={num_tasks}', color=color)

                if show_std:
                    steps = more_filtered_data['_step'].to_numpy()
                    means = pd.to_numeric(more_filtered_data[mean_col].to_numpy(), errors='coerce')
                    stds = pd.to_numeric(more_filtered_data[std_col].to_numpy(), errors='coerce')

                    ax.fill_between(steps, means-stds, means+stds, color=color, alpha=0.2)
                    
            ax.set_title(rf"$\epsilon={lr}, \gamma={elasticity}$")
            ax.set_xlabel(r"$t_\mathrm{SGLD}$")
            ax.set_ylabel(r"$\hat\lambda$")

    # plt.legend()
    plt.tight_layout()

    # Plot a color bar to the right of the grid
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=axes)
    # cbar.ax.set_clim(0, 20)
    cbar.ax.set_ylabel(r"$\log_2(M)$")
    cbar.locator = MaxNLocator(integer=True)
    cbar.update_ticks()


    plt.savefig(f"../figures/llc-grid-search-M{M}-{upper_M}.png")
    plt.show()

In [None]:
# Let's plot this as a function of M on the x axis

from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
import numpy as np
import seaborn as sns

# LLC estimate against the number of tasks M: one panel per (learning rate, elasticity)
# pair and one line per SGLD step, colored by the step.
num_chains = 25
unique_lrs = subdf['analysis_config/lr'].unique()
unique_elasticities = subdf['analysis_config/elasticity'].unique()

num_layers = 4
num_heads = 4

show_std = True

# Sort for visual consistency
unique_lrs.sort()
# unique_lrs = np.array([lr for lr in unique_lrs if lr <= 0.0001])
unique_lrs = unique_lrs[:-2]  # the two largest learning rates are left out of the grid
unique_elasticities.sort()

unique_num_tasks = np.array([2**m for m in range(0, 21)])

prefix = "" # "thresholded-" 
Prefix = "" # "Thresholded " 
mean_col = f'{prefix}llc/mean'
std_col = f'{prefix}llc/std'

# Initialize colormap; the same normalization colors the lines and scales the colorbar
# (the lines used to be scaled by 999 while the colorbar ran from 0 to 1000).
cmap = plt.cm.viridis
norm = Normalize(vmin=0, vmax=1000)

# Create subplots (squeeze=False keeps `axes` two-dimensional for a single row or column)
fig, axes = plt.subplots(len(unique_lrs), len(unique_elasticities), figsize=(15, 15), squeeze=False)
fig.set_facecolor('white')
fig.suptitle(
    rf"{Prefix}$\hat\lambda$ hyperparameter sweep ($n_\mathrm{{chains}}={num_chains}$)"
    "\n"
    rf"$L={num_layers}, H={num_heads}, t=500k$"
)

steps = np.array([9, 29, 99, 299, 999])

# Loop through the grid
for row_idx, lr in enumerate(unique_lrs):
    for col_idx, elasticity in enumerate(unique_elasticities):
        ax = axes[row_idx, col_idx]

        # Filter DataFrame for specific lr and elasticity
        filtered_df = subdf[(subdf['analysis_config/lr'] == lr) & (subdf['analysis_config/elasticity'] == elasticity)]

        # Nothing was logged for this combination: leave the panel empty instead of failing
        # on the column lookups below.
        for step in (steps if not filtered_df.empty else []):
            # Find the closest step to the desired step for each num_tasks 
            # Problem is wandb sometimes drops a log.
            closest_step_df = filtered_df.groupby('task_config/num_tasks').apply(lambda x: x.iloc[(x['_step']-step).abs().argsort()[:1]]).reset_index(drop=True)

            # Sort by number of tasks so the line runs left to right
            closest_step_df = closest_step_df.sort_values('task_config/num_tasks')

            # Calculate color based on the SGLD step
            color = cmap(norm(step))

            # Drop rows whose mean was logged as the string "NaN"
            filtered_data = closest_step_df[closest_step_df[mean_col] != "NaN"]
            num_tasks = pd.to_numeric(filtered_data['task_config/num_tasks'].to_numpy(), errors='coerce')
            means = pd.to_numeric(filtered_data[mean_col].to_numpy(), errors='coerce')

            ax.plot(num_tasks, means, label=f'_step={step}', color=color)

            if show_std:
                stds = pd.to_numeric(filtered_data[std_col].to_numpy(), errors='coerce')
                ax.fill_between(num_tasks, means-stds, means+stds, color=color, alpha=0.2)
                
        ax.set_title(rf"$\epsilon={lr}, \gamma={elasticity}$")
        ax.set_xlabel(r"$M$")
        ax.set_xscale("log")
        ax.set_xticks([2**m for m in range(0, 21, 4)], [f"$2^{{{m}}}$" for m in range(0, 21, 4)])
        ax.set_ylabel(r"$\hat\lambda$")

# plt.legend()
plt.tight_layout()

# Plot a color bar to the right of the grid
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=axes)
# cbar.ax.set_clim(0, 20)
cbar.ax.set_ylabel(r"$t_\mathrm{SGLD}$")
cbar.locator = MaxNLocator(integer=True)
cbar.update_ticks()


# NOTE(review): the file name carries neither the architecture nor the kind of plot, so
# any other figure saved under the same name overwrites this one.
if show_std:
    plt.savefig("../figures/llc-grid-search-std.png")
else:
    plt.savefig("../figures/llc-grid-search.png")
plt.show()

In [None]:
# Let's plot this as a function of M on the x axis

from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
import numpy as np
import seaborn as sns

# Get unique values for lrs, elasticitys, and num_tasks
# LLC estimate at the last logged SGLD step against the number of tasks M, one panel per
# (learning rate, elasticity) pair.
num_chains = 25
unique_lrs = subdf['analysis_config/lr'].unique()
unique_elasticities = subdf['analysis_config/elasticity'].unique()

num_layers = 2
num_heads = 4

show_std = True

# Sort for visual consistency
unique_lrs.sort()
unique_lrs = unique_lrs[:-1]  # the largest learning rate is left out of the grid
# unique_lrs = np.array([lr for lr in unique_lrs if lr <= 0.0001])
unique_elasticities.sort()

unique_num_tasks = np.array([2**m for m in range(0, 21)])

prefix = "" # "thresholded-" 
Prefix = "" # "Thresholded " 
mean_col = f'{prefix}llc/mean'
std_col = f'{prefix}llc/std'

# Initialize colormap
cmap = plt.cm.viridis

# Create subplots (squeeze=False keeps `axes` two-dimensional for a single row or column)
fig, axes = plt.subplots(len(unique_lrs), len(unique_elasticities), figsize=(15, 15), squeeze=False)
fig.set_facecolor('white')
fig.suptitle(
    rf"{Prefix}$\hat\lambda$ hyperparameter sweep ($n_\mathrm{{chains}}={num_chains}$)"
    "\n"
    rf"$L={num_layers}, H={num_heads}, t=500k$"
)

# Loop through the grid
for row_idx, lr in enumerate(unique_lrs):
    for col_idx, elasticity in enumerate(unique_elasticities):
        ax = axes[row_idx, col_idx]

        # Filter DataFrame for specific lr and elasticity
        filtered_df = subdf[(subdf['analysis_config/lr'] == lr) & (subdf['analysis_config/elasticity'] == elasticity)]

        # Get the last step for each num_tasks. Rows are ordered by step first, because
        # `last()` picks by position; it also takes the last non-missing value per column.
        last_step_df = (
            filtered_df
            .sort_values('_step', kind='stable')
            .groupby('task_config/num_tasks')
            .last()
            .reset_index()
        )

        # Single color: there is only one line per panel
        color = sns.color_palette()[0]

        # Drop rows whose mean was logged as the string "NaN"
        filtered_data = last_step_df[last_step_df[mean_col] != "NaN"]
        num_tasks = pd.to_numeric(filtered_data['task_config/num_tasks'].to_numpy(), errors='coerce')
        means = pd.to_numeric(filtered_data[mean_col].to_numpy(), errors='coerce')

        # The label no longer interpolates `step`, which is not defined in this cell and
        # only existed when another cell's loop had run before.
        ax.plot(num_tasks, means, label='_step=last', color=color)

        if show_std:
            stds = pd.to_numeric(filtered_data[std_col].to_numpy(), errors='coerce')
            ax.fill_between(num_tasks, means-stds, means+stds, color=color, alpha=0.2)
                
        ax.set_title(rf"$\epsilon={lr}, \gamma={elasticity}, t_\mathrm{{SGLD}}=1000$")
        ax.set_xlabel(r"$M$")
        ax.set_xscale("log")
        ax.set_xticks([2**m for m in range(0, 21, 4)], [f"$2^{{{m}}}$" for m in range(0, 21, 4)])
        ax.set_ylabel(r"$\hat\lambda$")

# plt.legend()
plt.tight_layout()

# Plot a color bar to the right of the grid
# norm = Normalize(vmin=0, vmax=1000)
# cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=axes)
# cbar.ax.set_clim(0, 20)
# cbar.ax.set_ylabel(r"$t_\mathrm{SGLD}$")
# cbar.locator = MaxNLocator(integer=True)
# cbar.update_ticks()


# NOTE(review): the file name carries neither the architecture nor the kind of plot, so
# any other figure saved under the same name overwrites this one.
if show_std:
    plt.savefig("../figures/llc-grid-search-std.png")
else:
    plt.savefig("../figures/llc-grid-search.png")
plt.show()

In [None]:
import yaml
import wandb
from icl.config import get_config
import pandas as pd
from tqdm import tqdm
from devinfra.utils.iterables import flatten_dict

api = wandb.Api()
# sweep = api.sweep("devinterp/icl-llc/d3ctawc7")  # L2H4
sweep = api.sweep("devinterp/icl-llc/eli1wlds")  # L4H4


def wandb_run_to_df(run):
    """
    Return the logged history of a wandb run with its flattened config added as columns.

    The run config is parsed with ``get_config`` and dumped to a dict, ``analysis_config``
    is taken verbatim from the raw run config, and the logger/checkpointer sections are
    dropped before flattening.
    """
    # NOTE(review): `run.history()` may return a sampled subset of the logged rows by
    # default — confirm against the wandb API docs if every step is needed.
    history_df = run.history()

    config_dict = get_config(**run.config).model_dump()
    config_dict["analysis_config"] = run.config["analysis_config"]
    for unwanted in ("logger_config", "checkpointer_config"):
        del config_dict[unwanted]

    for column, value in flatten_dict(config_dict, flatten_lists=True).items():
        # Tuples are repeated once per row so that each row holds the whole tuple.
        history_df[column] = [value] * len(history_df) if isinstance(value, tuple) else value

    return history_df


def wandb_runs_to_df(runs):
    """Concatenate the per-run DataFrames of ``runs`` into one DataFrame."""
    frames = []
    for run in tqdm(runs, desc="Converting runs to dfs"):
        frames.append(wandb_run_to_df(run))

    return pd.concat(frames)


subdf = wandb_runs_to_df(sweep.runs)

In [None]:
# Persist the sweep table that was just downloaded.
num_layers, num_heads = 2, 4

name = f"../analysis/L{num_layers}H{num_heads}-llc-grid-search-batches.csv"
subdf.to_csv(path_or_buf=name)
# subdf = pd.read_csv(name)
# df = pd.read_csv("../analysis/L2H4-llc-grid-search.csv")

In [None]:
# Show which columns the sweep table has.
subdf.columns

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_llc_estimation_hyperparams_sweep(observations_df, y="llc/mean", row="analysis_config/lr", col="analysis_config/batch_size"):
    """
    Plot ``y`` against the number of tasks on a facet grid, on log-log axes.

    Args:
        observations_df: Table with a ``task_config/num_tasks`` column plus the columns
            named by ``y``, ``row`` and ``col``.
        y: Column to plot; rows where it equals the string ``"NaN"`` are dropped.
        row: Column that defines the facet rows.
        col: Column that defines the facet columns.
    """
    # Human-readable names that end up in the facet titles and axis labels.
    display_names = {
        "task_config/num_tasks": "# Tasks",
        "analysis_config/lr": "Learning Rate",
        "analysis_config/batch_size": "Batch Size",
    }

    # Get rid of the NaNs
    observations_df = observations_df[observations_df[y] != "NaN"]
    observations_df = observations_df.rename(columns=display_names)

    # `row`/`col` used to be ignored in favour of hard-coded facets; they are honoured now,
    # translated through the renaming above where it applies.
    g = sns.FacetGrid(
        observations_df,
        col=display_names.get(col, col),
        row=display_names.get(row, row),
        sharey=False,
    )
    g.map_dataframe(sns.lineplot, x="# Tasks", y=display_names.get(y, y))
    g.add_legend()
    g.set(xscale="log", yscale="log")
    # g.set(xscale="log", yscale="linear")

    # FacetGrid creates its own figure, so no separate plt.figure() is needed (it only
    # produced an extra empty figure). The title previously said "Covariance estimation".
    g.fig.suptitle("LLC estimation hyperparameter sweep")
    g.fig.tight_layout()
    plt.show()


plot_llc_estimation_hyperparams_sweep(subdf)

## Gradient noise as a function of batch size

In each SGLD step, the injected Gaussian noise term of size $\epsilon$ should dominate the gradient noise that comes from computing gradients over batches of size $m$. Writing $g_m$ for the variance of the gradient norm at batch size $m$, we want $g_m$ to be small compared to the variance of the Gaussian noise.

Knowing $g_m$ as a function of batch size is useful because it establishes a lower bound on batch size for a given choice of $(\beta, n)$:

$$
g_m \ll \frac{2}{\beta n}
$$

Note: more precisely, we should be looking at the maximum eigenvalue of the gradient covariance matrix. This is expensive for large models, so we make a simplification and look at the variance of the gradient norm instead. We could do better.


### Where to calibrate?

We'll do the gradient noise check at two spots:
- Where the gradient norm is maximal (early in training)
- Where the learning coefficient estimates are maximal (~10k steps)

In [None]:
# Display the gradient-norm table (created outside this section).
gradient_norms

In [None]:
# Gradient norms over time
# Turn the per-row mean of squared gradients into a sum of squares, so that rows can be
# added up within each (m, step) group.
gradient_norms['grad_sq/sum'] = gradient_norms['grad_sq/mean'] * gradient_norms['numel']
avg_gradient_norms = gradient_norms.groupby(["m", "step"]).sum().reset_index()
# Use the summed squares computed above: the norm was previously taken from the summed
# per-row *means*, which left 'grad_sq/sum' unused.
# NOTE(review): the division by the total parameter count is kept from the original;
# confirm that a per-parameter normalization (not the plain L2 norm) is what is wanted.
avg_gradient_norms["grad/norm"] = (avg_gradient_norms["grad_sq/sum"] ** 0.5) / avg_gradient_norms["numel"]

fig, ax = plt.subplots(figsize=(15, 6))

sns.lineplot(avg_gradient_norms, x="step", y="grad/norm", hue="m", ax=ax)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("Step, $t$")
ax.set_ylabel(r"Gradient norm, $||\nabla L_n||$")
ax.legend(title=r"$\log_2 M$")

In [None]:
import yaml
from devinfra.io.storage import BaseStorageProvider
from icl.config import ICLConfig
from torchtyping import TensorType

def analyze_gradient_norm_variance(
        step: int, 
        configs: List[ICLConfig], 
        checkpointers: List[BaseStorageProvider], 
        batch_sizes: Iterable[int] = (1, 4, 64, 512, 1024, 8196),
        num_batches: int = 100,
        criterion: Callable[[TensorType["B", "D"], TensorType["B", "D"]], TensorType["B"]] = nn.MSELoss(reduction="mean"),
        log_num_tasks: Iterable[int] = tuple(range(21))
):
    """
    Estimate the mean and spread of the gradient norm for several batch sizes.

    For every entry ``m`` of ``log_num_tasks``, the run at index ``m`` is restored at the
    checkpoint closest to ``step``. For every batch size, ``num_batches`` fresh batches are
    drawn from the run's pretraining distribution, and the L2 norm of the gradient of
    ``criterion`` over all parameters is recorded for each batch.

    Args:
        step: Target training step; the nearest available checkpoint is used.
        configs: Run configs, indexed by the values in ``log_num_tasks``.
        checkpointers: Checkpoint providers, indexed like ``configs``.
        batch_sizes: Batch sizes to evaluate.
            NOTE(review): 8196 looks like a typo for 8192 (2**13) — confirm.
        num_batches: Number of batches drawn per batch size.
        criterion: Loss applied to ``(model(xs, ys), ys)``.
        log_num_tasks: Indices into ``configs``/``checkpointers``; stored as ``m``, with
            ``M = 2**m``.

    Returns:
        A DataFrame with one row per (m, batch size) and the columns ``m``, ``M``,
        ``step`` (the checkpoint actually used), ``grad/norm/mean``, ``grad/norm/std``,
        ``numel``, ``batch_size`` and ``layer`` (always ``"average"``).
    """
    # Materialize once: the batch sizes are iterated again for every model.
    batch_sizes = list(batch_sizes)
    gradient_norms = []

    for log2_M in tqdm(log_num_tasks, desc="Iterating over log_2(M)"):
        config = configs[log2_M]
        checkpointer = checkpointers[log2_M]

        nearest_step = min(checkpointer.file_ids, key=lambda x: abs(x-step))    

        checkpoint = checkpointer.load_file(nearest_step)
        run = Run.create_and_restore(config)
        run.model.load_state_dict(checkpoint["model"])
        run.model.eval()

        numel = sum(p.numel() for p in run.model.parameters())

        for batch_size in batch_sizes:
            # Running sums of the norm and of the squared norm over the batches.
            first_moment = 0.
            second_moment = 0.

            for _ in range(num_batches):
                run.model.zero_grad()
                xs, ys = run.pretrain_dist.get_batch(run.config.task_config.max_examples, batch_size)
                yhats = run.model(xs, ys)
                loss = criterion(yhats, ys)
                loss.backward()

                squared_norm = 0.

                for p in run.model.parameters():
                    if p.grad is not None:
                        squared_norm += (p.grad ** 2).sum().item()

                    # Drop the gradient right away to free its memory.
                    p.grad = None

                first_moment += squared_norm ** 0.5
                second_moment += squared_norm 

                del xs
                del ys
                del yhats
                del loss

            first_moment /= num_batches
            second_moment /= num_batches

            # E[x^2] - E[x]^2 can come out slightly negative through rounding when the norm
            # barely varies; a negative float raised to 0.5 is a complex number in Python.
            variance = max(second_moment - first_moment ** 2, 0.)

            gradient_norms.append({
                "m": log2_M,
                "M": 2**log2_M,
                "step": nearest_step,
                "grad/norm/mean": first_moment,
                "grad/norm/std": variance ** 0.5,
                "numel": numel,
                "batch_size": batch_size,
                "layer": "average"
            })

            # print(yaml.dump(gradient_norms[-1]))
            # print(torch.mps.current_allocated_memory() / 1e9)

        del checkpoint
        del run

    return pd.DataFrame(gradient_norms)

IVL = 0.5

log_num_tasks = [int(np.log2(m)) for m in MS]

# NOTE(review): this grid is computed but not passed to the calls below, which therefore
# use the default batch sizes of analyze_gradient_norm_variance — confirm which is intended.
batch_sizes = [int(2 ** (i * IVL)) for i in range(int(4 / IVL), int(15 / IVL))]

# Measure at three points of training and store each table next to the other analyses.
gradient_norm_variance_t100 = analyze_gradient_norm_variance(
    100, configs, checkpointers, log_num_tasks=log_num_tasks
)
gradient_norm_variance_t100.to_csv("../analysis/gradient-norm-variance-t100.csv")

gradient_norm_variance_t10k = analyze_gradient_norm_variance(
    10_000, configs, checkpointers, log_num_tasks=log_num_tasks
)
gradient_norm_variance_t10k.to_csv("../analysis/gradient-norm-variance-t10k.csv")

gradient_norm_variance_t500k = analyze_gradient_norm_variance(
    500_000, configs, checkpointers, log_num_tasks=log_num_tasks
)
gradient_norm_variance_t500k.to_csv("../analysis/gradient-norm-variance-t500k.csv")

In [None]:
# Let's fit a line to the gradient norm variance to get an empirical formula relation between batch size and gradient norm variance

from scipy.optimize import curve_fit

def func(x, a, b):
    """Power law ``a * x ** (-b)`` used to model gradient-norm variance against batch size."""
    return a * (x ** -b)

def fit_gradient_norm_variance(gradient_norm_variance_df, p0=(50, -2)):
    """
    Fit the power law ``func`` to gradient-norm variance as a function of batch size.

    Args:
        gradient_norm_variance_df: Table with a ``batch_size`` column and either a
            ``grad/norm/var`` column or a ``grad/norm/std`` column (squared on the fly).
        p0: Initial guess ``(a, b)`` handed to ``scipy.optimize.curve_fit``.

    Returns:
        ``(popt, pcov)`` exactly as returned by ``curve_fit``.

    Raises:
        KeyError: If neither variance nor standard-deviation column is present.
    """
    batch_sizes = gradient_norm_variance_df["batch_size"].to_numpy(dtype=float)

    # The variance column is only added by a later plotting step; derive it from the
    # standard deviation when it is not there yet instead of failing with a KeyError.
    if "grad/norm/var" in gradient_norm_variance_df.columns:
        norm_variances = gradient_norm_variance_df["grad/norm/var"].to_numpy(dtype=float)
    else:
        norm_variances = gradient_norm_variance_df["grad/norm/std"].to_numpy(dtype=float) ** 2

    # curve_fit rejects NaN/inf inputs, so rows without a usable measurement are left out.
    usable = np.isfinite(batch_sizes) & np.isfinite(norm_variances)

    # popt, pcov = curve_fit(func, np.log2(batch_sizes), np.log10(norm_variances), p0=(-1, 2))
    popt, pcov = curve_fit(func, batch_sizes[usable], norm_variances[usable], p0=p0)

    return popt, pcov

# Fit the power law separately to the measurements from each of the three training steps.
(t100_popt, t100_pcov), (t10k_popt, t10k_pcov), (t500k_popt, t500k_pcov) = (
    fit_gradient_norm_variance(measurements)
    for measurements in (
        gradient_norm_variance_t100,
        gradient_norm_variance_t10k,
        gradient_norm_variance_t500k,
    )
)

# plot the fit

# Dense batch-size grid (2**0 .. 2**13) on which the fitted curves are drawn.
batch_sizes = np.logspace(start=0, stop=13, num=1000, base=2)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase

# Plot gradient-norm variance against batch size in three side-by-side panels
# (one per analysed step), overlay the fitted power law on each panel, and add
# a shared colorbar for the hue variable.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Snap each target step to the closest step available in `checkpointer`
# (whichever checkpointer an earlier cell left bound to that name).
ts = [min(checkpointer.file_ids, key=lambda x: abs(x - step)) for step in [100, 10_000, 500_000]]

_panel_dfs = [gradient_norm_variance_t100, gradient_norm_variance_t10k, gradient_norm_variance_t500k]
_panel_popts = [t100_popt, t10k_popt, t500k_popt]

# Kept for any later cell that looks frames up by step.
gradient_norm_dfs = dict(zip(ts, _panel_dfs))

# Iterate over the aligned lists rather than over the dict: if two target steps
# snapped to the same checkpoint, the dict would collapse them and pair the
# remaining frames with the wrong fit parameters.
for t, grad_norm_df, ax, t_popt in zip(ts, _panel_dfs, axes.flatten(), _panel_popts):
    # Variance is derived from the stored standard deviation.
    grad_norm_df["grad/norm/var"] = grad_norm_df["grad/norm/std"] ** 2

    sns.lineplot(
        data=grad_norm_df[grad_norm_df["layer"] == "average"],
        x="batch_size",
        y="grad/norm/var",
        hue="m",
        palette="viridis",
        ax=ax,
        legend=False,
    )

    # Overlay the fitted power law on the dense batch-size grid.
    ax.plot(batch_sizes, func(batch_sizes, *t_popt), color="red", linestyle="--", label="Fit ($=am^{-b}$)")

    ax.set_xscale("log", base=2)
    ax.set_yscale("log")

    ax.set_xlabel(r"Batch Size, $m$")
    ax.set_ylabel(r"Gradient Norm Variance, $\mathbb{V}[|\nabla L_m|]$")
    ax.legend()

    ax.set_title(f"$t={t+1}$")

    param_text = f'a={t_popt[0]:.2f}, b={t_popt[1]:.2f}'
    ax.text(0.05, 0.1, param_text, transform=ax.transAxes, verticalalignment='top')

# Only the left-most panel keeps its y-axis label.
for ax in axes[1:]:
    ax.set_ylabel("")

# Dedicated axes for the colorbar, to the right of the panels.
cbar_ax = fig.add_axes([0.93, 0.125, 0.02, 0.75])  # Adjust as necessary for position and size

# Colormap matching the `palette` used for the line plots.
cmap = plt.cm.viridis

sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=20))
sm.set_array([])  # Dummy array; the mappable only feeds the colorbar.
cbar = fig.colorbar(sm, cax=cbar_ax)

tick_positions = [0, 5, 10, 15, 20]
tick_labels = [str(position) for position in tick_positions]  # a list, not a one-shot iterator
cbar.set_ticks(tick_positions)
cbar.set_ticklabels(tick_labels)
# Raw string: "\l" is not a valid escape sequence in a normal string literal.
cbar.set_label(r"$\log_2 M$")


In terms of the fitted power law $g_n = a m^{-b}$ (the form used in the fit above), we get that

$$
a m^{-b} \ll \frac{1}{\beta n}.
$$

Plugging in $\beta^*=1/\log n$ and moving terms around, this becomes (for $b > 0$) a lower bound on the batch size:

$$
\log_2 m \gg -\frac{1}{b}\log_2 \left(\frac{\log n}{2 a n}\right).
$$

In [None]:
# Values of n at which the bound is evaluated: 1000 evenly spaced points from 100 to 1,000,000.
n = np.linspace(100, 1_000_000, num=1000)

def upper_bound(n, a, b):
    """Return ``-log2(2 * ln(n) / (a * n)) / b``.

    Args:
        n: Value(s) at which to evaluate the bound (scalar or array).
        a: Prefactor of the fitted power law.
        b: Exponent of the fitted power law.

    Returns:
        The bound on ``log2(m)``, with the same shape as ``n``.

    NOTE(review): despite the name, the plot below labels these curves
    "Lower bounds". The accompanying markdown writes the log argument as
    ``log n / (2 a n)``, whereas this code uses ``2 log n / (a n)`` --
    confirm which placement of the factor 2 is intended.
    """
    ratio = (2 * np.log(n)) / (a * n)
    return -np.log2(ratio) / b

# Plot, for each fitted power law, the batch-size bound as a function of n.
from matplotlib.ticker import FuncFormatter

fig, ax = plt.subplots(figsize=(8, 4))

for t_label, t_popt in zip(["100", "10k", "500k"], [t100_popt, t10k_popt, t500k_popt]):
    ax.plot(n, upper_bound(n, *t_popt), linestyle="--", label=f"Fit for $t={t_label}$")

# The plotted values are bounds on log2(m), so a tick at y is labelled 2^y.
# A formatter is used instead of set_yticklabels on auto-located ticks: fixed
# label lists go stale when ticks are relocated, and truncating y to an int
# mislabels non-integer ticks.
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _pos: f"$2^{{{y:g}}}$"))
ax.set_xscale("log", base=2)

ax.set_ylabel(r"$m$")
ax.set_xlabel(r"$n$")
ax.legend(title="Lower bounds")

# Garbage Collection

In [None]:
# Garbage collection: print the memory currently allocated on the MPS device,
# run Python's garbage collector, release cached MPS memory, then evaluate the
# allocation again (shown as the cell output) for comparison.
import gc

print(torch.mps.current_allocated_memory() / 1e9)  # divided by 1e9: presumably bytes -> GB
gc.collect()
torch.mps.empty_cache()  # MPS backend call; on a CUDA GPU the counterpart would be torch.cuda.empty_cache()
torch.mps.current_allocated_memory() / 1e9