In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

import matplotlib.pyplot as plt
from tqdm import tqdm

from .new_mem_module import metaRNN, MemoryModule
from .synthetic_datasets import InContextRecallDataset
from .losses import windowed_p_loss

#torch.manual_seed(0) # You can uncomment this for reproducible runs

class LinearModel(MemoryModule):
    """Bias-free linear map from ``input_dim`` to ``output_dim`` features.

    Holds a single ``nn.Linear`` layer (attribute ``linear``). It also offers
    a functional variant that takes the weight explicitly instead of using
    the stored parameter.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # bias=False: the model is a pure matrix multiplication, so the only
        # learnable tensor is the weight.
        self.linear = nn.Linear(
            in_features=input_dim,
            out_features=output_dim,
            bias=False,
        )

    def forward(self, x):
        """Apply the stored (bias-free) linear layer to ``x``."""
        projected = self.linear(x)
        return projected

    def functional_forward(self, x, params):
        """Apply the linear map with an externally supplied weight.

        Args:
            x: Input tensor; its last dimension must match the weight's
                second dimension, as required by ``F.linear``.
            params: Indexable collection whose first element is used as the
                weight matrix. Presumably it is laid out like
                ``self.linear.weight``, i.e. ``(output_dim, input_dim)``.
                TODO confirm against MemoryModule's contract.
                NOTE(review): a generator such as ``self.parameters()`` is
                not subscriptable; callers need e.g. ``list(...)`` first.

        Returns:
            ``F.linear(x, params[0])``, i.e. ``x @ params[0].T`` with no bias.
        """
        weight = params[0]
        return F.linear(x, weight)

# --- Configuration ---
# The model below is built as LinearModel(n_features, n_features), so this is
# both the input and the output width.
n_features = 10
n_samples = 300
# Number of independent simulation runs to average over.
num_runs = 50
loss_fn = windowed_p_loss

# --- Main execution logic ---

# Collects the lookback accuracy obtained in each simulation run.
all_lookback_accuracies = []

print(f"Running simulation {num_runs} times...")

for i in tqdm(range(num_runs), desc="Running Simulations"):
    # 1. Re-initialize the model for each run to start from a fresh state.
    #    This is crucial for a fair average.
    base_model = LinearModel(n_features, n_features)
    params = base_model.parameters()