# Language

**TODO:** Add an introduction describing the language-model analyses and figures this notebook produces.

In [None]:
import os
from dotenv import load_dotenv

# Read variables from a .env file into the process environment (python-dotenv).
# The trailing semicolon keeps the notebook from echoing the call's return value.
load_dotenv();

In [None]:
from pathlib import Path
import json
import pickle
from collections import Counter
import warnings

import tqdm 
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

from icl.constants import DEVICE, DATA, ANALYSIS
from icl.language.model import get_model
from icl.language.utils import translate_int_to_str
from icl.figures.colors import plot_transitions, gen_transition_colors, get_transition_type, PRIMARY, SECONDARY, TERTIARY, BRED, BBLUE, BRED, BGREEN
from icl.figures.plotting import WIDTH, HEIGHT, FULL_WIDTH, FULL_HEIGHT

# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by later cells (pandas, seaborn, ...).
warnings.filterwarnings('ignore')

# NOTE(review): `model` is not referenced again in the visible cells — confirm it is still needed.
model = get_model()
# Bucket name is read from the environment; os.getenv returns None when the variable is unset.
AWS_BUCKET_NAME = os.getenv("AWS_BUCKET_NAME")

# Figure 1

In [None]:
import wandb 

# Client for the Weights & Biases public API; get_run() reads this module-level object.
api = wandb.Api()

# Ids of the W&B runs whose histories are loaded and combined into `df`.
LLC_RUN_IDS = [
    "72lcu5u5",
    "zpaasz5l",
    "vml7sbd2",
    "82ytc5lt",
    "5o92vslc",
]

def get_run(run_id, project="devinterp/tetrahedron-3m", max_step=50_000, null_seed_run_ids=("72lcu5u5",)):
    """Fetch one W&B run's logged history as a DataFrame annotated with its config.

    Args:
        run_id: Id of the run inside ``project``.
        project: ``entity/project`` path the run belongs to.
        max_step: Rows whose ``_step`` exceeds this value are discarded.
        null_seed_run_ids: Run ids whose ``model_seed`` column is overwritten with
            ``None``; the plotting cells select these rows via ``model_seed.isnull()``.

    Returns:
        A new DataFrame with one row per retained history entry and one extra,
        constant column per config key.
    """
    run = api.run(f"{project}/{run_id}")
    # NOTE(review): Run.history() returns a down-sampled history by default (see its
    # `samples` argument) — confirm that resolution is what the figures need.
    history = run.history()

    # Broadcast every config value to all rows so the runs can be concatenated and
    # later filtered or grouped by hyperparameter.
    for key, value in run.config.items():
        history[key] = [value] * len(history)

    # Take an explicit copy of the filtered rows: assigning a column onto a bare
    # `.loc` selection is chained assignment and triggers SettingWithCopyWarning.
    history = history.loc[history["_step"] <= max_step].copy()

    if run_id in null_seed_run_ids:
        history["model_seed"] = None

    return history

def get_llc_runs(run_ids):
    """Load each run in ``run_ids`` with ``get_run`` and stack the histories row-wise.

    Each history keeps its own row index; no re-indexing is done on concatenation.
    """
    histories = list(map(get_run, run_ids))
    return pd.concat(histories)

# Download all listed runs from W&B and display the combined frame.
df = get_llc_runs(LLC_RUN_IDS)
df

In [None]:
# Distinct seed labels present in the combined frame.
# NOTE(review): only a cell's last expression is displayed, so this list is computed and discarded.
list(df.model_seed.unique())

# Rows whose model_seed is null (get_run sets it to None for run "72lcu5u5").
df.loc[df.model_seed.isnull()]

In [None]:
import numpy as np
from sklearn.decomposition import PCA
import pickle

# SECURITY NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
# NOTE(review): the file name says "logits" while the variable says "loss" — confirm what is stored.
with open(ANALYSIS / 'language/L2W256-per-token-logits.pkl', 'rb') as f:
    per_token_loss = pickle.load(f)

per_token_loss = np.array(per_token_loss)
print(per_token_loss.shape)

# Fit a 3-component PCA on all rows and project the same rows onto it.
pca = PCA(n_components=3)
pca.fit(per_token_loss)
projections = pca.transform(per_token_loss)

# Displayed as the cell output: fraction of variance captured by each component.
pca.explained_variance_ratio_

In [None]:
from icl.figures.colors import gen_transition_colors

# Training-step boundaries between consecutive stages.
_STAGE_BOUNDARIES = (0, 900, 6500, 8500, 17_500, 49_900)

# (start_step, end_step, label) triples for stages LM1..LM5, one per adjacent boundary pair.
TRANSITIONS = [
    (start, end, f"LM{stage}")
    for stage, (start, end) in enumerate(zip(_STAGE_BOUNDARIES, _STAGE_BOUNDARIES[1:]), start=1)
]

# Colours passed to plot_transitions() alongside TRANSITIONS.
# Presumably one colour per type code (five codes for the five stages) — confirm against icl.figures.colors.
colors = gen_transition_colors(['A', 'A', 'B', 'A', 'Other']) # sns.color_palette("coolwarm_r", n_colors=len(TRANSITIONS))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.lines as mlines
from icl.constants import FIGURES

sns.set_palette('deep')

# Aggregate only the plotted metrics. get_run() attaches every config value as a
# column, and calling .mean()/.std() across non-numeric columns raises on pandas >= 2.0.
metric_cols = ['pile13m/init_loss', 'pile13m/llc']
by_step = df.groupby('_step')[metric_cols]
mean_df = by_step.mean().reset_index()
std_df = by_step.std().reset_index()

# Only the run whose model_seed is null is drawn in this figure.
baseline_df = df.loc[df.model_seed.isnull()]

fig, axs = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))

# The y-labels are set before plotting, as in the original cell ordering.
axs[0].set_ylabel(r'Test loss  $\hat{\ell}(w_t)$' '\n')
axs[1].set_ylabel(r'Local learning coeff.  $\hat\lambda(w_t)$')

# Top panel: test loss over training steps (log x-axis).
sns.lineplot(baseline_df, x='_step', y='pile13m/init_loss', ax=axs[0])
axs[0].set_xscale('log')

# Bottom panel: local learning coefficient over training steps (log x-axis).
sns.lineplot(baseline_df, x='_step', y='pile13m/llc', ax=axs[1])
axs[1].set_xscale('log')

# Remove any per-axes legend; the stage legend is attached to the figure below.
for ax in axs:
    ax_legend = ax.get_legend()
    if ax_legend is not None:
        ax_legend.remove()

# Mark the stages on both panels; the returned handles feed the figure legend.
handles = plot_transitions(axs, TRANSITIONS, xlim=True, colors=colors)

axs[0].set_xlabel('')
axs[1].set_xlabel('Training step $t$')

axs[0].set_ylim(3, 5.5)
axs[1].set_ylim(75, 160)

# An invisible dummy handle occupies the first legend column; the 'Stages:' text
# is placed separately with fig.text.
labels = [t[2] for t in TRANSITIONS]
dummy = mlines.Line2D([], [], color='none', label='')
legend = fig.legend(handles=[dummy] + handles, labels=[""] + labels, loc='upper center', bbox_to_anchor=(0.54, 0.01), ncol=len(TRANSITIONS)+1)
fig.text(0.14, -0.058, 'Stages:', horizontalalignment='left', verticalalignment='bottom', zorder=1000, fontsize=9)

plt.tight_layout()
fig.set_facecolor('white')

# NOTE(review): a later cell in this notebook saves to the same path and overwrites this file.
fig.savefig(FIGURES / "language/lm-fig1-top.pdf", bbox_inches='tight')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.lines as mlines
from icl.constants import FIGURES

sns.set_palette('deep')

# Aggregate only the plotted metrics. get_run() attaches every config value as a
# column, and calling .mean()/.std() across non-numeric columns raises on pandas >= 2.0.
metric_cols = ['pile13m/init_loss', 'pile13m/llc']
by_step = df.groupby('_step')[metric_cols]
mean_df = by_step.mean().reset_index()
std_df = by_step.std().reset_index()

# The run whose model_seed is null is drawn as a line on top of the across-run band.
baseline_df = df.loc[df.model_seed.isnull()]

fig, axs = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))

# The y-labels are set before plotting, as in the original cell ordering.
axs[0].set_ylabel(r'Test loss  $\hat{\ell}(w_t)$' '\n')
axs[1].set_ylabel(r'Local learning coeff.  $\hat\lambda(w_t)$')

for ax, col in zip(axs, metric_cols):
    # Shaded band: across-run mean +/- 2 standard deviations at each step.
    half_width = 2 * std_df[col]
    ax.fill_between(mean_df['_step'], mean_df[col] - half_width, mean_df[col] + half_width, alpha=0.5)
    # `linewidth` sets the stroke width; seaborn's `size` is a semantic grouping
    # variable and would also add a legend entry.
    sns.lineplot(baseline_df, x='_step', y=col, ax=ax, color='red', alpha=0.75, linewidth=0.5)
    ax.set_xscale('log')

# Remove any per-axes legend; the stage legend is attached to the figure below.
for ax in axs:
    ax_legend = ax.get_legend()
    if ax_legend is not None:
        ax_legend.remove()

# Mark the stages on both panels; the returned handles feed the figure legend.
handles = plot_transitions(axs, TRANSITIONS, xlim=True, colors=colors)

axs[0].set_xlabel('')
axs[1].set_xlabel('Training step $t$')

axs[0].set_ylim(3, 5.5)
axs[1].set_ylim(75, 170)

# An invisible dummy handle occupies the first legend column; the 'Stages:' text
# is placed separately with fig.text.
labels = [t[2] for t in TRANSITIONS]
dummy = mlines.Line2D([], [], color='none', label='')
legend = fig.legend(handles=[dummy] + handles, labels=[""] + labels, loc='upper center', bbox_to_anchor=(0.54, 0.01), ncol=len(TRANSITIONS)+1)
fig.text(0.14, -0.058, 'Stages:', horizontalalignment='left', verticalalignment='bottom', zorder=1000, fontsize=9)

plt.tight_layout()
fig.set_facecolor('white')

# NOTE(review): an earlier cell in this notebook saves to the same path; this call overwrites it.
fig.savefig(FIGURES / "language/lm-fig1-top.pdf", bbox_inches='tight')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.lines as mlines
from icl.constants import FIGURES

sns.set_palette('deep')

# Aggregate only the plotted metrics. get_run() attaches every config value as a
# column, and calling .mean()/.std() across non-numeric columns raises on pandas >= 2.0.
metric_cols = ['pile13m/init_loss', 'pile13m/llc']
by_step = df.groupby('_step')[metric_cols]
mean_df = by_step.mean().reset_index()
std_df = by_step.std().reset_index()

fig, axs = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))

# The y-labels are set before plotting, as in the original cell ordering.
axs[0].set_ylabel(r'Test loss  $\hat{\ell}(w_t)$' '\n')
axs[1].set_ylabel(r'Local learning coeff.  $\hat\lambda(w_t)$')

# One line per model_seed value on each panel.
# NOTE(review): rows whose model_seed is null (run "72lcu5u5", see get_run) may not
# receive a hue level of their own — confirm whether that run should appear here.
for ax, col in zip(axs, metric_cols):
    # `linewidth` sets the stroke width; seaborn's `size` is a semantic grouping variable.
    sns.lineplot(df, x='_step', y=col, ax=ax, errorbar=('sd', 2), hue="model_seed", palette="deep", alpha=0.8, linewidth=0.5)
    ax.set_xscale('log')

# Drop the per-seed legends and restrict both panels to the same step window.
for ax in axs:
    ax_legend = ax.get_legend()
    if ax_legend is not None:
        ax_legend.remove()
    ax.set_xlim(100, 50_000)

# No stage shading or stage legend is drawn in this variant of the figure.
axs[0].set_xlabel('')
axs[1].set_xlabel('Training step $t$')

axs[0].set_ylim(3, 5.5)
axs[1].set_ylim(75, 170)

plt.tight_layout()
fig.set_facecolor('white')

fig.savefig(FIGURES / "language/lm-fig1-top-multiple.pdf", bbox_inches='tight')
plt.show()


In [None]:
# LLC plotted against loss, one curve per run: raw loss on the top panel and
# loss shifted by a constant, on a log x-axis, on the bottom panel.
fig, axes = plt.subplots(2, 1, figsize=(WIDTH * 1.5, HEIGHT * 1.5))
loss_ax, shifted_loss_ax = axes

for model_seed in df.model_seed.unique():
    # Comparing with `==` never matches a null seed, which would silently drop the
    # run whose model_seed is None; select those rows with isnull() instead.
    if pd.isnull(model_seed):
        seed_mask = df.model_seed.isnull()
    else:
        seed_mask = df.model_seed == model_seed

    seed_df = df[seed_mask].sort_values(by="_step")
    loss = seed_df["pile13m/init_loss"].values
    llc = seed_df["pile13m/llc"].values

    # `linewidth` sets the stroke width; seaborn's `size` is a semantic grouping variable.
    sns.lineplot(x=loss, y=llc, linewidth=0.25, ax=loss_ax, alpha=0.5)
    # NOTE(review): 3.125 is presumably an estimate of the loss floor — confirm and name it.
    shifted_loss_ax.plot(loss - 3.125, llc, linewidth=0.5, alpha=0.5)

for ax in axes:
    ax_legend = ax.get_legend()
    if ax_legend is not None:
        ax_legend.remove()
    # Raw string: "\h" and "\l" are not valid escape sequences in a normal string literal.
    ax.set_ylabel(r"LLC $\hat\lambda$")

# As before, only the bottom panel gets the log x-axis and the x-label; it is now
# addressed explicitly rather than through the leaked loop variable.
shifted_loss_ax.set_xscale('log')
shifted_loss_ax.set_xlabel("Loss $L_n$")



In [None]:
from matplotlib import lines as mlines
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from icl.figures.derivatives import d_dt

from icl.figures.notation import str_d_dlogt, str_d_dt, str_dlog_dlogt
from icl.figures.colors import plot_transitions, gen_transition_colors, get_transition_type, PRIMARY, SECONDARY, TERTIARY, BRED, BBLUE, BRED, BGREEN

# For each metric: plot the raw curve (top row) and its derivative with respect to
# step or log-step (bottom row), optionally overlaying a Gaussian-process fit.
#
# NOTE(review): this cell reads df["step"], df["loss"], df["llc_mean"] and
# df["weight/norm"], whereas the earlier cells use "_step" and "pile13m/..." columns;
# `steps` and `MODEL_ID` are not defined in the visible cells either. Confirm which
# DataFrame and definitions this cell expects before running it on a fresh kernel.
plt.rcParams['figure.dpi'] = 300

# (LaTeX label, df column, options). Options read below: "logy", "derivative", "spline".
# NOTE(review): the "s" option is not read anywhere in this cell.
metrics_to_plot = [
    (r"\hat\ell(w_t)", "loss", {"logy": True, "derivative": "d_dlogt", "spline": True, "s": 0.1}, ),
    # (r"L_\mathcal{G}(t)", df["true/mse"], {"logy": False}),
    (r"\hat \lambda(w_t)", 'llc_mean', {"derivative": "d_dlogt", "spline": True}),
    (r"|w_t|", "weight/norm", {"derivative": "d_dt", "logy": True, "spline": True, "s": 0.1}),
] 
fig, axes = plt.subplots(2, len(metrics_to_plot), figsize=(FULL_WIDTH * 1.25, FULL_HEIGHT * 1.5))

# axes = np.array(axes)
# Guarantee a 2-D (row, column) array of axes regardless of the number of metrics.
axes = axes.reshape(2, len(metrics_to_plot))

# NOTE(review): this redefinition shadows the `str_dlog_dlogt` imported from
# icl.figures.notation at the top of the cell.
def str_dlog_dlogt(s):
    """Return the axis label for the log-log derivative of ``s`` as a LaTeX math string."""
    return r"$\delta \log " + s + r"/\delta\log t$"

for i, (metric_name, metric_key, kwargs) in enumerate(metrics_to_plot):
    use_spline = kwargs.get("spline", False)

    # Raw curve on the top row; it is faded (alpha 0.25) when a fit is overlaid.
    sns.lineplot(data=df, x="step", y=df[metric_key], ax=axes[0, i],label=metric_name, alpha=1 - use_spline * 0.75)
    # axes[0, i].plot(df['step'], metric_values, label=metric_name, marker='.')
    axes[0, i].set_title(f"")
    axes[0, i].set_xlabel('')
    axes[0, i].set_ylabel(f"${metric_name}$")

    if kwargs.get("logy", False):
        axes[0, i].set_yscale('log')

    axes[0, i].legend().remove()

    slope_type = kwargs.get("derivative", "d_dlogt")

    # Derivative of the raw data; `d_dt(x, y)` is a project helper, presumably a
    # numerical derivative dy/dx — confirm in icl.figures.derivatives.
    # NOTE(review): np.log(step) is taken without the +1 offset used for the GP input
    # below, so a step of 0 would give -inf here.
    if slope_type == "d_dlogt":
        slope_name = str_d_dlogt(metric_name)
        ys = d_dt(np.log(df['step'].values), df[metric_key].values)
    elif slope_type == "d_dt":
        slope_name = str_d_dt(metric_name)
        ys = d_dt(df['step'].values, df[metric_key].values)
    elif slope_type == "dlog_dlogt":
        slope_name = str_dlog_dlogt(metric_name)
        ys = d_dt(np.log(df['step'].values), np.log(df[metric_key].values))
    else:
        raise ValueError(f"Unknown slope type {slope_type}")

    # Derivative of the raw data on the bottom row, with a dashed zero reference line.
    sns.lineplot(data=df, x="step", y=ys, ax=axes[1, i], label=metric_name + " Slope", alpha=1 - use_spline * 0.75)
    axes[1, i].axhline(0, linestyle='--', color='gray')
    axes[1, i].set_title("")
    axes[1, i].set_ylabel(slope_name)
    axes[1, i].legend().remove()
    
    if use_spline:     
        # GP input: log(step + 1) as an (n, 1) column; target: per-step mean of the metric.
        # NOTE(review): `steps` must line up with the sorted unique values of df['step'] — confirm.
        _steps = np.log(np.array(steps) + 1 ).reshape((-1, 1))
        _y = df.groupby('step').mean()[metric_key].values

        kernel = C(1.0, (1e-3, 1e3)) * RBF(3, (5e-1, 1e3))

        # Create a Gaussian Process Regressor
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10)

        # Fit the Gaussian Process
        gp.fit(_steps, _y)
        _ypred = gp.predict(_steps)

        # Derivative of the fitted curve, in the same convention as the raw derivative above.
        if slope_type == "d_dlogt":
            _derivy = d_dt(_steps, _ypred)
        elif slope_type == "d_dt":
            _derivy = d_dt(np.exp(_steps), _ypred)
        elif slope_type == "dlog_dlogt":            
            _derivy = d_dt(_steps, np.log(_ypred))
        
        # The plotted label says "Spline", but the overlaid curve is the GP prediction.
        axes[0, i].plot(steps, _ypred, label="Spline", linestyle='--', color=BRED)
        axes[1, i].plot(steps, _derivy, label="Spline", linestyle='--', color=BRED)

        # Report the fitted value ("Original") and fitted derivative ("Smoothed") at each
        # stage start. list.index raises ValueError if a boundary is missing from `steps`.
        print()
        print(metric_name)
        for t1, t2, label in TRANSITIONS:
            i1, i2 = list(steps).index(t1), list(steps).index(t2)
            print(f"{label} {t1}: Original: {_ypred[i1]} Smoothed: {_derivy[i1]} ")

        print("End 50k: Original: ", _ypred[-1], "Smoothed: ", _derivy[-1])

        # Steps where the fitted LLC derivative is nearly flat (|derivy| < 0.25).
        # Indexing `steps` with an index array requires it to be a NumPy array.
        if metric_key == "llc_mean":
            print("Indices where |derivy| < .25: ", steps[np.where(np.abs(_derivy) < .25)[0]])



# Mark the stages on every panel; the returned handles feed the figure legend.
patch_list = plot_transitions(axes, TRANSITIONS, xlim=True, colors=colors)


for ax in axes[0]:
    ax.set_xlabel('')
    
for ax in axes.flatten():
    ax.set_xscale('log')
    ax.set_xlim(100, 49_000)
    # ax.set_ylabel("")

# axes[1, 1].set_yscale('symlog')
# axes[1, 0].set_yscale('symlog')
# axes[0,0].set_ylim(0, 70)

# Figure legend: one entry per stage plus one for the dashed GP-fit line.
milestone_labels = [label for _, _, label in TRANSITIONS]
gp_fit_patch = mlines.Line2D([], [], color=BRED, linestyle='--', label="GP Fit")
fig.legend(patch_list + [gp_fit_patch], milestone_labels + ["Fit"], loc='upper center', bbox_to_anchor=(0.5, -0.025), ncol=len(TRANSITIONS) + 1)

fig.set_facecolor("white")
fig.tight_layout()


# axes[0, 1].set_ylim(0, 100)
# axes[1, 0].set_ylim(-2.25, 2.25)
# axes[1, 1].set_ylim(-150, 160)
# NOTE(review): y-limits are set for the first two metric columns only; the third
# column (weight norm) keeps its automatic limits.
axes[0, 0].set_ylim(3, 6)
axes[0, 1].set_ylim(135, 185)
axes[1, 0].set_ylim(-.75, .25)
axes[1, 1].set_ylim(-15, 25)
# axes[1, 1].set_ylim(-0.003, .015)

# NOTE(review): MODEL_ID is not defined in the visible cells — confirm before running.
fig.savefig(FIGURES / f"language/{MODEL_ID}-loss-llc-with-slopes.pdf", bbox_inches='tight')