# Essential dynamics

In [1]:
import os
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
import seaborn as sns
from sklearn.decomposition import PCA
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.patches as mpatches

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
import numpy as np
from sklearn.decomposition import PCA
from torch.nn import functional as F
from sklearn.manifold import TSNE
import gc
import itertools
from scipy.ndimage import gaussian_filter1d
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo
import numpy as np
import tqdm
from infra.utils.iterables import int_linspace
from copy import deepcopy
from pathlib import Path

# import sys
# del sys.modules['icl.figures.colors']
# del sys.modules['icl.figures.notation']

from devinterp.slt.forms import get_osculating_circle
from icl.analysis.utils import get_unique_run
from icl.constants import ANALYSIS, FIGURES, SWEEPS, DATA
from icl.figures.notation import str_d_dlogt, str_d_dt, str_dlog_dlogt
from icl.figures.colors import (
    plot_transitions,
    gen_transition_colors,
    get_transition_type,
    PRIMARY,
    SECONDARY,
    TERTIARY,
    BRED,
    BBLUE,
    BRED,
    BGREEN,
)
from icl.constants import DEVICE

# from devinterp.slt.forms import
sns.set_style("white")
DEVICE

NUM_TASKS = "inf"
NUM_LAYERS = 2
MAX_LR = 0.003
MODEL_SEEDS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

steps = int_linspace(0, 500_000, 10_000)[::2]


plt.rcParams["figure.dpi"] = 300

In [3]:
def get_models_and_optimizers(run, steps, model_id):
    if os.path.exists(Path("../checkpoints") / f"{model_id}-models.pt"):
        print("Loading models from disk")
        models = torch.load(Path("../checkpoints") / f"{model_id}-models.pt")
        optimizer_state_dicts = torch.load(
            Path("../checkpoints") / f"{model_id}-optimizer_state_dicts.pt"
        )

    else:
        print("Retrieving models from AWS")
        # Let's generate these same plots and also look at their evolution.
        models = []
        optimizer_state_dicts = []

        for step in tqdm.tqdm(steps):
            checkpoint = run.checkpointer.load_file(step)

            m = deepcopy(run.model)
            m.load_state_dict(checkpoint["model"])
            models.append(m)
            optimizer_state_dicts.append(checkpoint["optimizer"])

        print("Saving models to disk")
        torch.save(models, Path("../checkpoints") / f"{model_id}-models.pt")
        torch.save(
            optimizer_state_dicts,
            Path("../checkpoints") / f"{model_id}-optimizer_state_dicts.pt",
        )
    
    return models, optimizer_state_dicts

In [52]:
import pickle
from icl.regression.model import to_token_sequence, from_predicted_token_sequence

K = 16
B = 1024
D = 4

def get_tokens(run, batch_size, max_examples, seed=0, include_x_and_y=False):
    torch.manual_seed(seed)

    xs, ys = run.pretrain_dist.get_batch(max_examples, batch_size, return_ws=False)
    tokens = to_token_sequence(xs, ys)

    if include_x_and_y:
        return tokens, xs, ys

    return tokens



In [49]:
def get_y_outputs(models, xs, ys, model_id, force_reeval=False):
    B, K, D = xs.shape

    outputs = np.zeros((len(models), K * B), dtype=np.float32)

    if not os.path.exists(DATA / f"{model_id}-outputs-y-only.pkl") or force_reeval:
        print("Computing outputs")
        for i, model in enumerate(tqdm.tqdm(models, desc="Computing outputs")):
            with torch.no_grad():
                output = model(xs, ys).flatten()
                outputs[i, :] = output.cpu().numpy()

        with open(DATA / f"{model_id}-outputs-y-only.pkl", "wb") as f:
            pickle.dump(outputs, f)
    else:
        print("Loading y outputs from disk")
        with open(DATA / f"{model_id}-outputs-y-only.pkl", "rb") as f:
            outputs = pickle.load(f)

    return outputs

def get_outputs(models, tokens, model_id, force_reeval=False):
    B, K, D = tokens.shape
    K = K // 2  
    D = D - 1 

    outputs = np.zeros((len(models), K * B * (D + 1) * 2), dtype=np.float32)

    if not os.path.exists(DATA / f"{model_id}-outputs.pkl") or force_reeval:
        print("Computing outputs")
        for i, model in enumerate(tqdm.tqdm(models, desc="Computing outputs")):
            with torch.no_grad():
                output = model.token_sequence_transformer(tokens).flatten()
                outputs[i, :] = output.cpu().numpy()

        with open(DATA / f"{model_id}-outputs.pkl", "wb") as f:
            pickle.dump(outputs, f)
    else:
        print("Loading outputs from disk")
        with open(DATA / f"{model_id}-outputs.pkl", "rb") as f:
            outputs = pickle.load(f)

    return outputs

In [7]:
def get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False):
    if not os.path.exists(DATA / f"{model_id}-pca.pkl") or force_reeval:
        print("Computing PCA")
        pca = PCA(n_components=n_components)
        pca.fit(outputs)
        reduced = pca.transform(outputs)
        with open(DATA / f"{model_id}-pca.pkl", "wb") as f:
            pickle.dump((pca, reduced), f)
    else:
        print("Loading PCA from disk")
        with open(DATA / f"{model_id}-pca.pkl", "rb") as f:
            pca, reduced = pickle.load(f)
    return pca, reduced

In [8]:
from typing import List, TypedDict
import yaml

class FormDict(TypedDict):
    name: str
    components: List[float]

def get_forms(model_id) -> List[FormDict]:
    if os.path.exists(DATA / f"{model_id}-forms.yaml"):
        print("Loading forms from disk")
        with open(DATA / f"{model_id}-forms.yaml", "r") as f:
            forms = yaml.safe_load(f)
    else:
        print("Computing forms")
        forms = []
        
    return forms

In [9]:
import plotly.express as px
from sklearn.decomposition import PCA
import seaborn as sns

cmap = sns.color_palette("Spectral", as_cmap=True)
color_indices = np.linspace(0, 1, len(steps))
colors = np.array([cmap(c) for c in color_indices])

def to_color_string(color):
    # return (256 * color[0], 256 * color[1], 256 * color[2], color[3])
    return f"rgb({int(256 * color[0])}, {int(256 * color[1])}, {int(256 * color[2])}, {color[3]})"


In [10]:
def plot_ed(pca, reduced, reduced_smooth, forms, model_id, form_cmap='rainbow', evolute_cmap='Spectral', num_components=3, title="", slug="pca.html"):
    labels = {
        str(i): f"PC {i+1} ({var:.1f}%)"
        for i, var in enumerate(pca.explained_variance_ratio_ * 100)
    }

    subplot_titles = []
    fig = make_subplots(rows=num_components, cols=num_components, subplot_titles=subplot_titles)

    if isinstance(form_cmap, str):
        form_cmap = sns.color_palette(form_cmap, as_cmap=True)
    if isinstance(evolute_cmap, str):
        evolute_cmap = sns.color_palette(evolute_cmap, as_cmap=True)

    form_colors = np.array([to_color_string(form_cmap(c)) for c in np.linspace(0, 1, len(forms))])   
    evolute_colors = np.array([to_color_string(evolute_cmap(c)) for c in np.linspace(0, 1, len(reduced_smooth)-4)])

    for i, j in tqdm.tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2): 
        row, col = i + 1, j + 1
            
        ymin, ymax = (
            reduced[:, i].min(),
            reduced[:, i].max(),
        )
        xmin, xmax = (
            reduced[:, j].min(),
            reduced[:, j].max(),
        )

        # Forms
        for f, form in enumerate(forms):
            if form[j] is not None:
                # Vertical line
                fig.add_shape(
                    type="line",
                    x0=form[j],
                    y0=ymin * 1.25,
                    x1=form[j],
                    y1=ymax * 1.25,
                    line=dict(color=form_colors[f], width=1),
                    row=row,
                    col=col,
                )
            if form[i] is not None:
                # Horizontal line
                fig.add_shape(
                    type="line",
                    x0=xmin * 1.25,
                    y0=form[i],
                    x1=xmax * 1.25,
                    y1=form[i],
                    line=dict(color=form_colors[f], width=1),
                    row=row,
                    col=col,
                )

        ts = np.array(range(2, len(reduced_smooth) - 2))
        centers = np.zeros((len(ts), 2))

        # Circles
        for ti, t in enumerate(ts):
            center, radius = get_osculating_circle(
                reduced_smooth[:, (j, i)], t
            )
            # if ti % 16 == 0:
            #     # This seems to be cheaper than directly plotting a circle
            #     circle = go.Scatter(
            #         x=center[0] + radius * np.cos(np.linspace(0, 2 * np.pi, 100)),
            #         y=center[1] + radius * np.sin(np.linspace(0, 2 * np.pi, 100)),
            #         mode="lines",
            #         line=dict(color="rgba(0.1, 0.1, 1, 0.05)", width=1),
            #         showlegend=False,
            #     )
            #     fig.add_trace(circle, row=row, col=col)

            centers[ti] = center

        # Centers
        fig.add_trace(
            go.Scatter(
                x=centers[:, 0],
                y=centers[:, 1],
                mode="markers",
                marker=dict(size=2, symbol="x", color=evolute_colors),
                name="Centers",
            ),
            row=row,
            col=col,
        )

        # Original samples
        # fig.add_trace(
        #     go.Scatter(
        #         x=reduced[:, j],
        #         y=reduced[:, i],
        #         mode="markers",
        #         marker=dict(color=colors, size=3),
        #         showlegend=False,
        #     ),
        #     row=row,
        #     col=col,
        # )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=reduced_smooth[:, j],
                y=reduced_smooth[:, i],
                mode="lines",
                line=dict(color="black", width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        if j == 0:
            fig.update_yaxes(title_text=labels[str(i)], row=row, col=col)

        fig.update_xaxes(title_text=labels[str(j)], row=row, col=col)

        fig.update_xaxes(
            range=(xmin * 1.25, xmax * 1.25),
            row=row,
            col=col,
        )
        fig.update_yaxes(
            range=(ymin * 1.25, ymax * 1.25),
            row=row,
            col=col,
        )

    fig.update_layout(width=2500, height=2500)  # Adjust the size as needed
    fig.update_layout(title_text=title, showlegend=False)

    # Save as html
    pyo.plot(fig, filename=str(FIGURES / model_id / slug))
    # fig.write_image(str(FIGURES / model_id / "pca.png"))

    return fig

In [35]:
num_checkpoints_total = 5000
num_downsample = 50
num_checkpoints = num_checkpoints_total // num_downsample
num_seeds = 10

In [56]:
# steps

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

TOKENS_SEED = 0

num_checkpoints_total = 5000
num_downsample = 10
num_checkpoints = num_checkpoints_total // num_downsample
num_seeds = 10

combined_outputs = np.zeros((num_seeds * num_checkpoints, 81920))
combined_y_outputs = np.zeros((num_seeds * num_checkpoints, 8192 * 2))

for model_seed in tqdm.tqdm(MODEL_SEEDS):
    model_id = f"L2H4Minf{model_seed}"

    os.makedirs(str(FIGURES / model_id), exist_ok=True)
    os.makedirs(str(DATA / model_id), exist_ok=True)

    print("Retrieving run...")
    run = get_unique_run(
        str(SWEEPS / "regression/training-runs/L2H4Minf.yaml"),
        task_config={
            "num_tasks": NUM_TASKS,
            "num_layers": NUM_LAYERS,
            "model_seed": model_seed,
        },
        optimizer_config={"lr": MAX_LR},
    )
    print("Retrieved run.")

    models, optimizer_state_dicts = get_models_and_optimizers(run, steps, model_id)
    tokens, xs, ys = get_tokens(run, B, K, seed=TOKENS_SEED, include_x_and_y=True)
    print(f"Tokens generated from seed {TOKENS_SEED} with shape {tokens.shape}")

    outputs, y_outputs = get_outputs(models, tokens, model_id, force_reeval=False), get_y_outputs(models, xs, ys, model_id, force_reeval=False)

    print(f"Outputs shape: {outputs.shape}")
    combined_outputs[model_seed * num_checkpoints: (model_seed + 1) * num_checkpoints, :] = outputs[::num_downsample, :]
    combined_y_outputs[model_seed * num_checkpoints: (model_seed + 1) * num_checkpoints, :] = y_outputs[::num_downsample, :]
    # pca, reduced = get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False)

    # start, end = 0.1, 300
    # reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    # forms = get_forms(model_id)
    # num_forms = len(forms)
    # form_cmap = sns.color_palette("rainbow", as_cmap=True)

    # fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id)

  0%|          | 0/10 [00:00<?, ?it/s]

Retrieving run...
Retrieved run.
Loading models from disk


  0%|          | 0/10 [00:33<?, ?it/s]


KeyboardInterrupt: 

In [57]:
# steps

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

TOKENS_SEED = 0

num_checkpoints_total = 5000
num_downsample = 10
num_checkpoints = num_checkpoints_total // num_downsample
num_seeds = 10

combined_outputs = np.zeros((num_seeds * num_checkpoints, 81920))
combined_y_outputs = np.zeros((num_seeds * num_checkpoints, 8192 * 2))

for model_seed in tqdm.tqdm(MODEL_SEEDS):
    model_id = f"L2H4Minf{model_seed}"

    os.makedirs(str(FIGURES / model_id), exist_ok=True)
    os.makedirs(str(DATA / model_id), exist_ok=True)

    tokens, xs, ys = get_tokens(run, B, K, seed=TOKENS_SEED, include_x_and_y=True)
    print(f"Tokens generated from seed {TOKENS_SEED} with shape {tokens.shape}")

    outputs, y_outputs = get_outputs(models, tokens, model_id, force_reeval=False), get_y_outputs(models, xs, ys, model_id, force_reeval=False)

    print(f"Outputs shape: {outputs.shape}")
    combined_outputs[model_seed * num_checkpoints: (model_seed + 1) * num_checkpoints, :] = outputs[::num_downsample, :]
    combined_y_outputs[model_seed * num_checkpoints: (model_seed + 1) * num_checkpoints, :] = y_outputs[::num_downsample, :]
    # pca, reduced = get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False)

    # start, end = 0.1, 300
    # reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    # forms = get_forms(model_id)
    # num_forms = len(forms)
    # form_cmap = sns.color_palette("rainbow", as_cmap=True)

    # fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id)

  0%|          | 0/10 [00:00<?, ?it/s]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 10%|█         | 1/10 [00:00<00:06,  1.46it/s]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 20%|██        | 2/10 [00:01<00:06,  1.21it/s]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Computing outputs


Computing outputs: 100%|██████████| 5000/5000 [00:34<00:00, 143.46it/s]
 30%|███       | 3/10 [00:37<01:57, 16.79s/it]

Outputs shape: (5000, 81920)
Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 40%|████      | 4/10 [00:38<01:02, 10.46s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 50%|█████     | 5/10 [00:38<00:34,  6.95s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 60%|██████    | 6/10 [00:41<00:21,  5.31s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 70%|███████   | 7/10 [00:42<00:11,  3.97s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 80%|████████  | 8/10 [00:43<00:06,  3.22s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


 90%|█████████ | 9/10 [00:44<00:02,  2.50s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 81920)


100%|██████████| 10/10 [00:45<00:00,  4.57s/it]


In [62]:
if not os.path.exists(DATA / f"combined-pca.pkl") or True:
    print("Computing Combined PCA")
    pca = PCA(n_components=8)
    pca.fit(combined_outputs)
    reduced = pca.transform(combined_outputs)
    with open(DATA / f"{model_id}-pca.pkl", "wb") as f:
        pickle.dump((pca, reduced), f)
else:
    print("Loading PCA from disk")
    with open(DATA / f"combined-pca.pkl", "rb") as f:
        pca, reduced = pickle.load(f)

Computing Combined PCA


In [33]:
reduced.shap

((1000, 5), (5000, 5))

In [67]:
start, end = 0.1 / num_downsample, 300 / num_downsample
reduced_smooth = np.zeros_like(reduced)

for i in range(10):
    reduced_smooth[i * num_checkpoints: (i + 1) * num_checkpoints] = gaussian_filter1d_variable_sigma(reduced[i * num_checkpoints: (i + 1) * num_checkpoints], np.linspace(start, end, num_checkpoints), axis=0)

num_components = 8

evolute_cmap = ''
fig = make_subplots(rows=num_components, cols=num_components, subplot_titles=subplot_titles)

for s in range(num_seeds):
    _reduced = reduced[s * num_checkpoints: (s + 1) * num_checkpoints]
    _reduced_smooth = reduced_smooth[s * num_checkpoints: (s + 1) * num_checkpoints]
    color = sns.color_palette("Spectral", as_cmap=True)(s / num_seeds)
    color = to_color_string(color)

    labels = {
        str(i): f"PC {i+1} ({var:.1f}%)"
        for i, var in enumerate(pca.explained_variance_ratio_ * 100)
    }

    subplot_titles = []
    
    for i, j in tqdm.tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2): 
        row, col = i + 1, j + 1
            
        ymin, ymax = (
            reduced[:, i].min(),
            reduced[:, i].max(),
        )
        xmin, xmax = (
            reduced[:, j].min(),
            reduced[:, j].max(),
        )

        ts = np.array(range(2, len(_reduced_smooth) - 2))
        centers = np.zeros((len(ts), 2))

        # Original samples
        # fig.add_trace(
        #     go.Scatter(
        #         x=reduced[:, j],
        #         y=reduced[:, i],
        #         mode="markers",
        #         marker=dict(color=colors, size=3),
        #         showlegend=False,
        #     ),
        #     row=row,
        #     col=col,
        # )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=_reduced_smooth[:, j],
                y=_reduced_smooth[:, i],
                mode="lines",
                line=dict(color=color, width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        fig.add_trace(
            go.Scatter(
                x=_reduced_smooth[:1, j],
                y=_reduced_smooth[:1, i],
                mode="markers",
                marker=dict(color=color),
                showlegend=False,
            ),
            row=row,
            col=col,
        )


        if j == 0:
            fig.update_yaxes(title_text=labels[str(i)], row=row, col=col)

        fig.update_xaxes(title_text=labels[str(j)], row=row, col=col)

        fig.update_xaxes(
            range=(xmin * 1.25, xmax * 1.25),
            row=row,
            col=col,
        )
        fig.update_yaxes(
            range=(ymin * 1.25, ymax * 1.25),
            row=row,
            col=col,
        )

    fig.update_layout(width=2500, height=2500)  # Adjust the size as needed
    fig.update_layout(title_text=f"Combined LR PCA ({num_seeds} seeds, {num_checkpoints} checkpoints per seed, all outputs)", showlegend=False)

# Save as html
pyo.plot(fig, filename=str(FIGURES / "combined-pca.html"))
fig.write_image(str(FIGURES / "combined-pca.png"))


100%|██████████| 64/64 [00:00<00:00, 204.63it/s]
100%|██████████| 64/64 [00:00<00:00, 251.47it/s]
100%|██████████| 64/64 [00:00<00:00, 251.16it/s]
100%|██████████| 64/64 [00:00<00:00, 248.12it/s]
100%|██████████| 64/64 [00:00<00:00, 251.81it/s]
100%|██████████| 64/64 [00:00<00:00, 257.36it/s]
100%|██████████| 64/64 [00:00<00:00, 259.82it/s]
100%|██████████| 64/64 [00:00<00:00, 260.65it/s]
100%|██████████| 64/64 [00:00<00:00, 256.20it/s]
100%|██████████| 64/64 [00:00<00:00, 256.25it/s]


In [58]:
if not os.path.exists(DATA / f"combined-pca-y-only.pkl") or True:
    print("Computing Combined PCA")
    y_pca = PCA(n_components=8)
    y_pca.fit(combined_y_outputs)
    y_reduced = y_pca.transform(combined_y_outputs)
    with open(DATA / f"{model_id}-pca-y-only.pkl", "wb") as f:
        pickle.dump((y_pca, y_reduced), f)
else:
    print("Loading PCA from disk")
    with open(DATA / f"combined-pca-y-only.pkl", "rb") as f:
        y_pca, y_reduced = pickle.load(f)

Computing Combined PCA


In [69]:
start, end = 0.1 / num_downsample, 300 / num_downsample
y_reduced_smooth = np.zeros_like(y_reduced)

for i in range(10):
    y_reduced_smooth[i * num_checkpoints: (i + 1) * num_checkpoints] = gaussian_filter1d_variable_sigma(y_reduced[i * num_checkpoints: (i + 1) * num_checkpoints], np.linspace(start, end, num_checkpoints), axis=0)

num_components = 8

evolute_cmap = ''
fig = make_subplots(rows=num_components, cols=num_components, subplot_titles=subplot_titles)

for s in range(num_seeds):
    _y_reduced = y_reduced[s * num_checkpoints: (s + 1) * num_checkpoints]
    _y_reduced_smooth = y_reduced_smooth[s * num_checkpoints: (s + 1) * num_checkpoints]
    color = sns.color_palette("Spectral", as_cmap=True)(s / num_seeds)
    color = to_color_string(color)

    labels = {
        str(i): f"PC {i+1} ({var:.1f}%)"
        for i, var in enumerate(pca.explained_variance_ratio_ * 100)
    }

    subplot_titles = []
    
    for i, j in tqdm.tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2): 
        row, col = i + 1, j + 1
            
        ymin, ymax = (
            y_reduced[:, i].min(),
            y_reduced[:, i].max(),
        )
        xmin, xmax = (
            y_reduced[:, j].min(),
            y_reduced[:, j].max(),
        )

        ts = np.array(range(2, len(_y_reduced_smooth) - 2))
        centers = np.zeros((len(ts), 2))

        # Original samples
        # fig.add_trace(
        #     go.Scatter(
        #         x=y_reduced[:, j],
        #         y=y_reduced[:, i],
        #         mode="markers",
        #         marker=dict(color=colors, size=3),
        #         showlegend=False,
        #     ),
        #     row=row,
        #     col=col,
        # )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=_y_reduced_smooth[:, j],
                y=_y_reduced_smooth[:, i],
                mode="lines",
                line=dict(color=color, width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        fig.add_trace(
            go.Scatter(
                x=_y_reduced_smooth[:1, j],
                y=_y_reduced_smooth[:1, i],
                mode="markers",
                marker=dict(color=color),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        if j == 0:
            fig.update_yaxes(title_text=labels[str(i)], row=row, col=col)

        fig.update_xaxes(title_text=labels[str(j)], row=row, col=col)

        fig.update_xaxes(
            range=(xmin * 1.25, xmax * 1.25),
            row=row,
            col=col,
        )
        fig.update_yaxes(
            range=(ymin * 1.25, ymax * 1.25),
            row=row,
            col=col,
        )

    fig.update_layout(width=2500, height=2500)  # Adjust the size as needed
    fig.update_layout(title_text=f"Combined LR PCA ({num_seeds} seeds, {num_checkpoints} checkpoints per seed, all outputs)", showlegend=False)

# Save as html
pyo.plot(fig, filename=str(FIGURES / "combined-pca-y-only.html"))
fig.write_image(str(FIGURES / "combined-pca-y-only.png"))


100%|██████████| 64/64 [00:00<00:00, 232.62it/s]
100%|██████████| 64/64 [00:00<00:00, 253.09it/s]
100%|██████████| 64/64 [00:00<00:00, 255.75it/s]
100%|██████████| 64/64 [00:00<00:00, 254.92it/s]
100%|██████████| 64/64 [00:00<00:00, 259.14it/s]
100%|██████████| 64/64 [00:00<00:00, 258.83it/s]
100%|██████████| 64/64 [00:00<00:00, 259.73it/s]
100%|██████████| 64/64 [00:00<00:00, 262.66it/s]
100%|██████████| 64/64 [00:00<00:00, 255.99it/s]
100%|██████████| 64/64 [00:00<00:00, 244.48it/s]


# Finite Ms

In [79]:
from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

MODEL_SEED = 0
NUM_TASKS = list(2 ** np.arange(0, 21))
TOKENS_SEED = 0

num_checkpoints_total = 5000
num_downsample = 10
num_checkpoints = num_checkpoints_total // num_downsample
num_diff_tasks = len(NUM_TASKS)

combined_outputs = np.zeros((num_diff_tasks * num_checkpoints, 81920 * 2))
combined_y_outputs = np.zeros((num_diff_tasks * num_checkpoints, 8192 * 2))

for i, num_tasks in enumerate(tqdm.tqdm(NUM_TASKS[:5] + ['inf'])):
    model_id = f"L2H4M{num_tasks}"

    if num_tasks == 'inf':
        model_id += '0'


    os.makedirs(str(FIGURES / model_id), exist_ok=True)
    os.makedirs(str(DATA / model_id), exist_ok=True)
    
    # if num_tasks == 'inf':
    #     run = get_unique_run(
    #         str(SWEEPS / "regression/training-runs/L2H4Minf.yaml"),
    #         task_config={
    #             "num_layers": NUM_LAYERS,
    #             "model_seed": MODEL_SEED,
    #         },
    #         optimizer_config={"lr": MAX_LR},
    #     )
    # else:
    #     run = get_unique_run(
    #         str(SWEEPS / "regression/training-runs/L2H4Mfin.yaml"),
    #         task_config={
    #             "num_tasks": num_tasks,
    #             "num_layers": NUM_LAYERS,
    #             "model_seed": MODEL_SEED,
    #         },
    #         optimizer_config={"lr": MAX_LR},
    #     )
    # print("Retrieved run.")

    # models, optimizer_state_dicts = get_models_and_optimizers(run, steps, model_id)

    tokens, xs, ys = get_tokens(run, B, K, seed=TOKENS_SEED, include_x_and_y=True)
    print(f"Tokens generated from seed {TOKENS_SEED} with shape {tokens.shape}")

    outputs, y_outputs = get_outputs(models, tokens, model_id, force_reeval=False), get_y_outputs(models, xs, ys, model_id, force_reeval=False)

    print(f"Outputs shape: {outputs.shape}, {y_outputs.shape}")
    combined_outputs[i * num_checkpoints: (i + 1) * num_checkpoints, :] = outputs[::num_downsample, :]
    combined_y_outputs[i * num_checkpoints: (i + 1) * num_checkpoints, :] = y_outputs[::num_downsample, :]
    # pca, reduced = get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False)

    # start, end = 0.1, 300
    # reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    # forms = get_forms(model_id)
    # num_forms = len(forms)
    # form_cmap = sns.color_palette("rainbow", as_cmap=True)

    # fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id)

  0%|          | 0/6 [00:00<?, ?it/s]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 163840), (5000, 16384)


 17%|█▋        | 1/6 [00:01<00:05,  1.20s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 163840), (5000, 16384)


 33%|███▎      | 2/6 [00:02<00:04,  1.23s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 163840), (5000, 16384)


 50%|█████     | 3/6 [00:04<00:05,  1.69s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 163840), (5000, 16384)


 67%|██████▋   | 4/6 [00:06<00:03,  1.62s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk
Loading outputs from disk
Outputs shape: (5000, 163840), (5000, 16384)


 83%|████████▎ | 5/6 [00:07<00:01,  1.53s/it]

Tokens generated from seed 0 with shape torch.Size([1024, 32, 5])
Loading outputs from disk


 83%|████████▎ | 5/6 [00:08<00:01,  1.65s/it]

Loading outputs from disk
Outputs shape: (5000, 81920), (5000, 16384)





ValueError: could not broadcast input array from shape (500,81920) into shape (500,163840)

: 