In [1]:
import os

# `src.*` (imported in the next cell) must be importable from the working directory,
# and this notebook sits one level below that directory. Only move up when `src/`
# is not already here, so re-running this cell does not climb another level.
if not os.path.isdir("src"):
    os.chdir("..")
os.getcwd()

/home/jovyan/work/SensoriumDecoding


In [2]:
import numpy as np
from sklearn.linear_model import Ridge
import argparse
import torch
import wandb
import os
import sys
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from torch.nn.functional import mse_loss
from skimage.metrics import structural_similarity as ssim
import typing as t
from tqdm import tqdm
from time import time
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr

import src.utils.data as data
from src.utils.data import load_args, get_training_ds, DataLoader
from src.utils.losses import get_criterion
import src.utils.utils as utils
from src.utils.utils import Logger, Scheduler
from src.models.model import get_model, FullModel

class Args:
    """Bare attribute container for the run configuration.

    Only ``output_dir`` is set here; the project helper ``load_args`` is then
    called on the instance and presumably fills in the remaining settings
    (device, learning rate, ...) for that run — TODO confirm in src.utils.data.
    """

    def __init__(self):
        run_dir = "runs/DNN/first_run"
        self.output_dir = run_dir


args = Args()
load_args(args)


In [4]:
def gather(result: t.Dict[str, t.List[torch.Tensor]]):
    """Reduce each list of tensors to a single summed CPU tensor, keeping the keys."""
    totals = {}
    for name, values in result.items():
        stacked = torch.stack(values)
        totals[name] = stacked.sum().cpu()
    return totals

def vstack(tensors: t.List[torch.Tensor]):
    """Stack tensors row-wise (``torch.vstack`` semantics) and return the result on CPU."""
    stacked = torch.vstack(tensors)
    return stacked.to("cpu")

@torch.no_grad()
def compute_metrics(y_true: torch.Tensor, y_pred: torch.Tensor, data_range: float = 4):
    """Image-reconstruction metrics for a stack of target/predicted images.

    Args:
        y_true: target images of shape (batch, h, w); the unpacking below rejects
            any other rank. Needs batch >= 2 for the per-pixel Pearson correlation.
        y_pred: predictions with batch * h * w elements; reshaped to (batch, h, w).
        data_range: value range passed to SSIM. Defaults to 4, the value that used
            to be hard-coded (presumably the span of the normalised images —
            TODO confirm against the data pipeline).

    Returns:
        dict with
          "correlation": Pearson r of each pixel across the batch, averaged over
              pixels (scipy yields NaN for a pixel constant across the batch, which
              makes the mean NaN),
          "rmse": root-mean-squared error over all pixels of all images,
          "ssim": SSIM of each (h, w) image pair, averaged over the batch.
    """
    batch_size, h, w = y_true.size()
    y_pred = y_pred.reshape(batch_size, h, w)

    # RMSE over every pixel of every image.
    rmse = torch.sqrt(mse_loss(y_true.reshape(batch_size, -1), y_pred.reshape(batch_size, -1)))

    # Move to host memory once instead of once per pixel column.
    images_true = y_true.cpu().numpy()
    images_pred = y_pred.cpu().numpy()
    pixels_true = images_true.reshape(batch_size, -1)
    pixels_pred = images_pred.reshape(batch_size, -1)

    # Pixel-wise correlation: Pearson r of each pixel across the batch.
    correlations = []
    for i in range(pixels_true.shape[1]):
        r, _ = pearsonr(pixels_true[:, i], pixels_pred[:, i])
        correlations.append(r)
    correlation = torch.tensor(correlations).mean().item()

    # SSIM must see the 2-D images: on flattened (h * w,) vectors its sliding
    # window would run along a 1-D signal that wraps across image rows.
    # (skimage's default 7-pixel window requires h, w >= 7.)
    ssim_score = torch.tensor([
        ssim(images_true[i], images_pred[i], data_range=data_range)
        for i in range(batch_size)
    ]).mean()

    return {
        "correlation": correlation,
        "rmse": rmse.item(),
        "ssim": ssim_score.item(),
    }

def train_step(
    mouse_id: str,
    batch: t.Dict[str, torch.Tensor],
    model: FullModel,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    update: bool,
    micro_batch_size: int,
    device: t.Union[str, torch.device] = "cpu"
):
    """Accumulate gradients for one mouse's batch and optionally step the optimizer.

    The batch is split with ``data.micro_batching``. Each micro-batch is run forward
    and backward, so gradients add up across micro-batches, and across calls until a
    call with ``update=True`` steps and clears them.

    Args:
        mouse_id: key forwarded to ``model`` and ``criterion``.
        batch: dict with "image", "response", "behavior" and "pupil_center" tensors.
        model: decoder that maps responses to an image per sample.
        optimizer: stepped and zeroed only when ``update`` is True.
        criterion: loss called with y_true, y_pred, mouse_id and the full batch size
            (presumably so micro-batch losses are normalised by the whole batch).
        update: whether to apply and clear the accumulated gradients afterwards.
        micro_batch_size: chunk size passed to ``data.micro_batching``.
        device: device the model and inputs are moved to.

    Returns:
        {"loss/loss": sum of the micro-batch losses, as a CPU tensor}.
    """
    model.to(device)
    batch_size = batch["image"].size(0)
    result = {"loss/loss": []}
    for micro_batch in data.micro_batching(batch, micro_batch_size):
        y_true = micro_batch["image"].to(device)
        y_pred = model(
            x=micro_batch["response"].to(device),
            mouse_id=mouse_id,
            behaviours=micro_batch["behavior"].to(device),
            pupil_centers=micro_batch["pupil_center"].to(device)
        )
        # Follow the target's shape rather than hard-coding a 36x64 image size.
        y_pred = y_pred.view_as(y_true)
        loss = criterion(
            y_true=y_true,
            y_pred=y_pred,
            mouse_id=mouse_id,
            batch_size=batch_size,
        )
        loss.backward()
        # Detach so the logged value does not keep the autograd graph alive.
        result["loss/loss"].append(loss.detach())
    if update:
        optimizer.step()
        optimizer.zero_grad()
    return gather(result)

def train(
    args,
    ds: t.Dict[str, DataLoader],
    model: FullModel,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    epoch: int
):
    """Run one training epoch over all mice and log per-mouse metrics.

    Batches from the per-mouse loaders are interleaved by ``data.CycleDataloaders``.
    Gradients are accumulated and the optimizer steps once every ``len(ds)`` batches
    (presumably one batch per mouse — depends on CycleDataloaders' ordering).
    Gradients left over at the end of the epoch are applied rather than dropped.

    Args:
        args: run configuration; reads ``verbose``, ``micro_batch_size`` and ``device``.
        ds: mapping mouse_id -> training DataLoader.
        model: the decoder being trained.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss module forwarded to ``train_step``.
        epoch: epoch index passed to the logger.

    Returns:
        Whatever ``utils.log_metrics`` returns for the collected results.
    """
    mouse_ids = list(ds.keys())
    results = {mouse_id: {} for mouse_id in mouse_ids}
    cycled = data.CycleDataloaders(ds)
    update_frequency = len(mouse_ids)
    model.train(True)
    optimizer.zero_grad()
    has_pending_grads = False
    for i, (mouse_id, mouse_batch) in tqdm(
        enumerate(cycled), desc="Train", total=len(cycled), disable=args.verbose < 2
    ):
        update = (i + 1) % update_frequency == 0
        result = train_step(
            mouse_id=mouse_id,
            batch=mouse_batch,
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            update=update,
            micro_batch_size=args.micro_batch_size,
            device=args.device
        )
        has_pending_grads = not update
        utils.update_dict(results[mouse_id], result)
    if has_pending_grads:
        # The batch count was not a multiple of the number of mice: apply the
        # trailing accumulated gradients instead of letting the next epoch's
        # zero_grad() silently discard them.
        optimizer.step()
        optimizer.zero_grad()
    return utils.log_metrics(results, epoch=epoch, mode=0)

@torch.no_grad()
def validation_step(
    mouse_id: str,
    batch: t.Dict[str, torch.Tensor],
    model: FullModel,
    criterion: torch.nn.Module,
    micro_batch_size: int,
    device: t.Union[str, torch.device] = "cpu"
):
    """Evaluate one mouse's batch without gradients, in micro-batches.

    Args:
        mouse_id: key forwarded to ``model`` and ``criterion``.
        batch: dict with "image", "response", "behavior" and "pupil_center" tensors.
        model: decoder that maps responses to an image per sample.
        criterion: loss called with y_true, y_pred, mouse_id and the full batch size.
        micro_batch_size: chunk size passed to ``data.micro_batching``.
        device: device the model and inputs are moved to.

    Returns:
        (losses, targets, predictions): ``{"loss/loss": summed loss}`` plus the
        targets and predictions of all micro-batches stacked along dim 0, all on CPU.
    """
    model.to(device)
    batch_size = batch["image"].size(0)
    result = {"loss/loss": []}
    targets, predictions = [], []
    for micro_batch in data.micro_batching(batch, micro_batch_size):
        y_true = micro_batch["image"].to(device)
        y_pred = model(
            x=micro_batch["response"].to(device),
            mouse_id=mouse_id,
            behaviours=micro_batch["behavior"].to(device),
            pupil_centers=micro_batch["pupil_center"].to(device),
        )
        # Follow the target's shape rather than hard-coding a 36x64 image size.
        y_pred = y_pred.view_as(y_true)
        loss = criterion(
            y_true=y_true,
            y_pred=y_pred,
            mouse_id=mouse_id,
            batch_size=batch_size,
        )
        result["loss/loss"].append(loss)
        targets.append(y_true)
        predictions.append(y_pred)
    return gather(result), vstack(targets), vstack(predictions)


def validate(
    args,
    ds: t.Dict[str, DataLoader],
    model: FullModel,
    criterion: torch.nn.Module,
    epoch: int
):
    """Evaluate the model on every mouse's validation loader and log metrics.

    For each mouse, losses from ``validation_step`` are accumulated, and
    ``compute_metrics`` is run once on all of that mouse's targets and predictions
    stacked together.

    Args:
        args: run configuration; reads ``verbose``, ``micro_batch_size`` and ``device``.
        ds: mapping mouse_id -> validation DataLoader.
        model: the decoder being evaluated (switched to eval mode here).
        criterion: loss module forwarded to ``validation_step``.
        epoch: epoch index passed to the logger.

    Returns:
        Whatever ``utils.log_metrics`` returns for the per-mouse results.
    """
    model.train(False)
    results = {}
    # The bar ticks once per batch, so its total must count batches across all
    # mice; len(ds) alone would only count the mice.
    total_batches = sum(len(mouse_ds) for mouse_ds in ds.values())
    with tqdm(desc="Val", total=total_batches, disable=args.verbose < 2) as pbar:
        for mouse_id, mouse_ds in ds.items():
            mouse_result, y_true, y_pred = {}, [], []
            for batch in mouse_ds:
                result, targets, predictions = validation_step(
                    mouse_id=mouse_id,
                    batch=batch,
                    model=model,
                    criterion=criterion,
                    micro_batch_size=args.micro_batch_size,
                    device=args.device
                )
                utils.update_dict(mouse_result, result)
                y_true.append(targets)
                y_pred.append(predictions)
                pbar.update(1)
            y_true, y_pred = vstack(y_true), vstack(y_pred)
            mouse_result.update(compute_metrics(y_true=y_true, y_pred=y_pred))
            results[mouse_id] = mouse_result
            # Release this mouse's stacked images before moving to the next one.
            del y_true, y_pred
    return utils.log_metrics(results, epoch=epoch, mode=1)

In [18]:
def save_checkpoint(model, optimizer, scheduler, args, epoch, history):
    """Save model/optimizer/scheduler state plus epoch and history to one file.

    The checkpoint is written to ``<args.output_dir>/test.pt`` (overwritten on
    every call); the output directory is created first if it does not exist.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history
    }
    torch.save(checkpoint, os.path.join(args.output_dir, "test.pt"))

def load_checkpoint(args, optimizer, scheduler, model=None):
    """Restore the single-file checkpoint ``<args.output_dir>/test.pt``.

    Args:
        args: run configuration; reads ``output_dir`` and, when ``model`` is not
            given, whatever ``get_model`` and ``args.device`` need.
        optimizer: optimizer whose state is restored in place.
        scheduler: LR scheduler whose state is restored in place.
        model: module to load the weights into. If omitted, a fresh instance is
            built with ``get_model(args)``, so a save/load round-trip really
            compares two models instead of reading the notebook-global ``model``.
            ``optimizer`` is not rebound to a freshly built model.

    Returns:
        ``(model, optimizer, scheduler, epoch, history)``. Any tensors stored in
        ``history`` come back on CPU.
    """
    if model is None:
        model = get_model(args).to(args.device)
    # Load on CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies the values onto the target objects' devices.
    checkpoint = torch.load(os.path.join(args.output_dir, "test.pt"), map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    history = checkpoint['history']
    return model, optimizer, scheduler, epoch, history

In [19]:
# Fresh model/optimizer/scheduler built from the loaded run configuration.
model = get_model(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_step_size,
    gamma=args.lr_gamma,
)
history = {key: [] for key in ("train_loss", "val_loss", "val_correlation")}

# Round-trip: write a checkpoint for epoch 1, then read it straight back.
save_checkpoint(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    args=args,
    epoch=1,
    history=history,
)
model2, optimizer2, scheduler2, start_epoch, history = load_checkpoint(
    args=args, optimizer=optimizer, scheduler=scheduler
)

In [20]:
# Build each state dict once (not twice per key) and require identical key sets:
# comparing values alone would ignore keys that exist only in the restored model.
original_state = model.state_dict()
restored_state = model2.state_dict()
original_state.keys() == restored_state.keys() and all(
    torch.allclose(original_state[key], restored_state[key]) for key in original_state
)

True

In [21]:
def save_checkpoint(model, optimizer, scheduler, args, epoch, history):
    """Write one checkpoint folder ``<output_dir>/checkpoints/epoch_<epoch>``.

    The folder receives four files:
      model.pt      the whole pickled module (not just its state_dict),
      optimizer.pt  the optimizer's state_dict,
      scheduler.pt  the LR scheduler's state_dict,
      history.pt    the ``history`` object as given.
    Existing files for the same epoch are overwritten.
    """
    target_dir = os.path.join(args.output_dir, "checkpoints", f"epoch_{epoch}")
    os.makedirs(target_dir, exist_ok=True)

    artefacts = (
        ("model.pt", model),
        ("optimizer.pt", optimizer.state_dict()),
        ("scheduler.pt", scheduler.state_dict()),
        ("history.pt", history),
    )
    for file_name, payload in artefacts:
        torch.save(payload, os.path.join(target_dir, file_name))

def load_checkpoint(args, optimizer, scheduler, model=None):
    """Restore the newest ``checkpoints/epoch_<n>`` folder written by ``save_checkpoint``.

    Args:
        args: run configuration; ``args.output_dir`` locates the checkpoints and
            ``args.device`` (if set) is used as ``map_location`` for the model.
        optimizer: optimizer whose state is restored in place from ``optimizer.pt``.
        scheduler: LR scheduler whose state is restored in place from ``scheduler.pt``.
        model: optional module to restore into. When given, the saved weights are
            copied into it with ``load_state_dict`` and it is returned, so an
            ``optimizer`` built over this module keeps updating the returned model.
            When omitted, the module unpickled from ``model.pt`` is returned;
            ``optimizer`` is then NOT rebound to its parameters.

    Returns:
        ``(model, optimizer, scheduler, epoch, history)`` with ``model`` in train
        mode and ``epoch`` parsed from the folder name.

    Raises:
        FileNotFoundError: if there is no ``epoch_<n>`` folder to load.
    """
    checkpoints_root = os.path.join(args.output_dir, "checkpoints")

    # Map epoch number -> folder name, ignoring anything not named epoch_<n>
    # (e.g. ".ipynb_checkpoints" or "best"), which int() would otherwise choke on.
    epoch_folders = {}
    if os.path.isdir(checkpoints_root):
        for name in os.listdir(checkpoints_root):
            prefix, _, number = name.rpartition("_")
            if (
                prefix == "epoch"
                and number.isdecimal()
                and os.path.isdir(os.path.join(checkpoints_root, name))
            ):
                epoch_folders[int(number)] = name
    if not epoch_folders:
        raise FileNotFoundError("No checkpoint folders found.")

    epoch = max(epoch_folders)
    checkpoint_dir = os.path.join(checkpoints_root, epoch_folders[epoch])

    # model.pt is a whole pickled nn.Module, which torch >= 2.6 refuses to load with
    # its default weights_only=True. Unpickling can execute arbitrary code, so only
    # load checkpoints you produced yourself.
    model_path = os.path.join(checkpoint_dir, "model.pt")
    map_location = getattr(args, "device", None)
    try:
        loaded_model = torch.load(model_path, map_location=map_location, weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` argument
        loaded_model = torch.load(model_path, map_location=map_location)

    if model is None:
        model = loaded_model
    else:
        model.load_state_dict(loaded_model.state_dict())
    model.train(True)

    # Read optimizer state on CPU; Optimizer.load_state_dict casts per-parameter
    # state onto the devices of the parameters it manages.
    optimizer.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location="cpu")
    )
    scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
    history = torch.load(os.path.join(checkpoint_dir, "history.pt"))

    return model, optimizer, scheduler, epoch, history

In [22]:
args.resume = False

# Fresh model/optimizer/scheduler built from the loaded run configuration.
model = get_model(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_step_size,
    gamma=args.lr_gamma,
)
history = {key: [] for key in ("train_loss", "val_loss", "val_correlation")}

# Round-trip through the per-epoch folder format: save epoch 1, load the newest.
save_checkpoint(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    args=args,
    epoch=1,
    history=history,
)
model2, optimizer2, scheduler2, start_epoch, history = load_checkpoint(
    args=args, optimizer=optimizer, scheduler=scheduler
)

In [23]:
# Build each state dict once (not twice per key) and require identical key sets:
# comparing values alone would ignore keys that exist only in the restored model.
original_state = model.state_dict()
restored_state = model2.state_dict()
original_state.keys() == restored_state.keys() and all(
    torch.allclose(original_state[key], restored_state[key]) for key in original_state
)

True