# RLCT estimation

In [None]:
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file, if present (PJRT_DEVICE below is read from the environment).
load_dotenv()

# On TPU (PJRT_DEVICE=TPU), import the torch_xla helpers and print the visible XLA devices.
if os.environ.get("PJRT_DEVICE") == "TPU":
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.distributed.parallel_loader as pl

    print(xm.get_xla_supported_devices())

In [None]:
from typing import Dict, List, Optional
from pprint import pp
import yaml

import wandb
from tqdm import tqdm
import torch
# from devinterp.evals import SamplerEvaluator
from devinfra.io.storage import CheckpointerConfig, BaseStorageProvider
from devinfra.evals import ModelEvaluator
from devinfra.optim.optimizers import OptimizerConfig
from devinfra.utils.seed import set_seed

from icl.config import get_config
from icl.evals import ICLEvaluator
import logging
import os

from matplotlib import pyplot as plt
import seaborn as sns
# NOTE: in the attention-entropy DataFrame columns, block indices are 1-based
# ("block_1" is the model's blocks.0), while head and token indices are 0-based.
# One colour per token position (16), indexed by token index when plotting per-token entropies.
cmap = sns.color_palette("viridis_r", 16) # Light to dark


logging.basicConfig(level=logging.WARNING)

In [None]:
def get_sweep_configs(sweep_config_dicts: List[Dict], **kwargs):
    """Lazily build one config per sweep entry.

    Each dict in ``sweep_config_dicts`` is expanded into keyword arguments for
    ``get_config``. The extra ``kwargs`` are forwarded to every call.
    """
    yield from (
        get_config(**overrides, **kwargs)
        for overrides in sweep_config_dicts
    )

# The checkpoint names were created from a hash that did not include every default
# config value, so these seeds must be passed explicitly to reproduce those names
# (and so locate the existing checkpoints).
shared = {"model_seed": 0, "pretrain_seed": 1, "true_seed": 2, "sampling_seed": 3}
# Task/architecture settings shared by every run in the sweep.
layers4 = {"task_size": 4, "max_examples": 8, "num_layers": 4, "num_heads": 4, "embed_size": 64, "mlp_size": 64, "noise_variance": 0.125}
# One config per task-set size M = 2**i for i in 0..20 (M = 1 ... 1,048,576), all with lr = 0.01.
configs = list(get_sweep_configs([{"task_config": {"num_tasks": 2**i, **layers4, **shared}, "optimizer_config": {"lr": 0.01}} for i in range(21)]))
configs

In [None]:
# Build one checkpoint storage provider per config (same order as `configs`) and print its file ids.
# NOTE(review): the loop variables `config` and `checkpointer` stay in the notebook's global
# namespace. After this cell, `checkpointer` is the LAST config's checkpointer (M = 2**20),
# not configs[0]'s.
checkpointers = []

for config in configs:
    checkpointer = config.checkpointer_config.factory()
    checkpointers.append(checkpointer)
    print(repr(checkpointer) + ":", str(tuple(checkpointer.file_ids)))

In [None]:
def load_model_at_step(config, step: int, checkpointer: Optional[BaseStorageProvider] = None):
    """Build a fresh model from ``config`` and load the weights saved at ``step``.

    If ``checkpointer`` is omitted, a new one is created from
    ``config.checkpointer_config``.
    """
    checkpointer = checkpointer if checkpointer is not None else config.checkpointer_config.factory()
    fresh_model = config.task_config.model_factory()
    fresh_model.load_state_dict(checkpointer.load_file(step)["model"])
    return fresh_model


def load_model_at_last_checkpoint(config, checkpointer: Optional[BaseStorageProvider] = None):
    """Build a fresh model from ``config`` and load the weights of its last checkpoint.

    The last checkpoint is ``checkpointer[-1]``. If ``checkpointer`` is omitted, a new
    one is created from ``config.checkpointer_config``.
    """
    provider = checkpointer
    if provider is None:
        provider = config.checkpointer_config.factory()

    fresh_model = config.task_config.model_factory()
    latest_state = provider[-1]["model"]
    fresh_model.load_state_dict(latest_state)
    return fresh_model
    

def eval_model_over_checkpoints(model, checkpointer: BaseStorageProvider, evaluator: ModelEvaluator, verbose=False):
    """Evaluate ``model`` at every checkpoint stored by ``checkpointer``.

    For each id in ``checkpointer.file_ids``, the checkpoint's ``"model"`` state dict
    is loaded into ``model`` in place and ``evaluator(model)`` is called. Its result
    dict is recorded together with the checkpoint's ``step``.

    Args:
        model: Module whose weights are overwritten at each checkpoint. After the
            call it holds the weights of the last checkpoint visited.
        checkpointer: Storage provider exposing ``file_ids`` and ``load_file``.
        evaluator: Callable mapping a model to a dict of metrics.
        verbose: If True, print each step's metrics as YAML.

    Returns:
        A ``pandas.DataFrame`` with one row per checkpoint: the evaluator's keys as
        columns, plus a ``"step"`` column.
    """
    # pandas is only imported in a later notebook cell. Importing it here makes the
    # function work regardless of cell execution order.
    import pandas as pd

    evals = []

    for step in checkpointer.file_ids:
        model.load_state_dict(checkpointer.load_file(step)["model"])
        evals.append({**evaluator(model), "step": step})

        if verbose:
            print("\n")
            print(f"Step {step}")
            print(yaml.dump(evals[-1]))

    return pd.DataFrame(evals)


In [None]:
# Evaluation setup for the first sweep config (configs[0], M = 2**0 = 1 task), on CPU.
device = "cpu"
config = configs[0]

# initialise model with the weights of configs[0]'s last checkpoint
model = load_model_at_last_checkpoint(config, checkpointers[0])

# initialise 'pretraining' data source (for training on fixed task set)
pretrain_dist = config.task_config.pretrain_dist_factory().to(device)

# initialise 'true' data source (for evaluation, including unseen tasks)
true_dist = config.task_config.true_dist_factory().to(device)

# initialise evaluations
evaluator = ICLEvaluator(
    pretrain_dist=pretrain_dist,
    true_dist=true_dist,
    max_examples=config.task_config.max_examples,
    eval_batch_size=config.eval_batch_size,
)

# Evaluate the loaded model (its weights were already loaded by load_model_at_last_checkpoint).
evaluator(model)

In [None]:
from torchtyping import TensorType
from devinterp.mechinterp.activations import ActivationProbe
from devinfra.utils.iterables import flatten_dict
import pandas as pd

def make_attention_entropy_evals(xs, ys, **paths):
    """
    Build an evaluator that measures attention entropy at the given module paths.

    Each keyword argument maps an output-name prefix to a dotted module path, as
    understood by ``ActivationProbe``, e.g.
    ``block_1="token_sequence_transformer.blocks.0.attention.attention_softmax"``.
    The module at each path is assumed to output attention probabilities of shape
    (batch, heads, query_tokens, key_tokens).

    Returns:
        A callable ``eval_attention_entropy(model)``. It runs ``model(xs, ys)`` once
        with forward hooks on every path and returns a dict, flattened with
        ``flatten_dict``, of entropies:
          * per head and query token;
          * per head (mean over tokens);
          * per path (mean over heads and tokens).
        Entropies are averaged over the batch.
    """
    def get_attention_entropy(attn: TensorType["B", "H", "T", "T"]):
        # 0 * log(0) is taken to be 0, so zero probabilities contribute nothing.
        log_attention = torch.where(attn > 0, torch.log(attn), torch.tensor(0.0).to(attn.device))

        # Entropy -sum(p * log p) over keys, averaged over the batch -> (H, T).
        # (No squeeze: squeezing would collapse T == 1 and break the unpacking below.)
        entropy = -torch.sum(attn * log_attention, dim=-1).mean(dim=0)

        num_heads, num_tokens = entropy.shape

        def get_head_attns(entropy):
            results = {
                f"token_{j}": entropy[j].item()
                for j in range(num_tokens)
            }

            results["mean"] = entropy.mean().item()
            return results

        results = {
            f"head_{i}": get_head_attns(entropy[i]) for i in range(num_heads)
        }

        results["mean"] = entropy.mean().item()
        return results

    def eval_attention_entropy(model):
        # Hook into the attention
        probes = [ActivationProbe(model, path) for path in paths.values()]

        for probe in probes:
            probe.register_hook()

        # Always remove the hooks, even if the forward pass raises, so the model is
        # not left with stale hooks. No gradients are needed for this measurement.
        try:
            with torch.no_grad():
                model(xs, ys)
        finally:
            for probe in probes:
                probe.unregister_hook()

        return flatten_dict({
            k: get_attention_entropy(probe.activation)
            for k, probe in zip(paths.keys(), probes)
        })

    return eval_attention_entropy

# Attention-entropy evaluator on the evaluator's pretraining batch. It probes the attention
# softmax of the first two transformer blocks (block_1 = blocks.0, block_2 = blocks.1).
# NOTE(review): `layers4` sets num_layers=4, but only the first two blocks are probed here.
attention_entropy_evals = make_attention_entropy_evals(
    evaluator.pretrain_xs,
    evaluator.pretrain_ys,
    block_1="token_sequence_transformer.blocks.0.attention.attention_softmax",
    block_2="token_sequence_transformer.blocks.1.attention.attention_softmax"
)

# Sanity check on the currently loaded model.
attention_entropy_evals(model)

In [None]:
# Attention entropies of the M=1 model (configs[0]) at every saved checkpoint.
# Use checkpointers[0] explicitly. The bare name `checkpointer` is left over from the loop
# that built `checkpointers` and refers to the LAST config (M = 2**20). That would not
# match `model`/`evaluator` (configs[0]) or the "M=1" figure below.
attn_entropies_over_checkpoints = eval_model_over_checkpoints(model, checkpointers[0], attention_entropy_evals, verbose=True)
attn_entropies_over_checkpoints

In [None]:
# Show every metric column name produced by the attention-entropy evaluator.
print(attn_entropies_over_checkpoints.columns.tolist())

In [None]:
def plot_attention_patterns(df, title="", save: Optional[str] = None):
    """Plot attention entropies recorded over training checkpoints.

    Expects ``df`` as produced by ``eval_model_over_checkpoints`` with the
    attention-entropy evaluator. It needs a ``step`` column plus columns
    ``block_{b}/mean``, ``block_{b}/head_{h}/mean`` and ``block_{b}/head_{h}/token_{t}``,
    for blocks b in {1, 2} (1-based), heads h in 0..3 and tokens t in 0..15.

    Layout (6 x 4 grid):
      * row 0: mean entropy of block 1 vs. block 2;
      * row 1: per-head mean entropy, one panel per block;
      * rows 2-5: per-token entropy for every (head, block) pair. Even token
        indices are in the panels titled "X", odd token indices in those titled "Y".

    Args:
        df: DataFrame of entropies, one row per checkpoint.
        title: Figure super-title.
        save: Optional path to save the figure to. Missing parent directories are
            created.
    """
    fig = plt.figure(figsize=(20, 25))
    plt.suptitle(title)

    # Row 0 (full width): block 1 vs. block 2 mean entropy.
    ax0 = plt.subplot2grid((6, 4), (0, 0), colspan=4)
    ax0.plot(df.step, df["block_1/mean"], label="block_1")
    ax0.plot(df.step, df["block_2/mean"], label="block_2")
    ax0.set_title("Block 1 vs. 2")
    ax0.set_xlabel("Step")
    ax0.set_ylabel("Entropy")
    ax0.legend()

    # Row 1: per-head mean entropy, one half-width panel per block.
    ax1 = [plt.subplot2grid((6, 4), (1, i*2), colspan=2) for i in range(2)]
    for block_idx in range(2):
        ax1[block_idx].set_title(f"Block {block_idx + 1}")
        ax1[block_idx].set_xlabel("Step")
        ax1[block_idx].set_ylabel("Entropy")
        for head_idx in range(4):
            series = df[f"block_{block_idx+1}/head_{head_idx}/mean"]
            ax1[block_idx].plot(df.step, series, label=f"Head {head_idx + 1}")
        ax1[block_idx].legend()

    # Rows 2-5: one panel per (head, block, even/odd token positions) = 16 panels.
    ax2 = [plt.subplot2grid((6, 4), (i//4 + 2, i%4)) for i in range(16)]
    ax_idx = 0
    for head_idx in range(4):
        for block_idx in range(2):
            for xs in (1, 0):
                ax2[ax_idx].set_title(f"Block {block_idx + 1} Head {head_idx + 1} {'X' if xs else 'Y'}")
                ax2[ax_idx].set_xlabel("Step")
                ax2[ax_idx].set_ylabel("Entropy")
                # xs=1 -> even token indices, xs=0 -> odd token indices.
                for token_idx in range(1-int(xs), 16, 2):
                    series = df[f"block_{block_idx+1}/head_{head_idx}/token_{token_idx}"]
                    ax2[ax_idx].plot(df.step, series, label=f"Token {token_idx + 1}", color=cmap[token_idx])
                ax_idx += 1

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save:
        parent_dir = os.path.dirname(save)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        fig.savefig(save)

    plt.show()

# NOTE(review): the other models' figures are saved as "attention_entropies.png"; this one
# uses a different file name.
plot_attention_patterns(
    attn_entropies_over_checkpoints,
    title="M=1",
    save="../figures/M=1/attention_patterns.png",
)

In [None]:
def plot_activations_over_models_over_checkpoints(configs, checkpointers):
    """Compute and plot attention entropies over checkpoints for each model.

    Despite the name, the quantity measured is the attention entropy of the first
    two transformer blocks (see ``make_attention_entropy_evals``), on each model's
    pretraining batch. For every (config, checkpointer) pair, the figure is shown
    and saved under ``../figures/M={num_tasks}/``.

    Returns:
        A list of per-model entropy DataFrames, in the order of ``configs``.
    """
    results = []
    banner = "=" * 20

    for cfg, ckpt in zip(configs, checkpointers):
        m = cfg.task_config.num_tasks
        print("\n\n")
        print(f"{banner} M={m} {banner}")
        net = load_model_at_last_checkpoint(cfg, ckpt)

        # Evaluation data: the fixed pretraining task set, and the 'true'
        # distribution (which includes unseen tasks).
        task_cfg = cfg.task_config
        icl_evaluator = ICLEvaluator(
            pretrain_dist=task_cfg.pretrain_dist_factory().to(device),
            true_dist=task_cfg.true_dist_factory().to(device),
            max_examples=task_cfg.max_examples,
            eval_batch_size=cfg.eval_batch_size,
        )

        entropy_eval = make_attention_entropy_evals(
            icl_evaluator.pretrain_xs,
            icl_evaluator.pretrain_ys,
            block_1="token_sequence_transformer.blocks.0.attention.attention_softmax",
            block_2="token_sequence_transformer.blocks.1.attention.attention_softmax"
        )

        entropies = eval_model_over_checkpoints(net, ckpt, entropy_eval, verbose=False)
        plot_attention_patterns(entropies, f"M={m}", save=f"../figures/M={m}/attention_entropies.png")
        results.append(entropies)

    return results


dfs = plot_activations_over_models_over_checkpoints(configs[1:], checkpointers[1:])

In [None]:
# Re-plot (and re-save) the entropy figures from the DataFrames returned above, without
# recomputing them. dfs[i] corresponds to configs[i + 1].
for df, config in zip(dfs, configs[1:]):
    num_tasks = config.task_config.num_tasks
    plot_attention_patterns(df, f"M={num_tasks}", save=f"../figures/M={num_tasks}/attention_entropies.png")

# Dim reduction of weights and activations

In [None]:
from contextlib import contextmanager
from typing import Callable, Optional, Tuple, Union

import torch
from devinfra.utils.iterables import prepend_dict
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm import tqdm

# Type alias for a tensor -> tensor function (not referenced elsewhere in this notebook view).
Transform = Callable[[torch.Tensor], torch.Tensor]

class ActivationProbe:
    """
    Capture (part of) the output of one submodule of a model via a forward hook.

    The location is a dot-separated string. Each part that names an attribute of the
    module reached so far is followed; this includes numbered children such as
    ``blocks.0``. Every other part is an index applied to the hooked module's output:

    - 'layer1.0.conv1': the whole output of ``layer1.0.conv1``.
    - 'layer1.0.conv1.3': ``output[..., 3]``.
    - 'layer1.0.conv1.3.2.2': ``output[..., 3, 2, 2]``.
    - 'layer1.0.conv1.*': ``output[..., :]``, i.e. '*' keeps every entry along that axis.
    - '': the output of ``model`` itself.

    Indices are applied to the TRAILING dimensions of the output, because an Ellipsis
    is prepended. For example, with an (N, C, H, W) output a single index selects
    along W, not along the channel dimension.

    Attributes:
        model: The model the probe was created for.
        layer: The module the forward hook is attached to.
        layer_location (List[str]): Attribute names followed to reach ``layer``.
        neuron_location (list): Indices (ints, or ``slice(None)`` for '*') applied to the output.
        activation: The (indexed) output captured by the most recent forward pass, or None.

    Example:
        model = ResNet18()
        probe = ActivationProbe(model, 'layer1.0.conv1.3')
        with probe.watch():
            output = model(input_tensor)
        print(probe.activation)  # output of layer1.0.conv1, indexed [..., 3]
    """

    def __init__(self, model, location):
        self.activation = None
        self.model = model
        self.handle = None

        self.layer_location = []
        self.neuron_location = []

        # Walk attribute names to find the target layer; all other parts are indices.
        self.layer = model
        for part in location.split('.'):
            if part == "":
                continue
            if hasattr(self.layer, part):
                self.layer_location.append(part)
                self.layer = getattr(self.layer, part)
            elif part == "*":
                # A full slice keeps the whole axis. An Ellipsis would clash with the one
                # hook_fn prepends, since only one Ellipsis is allowed per index.
                self.neuron_location.append(slice(None))
            else:
                self.neuron_location.append(int(part))

    def hook_fn(self, module, input, output):
        """Forward hook: store ``output``, indexed by ``neuron_location`` over its trailing dims."""
        if self.neuron_location:
            # Leading dims (e.g. batch) are kept by the Ellipsis.
            self.activation = output[(..., *self.neuron_location)]
        else:
            self.activation = output

    def register_hook(self):
        """Attach the forward hook to ``layer`` and return its removable handle."""
        self.handle = self.layer.register_forward_hook(self.hook_fn)
        return self.handle

    def unregister_hook(self):
        """Remove the hook attached by ``register_hook`` (no-op if none is attached)."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    @contextmanager
    def watch(self):
        """Context manager: the hook is registered on entry and always removed on exit."""
        handle = self.register_hook()
        try:
            yield
        finally:
            handle.remove()


In [None]:
from typing import Optional, Callable, Union, Dict
from sklearn.decomposition import PCA
import numpy as np
from collections import defaultdict
from copy import deepcopy
from devinterp.mechinterp.activations import ActivationProbe

# m = deepcopy(model)
# model = deepcopy(m)

class ActivationsReducer:
    """Collect a model's activations across checkpoints and reduce their dimensionality.

    Workflow:
      * ``run`` (or repeated ``eval``) records one flattened activation vector per
        target per call;
      * ``freeze`` stacks them into ``(num_calls, num_features)`` arrays;
      * ``transform`` applies ``dr_method`` to each array.
    ``run_transform`` performs all three steps.

    Args:
        xs, ys: Inputs passed as ``model(xs, ys)`` on every evaluation.
        activation_paths: Mapping from a name to a module path understood by
            ``ActivationProbe`` (e.g. ``{"y": "7.0"}``). If omitted or empty, every
            module in ``model.named_modules()`` is probed, keyed by its name.
            NOTE(review): modules whose output is not a tensor would fail in ``eval``.
        dr_method: Callable applied to each frozen ``(num_calls, num_features)``
            array. If None, ``transform`` returns the raw arrays.
    """

    activation_samples: Dict[str, Union[list, np.ndarray]]

    def __init__(
        self,
        xs: torch.Tensor,
        ys: torch.Tensor,
        activation_paths: Optional[Dict[str, str]] = None,
        dr_method: Callable = None,
    ):
        self.xs = xs
        self.ys = ys
        self.activation_samples = defaultdict(list)
        # The default targets depend on the model, which is only known in eval(), so they
        # are resolved there rather than by reading a global `model` here.
        self.targets = activation_paths
        self.dr_method = dr_method

    def _targets_for(self, model) -> Dict[str, str]:
        """Return the name -> module-path mapping to probe on ``model``."""
        if self.targets:
            return self.targets
        return {name: name for name, _ in model.named_modules()}

    def eval(self, model):
        """Run ``model(xs, ys)`` once and record the flattened activation of every target."""
        assert not len(self.activation_samples.values()) or isinstance(next(iter(self.activation_samples.values())), list), "Cannot run eval() after running freeze()"

        targets = self._targets_for(model)
        probes = [ActivationProbe(model, path) for path in targets.values()]

        for probe in probes:
            probe.register_hook()

        # Remove the hooks even if the forward pass raises. No gradients are needed.
        try:
            with torch.no_grad():
                model(self.xs, self.ys)
        finally:
            for probe in probes:
                probe.unregister_hook()

        for name, probe in zip(targets.keys(), probes):
            # The whole activation (batch included) becomes one feature vector.
            activation = probe.activation.detach().cpu().numpy().reshape(-1)
            self.activation_samples[name].append(activation)

    def run(self, model, checkpointer):
        """Load each checkpoint of ``checkpointer`` into ``model`` (in place) and call ``eval``."""
        # Iterate this checkpointer's own ids, not a module-level `steps` variable.
        for step in tqdm(checkpointer.file_ids, desc="Iterating over checkpoints"):
            model_state_dict = checkpointer.load_file(step)["model"]
            model.load_state_dict(model_state_dict)
            self.eval(model)

    def freeze(self):
        """Stack each target's recorded vectors into a ``(num_calls, num_features)`` array."""
        self.activation_samples = {k: np.array(v) for k, v in self.activation_samples.items()}

    def transform(self):
        """Apply ``dr_method`` to every frozen array, or return the arrays if it is None."""
        assert isinstance(next(iter(self.activation_samples.values())), np.ndarray), "Must run freeze() before transform()"
        return {k: self.dr_method(v) for k, v in self.activation_samples.items()} if self.dr_method else self.activation_samples

    def run_transform(self, model, checkpointer):
        """``run`` + ``freeze`` + ``transform`` in one call."""
        self.run(model, checkpointer)
        self.freeze()
        return self.transform()


def apply_pca(samples: np.ndarray, num_components=2):
    """Fit a ``num_components``-component PCA to ``samples`` (rows are observations).

    Returns:
        ``(pca, projected)``: the fitted ``PCA`` object and ``samples`` projected
        onto its components.
    """
    reducer = PCA(n_components=num_components)
    return reducer, reducer.fit_transform(samples)

# PCA of the "y" activation (path "7.0") of the M=1 model over all checkpoints, plus a
# separate PCA refitted on the last half of the checkpoints.
# Rebind `checkpointer` to configs[0]'s checkpointer. Left over from the cell that built
# `checkpointers`, it otherwise refers to the LAST config (M = 2**20), whereas `model`,
# `evaluator` and the "M=1" figure below belong to configs[0].
checkpointer = checkpointers[0]

activation_reducer = ActivationsReducer(
    evaluator.pretrain_xs,
    evaluator.pretrain_ys,
    # {"ys": ""},
    {"y": "7.0"},
    dr_method=apply_pca
)

steps = checkpointer.file_ids
activation_reducer.run(model, checkpointer)
activation_reducer.freeze()
dr_samples_pca, dr_samples = activation_reducer.transform()["y"]
samples = activation_reducer.activation_samples["y"]
# Second PCA, fitted only on the activations from the last half of the checkpoints.
samples_last_half = samples[len(samples)//2:, :]
pca_last_half = PCA(n_components=2)
transformed_samples_last_half = pca_last_half.fit_transform(samples_last_half)
steps_last_half = checkpointer.file_ids[len(samples)//2:]

In [None]:
# Flattened activation size divided by pretrain_xs.shape[0], i.e. features per example,
# assuming dim 0 of pretrain_xs is the batch dimension.
print(samples.shape[1] // evaluator.pretrain_xs.shape[0])

In [None]:

# Function to plot sample evolution with color linear in steps and rescale samples
def plot_sample_evolution_with_inset(steps, samples, explained_variance, title="Sample Evolution in 2D Plane", num_points_to_label=10, save: Optional[str] = None, ax: Optional = None):
    """Scatter 2-D samples coloured by training step, with an explained-variance inset.

    Args:
        steps: Step value for each sample, used for the colour scale and point labels.
        samples: Array of shape (num_samples, >= 2). Columns 0 and 1 are plotted.
        explained_variance: Per-component values shown as a bar chart in the inset
            (e.g. a fitted PCA's ``explained_variance_ratio_``).
        title: Axes title.
        num_points_to_label: Approximate number of evenly spaced points annotated
            with their step. No labels are drawn if it is <= 0.
        save: Optional path to save the figure containing ``ax`` to. Missing parent
            directories are created.
        ax: Axes to draw on. If None, a new (15, 8) figure is created.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(15, 8))

    # Main plot
    sc = ax.scatter(samples[:, 0], samples[:, 1], c=steps, cmap='viridis', s=50, alpha=0.6)
    plt.colorbar(sc, ax=ax, label='Steps')

    # Label some points. The stride is at least 1: with fewer samples than labels,
    # total // num_points_to_label would be 0 and range() would raise ValueError.
    total_samples = len(samples)
    if total_samples and num_points_to_label > 0:
        stride = max(1, total_samples // num_points_to_label)
        for i in range(0, total_samples, stride):
            ax.text(samples[i, 0], samples[i, 1], str(steps[i]), fontsize=12, ha='right', va='bottom')

    ax.set_xlabel('Feature 1')
    ax.set_ylabel('Feature 2')
    ax.set_title(title)

    # Inset for explained variance at the bottom right corner, with slight transparency
    axins = ax.inset_axes([0.7, 0.05, 0.25, 0.25])  # x, y, width, height
    axins.bar(range(len(explained_variance)), explained_variance, alpha=0.5)
    axins.set_title('Explained Variance')
    axins.set_xlabel('Component')
    axins.set_ylabel('Variance')
    axins.patch.set_alpha(0.5)

    if save:
        parent_dir = os.path.dirname(save)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Save the figure this axes belongs to, not whatever figure happens to be current.
        ax.figure.savefig(save)

# Side-by-side PCA views for M=1: PCA over all checkpoints (left), and a PCA refitted on
# the last half of the checkpoints (right).
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
plt.suptitle("M=1")
plot_sample_evolution_with_inset(checkpointer.file_ids, dr_samples, dr_samples_pca.explained_variance_ratio_, title="All checkpoints", ax=axes[0])
# Plot the last-half PCA projection, not the raw high-dimensional activations: their
# first two coordinates are unrelated to pca_last_half's explained variance.
plot_sample_evolution_with_inset(steps_last_half, transformed_samples_last_half, pca_last_half.explained_variance_ratio_, title="Last half of checkpoints", save=f"../figures/M=1/M1_y_pca.png", ax=axes[1])
plt.show()

# Example usage could be similar to before, with the option to specify immediate_dr
# immediate_dr = SomeRandomProjectionFunction
# weight_eval, apply_weight_dr = make_weight_dr_evals({"layer1": "layer1.weight"}, dr_method, immediate_dr)
# activation_eval, apply_activation_dr = make_activation_dr_evals({"act1": "layer1"}, dr_method, immediate_dr)

In [None]:
def prepend_dict(d: dict, prefix: str, delimiter="."):
    """Return a new dict whose keys are ``f"{prefix}{delimiter}{key}"``; values are unchanged."""
    renamed = {}
    for key, value in d.items():
        renamed[f"{prefix}{delimiter}{key}"] = value
    return renamed

def plot_pcas_over_models_over_checkpoints(configs, checkpointers, paths):
    """PCA each model's activations across checkpoints and plot the trajectories.

    For every (config, checkpointer) pair and every entry of ``paths`` (name -> module
    path, as accepted by ``ActivationsReducer``), the function:
      * records the activation at each checkpoint;
      * fits a 2-component PCA over all checkpoints, and a second PCA over the last
        half of the checkpoints;
      * shows both projections side by side, saving the figure to
        ``../figures/L={num_layers}_M={num_tasks}_{name}_pca.png``.

    Returns:
        One dict per model with ``"num_tasks"`` plus, for each path name ``p``:
          * ``p/explained_variance`` and ``p/samples`` (projection over all checkpoints);
          * ``p/explained_variance_last_half``;
          * ``p/samples_last_half`` (the raw, un-projected activations of the last
            half of the checkpoints).
    """
    pca_results = []

    def apply_pca(samples: np.ndarray, num_components=2):
        pca = PCA(n_components=num_components)
        transformed_samples = pca.fit_transform(samples)
        return pca, transformed_samples

    for config, checkpointer in zip(configs, checkpointers):
        num_layers = config.task_config.num_layers
        num_tasks = config.task_config.num_tasks
        print("\n\n")
        print("=" * 20 + f" M={num_tasks} " + "=" * 20)
        model = load_model_at_last_checkpoint(config, checkpointer)

        # initialise 'pretraining' data source (for training on fixed task set)
        pretrain_dist = config.task_config.pretrain_dist_factory().to(device)

        # initialise 'true' data source (for evaluation, including unseen tasks)
        true_dist = config.task_config.true_dist_factory().to(device)

        # initialise evaluations
        evaluator = ICLEvaluator(
            pretrain_dist=pretrain_dist,
            true_dist=true_dist,
            max_examples=config.task_config.max_examples,
            eval_batch_size=config.eval_batch_size,
        )

        activation_reducer = ActivationsReducer(
            evaluator.pretrain_xs,
            evaluator.pretrain_ys,
            paths,
            dr_method=apply_pca
        )

        local_pca_results = {
            "num_tasks": num_tasks,
        }

        transformed_samples = activation_reducer.run_transform(model, checkpointer)

        for path in paths.keys():
            pca, samples = transformed_samples[path]
            samples_full = activation_reducer.activation_samples[path]
            half = len(samples_full) // 2
            samples_last_half = samples_full[half:, :]
            steps_last_half = checkpointer.file_ids[half:]
            pca_last_half, transformed_samples_last_half = apply_pca(samples_last_half)

            fig, axes = plt.subplots(1, 2, figsize=(20, 8))
            plt.suptitle(f"M={num_tasks}")
            plot_sample_evolution_with_inset(checkpointer.file_ids, samples, pca.explained_variance_ratio_, title="All checkpoints", ax=axes[0])
            # Plot the last-half PCA projection, not the raw high-dimensional activations.
            plot_sample_evolution_with_inset(steps_last_half, transformed_samples_last_half, pca_last_half.explained_variance_ratio_, title="Last half of checkpoints", save=f"../figures/L={num_layers}_M={num_tasks}_{path}_pca.png", ax=axes[1])
            plt.show()

            local_pca_results.update(prepend_dict({
                "explained_variance": pca.explained_variance_ratio_,
                "samples": samples,
                "explained_variance_last_half": pca_last_half.explained_variance_ratio_,
                "samples_last_half": samples_last_half
            }, prefix=path, delimiter="/"))

        pca_results.append(local_pca_results)

    return pca_results

# Earlier variants, kept for reference:
# pca_results = plot_pcas_over_models_over_checkpoints(configs[1:], checkpointers[1:])
# pca_results = plot_pcas_over_models_over_checkpoints(configs[12:], checkpointers[12:], paths={path: "token_sequence_transformer"})
# For every model in the sweep, probe the model itself ("ys", path "") and the activation at
# path "7.0" ("y").
pca_results = plot_pcas_over_models_over_checkpoints(configs, checkpointers, paths={"ys": "", "y": "7.0"})

# Covariance analysis

# RLCTs

In [None]:
# from boto3.s3 import NoSuchKey
from devinfra.evals import RepeatEvaluator
from tqdm import trange


def eval_rlct(
    model: nn.Module,
    *,
    lr: float = 1e-5,
    noise_level: float = 1.0,
    weight_decay: float = 3e-7,
    elasticity: float = 10.0,
    num_draws: int = 20,
    num_chains: int = 8,
    cores: int = 8,
    device: str = "xla",
):
    """Estimate the RLCT of ``model`` with SGLD sampling via ``estimate_rlct``.

    The whole pretraining batch of the global ``evaluator`` is used as a single
    full-batch DataLoader, with MSE loss. The keyword-only arguments default to the
    previously hard-coded settings, so ``eval_rlct(model)`` behaves as before. They
    allow e.g. ``device="cpu"`` when not running on a TPU.

    Returns:
        ``{"rlct": <result of estimate_rlct>}``.

    NOTE(review): ``estimate_rlct`` is not imported anywhere in this notebook
    (presumably it comes from devinterp). Import it before calling this function.
    """
    xs, ys = evaluator.pretrain_xs, evaluator.pretrain_ys
    trainset = torch.utils.data.TensorDataset(xs, ys)
    # A single batch containing the whole pretraining set.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(xs))

    optimizer_kwargs = dict(
        lr=lr,
        noise_level=noise_level,
        weight_decay=weight_decay,
        elasticity=elasticity,
        temperature="adaptive",
        num_samples=len(xs),
    )
    return {
        "rlct": estimate_rlct(
            model,
            trainloader,
            F.mse_loss,
            "sgld",
            optimizer_kwargs,
            num_draws=num_draws,
            num_chains=num_chains,
            num_burnin_steps=0,
            num_steps_bw_draws=1,
            cores=cores,
            pbar=False,
            device=device,
        )
    }


# Wrap eval_rlct in a RepeatEvaluator with 5 repeats (presumably 5 independent estimates,
# to gauge their spread) and run it on the currently loaded model.
eval_rlcts = RepeatEvaluator(eval_rlct, 5)
eval_rlcts(model)

In [None]:


def eval_models_at_step(step=-1):
    """Estimate RLCTs for every model in the sweep at one checkpoint step.

    For each checkpointer ``i`` (the model trained on M = 2**i tasks), the checkpoint
    at ``step`` is loaded into the global ``model``, and the RLCT evaluator built from
    ``ICLSampler`` / ``SamplerEvaluator`` is run 10 times.

    Args:
        step: Checkpoint step to evaluate. ``-1`` means each model's own last
            checkpoint.

    Returns:
        Dict mapping M = 2**i to the list of 10 evaluator outputs. Models whose
        checkpoint fails to load (e.g. no checkpoint at ``step``) are skipped.

    NOTE(review): ``ICLSampler``, ``sampler_config`` and ``SamplerEvaluator`` are not
    defined in the visible notebook (the SamplerEvaluator import is commented out).
    """
    evals = {}

    for i, _checkpointer in enumerate(checkpointers):
        checkpointer = _checkpointer.providers[0]
        print("-" * 20 + f" M {2 ** i} " + "-" * 20)

        # Resolve -1 per model without overwriting `step`. Reassigning it would make
        # every later model reuse the first model's last step.
        model_step = checkpointer.file_ids[-1] if step == -1 else step

        try:
            model.load_state_dict(checkpointer.load_file(model_step)["model"])
        except Exception as e:
            # TODO: Figure out where to find NoSuchKey
            print(f"Step {model_step} not found for checkpoint {i}. Skipping...", e, model_step)
            continue

        sampler = ICLSampler(
            model,
            evaluator.pretrain_xs,
            evaluator.pretrain_ys,
            sampler_config
        )

        rlct_evaluator = SamplerEvaluator.create_rlct_evaluator(sampler)
        _evals = []

        def wipe():
            """Drop the sampler/evaluator references and release cached CUDA memory."""
            nonlocal sampler
            nonlocal rlct_evaluator

            del sampler
            del rlct_evaluator

            torch.cuda.empty_cache()

        try:
            for _ in trange(10):
                _evals.append(rlct_evaluator(model, None, None))
                print(_evals[-1])

            evals[2 ** i] = _evals
            print(evals)
        except Exception:
            # Free GPU memory before propagating the error.
            wipe()
            raise

    return evals


evals = eval_models_at_step(81632)

In [None]:
# Release PyTorch's cached CUDA memory, then show GPU usage (IPython shell escape; not plain Python).
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
import matplotlib.pyplot as plt

# Mean ± std of the repeated RLCT estimates for each number of tasks M (step 81632 above).
keys = sorted(evals)
means = [np.mean(evals[m]) for m in keys]
std_devs = [np.std(evals[m]) for m in keys]

plt.figure(figsize=(10, 6))
plt.errorbar(keys, means, yerr=std_devs, fmt='o-', capsize=5, label="Evaluation Metrics")
plt.xlabel("Number of Tasks")
plt.ylabel(r"$\hat\lambda$")
plt.title("Steps = 81632")
plt.legend()
plt.grid(True)
plt.xscale("log")
plt.show()

In [None]:
from functools import reduce

# Sorted union of the checkpoint steps of every model in the sweep, keeping every second step.
all_steps = sorted(set().union(*(c.file_ids for c in checkpointers)))[::2]
print(all_steps, len(all_steps))

In [None]:
def eval_models_at_multiple_steps(steps):
    """Run ``eval_models_at_step`` for each step; return ``{step: result_for_that_step}``."""
    separator = "=" * 20
    results = {}
    for current in steps:
        print(f"{separator} Step {current} {separator}")
        results[current] = eval_models_at_step(current)
    return results

# RLCT estimates of every model at every other checkpoint step (all_steps, computed above).
evals_over_time = eval_models_at_multiple_steps(all_steps)

In [None]:
from matplotlib.animation import FuncAnimation

# Initialize figure
fig, ax = plt.subplots(figsize=(10, 6))
line, = ax.plot([], [], 'o-', label="Evaluation Metrics")
ax.set_xlabel("Num tasks")
ax.set_ylabel(r"$\hat\lambda$")
ax.set_title("RLCT estimates over time")
ax.grid(True)
ax.legend()

# Initialize data
x_data = []
y_data = []
y_err = []
# Error bars drawn for the previous frame. They are removed before the next frame is drawn;
# otherwise every frame's error bars would stay on the axes and pile up.
errorbar_container = None

# Update function for animation
def update(step):
    """Draw the mean ± std RLCT per number of tasks at checkpoint ``step``.

    Reads ``evals_over_time[step]``, a dict mapping number of tasks to a list of
    repeated RLCT estimates. Steps without an entry draw empty data.
    """
    global x_data, y_data, y_err, errorbar_container

    evals = evals_over_time.get(step, {})
    keys = sorted(list(evals.keys()))
    means = [np.mean(evals[k]) for k in keys]
    std_devs = [np.std(evals[k]) for k in keys]

    x_data = keys
    y_data = means
    y_err = std_devs

    line.set_data(x_data, y_data)
    ax.relim()
    ax.autoscale_view()

    if errorbar_container is not None:
        errorbar_container.remove()
    errorbar_container = ax.errorbar(x_data, y_data, yerr=y_err, fmt='o-', capsize=5, label="Evaluation Metrics (Step {})".format(step))

# Create animation
ani = FuncAnimation(fig, update, frames=sorted(evals_over_time.keys()), repeat=False)

plt.show()

# Analysis

In [None]:
import pandas as pd
from devinterp.utils import flatten_dict

def wandb_run_to_df(run):
    """Return a run's logged history, with its flattened config added as constant columns.

    ``run`` is assumed to be a ``wandb.Api().run(...)`` object.
    """
    flat_config = flatten_dict(run.config, flatten_lists=True)
    history = run.history()

    # Broadcast each config value to every history row.
    for column, value in flat_config.items():
        history[column] = value

    return pd.DataFrame.from_dict(history)

def wandb_runs_to_df(*runs):
    """Concatenate ``wandb_run_to_df`` for each run into one DataFrame.

    Per-run row indices are kept, so they may repeat. With no runs, an empty
    DataFrame is returned (``pd.concat`` raises on an empty list).
    """
    df_list = [wandb_run_to_df(run) for run in runs]
    if not df_list:
        return pd.DataFrame()
    return pd.concat(df_list)

def wandb_sweep_to_df(sweep):
    """Return one DataFrame with the history and config of every run in ``sweep``.

    ``sweep`` is assumed to be a ``wandb.Api().sweep(...)`` object.
    """
    runs = list(sweep.runs)
    return wandb_runs_to_df(*runs)

def wandb_sweeps_to_df(*sweeps):
    """Concatenate ``wandb_sweep_to_df`` for each sweep into one DataFrame.

    With no sweeps, an empty DataFrame is returned (``pd.concat`` raises on an empty list).
    """
    df_list = [wandb_sweep_to_df(sweep) for sweep in sweeps]
    if not df_list:
        return pd.DataFrame()
    return pd.concat(df_list)

def wandb_run_id_to_df(run_id: str, entity="devinterp", project="icl"):
    """Fetch one W&B run by id and return it as a DataFrame (see ``wandb_run_to_df``).

    The run is looked up as ``{entity}/{project}/{run_id}``. The defaults match the
    sweep helpers in this cell, replacing the hard-coded ``"project_path"`` placeholder
    that had to be edited by hand.
    """
    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")
    return wandb_run_to_df(run)

def wandb_sweep_id_to_df(sweep_id: str, entity="devinterp", project="icl"):
    """Fetch the W&B sweep ``{entity}/{project}/{sweep_id}`` and return all its runs as one DataFrame."""
    sweep_path = f"{entity}/{project}/{sweep_id}"
    return wandb_sweep_to_df(wandb.Api().sweep(sweep_path))

def wandb_sweep_ids_to_df(sweep_ids: str, entity="devinterp", project="icl"):
    """Fetch several sweeps given as a comma-separated string of ids, and combine all their runs.

    Whitespace around ids and empty entries are ignored, so ``"a, b,"`` fetches sweeps
    ``a`` and ``b``.
    """
    api = wandb.Api()
    ids = [sweep_id.strip() for sweep_id in sweep_ids.split(",")]
    sweeps = [api.sweep(f"{entity}/{project}/{sweep_id}") for sweep_id in ids if sweep_id]
    return wandb_sweeps_to_df(*sweeps)

# Load every run of W&B sweep "xksoyrhb" (entity "devinterp", project "icl" by default) into one DataFrame.
df = wandb_sweep_ids_to_df("xksoyrhb")
df

In [None]:
# Inspect the available columns. Config keys appear flattened as "section/key", e.g. "analysis_config/lr".
df.columns

In [None]:
# Distinct values of each swept analysis hyperparameter.
lrs, num_draws, elasticities = (
    df[f"analysis_config/{name}"].unique()
    for name in ("lr", "num_draws", "elasticity")
)

lrs, num_draws, elasticities

In [None]:
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns

# Set Seaborn style
sns.set(style="whitegrid")

# One figure per (lr, num_draws, elasticity) combination: mean RLCT vs. number of tasks,
# with a ±1 std band.
for (lr, _num_draws, elasticity) in product(lrs, num_draws, elasticities):
    _df = df[
        (df["analysis_config/lr"] == lr) &
        (df["analysis_config/num_draws"] == _num_draws) &
        (df["analysis_config/elasticity"] == elasticity)
    ]
    if _df.empty:
        # The product of the unique values may include combinations that were never
        # run. Skip them instead of drawing an empty figure.
        continue
    _df = _df.sort_values(by="task_config/num_tasks")  # Ensure data is sorted by x-values

    # Values for plotting
    x_vals = _df["task_config/num_tasks"]
    y_means = _df["rlct/mean"]
    y_stds = _df["rlct/std"]

    # Create a figure and a plot
    plt.figure(figsize=(10, 6))

    # Plot mean values as a line
    plt.plot(x_vals, y_means, 'o-', label="RLCTs", linewidth=2)

    # Add shaded area for error
    plt.fill_between(x_vals, y_means - y_stds, y_means + y_stds, color='gray', alpha=0.4)

    # Labels and scales
    plt.xlabel("Number of Tasks")
    plt.xscale("log")
    plt.ylabel(r"$\hat\lambda$")

    plt.title(f"$\\eta={lr}, n_\\mathrm{{draws}}={_num_draws}, \\gamma={elasticity}$")

    # Show legend
    plt.legend()

    plt.show()


In [None]:
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns

# Set Seaborn style
sns.set(style="whitegrid")

# Map each learning rate to a grid row and each num_draws value to a grid column.
lr_to_row = {lr: i for i, lr in enumerate(sorted(set(lrs)))}
num_draws_to_col = {nd: i for i, nd in enumerate(sorted(set(num_draws)))}

# Size the grid from the data instead of assuming a 3 x 3 sweep (6 x 4 inches per panel,
# as before). squeeze=False keeps `axes` 2-D even with a single row or column.
fig, axes = plt.subplots(
    len(lr_to_row),
    len(num_draws_to_col),
    figsize=(6 * len(num_draws_to_col), 4 * len(lr_to_row)),
    squeeze=False,
)
fig.tight_layout(pad=6.0)

for (lr, _num_draws, elasticity) in product(lrs, num_draws, elasticities):
    _df = df[
        (df["analysis_config/lr"] == lr) &
        (df["analysis_config/num_draws"] == _num_draws) &
        (df["analysis_config/elasticity"] == elasticity)
    ]
    if _df.empty:
        # Combination not present in the sweep; don't add an empty, labelled line.
        continue
    _df = _df.sort_values(by="task_config/num_tasks")  # Ensure data is sorted by x-values

    # Select subplot
    ax = axes[lr_to_row[lr], num_draws_to_col[_num_draws]]

    # Values for plotting
    x_vals = _df["task_config/num_tasks"]
    y_means = _df["rlct/mean"]
    y_stds = _df["rlct/std"]

    # Several elasticities share one subplot; the legend tells them apart.
    # (Backslashes are escaped: "\g", "\e", "\m" are invalid string escapes.)
    ax.plot(x_vals, y_means, 'o-', label=f"RLCTs, $\\gamma={elasticity}$", linewidth=2)

    # Add shaded area for error
    ax.fill_between(x_vals, y_means - y_stds, y_means + y_stds, color='gray', alpha=0.4)

    # Labels and scales
    ax.set_xlabel("Number of Tasks")
    ax.set_xscale("log")
    ax.set_ylabel(r"$\hat\lambda$")

    # The title names only this subplot's own (lr, num_draws); elasticity varies within it.
    ax.set_title(f"$\\eta={lr}, n_\\mathrm{{draws}}={_num_draws}$")

    # Show legend
    ax.legend()

# Show the complete figure
plt.show()
