# OLD
See the Neuro_visualizer notebook for the current version.

In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader

from NeuroVisualizer.neuro_aux.AEmodel import UniformAutoencoder
from NeuroVisualizer.neuro_aux.utils import get_files
from NeuroVisualizer.neuro_aux.trajectories_data import get_trajectory_dataloader

In [None]:
# === Config ===
# Input: directory of flattened training checkpoints; output: where the
# trained autoencoder's state_dict is written.
checkpoint_dir = "trainings/models_DenseNet_cifar10"
ae_save_path = "ae_models/ae_model_densenet.pt"

# Autoencoder shape and optimisation hyper-parameters.
latent_dim, num_layers = 2, 3
batch_size, num_epochs = 32, 100
lr = 0.001

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torch.utils.data import Dataset
import torch

class FlatTensorDataset(Dataset):
    """Map-style dataset with one item per checkpoint file.

    Items are read lazily: each ``__getitem__`` call runs ``torch.load`` on the
    corresponding path (mapped onto the CPU) and, if a ``transform`` was given
    and is truthy, returns ``transform(tensor)`` instead of the raw tensor.

    NOTE: ``torch.load`` unpickles its input; only point this at trusted files.
    """

    def __init__(self, file_paths, transform=None):
        self.transform = transform
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        loaded = torch.load(self.file_paths[idx], map_location='cpu')
        return self.transform(loaded) if self.transform else loaded

def calculate_mean_std_flat(file_paths):
    """Return the element-wise mean and std across a set of checkpoint tensors.

    The files are streamed one at a time and combined with Welford's online
    algorithm, accumulated in float64. Peak memory is therefore a few
    parameter-sized tensors rather than one copy per checkpoint, as with
    ``torch.stack``.

    Args:
        file_paths: Iterable of paths, each holding a tensor saved with
            ``torch.save``. All tensors must share the same shape.

    Returns:
        A ``(mean, std)`` pair of tensors with the checkpoints' shape.
        ``std`` uses the unbiased (N - 1) estimator, the same as
        ``torch.std``'s default, so it is NaN everywhere when only one file
        is given. Floating-point inputs keep their dtype; other dtypes are
        returned in the default float dtype.

    Raises:
        ValueError: if ``file_paths`` is empty or the tensors differ in shape.
    """
    count = 0
    mean = m2 = None
    out_dtype = None
    for fp in file_paths:
        w = torch.load(fp, map_location='cpu')
        if mean is None:
            out_dtype = w.dtype if w.is_floating_point() else torch.get_default_dtype()
            mean = torch.zeros(w.shape, dtype=torch.float64)
            m2 = torch.zeros(w.shape, dtype=torch.float64)
        elif w.shape != mean.shape:
            raise ValueError(
                f"Checkpoint {fp!r} has shape {tuple(w.shape)}, "
                f"expected {tuple(mean.shape)}"
            )
        count += 1
        x = w.to(torch.float64)
        # Welford update: running mean and sum of squared deviations (m2).
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)

    if count == 0:
        raise ValueError("calculate_mean_std_flat() needs at least one checkpoint file")

    if count > 1:
        std = torch.sqrt(m2 / (count - 1))
    else:
        # With a single sample the unbiased estimator is undefined; torch.std
        # returns NaN in that case, and so do we.
        std = torch.full_like(m2, float('nan'))
    return mean.to(out_dtype), std.to(out_dtype)

from NeuroVisualizer.neuro_aux.trajectories_data import NormalizeModelParameters, ModelParamsDataset

def get_trajectory_dataloader_flat(pt_files, batch_size, normalize=True, shuffle=True):
    """Build a DataLoader over flattened checkpoint files.

    Mean/std statistics are computed over all ``pt_files`` with
    ``calculate_mean_std_flat`` and wrapped in a ``NormalizeModelParameters``.
    That normalizer is attached as the dataset transform only when
    ``normalize`` is true.

    Returns:
        ``(loader, normalizer)``. NOTE: the statistics are computed, and the
        normalizer returned, even when ``normalize`` is False; in that case
        the loader yields raw, un-normalized tensors.
    """
    stats_mean, stats_std = calculate_mean_std_flat(pt_files)
    norm = NormalizeModelParameters(stats_mean, stats_std)
    sample_transform = norm if normalize else None
    data_loader = DataLoader(
        FlatTensorDataset(pt_files, transform=sample_transform),
        batch_size=batch_size,
        shuffle=shuffle,
    )
    return data_loader, norm

In [None]:
# === Load flattened checkpoints ===
pt_files = get_files(checkpoint_dir, prefix="model-")
if len(pt_files) == 0:
    # Fail early and name the directory; otherwise the failure surfaces later
    # inside the mean/std computation with no hint about what was missing.
    raise FileNotFoundError(
        f"No checkpoint files with prefix 'model-' found in {checkpoint_dir!r}"
    )
loader, transform = get_trajectory_dataloader_flat(pt_files, batch_size)

In [None]:
# === Get input dimension from one sample ===
# Size along the first axis of the first dataset item (after the dataset's
# transform, i.e. the normalizer when normalization is on).
input_dim = loader.dataset[0].size(0)

In [None]:
# Sanity-check one raw checkpoint (loaded directly, bypassing the normalizer).
# Index 21 is arbitrary; clamp it so the cell also works with fewer checkpoints,
# and map to CPU like the rest of the notebook so CUDA-saved files load anywhere.
x = torch.load(pt_files[min(21, len(pt_files) - 1)], map_location='cpu')
print(x.shape, x.numel(), x.dtype)

In [None]:
# === Init AE ===
# Reconstruction objective: mean squared error (nn.MSELoss default reduction).
loss_fn = nn.MSELoss()
ae = UniformAutoencoder(
    input_dim,
    num_of_layers=num_layers,
    latent_dim=latent_dim,
).to(device)
optimizer = torch.optim.Adam(params=ae.parameters(), lr=lr)

In [None]:
# === AE Training Loop ===
# Each sample is one flattened checkpoint; the AE is trained to reconstruct it.
num_samples = len(loader.dataset)
for epoch in range(num_epochs):
    ae.train()
    total_loss = 0.0
    for x in loader:
        x = x.to(device)
        # ae(x) is unpacked as (reconstruction, second output); the second is
        # presumably the latent code -- unused here.
        x_recon, _ = ae(x)
        loss = loss_fn(x_recon, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size: the final batch can be smaller than batch_size,
        # and a plain mean over batches would over-weight it.
        total_loss += loss.item() * x.size(0)
    print(f"Epoch {epoch+1:03d} | Loss: {total_loss / num_samples:.6f}")

In [None]:
# === Save AE Model ===
# os.path.dirname() returns "" for a bare filename and os.makedirs("") raises,
# so only create the parent directory when the path actually has one.
_ae_save_dir = os.path.dirname(ae_save_path)
if _ae_save_dir:
    os.makedirs(_ae_save_dir, exist_ok=True)
# NOTE(review): only the AE weights are saved; the normalization statistics
# held by `transform` are not. Code that later projects checkpoints through
# this AE presumably has to recompute them from the same files -- confirm.
torch.save(ae.state_dict(), ae_save_path)
print(f"AE model saved to {ae_save_path}")