# Essential dynamics

In [1]:
import os
from dotenv import load_dotenv

# Load variables from a .env file into the process environment (python-dotenv).
# The call returns a bool, which is what the notebook echoes as this cell's output.
load_dotenv()

True

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
import seaborn as sns
from sklearn.decomposition import PCA
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.patches as mpatches

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
import numpy as np
from sklearn.decomposition import PCA
from torch.nn import functional as F
from sklearn.manifold import TSNE
import gc
import itertools
from scipy.ndimage import gaussian_filter1d
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo
import numpy as np
import tqdm
from infra.utils.iterables import int_linspace
from copy import deepcopy
from pathlib import Path

# import sys
# del sys.modules['icl.figures.colors']
# del sys.modules['icl.figures.notation']

from devinterp.slt.forms import get_osculating_circle
from icl.analysis.utils import get_unique_run
from icl.constants import ANALYSIS, FIGURES, SWEEPS, DATA
from icl.figures.notation import str_d_dlogt, str_d_dt, str_dlog_dlogt
from icl.figures.colors import (
    plot_transitions,
    gen_transition_colors,
    get_transition_type,
    PRIMARY,
    SECONDARY,
    TERTIARY,
    BRED,
    BBLUE,
    BRED,
    BGREEN,
)
from icl.constants import DEVICE

# from devinterp.slt.forms import
sns.set_style("white")
# NOTE(review): a bare expression in the middle of a cell is not displayed;
# only the last expression of a cell is echoed.
DEVICE

# Selectors used further down to look up a training run via get_unique_run.
NUM_LAYERS = 2
MAX_LR = 0.003
MODEL_SEED = 0
# Powers of two from 2**0 to 2**20. The list elements are NumPy integers, not Python ints.
NUM_TASKS = list(2 ** np.arange(0, 21))

# Every second entry of int_linspace(0, 500_000, 10_000); used below as the
# checkpoint steps to load.
# NOTE(review): int_linspace is project-local; presumably integer, evenly
# spaced steps — confirm.
steps = int_linspace(0, 500_000, 10_000)[::2]


# Default resolution for matplotlib figures.
plt.rcParams["figure.dpi"] = 300

In [None]:
def get_models_and_optimizers(run, steps, model_id):
    """Return one model and one optimizer state dict per checkpoint step, cached on disk.

    If both cache files for ``model_id`` exist under ``../checkpoints`` they are
    loaded with ``torch.load``. Otherwise every step in ``steps`` is fetched via
    ``run.checkpointer.load_file``, a deep copy of ``run.model`` is populated
    with the checkpoint's ``"model"`` state dict, and both lists are saved for
    the next call.

    Args:
        run: Object exposing ``checkpointer.load_file(step)`` and ``model``.
        steps: Iterable of checkpoint steps to load.
        model_id: Identifier used as the cache file prefix.

    Returns:
        ``(models, optimizer_state_dicts)``: two lists aligned with ``steps``.
    """
    checkpoint_dir = Path("../checkpoints")
    models_path = checkpoint_dir / f"{model_id}-models.pt"
    optimizers_path = checkpoint_dir / f"{model_id}-optimizer_state_dicts.pt"

    # Require BOTH files: a half-written cache would otherwise fail on load.
    if models_path.exists() and optimizers_path.exists():
        print("Loading models from disk")
        # NOTE(review): these files hold whole pickled modules; recent torch
        # releases may need weights_only=False here — confirm for your version.
        models = torch.load(models_path)
        optimizer_state_dicts = torch.load(optimizers_path)
    else:
        print("Retrieving models from AWS")
        models = []
        optimizer_state_dicts = []

        for step in tqdm.tqdm(steps):
            checkpoint = run.checkpointer.load_file(step)

            # Copy the run's model so every checkpoint gets its own instance.
            m = deepcopy(run.model)
            m.load_state_dict(checkpoint["model"])
            models.append(m)
            optimizer_state_dicts.append(checkpoint["optimizer"])

        print("Saving models to disk")
        # The directory may not exist on a fresh checkout; torch.save will not create it.
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        torch.save(models, models_path)
        torch.save(optimizer_state_dicts, optimizers_path)

    return models, optimizer_state_dicts

In [None]:
import pickle
from icl.regression.model import to_token_sequence, from_predicted_token_sequence

# Passed as `max_examples` to get_tokens in the sweep cells below.
K = 16
# Passed as `batch_size` to get_tokens in the sweep cells below.
B = 1024
# NOTE(review): D is not referenced elsewhere in the visible notebook;
# presumably the input/task dimension — confirm or remove.
D = 4

def get_tokens(run, batch_size, max_examples, seed=0):
    """Sample one batch from the run's pretraining distribution and tokenize it.

    The torch global RNG is seeded with ``seed`` before sampling, so repeated
    calls with the same arguments draw the same batch.

    Args:
        run: Object exposing ``pretrain_dist.get_batch``.
        batch_size: Forwarded to ``get_batch`` as its second positional argument.
        max_examples: Forwarded to ``get_batch`` as its first positional argument.
        seed: Seed for ``torch.manual_seed``.

    Returns:
        The result of ``to_token_sequence`` applied to the sampled ``(xs, ys)``.
    """
    # Seed first: the draw below consumes the global torch RNG.
    torch.manual_seed(seed)

    inputs, targets = run.pretrain_dist.get_batch(
        max_examples, batch_size, return_ws=False
    )
    return to_token_sequence(inputs, targets)



In [None]:
def get_outputs(models, tokens, model_id, force_reeval=False):
    """Return the flattened outputs of every model on ``tokens``, cached on disk.

    The result is a float32 array with one row per model; each row is
    ``model.token_sequence_transformer(tokens)`` flattened. It is pickled to
    ``DATA / f"{model_id}-outputs.pkl"`` and reloaded on later calls.

    On a cache hit neither ``models`` nor ``tokens`` is touched, so callers that
    only want the cached array may pass placeholders.

    Args:
        models: Sequence of models exposing ``token_sequence_transformer``.
        tokens: 3-D tensor; its shape determines the row width.
        model_id: Identifier used as the cache file prefix.
        force_reeval: Recompute and overwrite the cache even if it exists.

    Returns:
        Array of shape ``(len(models), row_width)``.
    """
    cache_path = DATA / f"{model_id}-outputs.pkl"

    if os.path.exists(cache_path) and not force_reeval:
        print("Loading outputs from disk")
        # NOTE: pickle can execute arbitrary code; only load caches this notebook wrote.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print("Computing outputs")
    # Allocate only when actually computing: the matrix can be several GB.
    B, K, D = tokens.shape
    K = K // 2
    D = D - 1
    outputs = np.zeros((len(models), K * B * (D + 1) * 2), dtype=np.float32)

    for i, model in enumerate(tqdm.tqdm(models, desc="Computing outputs")):
        with torch.no_grad():
            output = model.token_sequence_transformer(tokens).flatten()
            outputs[i, :] = output.cpu().numpy()

    with open(cache_path, "wb") as f:
        pickle.dump(outputs, f)

    return outputs

In [None]:
def get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False):
    """Fit a PCA on ``outputs`` and project them, caching both results on disk.

    The pair ``(pca, reduced)`` is pickled to ``DATA / f"{model_id}-pca.pkl"``.
    A cached pair is returned as-is, whatever ``n_components`` it was fitted
    with; pass ``force_reeval=True`` to refit.

    Args:
        outputs: 2-D array to fit on. May be ``None`` only when a cache exists
            and ``force_reeval`` is false.
        model_id: Identifier used as the cache file prefix.
        n_components: Number of principal components to keep when fitting.
        force_reeval: Refit and overwrite the cache even if it exists.

    Returns:
        ``(pca, reduced)``: the fitted PCA and ``pca.transform(outputs)``.

    Raises:
        ValueError: If a fit is required but ``outputs`` is ``None``.
    """
    cache_path = DATA / f"{model_id}-pca.pkl"

    if os.path.exists(cache_path) and not force_reeval:
        print("Loading PCA from disk")
        # NOTE: pickle can execute arbitrary code; only load caches this notebook wrote.
        with open(cache_path, "rb") as f:
            pca, reduced = pickle.load(f)
        return pca, reduced

    if outputs is None:
        # Fail here with a clear message rather than deep inside sklearn.
        raise ValueError(
            f"No cached PCA at {cache_path} (or force_reeval=True) and no outputs were given to fit on."
        )

    print("Computing PCA")
    pca = PCA(n_components=n_components)
    pca.fit(outputs)
    reduced = pca.transform(outputs)
    with open(cache_path, "wb") as f:
        pickle.dump((pca, reduced), f)

    return pca, reduced

In [None]:
from typing import List, TypedDict
import yaml

class FormDict(TypedDict):
    """Schema of one entry in a model's forms.yaml, as consumed by plot_ed."""

    # Label for the form (not read by the plotting code visible in this notebook).
    name: str
    # Coordinate of the form along each principal component; plot_ed draws
    # components[j] as a line on the PC j axis.
    # NOTE(review): plot_ed skips entries that are None, so values are
    # presumably Optional[float] in practice — confirm against the YAML files.
    components: List[float]

def get_forms(model_id) -> List[FormDict]:
    """Load the forms for ``model_id`` from ``DATA / model_id / forms.yaml``.

    Args:
        model_id: Identifier of the model; names the sub-directory of ``DATA``.

    Returns:
        The parsed list of forms, or an empty list when the file is missing or
        contains an empty YAML document.
    """
    forms_path = DATA / f"{model_id}/forms.yaml"

    if not os.path.exists(forms_path):
        # Nothing is computed here: forms are authored by hand in the YAML file.
        print("No forms file found; using an empty list of forms")
        return []

    print("Loading forms from disk")
    with open(forms_path, "r") as f:
        forms = yaml.safe_load(f)

    # safe_load returns None for an empty document; callers take len(forms).
    return forms if forms is not None else []

In [None]:
import plotly.express as px
from sklearn.decomposition import PCA
import seaborn as sns

# One RGBA colour per checkpoint step, spread evenly over the "Spectral" colormap.
cmap = sns.color_palette(palette="Spectral", as_cmap=True)
color_indices = np.linspace(start=0, stop=1, num=len(steps))
colors = np.array(list(map(cmap, color_indices)))

def to_color_string(color):
    """Convert an RGB(A) sequence of floats in [0, 1] to a CSS colour string.

    Args:
        color: Sequence ``(r, g, b)`` or ``(r, g, b, a)`` with channels in
            ``[0, 1]``, e.g. the tuple returned by a matplotlib colormap.

    Returns:
        ``"rgba(R, G, B, a)"`` when an alpha channel is present, otherwise
        ``"rgb(R, G, B)"``, with ``R``, ``G``, ``B`` as integers in ``0..255``.
    """
    # Scale by 255 (not 256) so that 1.0 maps to 255, and clamp stray values.
    red, green, blue = (
        min(255, max(0, int(round(255 * float(channel))))) for channel in color[:3]
    )

    if len(color) > 3:
        # Four components need the rgba() form for the alpha to be honoured.
        return f"rgba({red}, {green}, {blue}, {float(color[3])})"
    return f"rgb({red}, {green}, {blue})"


In [None]:
def plot_ed(pca, reduced, reduced_smooth, forms, model_id, form_cmap='rainbow', evolute_cmap='Spectral', num_components=3, title="", slug="pca", auto_open=True, tmin=0, tmax=-1,
            save_png=False, kmeans=None):
    """Plot the smoothed trajectory in every pair of PC planes and save it as HTML.

    Builds a ``num_components`` x ``num_components`` grid of Plotly subplots. The
    subplot in row ``i``, column ``j`` has PC ``j`` on the x-axis and PC ``i`` on
    the y-axis and contains: one vertical/horizontal line per form coordinate,
    the centers of the osculating circles of the smoothed curve, the smoothed
    curve itself and, if given, the k-means cluster centers.

    Args:
        pca: Fitted PCA-like object; only ``explained_variance_ratio_`` is read
            (for the axis titles).
        reduced: Array indexed ``[time, component]``; used only to choose the
            axis ranges.
        reduced_smooth: Array indexed ``[time, component]``; the curve drawn.
        forms: Sequence of dicts with a ``"components"`` list; ``None`` entries
            and components beyond the list length are skipped.
        model_id: Name of the sub-directory of ``FIGURES`` to write into.
        form_cmap: Colormap (or seaborn palette name) for the form lines.
        evolute_cmap: Colormap (or seaborn palette name) for the circle centers.
        num_components: Number of leading PCs to plot against each other.
        title: Figure title.
        slug: Output file name without extension.
        auto_open: Forwarded to ``plotly.offline.plot``.
        tmin: Start of the slice ``reduced[tmin:tmax]`` used for axis ranges.
        tmax: End of that slice; the default ``-1`` excludes the last sample.
        save_png: Also write ``{slug}.png`` next to the HTML file.
        kmeans: Optional object with ``cluster_centers_`` indexed
            ``[cluster, component]``.

    Returns:
        The Plotly figure.
    """
    labels = {
        str(i): f"PC {i+1} ({var:.1f}%)"
        for i, var in enumerate(pca.explained_variance_ratio_ * 100)
    }

    fig = make_subplots(rows=num_components, cols=num_components, subplot_titles=[])

    if isinstance(form_cmap, str):
        form_cmap = sns.color_palette(form_cmap, as_cmap=True)
    if isinstance(evolute_cmap, str):
        evolute_cmap = sns.color_palette(evolute_cmap, as_cmap=True)

    # Time indices at which osculating circles are fitted: two samples are
    # skipped at each end of the curve. Identical for every subplot.
    ts = np.arange(2, len(reduced_smooth) - 2)

    form_colors = np.array([to_color_string(form_cmap(c)) for c in np.linspace(0, 1, len(forms))])
    evolute_colors = np.array([to_color_string(evolute_cmap(c)) for c in np.linspace(0, 1, len(ts))])

    for i, j in tqdm.tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2):
        row, col = i + 1, j + 1

        ymin, ymax = (
            reduced[tmin:tmax, i].min(),
            reduced[tmin:tmax, i].max(),
        )
        xmin, xmax = (
            reduced[tmin:tmax, j].min(),
            reduced[tmin:tmax, j].max(),
        )

        # Forms
        for f, form in enumerate(forms):
            if j < len(form['components']) and form['components'][j] is not None:
                # Vertical line
                fig.add_shape(
                    type="line",
                    x0=form['components'][j],
                    y0=ymin * 1.25,
                    x1=form['components'][j],
                    y1=ymax * 1.25,
                    line=dict(color=form_colors[f], width=1),
                    row=row,
                    col=col,
                )
            if i < len(form['components']) and form['components'][i] is not None:
                # Horizontal line
                fig.add_shape(
                    type="line",
                    x0=xmin * 1.25,
                    y0=form['components'][i],
                    x1=xmax * 1.25,
                    y1=form['components'][i],
                    line=dict(color=form_colors[f], width=1),
                    row=row,
                    col=col,
                )

        # Centers of the osculating circles in the (PC j, PC i) plane. The
        # two-column slice is taken once per subplot rather than once per step.
        plane = reduced_smooth[:, (j, i)]
        centers = np.zeros((len(ts), 2))
        for ti, t in enumerate(ts):
            center, _radius = get_osculating_circle(plane, t)
            centers[ti] = center

        fig.add_trace(
            go.Scatter(
                x=centers[:, 0],
                y=centers[:, 1],
                mode="markers",
                marker=dict(size=2, symbol="x", color=evolute_colors),
                name="Centers",
            ),
            row=row,
            col=col,
        )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=reduced_smooth[:, j],
                y=reduced_smooth[:, i],
                mode="lines",
                line=dict(color="black", width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        if kmeans is not None:
            fig.add_trace(
                go.Scatter(
                    x=kmeans.cluster_centers_[:, j],
                    y=kmeans.cluster_centers_[:, i],
                    mode="markers",
                    marker=dict(size=8, symbol="x", color="black"),
                    showlegend=False,
                ),
                row=row,
                col=col,
            )

        # Only the first column carries y-axis titles; every subplot gets an x title.
        if j == 0:
            fig.update_yaxes(title_text=labels[str(i)], row=row, col=col)

        fig.update_xaxes(title_text=labels[str(j)], row=row, col=col)

        # NOTE(review): scaling min/max by 1.25 only pads the view when the
        # data straddle zero; kept as in the original.
        fig.update_xaxes(
            range=(xmin * 1.25, xmax * 1.25),
            row=row,
            col=col,
        )
        fig.update_yaxes(
            range=(ymin * 1.25, ymax * 1.25),
            row=row,
            col=col,
        )

    fig.update_layout(width=2500, height=2500)  # Adjust the size as needed
    fig.update_layout(title_text=title, showlegend=False)

    # Not every caller creates the per-model figure directory beforehand.
    out_dir = FIGURES / model_id
    os.makedirs(str(out_dir), exist_ok=True)

    # Save as html
    pyo.plot(fig, filename=str(out_dir / f"{slug}.html"), auto_open=auto_open)

    if save_png:
        fig.write_image(str(out_dir / f"{slug}.png"), scale=0.25)

    return fig

In [None]:
# steps

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

# Seed used when sampling the evaluation batch, so every model sees the same tokens.
TOKENS_SEED = 0

# For every task count from NUM_TASKS[4] onwards: load the checkpoints,
# evaluate them on a fixed batch, run PCA on the outputs and plot the result.
# NOTE: the variables assigned in the loop body (pca, reduced, reduced_smooth,
# forms, model_id, tokens, ...) stay bound to the LAST iteration and are read
# by later cells.
for num_tasks in NUM_TASKS[4:]:
    model_id = f"L2H4M{num_tasks}"

    # Output locations for this model's figures and cached data.
    os.makedirs(str(FIGURES / model_id), exist_ok=True)
    os.makedirs(str(DATA / model_id), exist_ok=True)

    print("Retrieving run...")
    run = get_unique_run(
        str(SWEEPS / "regression/training-runs/L2H4Mfin.yaml"),
        task_config={
            "num_tasks": num_tasks,
            "num_layers": NUM_LAYERS,
            "model_seed": MODEL_SEED,
        },
        optimizer_config={"lr": MAX_LR},
    )
    print("Retrieved run.")

    models, optimizer_state_dicts = get_models_and_optimizers(run, steps, model_id)
    tokens = get_tokens(run, B, K, seed=TOKENS_SEED)
    print(f"Tokens generated from seed {TOKENS_SEED} with shape {tokens.shape}")

    # Both helpers cache to disk; force_reeval=False reuses existing caches.
    outputs = get_outputs(models, tokens, model_id, force_reeval=False)
    pca, reduced = get_pca_and_reduced(outputs, model_id, n_components=30, force_reeval=False)

    # Smooth along the time axis with a per-sample sigma running from `start` to `end`.
    # NOTE(review): gaussian_filter1d_variable_sigma is project-local; presumably
    # the second argument is the sigma at each time index — confirm.
    start, end = 0.1, 300
    reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    forms = get_forms(model_id)
    # NOTE(review): num_forms and form_cmap are not read again in the visible notebook.
    num_forms = len(forms)
    form_cmap = sns.color_palette("rainbow", as_cmap=True)

    fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id)
    

In [None]:
# Show the forms loaded for the last model_id processed by the preceding cell.
forms

In [None]:
# steps

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

TOKENS_SEED = 0

# Re-plot the first task count from the cached PCA only: no run, checkpoints
# or outputs are loaded. get_pca_and_reduced is given outputs=None, so this
# cell requires DATA / f"{model_id}-pca.pkl" to exist already.
for num_tasks in NUM_TASKS[:1]:
    model_id = f"L2H4M{num_tasks}"

    # The main sweep starts at NUM_TASKS[4], so this model's figure directory
    # may not exist yet; plot_ed writes into it.
    os.makedirs(str(FIGURES / model_id), exist_ok=True)

    # To recompute from the checkpoints instead, uncomment the block below and
    # pass `outputs` to get_pca_and_reduced.
    # print("Retrieving run...")
    # run = get_unique_run(
    #     str(SWEEPS / "regression/training-runs/L2H4Mfin.yaml"),
    #     task_config={
    #         "num_tasks": num_tasks,
    #         "num_layers": NUM_LAYERS,
    #         "model_seed": MODEL_SEED,
    #     },
    #     optimizer_config={"lr": MAX_LR},
    # )
    # print("Retrieved run.")

    # models, optimizer_state_dicts = get_models_and_optimizers(run, steps, model_id)
    # tokens = get_tokens(run, B, K, seed=TOKENS_SEED)
    # print(f"Tokens generated from seed {TOKENS_SEED} with shape {tokens.shape}")

    # outputs = get_outputs(models, tokens, model_id, force_reeval=False)
    pca, reduced = get_pca_and_reduced(None, model_id, n_components=30, force_reeval=False)

    # Smooth along the time axis with a per-sample sigma running from `start` to `end`.
    start, end = 0.1, 300
    reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    forms = get_forms(model_id)
    num_forms = len(forms)
    form_cmap = sns.color_palette("rainbow", as_cmap=True)

    fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id, save_png=False)
    

# Cross-ED

Now let's see what happens when we project the outputs of one training run onto the principal components fitted to a different (reference) training run.

In [None]:
# steps

from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

REF_SEED = 0
TOKENS_SEED = 0

# Get reference pca
# NOTE(review): `steps` is passed where get_outputs expects models, and the PCA
# is requested with outputs=None, so both calls only work when the caches for
# this model id already exist. `tokens` comes from an earlier cell.
ref_model_id = f"L2H4Minf{REF_SEED}"
ref_outputs = get_outputs(steps, tokens, ref_model_id, force_reeval=False)
ref_pca, ref_reduced = get_pca_and_reduced(None, ref_model_id, n_components=30, force_reeval=False)

ref_mean = np.mean(ref_outputs, axis=0)

# Work on a copy of the reference PCA: the explained-variance attributes are
# overwritten per model below, and neither ref_pca nor a PCA left over from an
# earlier cell should be modified as a side effect.
pca = deepcopy(ref_pca)

# NOTE(review): MODEL_SEEDS is not defined anywhere in the visible notebook;
# it must be defined before this cell can run on a fresh kernel.
for model_seed in tqdm.tqdm(MODEL_SEEDS[3:], desc="Sweeping over model seeds"):
    model_id = f"L2H4Minf{model_seed}"

    # plot_ed writes into FIGURES / model_id.
    os.makedirs(str(FIGURES / model_id), exist_ok=True)

    outputs = get_outputs(steps, tokens, model_id, force_reeval=False)
    # outputs -= np.mean(outputs, axis=0)
    reduced = ref_pca.transform(outputs)

    # Reevaluate explained variance on this new dataset
    print("Evaluating explained variance on new dataset")
    total_variance_new_dataset = np.sum(np.var(outputs, axis=0))
    explained_variances = np.var(reduced, axis=0)  # Variance of each PC in the new dataset
    explained_variance_ratio = explained_variances / total_variance_new_dataset
    total_explained_variance = np.sum(explained_variance_ratio)

    pca.explained_variance_ = explained_variances
    pca.explained_variance_ratio_ = explained_variance_ratio

    print("Applying smoothing")
    start, end = 0.1, 300
    reduced_smooth = gaussian_filter1d_variable_sigma(reduced, np.linspace(start, end, len(reduced)), axis=0)

    # plot_ed appends ".html" itself, so the slug must not carry an extension.
    fig = plot_ed(pca, reduced, reduced_smooth, [], model_id, num_components=8, title=f"ED of {model_id} using {ref_model_id}", slug=f"pca-via-{ref_model_id}")

In [None]:
from icl.figures.derivatives import d_dt, d_dlogt


def plot_pc_over_time(pca, reduced, reduced_smooth, forms, model_id, form_cmap='rainbow', evolute_cmap='Spectral', num_components=3, title="", slug="pca-over-time", auto_open=True, tmin=0, tmax=-1,
            save_png=False):
    """Plot each principal component and its derived quantities against time.

    Builds a 5-row grid with one column per component ``j``:

    1. the smoothed component ``reduced_smooth[:, j]``;
    2. ``d_dlogt`` of that component ("Slopes");
    3. ``d_dlogt`` of the slopes ("Curvatures");
    4. for every component ``i``, the second coordinate of the osculating-circle
       centers of the curve in the (PC i, PC j) plane;
    5. for every component ``i``, ``1 / radius`` of those circles.

    The x-axis is the sample index, not the training step.

    Args:
        pca, reduced, forms, form_cmap, evolute_cmap, tmin, tmax: Accepted for
            signature compatibility with ``plot_ed``; not used here.
        reduced_smooth: Array indexed ``[time, component]``.
        model_id: Name of the sub-directory of ``FIGURES`` to write into.
        num_components: Number of leading components (columns) to plot.
        title: Figure title.
        slug: Output file name without extension.
        auto_open: Forwarded to ``plotly.offline.plot``.
        save_png: Also write ``{slug}.png`` next to the HTML file.

    Returns:
        The Plotly figure.
    """
    pc_row = 1
    slope_row = 2
    curvature_row = 3
    center_row = 4
    radius_row = 5

    fig = make_subplots(rows=5, cols=num_components, subplot_titles=[])

    ts = np.arange(len(reduced_smooth))
    # Osculating circles are fitted at interior indices only, as in plot_ed:
    # two samples are skipped at each end of the curve.
    _ts = ts[2:-2]

    # Principal components
    for j in tqdm.trange(num_components):
        col = j + 1
        fig.add_trace(
            go.Scatter(
                x=ts,
                y=reduced_smooth[:, j],
                mode="lines",
                line=dict(width=1),
                name=f"PC{j+1}",
            ),
            row=pc_row,
            col=col,
        )

    # Derivatives and osculating circles
    for j in tqdm.trange(num_components):
        col = j + 1

        # NOTE(review): d_dlogt is project-local and ts starts at 0; confirm
        # how it treats log(0) and whether its output has the input's length.
        slopes = d_dlogt(ts, reduced_smooth[:, j])
        curvatures = d_dlogt(ts[1:-1], slopes[1:-1])

        fig.add_trace(
            go.Scatter(
                x=ts[1:-1],
                y=slopes[1:-1],
                mode="lines",
                line=dict(width=1),
                name="Slopes",
            ),
            row=slope_row,
            col=col,
        )

        # NOTE(review): if d_dlogt preserves length, `curvatures` is already
        # aligned with ts[1:-1], so curvatures[1:-1] would match ts[2:-2];
        # the original slicing is kept until that is confirmed.
        fig.add_trace(
            go.Scatter(
                x=ts[2:-2],
                y=curvatures[2:-2],
                mode="lines",
                line=dict(width=1),
                name="Curvatures",
            ),
            row=curvature_row,
            col=col,
        )

        # Circles
        for i in range(num_components):
            plane = reduced_smooth[:, (i, j)]
            # Sized to the interior indices so x and y have equal length.
            centers = np.zeros((len(_ts), 2))
            radiuses = np.zeros(len(_ts))

            for ti, t in enumerate(_ts):
                center, radius = get_osculating_circle(plane, t)
                centers[ti] = center
                # NOTE(review): this stores the reciprocal of the radius although
                # the trace is labelled "Radius"; kept as in the original.
                radiuses[ti] = 1 / radius

            # Centers
            fig.add_trace(
                go.Scatter(
                    x=_ts,
                    y=centers[:, 1],
                    mode="lines",
                    line=dict(width=0.5),
                    name=f"Evolute in PC{i+1}-{j+1} plane",
                ),
                row=center_row,
                col=col,
            )

            # Radius
            fig.add_trace(
                go.Scatter(
                    x=_ts,
                    y=radiuses,
                    mode="lines",
                    line=dict(width=0.5),
                    name=f"Radius of osculating circle in PC{i+1}-{j+1} plane",
                ),
                row=radius_row,
                col=col,
            )

    # Axis configuration
    for j in tqdm.trange(num_components):
        col = j + 1

        # Both the center and the radius rows are scaled by the largest value
        # of the component itself.
        max_center = reduced_smooth[:, j].max()
        max_radius = reduced_smooth[:, j].max()

        # Only the component and slope rows get a logarithmic time axis.
        for row in (pc_row, slope_row):
            fig.update_xaxes(
                row=row,
                col=col,
                type='log'
            )

        # Update scale
        fig.update_yaxes(
            range=(-max_center * 1.25, max_center * 1.25),
            row=center_row,
            col=col,
        )
        # NOTE(review): on a log axis Plotly interprets `range` in log10
        # units, so these bounds may not be what was intended.
        fig.update_yaxes(
            range=(0, max_radius * 1.25),
            row=radius_row,
            col=col,
            type='log'
        )

    fig.update_layout(width=2000, height=1500)  # Adjust the size as needed
    fig.update_layout(title_text=title, showlegend=False)

    # Not every caller creates the per-model figure directory beforehand.
    out_dir = FIGURES / model_id
    os.makedirs(str(out_dir), exist_ok=True)

    # Save as html
    pyo.plot(fig, filename=str(out_dir / f"{slug}.html"), auto_open=auto_open)

    if save_png:
        fig.write_image(str(out_dir / f"{slug}.png"), scale=0.25)

    return fig
    

# Uses pca / reduced / reduced_smooth / forms / model_id left over from the most recently run cell.
# tmin and tmax are accepted by plot_pc_over_time but not used in its body, so they have no effect.
plot_pc_over_time(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id, tmin=0, tmax=1000, save_png=False)


In [None]:
from sklearn.cluster import KMeans

# Cluster the points of the smoothed trajectory (one point per time index).
print(reduced_smooth.shape)

kmeans = (
    KMeans(n_clusters=8, random_state=0)
    .fit(reduced_smooth)
)


In [None]:
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans

class KVars(KMeans):
    """Experimental k-means variant driven by the variance of point-to-centroid distances.

    Points are assigned to their nearest centroid as in k-means, but each
    centroid is then moved by a step proportional to ``tol / sqrt(variance)``,
    where ``variance`` is the variance of the distances from the cluster's
    points to the centroid. ``fit`` sets ``cluster_centers_``, ``labels_``,
    ``inertia_`` and ``n_iter_``.
    """

    def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, verbose=0, random_state=None, copy_x=True):
        super().__init__(
            n_clusters=n_clusters,
            init='random',
            max_iter=max_iter,
            tol=tol,
            verbose=verbose,
            random_state=random_state,
            copy_x=copy_x,
            n_init=1
        )

    def fit(self, X, sample_weight=None):
        """Fit the centroids to ``X``.

        Args:
            X: Array-like of shape ``(n_samples, n_features)``.
            sample_weight: Accepted for API compatibility; ignored.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``X`` has fewer rows than ``n_clusters``.
        """
        X = np.asarray(X)
        if X.shape[0] < self.n_clusters:
            raise ValueError(
                f"n_samples={X.shape[0]} should be >= n_clusters={self.n_clusters}."
            )

        random_state = np.random.RandomState(self.random_state)

        # Initialize centroids randomly from the dataset
        initial_centroids = random_state.permutation(X.shape[0])[:self.n_clusters]
        self.cluster_centers_ = X[initial_centroids]

        # Defined up front so the attributes below exist even when max_iter == 0.
        labels = pairwise_distances(X, self.cluster_centers_).argmin(axis=1)
        cluster_variances = np.zeros(self.n_clusters)
        n_iter = 0

        for i in range(self.max_iter):
            n_iter = i + 1

            # Assign clusters based on closest centroid
            labels = pairwise_distances(X, self.cluster_centers_).argmin(axis=1)

            # Start from the current centroids so that empty clusters stay
            # where they are instead of jumping to the origin.
            new_centroids = self.cluster_centers_.copy()
            cluster_variances = np.zeros(self.n_clusters)
            for k in range(self.n_clusters):
                cluster_points = X[labels == k]
                if len(cluster_points) == 0:
                    continue

                # Variance of the distances from the cluster's points to its centroid
                cluster_variance = np.var(pairwise_distances(cluster_points, [self.cluster_centers_[k]]))
                cluster_variances[k] = cluster_variance

                # A zero variance (e.g. a single-point cluster) would divide by
                # zero below; leave such a centroid unchanged.
                if cluster_variance > 0:
                    # NOTE(review): this step moves the centroid AWAY from the
                    # cluster mean (the summed offsets are subtracted); kept as
                    # in the original — confirm that the sign is intended.
                    new_centroids[k] = self.cluster_centers_[k] - (self.tol / np.sqrt(cluster_variance)) * (cluster_points - self.cluster_centers_[k]).sum(axis=0)

            # Check for convergence
            shift = np.sqrt(np.sum((new_centroids - self.cluster_centers_) ** 2, axis=1)).max()
            if shift <= self.tol:
                if self.verbose:
                    print(f"Converged at iteration {i}: center shift {shift} within tolerance {self.tol}.")
                break

            self.cluster_centers_ = new_centroids

        self.labels_ = labels
        cluster_sizes = np.bincount(labels, minlength=self.n_clusters)
        self.inertia_ = np.sum(cluster_variances * cluster_sizes)
        # Number of iterations actually run (not the zero-based loop index).
        self.n_iter_ = n_iter

        return self
    

kvars = KVars(n_clusters=8, random_state=0, verbose=False).fit(reduced_smooth)
# Use a slug of its own: the next cell writes the plain KMeans figure to
# 'pca-kmeans' and would otherwise overwrite this one.
fig = plot_ed(pca, reduced, reduced_smooth, forms, model_id, num_components=8, title=model_id, save_png=False, slug='pca-kvars', kmeans=kvars)

In [None]:
# Same pairwise-PC figure, with the standard KMeans cluster centers overlaid.
fig = plot_ed(
    pca,
    reduced,
    reduced_smooth,
    forms,
    model_id,
    num_components=8,
    title=model_id,
    save_png=False,
    slug='pca-kmeans',
    kmeans=kmeans,
)