# Essential dynamics

In [None]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ.
# NOTE(review): presumably supplies credentials/config for the remote checkpoint store
# used by get_models_and_optimizers below — confirm.
load_dotenv()

In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
import seaborn as sns
from sklearn.decomposition import PCA
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.patches as mpatches

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
import numpy as np
from sklearn.decomposition import PCA
from torch.nn import functional as F
from sklearn.manifold import TSNE
import gc
import itertools
from scipy.ndimage import gaussian_filter1d
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo
import numpy as np
import tqdm
from infra.utils.iterables import int_linspace
from copy import deepcopy
from pathlib import Path

# import sys
# del sys.modules['icl.figures.colors']
# del sys.modules['icl.figures.notation']

from devinterp.slt.forms import get_osculating_circle
from icl.analysis.utils import get_unique_run
from icl.constants import ANALYSIS, FIGURES, SWEEPS, DATA
from icl.figures.notation import str_d_dlogt, str_d_dt, str_dlog_dlogt
from icl.figures.colors import (
    plot_transitions,
    gen_transition_colors,
    get_transition_type,
    PRIMARY,
    SECONDARY,
    TERTIARY,
    BRED,
    BBLUE,
    BRED,
    BGREEN,
)
from icl.constants import DEVICE

sns.set_style("white")
plt.rcParams["figure.dpi"] = 300

# Sweep configuration used to look up training runs via `get_unique_run`.
NUM_TASKS = "inf"
NUM_LAYERS = 2
MAX_LR = 0.003
MODEL_SEEDS = list(range(10))

# Checkpoint steps to analyse: every other value of int_linspace(0, 500_000, 10_000).
# NOTE(review): presumably int_linspace yields ~10_000 integer steps in [0, 500_000] — confirm.
steps = int_linspace(0, 500_000, 10_000)[::2]

# Show the compute device as this cell's output (a bare expression mid-cell displays nothing).
DEVICE

In [None]:
def get_models_and_optimizers(run, steps, model_id, checkpoint_dir=Path("../checkpoints")):
    """Return the model and optimizer state at each checkpoint step of `run`.

    Results are cached in `checkpoint_dir` as two files:
    `{model_id}-models.pt` (list of models) and
    `{model_id}-optimizer_state_dicts.pt` (list of optimizer state dicts).
    The cache is used only when BOTH files exist; otherwise every step is
    fetched via `run.checkpointer.load_file` and the cache is (re)written.

    Args:
        run: Training run exposing `.model` (used as an architecture template)
            and `.checkpointer.load_file(step)`, which returns a dict with
            "model" and "optimizer" entries.
        steps: Iterable of checkpoint steps to load.
        model_id: Identifier used in the cache file names.
        checkpoint_dir: Directory holding the cache files (created if missing).

    Returns:
        (models, optimizer_state_dicts): two lists aligned with `steps`.
    """
    checkpoint_dir = Path(checkpoint_dir)
    models_path = checkpoint_dir / f"{model_id}-models.pt"
    optimizers_path = checkpoint_dir / f"{model_id}-optimizer_state_dicts.pt"

    if models_path.exists() and optimizers_path.exists():
        print("Loading models from disk")
        # The cache holds whole pickled nn.Module objects written by this function,
        # so weights_only=False is needed on torch versions that default it to True.
        models = torch.load(models_path, weights_only=False)
        optimizer_state_dicts = torch.load(optimizers_path, weights_only=False)
        return models, optimizer_state_dicts

    print("Retrieving models from AWS")
    models = []
    optimizer_state_dicts = []
    for step in tqdm.tqdm(steps):
        checkpoint = run.checkpointer.load_file(step)
        # Copy the run's model so each checkpoint gets its own parameter tensors.
        model = deepcopy(run.model)
        model.load_state_dict(checkpoint["model"])
        models.append(model)
        optimizer_state_dicts.append(checkpoint["optimizer"])

    print("Saving models to disk")
    # torch.save does not create missing parent directories.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    torch.save(models, models_path)
    torch.save(optimizer_state_dicts, optimizers_path)

    return models, optimizer_state_dicts

In [None]:
import pickle
from icl.regression.model import to_token_sequence, from_predicted_token_sequence

# Default batch dimensions. Later cells pass B as `batch_size` and K as `max_examples`
# to get_data, and get_outputs unpacks the batch as `B, K, D = xs.shape`.
# NOTE(review): D looks like the trailing (input) dimension of xs — confirm.
K = 16
B = 1024
D = 4

def get_data(run, batch_size, max_examples, seed=0):
    """Draw one (xs, ys) batch from `run.pretrain_dist` after seeding torch.

    `torch.manual_seed(seed)` is called first (a global side effect), so the same
    seed reproduces the same batch when `get_batch` draws from torch's RNG.

    Args:
        run: Object exposing `pretrain_dist.get_batch`.
        batch_size: Forwarded as the second positional argument of `get_batch`.
        max_examples: Forwarded as the first positional argument of `get_batch`.
        seed: Seed for torch's global RNG.

    Returns:
        The `(xs, ys)` pair produced by `get_batch(..., return_ws=False)`.
    """
    torch.manual_seed(seed)
    distribution = run.pretrain_dist
    batch_xs, batch_ys = distribution.get_batch(max_examples, batch_size, return_ws=False)
    return batch_xs, batch_ys

In [None]:
def get_outputs(models, xs, ys, model_id, force_reeval=False):
    """Evaluate every model on the fixed batch `(xs, ys)` and stack the flattened outputs.

    Results are cached at `DATA / f"{model_id}-outputs-y-only.pkl"`. When that file
    exists and `force_reeval` is False, the cached array is returned and `models`
    is not used at all.

    Args:
        models: Sequence of callables invoked as `model(xs, ys)`.
        xs: 3-D input batch (its first two dims size the result when `models` is empty).
        ys: Targets passed alongside `xs`.
        model_id: Identifier used in the cache file name.
        force_reeval: Recompute and overwrite the cache even if it exists.

    Returns:
        float32 array with one row per model, each row a flattened model output.
        For a model emitting one scalar per (batch, position) entry the row width
        is B * K, where B, K = xs.shape[:2].
    """
    cache_path = DATA / f"{model_id}-outputs-y-only.pkl"

    if os.path.exists(cache_path) and not force_reeval:
        print("Loading outputs from disk")
        # Trusted local cache written by this function; never point it at external files.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print("Computing outputs")
    B, K, _ = xs.shape
    rows = []
    # No autograd graph is needed for any model in the sweep.
    with torch.no_grad():
        for model in tqdm.tqdm(models, desc="Computing outputs"):
            output = model(xs, ys).flatten()
            rows.append(output.cpu().numpy().astype(np.float32))

    # Infer the row width from the outputs rather than assuming B * K;
    # fall back to that width only when there are no models.
    outputs = np.stack(rows) if rows else np.zeros((0, K * B), dtype=np.float32)

    with open(cache_path, "wb") as f:
        pickle.dump(outputs, f)

    return outputs

In [None]:
def get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False):
    """Fit (or load a cached) PCA on `outputs` and return it with the projected data.

    The `(pca, reduced)` pair is pickled to `DATA / f"{model_id}-pca-y-only.pkl"`.
    If that file exists and `force_reeval` is False, it is loaded instead and
    `outputs` / `n_components` are ignored.

    Returns:
        (pca, reduced): the fitted sklearn `PCA` and `pca.transform(outputs)`.
    """
    cache_file = DATA / f"{model_id}-pca-y-only.pkl"
    use_cache = os.path.exists(cache_file) and not force_reeval

    if use_cache:
        print("Loading PCA from disk")
        # Trusted local cache written by this function.
        with open(cache_file, "rb") as handle:
            pca, reduced = pickle.load(handle)
        return pca, reduced

    print("Computing PCA")
    pca = PCA(n_components=n_components).fit(outputs)
    reduced = pca.transform(outputs)
    with open(cache_file, "wb") as handle:
        pickle.dump((pca, reduced), handle)
    return pca, reduced

In [None]:
from typing import List, TypedDict
import yaml

class FormDict(TypedDict):
    # NOTE(review): schema declared here only; confirm it matches the forms YAML files.
    name: str  # presumably a label for the form
    components: List[float]  # presumably one coordinate per principal component

def get_forms(model_id) -> List[FormDict]:
    """Load the stored "forms" for `model_id` from YAML, or return [] if there are none.

    Reads `DATA / f"{model_id}-forms-y-only.yaml"` with `yaml.safe_load`. Nothing is
    computed here: a missing file, or an empty one (which `safe_load` parses as
    None), yields an empty list so downstream `len()` / iteration still work.
    """
    forms_path = DATA / f"{model_id}-forms-y-only.yaml"
    if not os.path.exists(forms_path):
        print("No forms found on disk; using none")
        return []

    print("Loading forms from disk")
    with open(forms_path, "r") as f:
        forms = yaml.safe_load(f)

    # An empty YAML document parses to None.
    return forms if forms is not None else []

In [None]:
import plotly.express as px
from sklearn.decomposition import PCA
import seaborn as sns

# Spectral colormap sampled once per checkpoint step: one RGBA row per entry of `steps`.
cmap = sns.color_palette("Spectral", as_cmap=True)
color_indices = np.linspace(0, 1, len(steps))
colors = np.array(list(map(cmap, color_indices)))

def to_color_string(color):
    """Convert an RGBA tuple with channels in [0, 1] to a Plotly "rgba(r, g, b, a)" string.

    RGB channels are scaled by 255 (not 256, which would map 1.0 to an
    out-of-range 256), rounded, and clamped to 0..255. Alpha is passed through
    unchanged because CSS/Plotly expects alpha in [0, 1].
    """
    r, g, b = (min(255, max(0, int(round(255 * channel)))) for channel in color[:3])
    return f"rgba({r}, {g}, {b}, {color[3]})"

In [None]:
def _padded_range(values, pad=0.25):
    """Return (lo, hi) axis limits: [min, max] of `values`, each pushed outward by `pad` of its magnitude.

    For min <= 0 <= max this equals (1.25 * min, 1.25 * max); unlike plain scaling it
    never clips the data when both bounds share a sign.
    """
    lo, hi = values.min(), values.max()
    return lo - pad * abs(lo), hi + pad * abs(hi)


def _form_component(form, k):
    """Return the form's coordinate on principal component `k`, or None if it has none.

    Accepts a FormDict-style mapping with a "components" list, or an object indexed
    directly by component (a list, or a mapping keyed by int).
    """
    components = form.get("components", form) if isinstance(form, dict) else form
    try:
        return components[k]
    except (IndexError, KeyError):
        return None


def _figure_stem(slug):
    """Strip a trailing ".html" so that "pca.html" and "pca" name the same output files."""
    return slug[: -len(".html")] if slug.endswith(".html") else slug


def plot_ed(pca, reduced, reduced_smooth, forms, model_id, form_cmap='rainbow', evolute_cmap='Spectral', num_components=3, title="", slug="pca.html"):
    """Plot a pairwise grid of the essential-dynamics trajectory over the leading PCs.

    Subplot (row i, col j) shows PC j on x against PC i on y with:
      * the smoothed trajectory `reduced_smooth` as a black line,
      * the centers of the osculating circles of that smoothed 2-D curve at
        indices 2 .. len(reduced_smooth) - 3, colored along `evolute_cmap`,
      * for each form, a vertical line at its PC-j coordinate and a horizontal
        line at its PC-i coordinate (missing/None coordinates are skipped).

    Axis ranges come from the unsmoothed `reduced`, padded outward by 25%.
    The figure is written to FIGURES / model_id / "<stem>.png" and "<stem>.html",
    where <stem> is `slug` without a trailing ".html".

    Args:
        pca: Fitted PCA; only `explained_variance_ratio_` is read (axis labels).
        reduced: 2-D array indexed [step, component].
        reduced_smooth: Smoothed counterpart of `reduced`, indexed the same way.
        forms: Forms as returned by `get_forms` (may be empty).
        model_id: Subdirectory of FIGURES to write into (created if missing).
        form_cmap, evolute_cmap: Colormap names or callables mapping [0, 1] to RGBA.
        num_components: Grid size (number of leading PCs shown).
        title: Figure title.
        slug: Output file-name stem.

    Returns:
        The plotly Figure.
    """
    labels = [
        f"PC {k + 1} ({var:.1f}%)"
        for k, var in enumerate(pca.explained_variance_ratio_ * 100)
    ]

    fig = make_subplots(rows=num_components, cols=num_components, subplot_titles=[])

    if isinstance(form_cmap, str):
        form_cmap = sns.color_palette(form_cmap, as_cmap=True)
    if isinstance(evolute_cmap, str):
        evolute_cmap = sns.color_palette(evolute_cmap, as_cmap=True)

    form_colors = [to_color_string(form_cmap(c)) for c in np.linspace(0, 1, len(forms))]

    # Osculating circles are evaluated at every index with two neighbours on each
    # side, giving one center (and one color) per entry of `ts`.
    ts = np.arange(2, len(reduced_smooth) - 2)
    evolute_colors = np.array([to_color_string(evolute_cmap(c)) for c in np.linspace(0, 1, len(ts))])

    for i, j in tqdm.tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2):
        row, col = i + 1, j + 1
        ymin, ymax = _padded_range(reduced[:, i])
        xmin, xmax = _padded_range(reduced[:, j])

        # Forms
        for f, form in enumerate(forms):
            x_form = _form_component(form, j)
            y_form = _form_component(form, i)
            if x_form is not None:
                # Vertical line at the form's PC-j coordinate.
                fig.add_shape(
                    type="line",
                    x0=x_form, y0=ymin, x1=x_form, y1=ymax,
                    line=dict(color=form_colors[f], width=1),
                    row=row, col=col,
                )
            if y_form is not None:
                # Horizontal line at the form's PC-i coordinate.
                fig.add_shape(
                    type="line",
                    x0=xmin, y0=y_form, x1=xmax, y1=y_form,
                    line=dict(color=form_colors[f], width=1),
                    row=row, col=col,
                )

        # Centers of the osculating circles of the smoothed (PC j, PC i) curve.
        curve = reduced_smooth[:, (j, i)]
        centers = np.zeros((len(ts), 2))
        for ti, t in enumerate(ts):
            center, _radius = get_osculating_circle(curve, t)
            centers[ti] = center

        fig.add_trace(
            go.Scatter(
                x=centers[:, 0],
                y=centers[:, 1],
                mode="markers",
                marker=dict(size=2, symbol="x", color=evolute_colors),
                name="Centers",
            ),
            row=row,
            col=col,
        )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=reduced_smooth[:, j],
                y=reduced_smooth[:, i],
                mode="lines",
                line=dict(color="black", width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        if j == 0:
            fig.update_yaxes(title_text=labels[i], row=row, col=col)
        fig.update_xaxes(title_text=labels[j], range=(xmin, xmax), row=row, col=col)
        fig.update_yaxes(range=(ymin, ymax), row=row, col=col)

    fig.update_layout(width=2500, height=2500, title_text=title, showlegend=False)

    out_dir = FIGURES / model_id
    os.makedirs(str(out_dir), exist_ok=True)
    stem = _figure_stem(slug)
    # Build the file name before joining: `out_dir / slug + ".png"` would evaluate Path + str.
    fig.write_image(str(out_dir / f"{stem}.png"))
    pyo.plot(fig, filename=str(out_dir / f"{stem}.html"))

    return fig

In [None]:
# Essential dynamics for each model seed, evaluated on one fixed batch of tokens
# drawn from the reference run.

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

REF_SEED = 0
TOKENS_SEED = 0

B = 2 ** 10
K = 16


def load_regression_run(model_seed):
    """Look up the L2H4Minf training run for `model_seed` in the regression sweep."""
    return get_unique_run(
        str(SWEEPS / "regression/training-runs/L2H4Minf.yaml"),
        task_config={
            "num_tasks": NUM_TASKS,
            "num_layers": NUM_LAYERS,
            "model_seed": model_seed,
        },
        optimizer_config={"lr": MAX_LR},
    )


ref_model_id = f"L2H4Minf{REF_SEED}"

print("Retrieving run...")
ref_run = load_regression_run(REF_SEED)
print("Retrieved run.")

# Every model below is evaluated on these same tokens.
xs, ys = get_data(ref_run, B, K, TOKENS_SEED)
print(f"Tokens generated from seed {TOKENS_SEED} with shape {xs.shape, ys.shape}")

# NOTE(review): only seeds 1 and 2 are processed here — confirm this slice is intended.
for model_seed in MODEL_SEEDS[1:3]:
    model_id = f"L2H4Minf{model_seed}"

    os.makedirs(str(FIGURES / model_id), exist_ok=True)
    os.makedirs(str(DATA / model_id), exist_ok=True)

    # Release the previous seed's checkpoints before loading this seed's, so two
    # full lists of models are never alive at once.
    models = optimizer_state_dicts = None
    gc.collect()

    run = load_regression_run(model_seed)
    models, optimizer_state_dicts = get_models_and_optimizers(run, steps, model_id)
    outputs = get_outputs(models, xs, ys, model_id, force_reeval=False)
    pca, reduced = get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False)

    # Smoothing width ramps linearly from `start` to `end` along the trajectory.
    start, end = 0.1, 300
    reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    forms = get_forms(model_id)
    num_forms = len(forms)
    form_cmap = sns.color_palette("rainbow", as_cmap=True)

    fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=f"ED of {model_id} (y-only)", slug="pca-y-only")
    

In [None]:
# Compact 3x3 view for the last seed processed by the cell above
# (reuses its pca / reduced / reduced_smooth / forms / model_id).
fig = plot_ed(
    pca,
    reduced,
    reduced_smooth,
    forms,
    model_id,
    num_components=3,
    title=model_id,
)


# Cross-ED

Next, we project each training run's outputs onto the principal components fitted to a reference run (seed 0), so that the essential dynamics of all runs are viewed in the same coordinate system.

In [None]:
# Cross-ED: project every seed's outputs onto the reference run's principal components.

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

REF_SEED = 0
TOKENS_SEED = 0


def _require_cache(path):
    """Fail early with a clear message if a cache file this cell relies on is missing.

    This cell reuses cached outputs/PCAs instead of evaluating models (it passes
    `steps` in place of the models), so a missing cache would otherwise surface
    as an obscure error deep inside get_outputs / get_pca_and_reduced.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Missing cache {path}; compute it with get_outputs / get_pca_and_reduced for this model first."
        )


ref_model_id = f"L2H4Minf{REF_SEED}"
# Same arguments as the per-seed ED cell, so these are the tokens the cached outputs were
# computed on (assuming get_batch is deterministic under torch.manual_seed).
xs, ys = get_data(ref_run, B, K, TOKENS_SEED)

_require_cache(DATA / f"{ref_model_id}-outputs-y-only.pkl")
_require_cache(DATA / f"{ref_model_id}-pca-y-only.pkl")
ref_outputs = get_outputs(steps, xs, ys, ref_model_id, force_reeval=False)
ref_pca, ref_reduced = get_pca_and_reduced(None, ref_model_id, n_components=30, force_reeval=False)

ref_mean = np.mean(ref_outputs, axis=0)

for model_seed in tqdm.tqdm(MODEL_SEEDS, desc="Sweeping over model seeds"):
    model_id = f"L2H4Minf{model_seed}"
    # plot_ed writes into FIGURES / model_id, which may not exist for this seed yet.
    os.makedirs(str(FIGURES / model_id), exist_ok=True)

    _require_cache(DATA / f"{model_id}-outputs-y-only.pkl")
    outputs = get_outputs(steps, xs, ys, model_id, force_reeval=False)
    reduced = ref_pca.transform(outputs)

    # Re-evaluate explained variance on this seed's outputs, measured along the reference PCs.
    print("Evaluating explained variance on new dataset")
    total_variance_new_dataset = np.sum(np.var(outputs, axis=0))
    explained_variances = np.var(reduced, axis=0)  # Variance of each PC in the new dataset
    explained_variance_ratio = explained_variances / total_variance_new_dataset
    total_explained_variance = np.sum(explained_variance_ratio)

    # Work on a copy of the reference PCA (not a `pca` left over from another cell) so the
    # overridden variances only affect this seed's axis labels.
    pca = deepcopy(ref_pca)
    pca.explained_variance_ = explained_variances
    pca.explained_variance_ratio_ = explained_variance_ratio

    print("Applying smoothing")
    start, end = 0.1, 300
    reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    fig = plot_ed(pca, reduced, reduced_smooth, [], model_id, num_components=8, title=f"ED of {model_id} using {ref_model_id}", slug=f"pca-via-{ref_model_id}")