# RLCT Estimation of Multitask Sparse Parity

In [None]:
%pip install devinterp seaborn torchvision pickleshare wandb plotly einops scikit-learn
!git clone https://github.com/ucla-vision/entropy-sgd.git
%cd entropy-sgd
from python.optim import EntropySGD
%cd ..

In [None]:
import numpy as np
import torch as t
import torch
import torch.nn as nn
import torch.optim as optim
import time
import torch.nn.functional as F
import einops
import random
from dataclasses import dataclass
import os
import copy
import wandb
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from python.optim import EntropySGD
from torch.utils.data import DataLoader
from collections import defaultdict
from itertools import islice, product
import random
import time
from pathlib import Path

from devinterp.optim.sgld import SGLD
from devinterp.optim.sgnht import SGNHT

# Named plot colours: the first six entries of seaborn's "muted" palette, and
# the first six entries of the "pastel" palette as their *_LIGHT counterparts.
PRIMARY, SECONDARY, TERTIARY, QUATERNARY, QUINARY, SENARY = sns.color_palette("muted")[:6]
PRIMARY_LIGHT, SECONDARY_LIGHT, TERTIARY_LIGHT, QUATERNARY_LIGHT, QUINARY_LIGHT, SENARY_LIGHT = sns.color_palette(
    "pastel"
)[:6]

# Print how many colours the "pastel" palette contains.
print(len(sns.color_palette("pastel")))

In [None]:
class FastTensorDataLoader:
    """
    Lightweight batch iterator over a group of tensors of equal length.

    Batches are cut directly out of the stored tensors (a plain slice, or one
    ``index_select`` per tensor when shuffling) instead of collating
    individual samples the way TensorDataset + DataLoader does.

    The object is its own iterator: calling ``iter()`` on it resets the
    position, so only one iteration can be in progress at a time.
    """
    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Build the loader.

        :param *tensors: tensors to serve; all must share their size at dim 0.
        :param batch_size: number of rows per batch (the last batch may be
            shorter).
        :param shuffle: if True, a fresh random permutation of the row
            indices is drawn every time iteration starts. The stored tensors
            themselves are left untouched.
        """
        n_rows = tensors[0].shape[0]
        assert all(tensor.shape[0] == n_rows for tensor in tensors)
        self.tensors = tensors

        self.dataset_len = n_rows
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of batches, counting a trailing partial batch as one.
        full_batches, leftover = divmod(self.dataset_len, self.batch_size)
        self.n_batches = full_batches + 1 if leftover > 0 else full_batches

    def __iter__(self):
        # The permutation lives on the same device as the data so that
        # index_select needs no host/device transfer.
        self.indices = (
            torch.randperm(self.dataset_len, device=self.tensors[0].device)
            if self.shuffle
            else None
        )
        self.i = 0
        return self

    def __next__(self):
        start = self.i
        if start >= self.dataset_len:
            raise StopIteration
        stop = start + self.batch_size
        if self.indices is None:
            batch = tuple(tensor[start:stop] for tensor in self.tensors)
        else:
            picked = self.indices[start:stop]
            batch = tuple(torch.index_select(tensor, 0, picked) for tensor in self.tensors)
        self.i = stop
        return batch

    def __len__(self):
        return self.n_batches


def get_batch(n_tasks, n, Ss, codes, sizes, device='cpu', dtype=torch.float32):
    """Sample a batch of multitask sparse-parity examples.

    Each input row is a one-hot task code of length ``n_tasks`` followed by
    ``n`` random bits; the label is the parity (sum mod 2) of the bits
    selected by that task's subset.

    Parameters
    ----------
    n_tasks : int
        Number of tasks (width of the task-code prefix).
    n : int
        Bit string length for the sparse parity problem.
    Ss : list of lists of ints
        Bit positions each subtask computes its parity over.
    codes : list of int
        Task-code column to set for each subtask in the batch.
    sizes : list of int
        Number of samples to draw for each subtask; entries that are not
        positive are skipped.
    device : str
        Device the batch is created on.
    dtype : torch.dtype
        Data type of the inputs x. Labels y are torch.int64.

    Returns
    -------
    x : torch.Tensor
        Inputs of shape (sum(sizes), n_tasks + n).
    y : torch.Tensor
        Labels of shape (sum(sizes),).
    """
    total = sum(sizes)
    batch_x = torch.zeros((total, n_tasks+n), dtype=dtype, device=device)
    batch_y = torch.zeros((total,), dtype=torch.int64, device=device)
    offset = 0
    for subset, count, code in zip(Ss, sizes, codes):
        if count <= 0:
            continue
        rows = slice(offset, offset + count)
        bits = torch.randint(low=0, high=2, size=(count, n), dtype=dtype, device=device)
        batch_y[rows] = torch.sum(bits[:, subset], dim=1) % 2
        # Write the one-hot task code through a view restricted to the first
        # n_tasks columns, then the random bits into the remaining columns.
        task_columns = batch_x[rows, :n_tasks]
        task_columns[:, code] = 1
        batch_x[rows, n_tasks:] = bits
        offset += count
    return batch_x, batch_y
    
def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it on each pass.

    Unlike ``itertools.cycle`` nothing is cached: every pass starts a new
    iteration over ``iterable``.
    """
    while True:
        yield from iterable


In [None]:
class MLP(nn.Module):
    """Fully connected network that outputs 2 logits.

    The network is ``depth`` linear layers with the chosen activation between
    them (none after the last layer), stored as ``self.model``.

    Parameters
    ----------
    activation : str
        One of 'ReLU', 'Tanh' or 'Sigmoid'.
    depth : int
        Number of linear layers.
    width : int
        Size of every hidden layer.
    input_dim : int, optional
        Size of the input vector. When omitted it falls back to the
        module-level ``n_tasks + n``, which therefore must be defined before
        the model is built.

    Raises
    ------
    ValueError
        If ``activation`` is not one of the recognised identifiers.
    """

    _ACTIVATIONS = {
        'ReLU': nn.ReLU,
        'Tanh': nn.Tanh,
        'Sigmoid': nn.Sigmoid,
    }

    def __init__(self, activation, depth, width, input_dim=None):
        super().__init__()

        try:
            activation_fn = self._ACTIVATIONS[activation]
        except KeyError:
            raise ValueError(
                f"Unrecognized activation function identifier: {activation}"
            ) from None

        if input_dim is None:
            # Backward-compatible default: task code plus bit string.
            input_dim = n_tasks + n

        # The last layer always maps to the 2 output logits, including when
        # depth == 1 (a single Linear(input_dim, 2) with no activation).
        layers = []
        in_features = input_dim
        for i in range(depth):
            is_last = i == depth - 1
            out_features = 2 if is_last else width
            layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                layers.append(activation_fn())
            in_features = out_features
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
def l1_loss_zero_one(parameters):
    """Sum, over all entries of all parameters, of the distance to the nearer of 0 and 1.

    Returns the plain integer 0 when ``parameters`` is empty.
    """
    total = 0
    for weights in parameters:
        distance_to_zero = weights.abs()
        distance_to_one = (weights - 1).abs()
        total = total + torch.min(distance_to_zero, distance_to_one).sum()
    return total

def run(n_tasks,
        n,
        k,
        D,
        width,
        depth,
        activation,
        test_points,
        test_points_per_task,
        steps,
        batch_size,
        lr,
        weight_decay,
        device,
        dtype,
        log_freq,
        verbose,
        seed):
    """Train an MLP on multitask sparse parity and keep periodic checkpoints.

    Parameters
    ----------
    n_tasks : int
        Number of parity subtasks.
    n : int
        Bit string length.
    k : int
        Number of bits in each task's parity subset.
    D : int
        Size of a fixed training set, or -1 to draw a fresh batch every step.
    width, depth, activation
        Forwarded to ``MLP``.
    test_points : int
        Nominal size of the mixed evaluation batch; each task contributes
        ``int(test_points / n_tasks)`` samples.
    test_points_per_task : int
        Size of each per-task evaluation batch.
    steps : int
        Number of optimizer steps.
    batch_size : int
        Training batch size.
    lr, weight_decay : float
        AdamW hyperparameters.
    device, dtype
        Where and in which precision batches and the model are created.
    log_freq : int
        Evaluate and checkpoint every ``log_freq`` steps (starting at step 0).
    verbose : bool
        Show a progress bar when True.
    seed : int
        Seed for torch, CUDA, ``random`` and NumPy.

    Returns
    -------
    info : dict
        'P' (parameter count), 'Ss' (task subsets), 'D', 'log_steps',
        'accuracies', 'losses', and the per-task dicts 'losses_subtasks' /
        'accuracies_subtasks' keyed by ``str(task)``.
    models_saved : list
        Deep copy of the model at every logged step.
    loss_fn : nn.CrossEntropyLoss
        The criterion used for training.

    Notes
    -----
    Calls ``torch.set_default_dtype(dtype)``, which changes the process-wide
    default. ``MLP`` takes its input width from the module-level ``n_tasks``
    and ``n`` rather than from this function's arguments, so the two must
    agree.
    """
    torch.set_default_dtype(dtype)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    info = {}
    models_saved = []

    mlp = MLP(activation, depth, width).to(device)
    info['P'] = sum(t.numel() for t in mlp.parameters())

    # Rejection-sample n_tasks distinct k-subsets of the n bit positions,
    # giving up after n_tasks * 10 draws.
    Ss = []
    for _ in range(n_tasks * 10):
        S = tuple(sorted(list(random.sample(range(n), k))))
        if S not in Ss:
            Ss.append(S)
        if len(Ss) == n_tasks:
            break
    assert len(Ss) == n_tasks, "Couldn't find enough subsets for tasks for the given n, k"
    info['Ss'] = Ss

    # Tasks are sampled uniformly.
    probs = np.array([1 / n_tasks for _ in range(n_tasks)])
    cdf = np.cumsum(probs)

    test_batch_sizes = [int(prob * test_points) for prob in probs]

    if D != -1:
        # Fixed training set: draw a task id per sample, count samples per
        # task, and materialise the data once.
        samples = np.searchsorted(cdf, np.random.rand(D,))
        hist, _ = np.histogram(samples, bins=n_tasks, range=(0, n_tasks-1))
        train_x, train_y = get_batch(n_tasks=n_tasks, n=n, Ss=Ss, codes=list(range(n_tasks)), sizes=hist, device='cpu', dtype=dtype)
        train_x = train_x.to(device)
        train_y = train_y.to(device)
        train_loader = FastTensorDataLoader(train_x, train_y, batch_size=min(D, batch_size), shuffle=True)
        train_iter = cycle(train_loader)
        info['D'] = D
    else:
        info['D'] = steps * batch_size

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(mlp.parameters(), lr=lr, weight_decay=weight_decay)
    info['log_steps'] = list()
    info['accuracies'] = list()
    info['losses'] = list()
    info['losses_subtasks'] = dict()
    info['accuracies_subtasks'] = dict()
    for i in range(n_tasks):
        info['losses_subtasks'][str(i)] = list()
        info['accuracies_subtasks'][str(i)] = list()
    for step in tqdm(range(steps), disable=not verbose):
        if step % log_freq == 0:
            with torch.no_grad():
                x_i, y_i = get_batch(n_tasks=n_tasks, n=n, Ss=Ss, codes=list(range(n_tasks)), sizes=test_batch_sizes, device=device, dtype=dtype)
                y_i_pred = mlp(x_i)
                labels_i_pred = torch.argmax(y_i_pred, dim=1)
                # Divide by the number of samples actually drawn: the batch
                # holds sum(test_batch_sizes) rows, which is smaller than
                # test_points when test_points is not divisible by n_tasks.
                info['accuracies'].append(torch.sum(labels_i_pred == y_i).item() / y_i.shape[0])
                info['losses'].append(loss_fn(y_i_pred, y_i).item())
                for i in range(n_tasks):
                    x_i, y_i = get_batch(n_tasks=n_tasks, n=n, Ss=[Ss[i]], codes=[i], sizes=[test_points_per_task], device=device, dtype=dtype)
                    y_i_pred = mlp(x_i)
                    info['losses_subtasks'][str(i)].append(loss_fn(y_i_pred, y_i).item())
                    labels_i_pred = torch.argmax(y_i_pred, dim=1)
                    info['accuracies_subtasks'][str(i)].append(torch.sum(labels_i_pred == y_i).item() / y_i.shape[0])
                info['log_steps'].append(step)
                models_saved += [copy.deepcopy(mlp)]
        optimizer.zero_grad()
        if D == -1:
            # Online setting: a fresh batch with a random task mix each step.
            samples = np.searchsorted(cdf, np.random.rand(batch_size,))
            hist, _ = np.histogram(samples, bins=n_tasks, range=(0, n_tasks-1))
            x, y_target = get_batch(n_tasks=n_tasks, n=n, Ss=Ss, codes=list(range(n_tasks)), sizes=hist, device=device, dtype=dtype)
        else:
            x, y_target = next(train_iter)
        y_pred = mlp(x)
        # Optional regulariser (disabled): + l1_loss_zero_one(mlp.parameters())
        loss = loss_fn(y_pred, y_target)
        loss.backward()
        optimizer.step()
    return info, models_saved, loss_fn

# --- Task definition -------------------------------------------------------
# These module-level names are also read directly by MLP (n_tasks, n) and by
# later cells, not only passed to run().
n_tasks = 10
n = 50
k = 3

D = -1 # -1 for infinite data

# --- Model -----------------------------------------------------------------
width = 100
depth = 2
activation = 'ReLU'
    
# --- Optimisation and evaluation --------------------------------------------
steps = 5000
batch_size = 10000
lr = 1e-3
weight_decay = 0.0
test_points = 30000
test_points_per_task = 1000
    
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

#log_freq = max(1, steps // 1000)
# Evaluate and checkpoint every log_freq optimizer steps.
log_freq = 100
verbose=True
seed = 0
# Number of training runs whose checkpoints are averaged in the LLC cell.
runs = 1
        
# Train once; `criterion` is the loss function returned by run().
info, models_saved, criterion = run(n_tasks,
    n, 
    k, 
    D, 
    width, 
    depth, 
    activation, 
    test_points, 
    test_points_per_task, 
    steps, 
    batch_size, 
    lr, 
    weight_decay, 
    device, 
    dtype, 
    log_freq, 
    verbose, 
    seed)

In [None]:
from devinterp.slt import estimate_learning_coeff_with_summary

# Checkpoints were saved every `log_freq` steps over `steps` steps of training.
N_EPOCHS = steps
SAVE_EVERY_N_EPOCHS = log_freq

# Reuse the parity subsets the model was trained on. Re-sampling them here
# (as before) draws from an advanced `random` state and yields different
# subsets, so the LLC data and the later ablation/evaluation cells would
# measure tasks the model never saw.
Ss = info['Ss']
assert len(Ss) == n_tasks, "Training run recorded a different number of task subsets than n_tasks"

probs = np.array([1 / n_tasks for _ in range(n_tasks)])
cdf = np.cumsum(probs)
test_batch_sizes = [int(prob * test_points) for prob in probs]

# Fixed dataset used for learning-coefficient estimation, with an equal
# number of samples per task.
train_x, train_y = get_batch(n_tasks=n_tasks, n=n, Ss=Ss, codes=list(range(n_tasks)), sizes=test_batch_sizes, device=device, dtype=dtype)
train_data = list(zip(train_x, train_y))
        
#train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

def estimate_rlcts(models, train_loader, criterion, data_length, device, num_draws, num_models):
    """Estimate the local learning coefficient of every model in ``models``.

    For each model, ``estimate_learning_coeff_with_summary`` is run with SGLD
    (2 chains, 200 burn-in steps, ``num_draws`` draws, seed 0) and the
    "llc/mean" entry of its result is recorded.

    ``data_length`` and ``num_models`` are accepted but not used.

    Returns a dict with keys "sgnht" and "sgld", each a list of estimates in
    model order; "sgnht" stays empty because that sampler is disabled.
    """
    samplers = {"sgnht": SGNHT}
    estimates = {"sgnht": [], "sgld": []}
    for model in tqdm(models):
        # Built per model so every call receives its own kwargs dict.
        sampler_configs = [
            #("sgnht", {"lr": 1e-7, "diffusion_factor": 0.01}),
            ("sgld", {"lr": 1e-3, "localization": 400.0, "noise_level": 1.0}),
        ]
        for method, optimizer_kwargs in sampler_configs:
            summary = estimate_learning_coeff_with_summary(
                model,
                train_loader,
                criterion=criterion,
                optimizer_kwargs=optimizer_kwargs,
                sampling_method=samplers.get(method, SGLD),
                num_chains=2,
                num_draws=num_draws,
                num_burnin_steps=200,
                num_steps_bw_draws=1,
                device=device,
                seed=0
            )
            estimates[method].append(summary["llc/mean"])
    return estimates

def obtain_rlct_estimates(train_loader, models_saved, criterion, runs):
    """Estimate learning coefficients per checkpoint and average them over runs.

    ``models_saved`` is read as ``runs`` consecutive groups of
    ``N_EPOCHS // SAVE_EVERY_N_EPOCHS`` checkpoints. Each group is passed to
    ``estimate_rlcts`` with 800 draws, and the SGLD estimates are averaged
    across runs.

    Returns a dict with keys "sgnht" and "sgld", each a tensor with one entry
    per checkpoint; "sgnht" is all zeros because it is never filled in.
    """
    checkpoints_per_run = N_EPOCHS // SAVE_EVERY_N_EPOCHS
    n_batches = len(train_loader)
    num_draws = 800
    per_run = {key: torch.zeros(runs, checkpoints_per_run) for key in ("sgnht", "sgld")}

    for run_idx in tqdm(range(runs)):
        first = checkpoints_per_run * run_idx
        checkpoints = models_saved[first : first + checkpoints_per_run]
        estimates = estimate_rlcts(
            checkpoints, train_loader, criterion, n_batches, device, num_draws, checkpoints_per_run
        )
        #per_run["sgnht"][run_idx] = torch.tensor(estimates["sgnht"])
        per_run["sgld"][run_idx] = torch.tensor(estimates["sgld"])

    return {key: values.mean(dim=0) for key, values in per_run.items()}

# Learning-coefficient estimate for every saved checkpoint, averaged over runs.
rlct_estimates_final = obtain_rlct_estimates(train_loader, models_saved, criterion, runs)

In [None]:
# Tag passed as `name` to plot_losses / plot_accuracies below, where it becomes
# part of the saved figure's file name.
dataset = '0'

def plot_losses(train_losses_final, name = ''):
    """Plot a loss curve on a log y-axis, show it and save it as a PNG.

    :param train_losses_final: one loss value per saved checkpoint.
    :param name: tag inserted into the output file name
        ``losses_<name>_<N_EPOCHS>_epochs.png``.
    """
    sns.set_style("whitegrid")
    # Checkpoints are logged at steps 0, SAVE_EVERY_N_EPOCHS, 2 * SAVE_EVERY_N_EPOCHS, ...
    # Deriving the axis from the data keeps x and y the same length.
    x_axis = np.arange(len(train_losses_final)) * SAVE_EVERY_N_EPOCHS

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=PRIMARY)
    ax1.set_yscale('log')
    ax1.plot(x_axis, train_losses_final, label="Train Loss, sgd", color=PRIMARY)
    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig("losses_" + name + "_" + str(N_EPOCHS) + "_epochs.png")
    
def plot_subtask_losses(train_losses_subtasks, n_tasks):
    """Plot one loss curve per task on a shared log y-axis, show and save it.

    :param train_losses_subtasks: dict mapping ``str(task)`` to that task's
        loss values, one per saved checkpoint.
    :param n_tasks: number of tasks to plot (keys "0" .. str(n_tasks - 1)).
    """
    sns.set_style("whitegrid")
    palette = sns.color_palette("muted")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=PRIMARY)
    ax1.set_yscale('log')

    for task in range(n_tasks):
        series = train_losses_subtasks[str(task)]
        # Checkpoints are logged at steps 0, SAVE_EVERY_N_EPOCHS, ...; build
        # the axis from the series so the lengths always match.
        x_axis = np.arange(len(series)) * SAVE_EVERY_N_EPOCHS
        # Wrap around the palette so more tasks than colours do not raise.
        ax1.plot(x_axis, series, label="Task " + str(task), color=palette[task % len(palette)])

    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig("losses_subtasks_" + str(N_EPOCHS) + "_epochs.png")
    
def plot_accuracies(train_accuracies_final, name = ''):
    """Plot an accuracy curve on a log y-axis, show it and save it as a PNG.

    :param train_accuracies_final: one accuracy value per saved checkpoint.
    :param name: tag inserted into the output file name
        ``accuracies_<name>_<N_EPOCHS>_epochs.png``.
    """
    sns.set_style("whitegrid")
    # Checkpoints are logged at steps 0, SAVE_EVERY_N_EPOCHS, 2 * SAVE_EVERY_N_EPOCHS, ...
    # Deriving the axis from the data keeps x and y the same length.
    x_axis = np.arange(len(train_accuracies_final)) * SAVE_EVERY_N_EPOCHS

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy", color=PRIMARY)
    ax1.set_yscale('log')
    ax1.plot(x_axis, train_accuracies_final, label="Train Accuracy, sgd", color=PRIMARY)
    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig("accuracies_" + name + "_" + str(N_EPOCHS) + "_epochs.png")
    
def plot_subtask_accuracies(train_accuracies_subtasks, n_tasks):
    """Plot one accuracy curve per task on a shared log y-axis, show and save it.

    :param train_accuracies_subtasks: dict mapping ``str(task)`` to that
        task's accuracy values, one per saved checkpoint.
    :param n_tasks: number of tasks to plot (keys "0" .. str(n_tasks - 1)).
    """
    sns.set_style("whitegrid")
    palette = sns.color_palette("muted")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy", color=PRIMARY)
    ax1.set_yscale('log')

    for task in range(n_tasks):
        series = train_accuracies_subtasks[str(task)]
        # Checkpoints are logged at steps 0, SAVE_EVERY_N_EPOCHS, ...; build
        # the axis from the series so the lengths always match.
        x_axis = np.arange(len(series)) * SAVE_EVERY_N_EPOCHS
        # Wrap around the palette so more tasks than colours do not raise.
        ax1.plot(x_axis, series, label="Task " + str(task), color=palette[task % len(palette)])

    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig("accuracies_subtasks_" + str(N_EPOCHS) + "_epochs.png")
    
def plot_rlcts(rlct_estimates_final, dataset, rlct_estimates_final_other = None):
    """Plot the SGLD learning-coefficient estimates, show and save the figure.

    :param rlct_estimates_final: dict whose "sgld" entry holds one estimate
        per saved checkpoint.
    :param dataset: tag inserted into the output file name.
    :param rlct_estimates_final_other: optional second dict of the same form;
        when given and non-empty its "sgld" entry is drawn as an extra curve
        labelled "summed curve".
    """
    sns.set_style("whitegrid")

    def checkpoint_steps(values):
        # Checkpoints are logged at steps 0, SAVE_EVERY_N_EPOCHS, ...; build
        # the axis from the data so the lengths always match.
        return np.arange(len(values)) * SAVE_EVERY_N_EPOCHS

    fig, ax2 = plt.subplots(figsize=(10, 6))
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel(r"Local Learning Coefficient, $\hat \lambda$", color=SECONDARY)
    if rlct_estimates_final_other:
        other = rlct_estimates_final_other["sgld"]
        ax2.plot(checkpoint_steps(other), other, label="summed curve", color=TERTIARY)
    main = rlct_estimates_final["sgld"]
    ax2.plot(checkpoint_steps(main), main, label="SGLD, sgd", color=TERTIARY_LIGHT)
    ax2.tick_params(axis="y", labelcolor=SECONDARY)
    ax2.legend(loc="center right")

    fig.tight_layout()
    plt.show()
    fig.savefig("rclt_" + dataset + "_" + str(N_EPOCHS) + "_epochs.png")

# Curves recorded by run() on the mixed-task evaluation batch.
train_losses_final = info['losses']
train_accuracies_final = info['accuracies']
    
# Overall curves, per-task curves, then the learning-coefficient estimates.
plot_losses(train_losses_final, dataset)
plot_subtask_losses(info['losses_subtasks'], n_tasks)
plot_accuracies(train_accuracies_final, dataset)
plot_subtask_accuracies(info['accuracies_subtasks'], n_tasks)
plot_rlcts(rlct_estimates_final, dataset='full')


In [None]:
def return_topk_percent_mask(tensor, proportion):
    """Return a boolean mask selecting the largest entries of ``tensor``.

    ``K = int(proportion * tensor.numel())`` entries are targeted; the mask is
    True wherever the value is at least the K-th largest value, so ties at
    the threshold can select more than K entries. When K is 0 (including an
    empty tensor) the mask is all False.

    :param tensor: tensor of scores.
    :param proportion: fraction of entries to keep.
    :returns: bool tensor with the same shape as ``tensor``.
    """
    flat = tensor.flatten()
    k = int(proportion * flat.numel())
    if k <= 0:
        # torch.topk with k == 0 returns nothing to take a threshold from.
        return torch.zeros_like(tensor, dtype=torch.bool)

    # Smallest of the k largest values is the inclusion threshold.
    threshold_value = torch.topk(flat, k).values[-1]
    return tensor >= threshold_value


def unravel_index(index, shape):
    """Convert a flat row-major ``index`` into a tuple of per-axis indices for ``shape``."""
    coords = [0] * len(shape)
    # Peel off the fastest-varying (last) axis first.
    for axis in range(len(shape) - 1, -1, -1):
        extent = shape[axis]
        coords[axis] = index % extent
        index = index // extent
    return tuple(coords)

def ablation_study(model, loss_fn):
    """Measure how much zeroing each single weight changes each task's loss.

    For every task and every parameter tensor, a batch of
    ``test_points_per_task`` samples of that task is drawn. Each scalar weight
    is then set to zero in turn, the loss is re-evaluated on the same batch,
    and the weight is restored.

    Reads the module-level ``n_tasks``, ``n``, ``Ss``, ``test_points_per_task``,
    ``device`` and ``dtype``.

    :param model: model to probe; modified in place only temporarily.
    :param loss_fn: callable ``loss_fn(predictions, targets)``.
    :returns: list with one dict per task, mapping parameter name to a tensor
        shaped like the parameter that holds ``|loss_ablated - loss_baseline|``.
    """
    loss_diffs_per_task = []
    # Nothing here needs gradients, including the baseline forward pass.
    with torch.no_grad():
        for index in tqdm(range(n_tasks)):
            loss_diffs = {}
            for name, param in model.named_parameters():
                loss_diffs[name] = torch.zeros(param.shape)
                x_i, y_i = get_batch(n_tasks=n_tasks, n=n, Ss=[Ss[index]], codes=[index], sizes=[test_points_per_task], device=device, dtype=dtype)
                loss_baseline = loss_fn(model(x_i), y_i).item()

                for idx in range(param.numel()):
                    # Flat index -> index tuple for the parameter's shape.
                    multi_idx = unravel_index(idx, param.shape)
                    original_value = param[multi_idx].item()
                    param[multi_idx] = 0.0
                    try:
                        loss_ablated = loss_fn(model(x_i), y_i).item()
                    finally:
                        # Always put the weight back, even if the forward
                        # pass fails, so the model is never left damaged.
                        param[multi_idx] = original_value
                    loss_diffs[name][multi_idx] = abs(loss_ablated - loss_baseline)
            loss_diffs_per_task.append(loss_diffs)
            # Once per task is enough; calling this per weight only slows
            # the loop down.
            torch.cuda.empty_cache()

    return loss_diffs_per_task

# Per-weight ablation scores for the final checkpoint.
loss_diffs_per_task = ablation_study(models_saved[-1], criterion)
epsilon = .1
num_params = sum(p.numel() for p in models_saved[-1].parameters())
indices_per_task = []

# Report, per task and layer, how many weights change the loss by more than
# epsilon when ablated.
for index in tqdm(range(n_tasks)):
    num_weights = 0
    print(index)
    indices_per_task.append({})
    for name in loss_diffs_per_task[index].keys():
        print(f"Layer: {name}")
        indices = loss_diffs_per_task[index][name] > epsilon
        indices_per_task[index][name] = indices
        print(loss_diffs_per_task[index][name][indices].shape)
        num_weights += sum(loss_diffs_per_task[index][name][indices].shape)
    print(index, num_weights, num_params)

# Rebuild the per-task masks: keep the top `proportion` of weights by ablation
# score in each weight matrix, and keep every bias.
indices_per_task = []
for task in tqdm(range(n_tasks)):
    
    indices_dict = {}
    task_diffs = loss_diffs_per_task[task]
    
    proportion = 0.3
    indices_dict['model.0.weight'] = return_topk_percent_mask(task_diffs['model.0.weight'], proportion)
    # Bias shapes come from the ablation results; the previous `b_0` / `b_1`
    # names were never defined at module scope.
    indices_dict['model.0.bias'] = torch.ones(task_diffs['model.0.bias'].shape, dtype=torch.bool)
    indices_dict['model.2.weight'] = return_topk_percent_mask(task_diffs['model.2.weight'], proportion)
    indices_dict['model.2.bias'] = torch.ones(task_diffs['model.2.bias'].shape, dtype=torch.bool)
        
    indices_per_task.append(indices_dict)

In [None]:
def get_indices_per_task(path_losses_per_task, path_indices_per_task, indices_per_task):
    """Mark the weights of every path whose loss exceeds 0.01.

    Tasks are taken from the inputs themselves (one entry per task) rather
    than from a module-level task count.

    :param path_losses_per_task: per task, a 1-D tensor of path losses.
    :param path_indices_per_task: per task, an integer tensor of shape
        (num_paths, 2, 2) holding ``[[i, j], [k, l]]`` for each path, where
        ``(i, j)`` indexes 'model.0.weight' and ``(k, l)`` indexes
        'model.2.weight'.
    :param indices_per_task: per task, a dict of boolean masks; the
        'model.0.weight' and 'model.2.weight' masks are updated in place.
    :returns: the same ``indices_per_task`` list.
    """
    eps = .01

    for task, (path_losses, path_indices) in enumerate(zip(path_losses_per_task, path_indices_per_task)):
        selected = path_indices[path_losses > eps]
        first_layer = selected[:, 0]
        second_layer = selected[:, 1]

        # One vectorised assignment per layer instead of a Python loop over
        # every selected path.
        masks = indices_per_task[task]
        masks['model.0.weight'][first_layer[:, 0], first_layer[:, 1]] = True
        masks['model.2.weight'][second_layer[:, 0], second_layer[:, 1]] = True

    return indices_per_task
        

def get_path_losses_per_task(loss_diffs_per_task, n_tasks):
    """Score every pair (first-layer weight, second-layer weight) per task.

    For each task the "path loss" of the pair ``(W_0[i, j], W_1[k, l])`` is
    ``W_0[i, j] + W_1[k, l]``, where ``W_0`` and ``W_1`` are that task's
    ablation scores for 'model.0.weight' and 'model.2.weight'. Pairs are
    enumerated in row-major order over ``(i, j, k, l)``.

    :param loss_diffs_per_task: per task, a dict with the 2-D score tensors
        'model.0.weight' and 'model.2.weight' and the bias score tensors
        'model.0.bias' and 'model.2.bias' (biases are only read for their
        shape).
    :param n_tasks: number of leading tasks to process.
    :returns: ``(path_losses_per_task, path_indices_per_task,
        indices_per_task)``: per task, a 1-D tensor of path losses; an
        integer tensor of shape (num_paths, 2, 2) with ``[[i, j], [k, l]]``
        per path; and a dict of boolean masks with the weight masks all False
        and the bias masks all True.
    """
    path_losses_per_task = []
    path_indices_per_task = []
    indices_per_task = []

    for task in range(n_tasks):
        task_diffs = loss_diffs_per_task[task]
        W_0 = task_diffs['model.0.weight']
        b_0 = task_diffs['model.0.bias']
        W_1 = task_diffs['model.2.weight']
        b_1 = task_diffs['model.2.bias']

        indices_per_task.append({
            'model.0.weight': torch.zeros(W_0.shape, dtype=torch.bool),
            'model.0.bias': torch.ones(b_0.shape, dtype=torch.bool),
            'model.2.weight': torch.zeros(W_1.shape, dtype=torch.bool),
            'model.2.bias': torch.ones(b_1.shape, dtype=torch.bool),
        })

        # Broadcast to shape (i, j, k, l); flattening gives the same order as
        # four nested loops over i, j, k, l.
        path_losses = (W_0[:, :, None, None] + W_1[None, None, :, :]).flatten().cpu()

        # Rows (i, j, k, l) in the same order, regrouped as [[i, j], [k, l]].
        axes = [torch.arange(size) for size in (*W_0.shape, *W_1.shape)]
        path_indices = torch.cartesian_prod(*axes).reshape(-1, 2, 2)

        path_losses_per_task.append(path_losses)
        path_indices_per_task.append(path_indices)

    return path_losses_per_task, path_indices_per_task, indices_per_task

# Print the name and shape of every parameter of the final checkpoint.
for name, param in models_saved[-1].named_parameters():
    print(name)
    print(param.shape)

# Score all weight pairs per task, then overwrite indices_per_task with masks
# marking the weights on pairs whose score exceeds the threshold.
path_losses_per_task, path_indices_per_task, indices_per_task = get_path_losses_per_task(loss_diffs_per_task, n_tasks)
indices_per_task = get_indices_per_task(path_losses_per_task, path_indices_per_task, indices_per_task)

In [None]:
import copy

def prune_to_obtain_circuit(model, model_indices):
    """Zero every parameter entry that is not selected by its mask.

    :param model: module to prune; it is modified in place.
    :param model_indices: dict mapping each parameter name to a boolean mask
        of the parameter's shape, True for entries to keep.
    :returns: the same ``model`` object.
    """
    state = model.state_dict()

    for param_name, _ in model.named_parameters():
        keep = model_indices[param_name]
        # Everything outside the mask is set to zero.
        state[param_name][~keep] = 0.0

    model.load_state_dict(state)

    return model

# models_per_task[task] is a list with one pruned copy of every saved
# checkpoint, masked with that task's indices_per_task entry.
models_per_task = []

for task in tqdm(range(n_tasks)):
    models = []
    for model in tqdm(models_saved):
        # Prune a deep copy so the original checkpoints stay intact.
        task_model = copy.deepcopy(model)
        task_model = prune_to_obtain_circuit(task_model, indices_per_task[task])
        models.append(task_model)
    models_per_task.append(models)

        

In [None]:
def compute_loss_curve_for_model(models, steps, n, Ss, n_tasks, test_points, test_batch_sizes, device, dtype):
    """Evaluate a sequence of models on freshly sampled mixed and per-task batches.

    For every model, cross-entropy loss and accuracy are computed on one
    mixed batch (``test_batch_sizes`` samples per task; accuracy is divided by
    ``test_points``), then on one batch per task. The per-task batch size is
    the module-level ``test_points_per_task``. ``steps`` is accepted but not
    used.

    :returns: ``(losses, accuracies, losses_subtasks, accuracies_subtasks)``:
        two lists with one entry per model, and two dicts keyed by
        ``str(task)`` whose values are lists with one entry per model.
    """
    loss_fn = nn.CrossEntropyLoss()
    task_keys = [str(task) for task in range(n_tasks)]

    losses, accuracies = [], []
    losses_subtasks = {key: list() for key in task_keys}
    accuracies_subtasks = {key: list() for key in task_keys}

    def evaluate(model, inputs, targets):
        # Returns (loss value, number of correct predictions).
        logits = model(inputs)
        n_correct = torch.sum(torch.argmax(logits, dim=1) == targets).item()
        return loss_fn(logits, targets).item(), n_correct

    for model in tqdm(models):
        with torch.no_grad():
            inputs, targets = get_batch(n_tasks=n_tasks, n=n, Ss=Ss, codes=list(range(n_tasks)), sizes=test_batch_sizes, device=device, dtype=dtype)
            mixed_loss, mixed_correct = evaluate(model, inputs, targets)
            accuracies.append(mixed_correct / test_points)
            losses.append(mixed_loss)

            for task, key in enumerate(task_keys):
                inputs, targets = get_batch(n_tasks=n_tasks, n=n, Ss=[Ss[task]], codes=[task], sizes=[test_points_per_task], device=device, dtype=dtype)
                task_loss, task_correct = evaluate(model, inputs, targets)
                losses_subtasks[key].append(task_loss)
                accuracies_subtasks[key].append(task_correct / test_points_per_task)
    return losses, accuracies, losses_subtasks, accuracies_subtasks

# Evaluate the checkpoints pruned with the first task's mask (models_per_task[0]).
losses, accuracies, losses_subtasks, accuracies_subtasks = compute_loss_curve_for_model(models_per_task[0], steps, n, Ss, n_tasks, test_points, test_batch_sizes, device, dtype)

# Mixed-batch curves, then the curves for subtask '0' only.
plot_losses(losses, 'first_task_with_ablation')
plot_accuracies(accuracies, 'first_task_with_ablation')
plot_losses(losses_subtasks['0'], 'first_task_with_ablation_subtask_' + str(0))
plot_accuracies(accuracies_subtasks['0'], 'first_task_with_ablation_subtask_' + str(0))
