In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import pandas as pd
import os
import cv2
from tqdm import tqdm
from scipy.interpolate import CubicSpline
import time

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import font_manager, rcParams

import warnings
from collections import defaultdict
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# Window sizes (units presumably time steps/samples — TODO confirm).
PAST = 100                    # presumably number of past steps fed to the model
FUTURE = 10                   # presumably number of future steps predicted
TOTAL_WINDOW = PAST + FUTURE  # combined length of the past + future window
REAL_SEQ_DIM = 28             # NOTE(review): looks like the real-vehicle sequence feature width — verify

def load_pretrained_from_lab(model: "Visual_Module_Real",
                             pretrained_path: str):
    """Copy the shape-compatible entries of a saved state dict into ``model``.

    A checkpoint entry is copied only when its key exists in
    ``model.state_dict()`` and its tensor shape matches the model's. All other
    checkpoint entries are skipped and reported. Model entries without a
    compatible counterpart keep their current values. This lets a model that
    differs in a few layers reuse the remaining pretrained weights.

    Args:
        model: Target module, updated in place via ``load_state_dict``.
        pretrained_path: File written by ``torch.save``; assumed to hold a
            flat state dict (TODO confirm for the lab checkpoints).

    If ``pretrained_path`` does not exist, a warning is printed and ``model``
    is left unchanged.
    """
    if not os.path.exists(pretrained_path):
        print(f"Warning: Pretrained model not found: {pretrained_path}")
        return

    print(f"\nLoading pretrained weights from lab model: {pretrained_path}")
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(pretrained_path, map_location="cpu")
    model_dict = model.state_dict()

    compatible_dict = {}
    shape_mismatch = []  # key exists in the model, but the tensor shape differs
    unexpected = []      # key does not exist in the model at all
    for k, v in state_dict.items():
        if k not in model_dict:
            unexpected.append(k)
        elif model_dict[k].shape != v.shape:
            shape_mismatch.append(k)
        else:
            compatible_dict[k] = v

    print(f"  - Compatible parameters: {len(compatible_dict)}/{len(model_dict)}")
    # Report dropped entries instead of silently ignoring them, so a mismatched
    # checkpoint is noticed. Only a short preview of the key names is shown.
    for label, keys in (("shape mismatch", shape_mismatch),
                        ("not in model", unexpected)):
        if keys:
            preview = ", ".join(map(str, keys[:5])) + (", ..." if len(keys) > 5 else "")
            print(f"  - Skipped ({label}): {len(keys)} [{preview}]")
    if not compatible_dict:
        print("  - Warning: no compatible parameters found; model weights unchanged")

    model_dict.update(compatible_dict)
    model.load_state_dict(model_dict)
    print("  ✅ Pretrained weights loaded (incompatible layers skipped)")

def freeze_backbone_for_transfer(model: "Visual_Module_Real", *, train_keys=None):
    """Freeze every parameter of ``model`` except those selected for fine-tuning.

    A parameter stays trainable iff its qualified name from
    ``model.named_parameters()`` contains at least one of ``train_keys`` as a
    substring. Every other parameter gets ``requires_grad = False``. A summary
    of total and trainable parameter counts is printed.

    Args:
        model: Module whose parameters are toggled in place.
        train_keys: Name substrings (a single string or an iterable of them)
            that select trainable parameters. Defaults to
            ``["temporal_module", "branch_att", "decoder_4", "output_layer"]``.
    """
    if train_keys is None:
        train_keys = ["temporal_module", "branch_att", "decoder_4", "output_layer"]
    elif isinstance(train_keys, str):
        train_keys = [train_keys]
    else:
        train_keys = list(train_keys)

    # Single pass: freeze by default, unfreeze on a name match. This is a
    # plain substring test, so "decoder_4" would also match e.g. "decoder_40".
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in train_keys)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nTransfer learning setup:")
    print(f"  - Total parameters: {total_params / 1e6:.2f} M")
    print(f"  - Trainable parameters: {trainable_params / 1e6:.2f} M")
    print(f"  - Training modules: {train_keys}")
    if trainable_params == 0:
        print("  - Warning: no parameter name matched train_keys; the whole model is frozen")


class CustomDataset(Dataset):
    """In-memory dataset of ``(image, sequence, target)`` triples.

    Each of the three inputs is copied into a new ``float32`` tensor at
    construction time. The dataset length is taken from ``targets``.
    ``image`` and ``sequence`` are indexed with the same index.
    NOTE(review): their lengths are not checked against ``targets``.
    """

    @staticmethod
    def _as_float32(data):
        """Copy ``data`` into a fresh float32 tensor."""
        return torch.tensor(data, dtype=torch.float32)

    def __init__(self, image, sequence, targets):
        self.image, self.sequence, self.targets = (
            self._as_float32(part) for part in (image, sequence, targets)
        )

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        fields = (self.image, self.sequence, self.targets)
        return tuple(field[idx] for field in fields)

def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when ``values`` is empty.

    Gives the same value as ``np.mean`` but without NumPy's empty-slice warning.
    """
    return float(np.mean(values)) if len(values) else float("nan")


def _step_scheduler(scheduler, metric):
    """Advance ``scheduler`` by one epoch.

    Built-in epoch-based schedulers (``StepLR``, ``CosineAnnealingLR``, ...)
    are stepped without arguments. Their optional positional argument is an
    epoch index, so passing the validation loss there would corrupt the
    schedule. ``ReduceLROnPlateau`` and any scheduler type not recognised here
    receive ``metric``, as before.
    """
    lr_sched = optim.lr_scheduler
    # ``LRScheduler`` is the public base class in torch >= 2.0; older releases
    # only expose ``_LRScheduler``.
    epoch_based = getattr(lr_sched, "LRScheduler", None) or lr_sched._LRScheduler
    if (isinstance(scheduler, epoch_based)
            and not isinstance(scheduler, lr_sched.ReduceLROnPlateau)):
        scheduler.step()
    else:
        scheduler.step(metric)


def _train_one_epoch(model, loader, criterion, optimizer, max_grad_norm, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    batch_loss = []
    progress = tqdm(loader, desc=desc)
    for img_in, seq_in, targets in progress:
        img_in = img_in.to(device)
        seq_in = seq_in.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(img_in, seq_in)
        loss = criterion(outputs, targets)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        loss_value = loss.item()
        batch_loss.append(loss_value)
        progress.set_postfix({'loss': f'{loss_value:.6f}'})
    return _mean_or_nan(batch_loss)


def _evaluate_loss(model, loader, criterion):
    """Return the mean batch loss of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    batch_loss = []
    with torch.no_grad():
        for img_in, seq_in, targets in loader:
            outputs = model(img_in.to(device), seq_in.to(device))
            batch_loss.append(criterion(outputs, targets.to(device)).item())
    return _mean_or_nan(batch_loss)


def train_model(train_loader, val_loader, model, epochs,
                criterion, optimizer, scheduler, save_path,
                *, patience=15, max_grad_norm=1.0):
    """Fine-tune ``model`` with early stopping, keeping the best checkpoint on disk.

    Each epoch runs one training pass over ``train_loader`` and one evaluation
    pass over ``val_loader``. Batches are ``(image, sequence, target)`` triples
    and the model is called as ``model(image, sequence)``; tensors are moved to
    the module-level ``device``. Whenever the mean validation loss improves,
    ``model.state_dict()`` is saved to ``save_path``. Training stops after
    ``patience`` consecutive epochs without improvement. ``model`` is left in
    its final-epoch state; it is not reloaded from the best checkpoint.

    Args:
        train_loader: DataLoader for training batches.
        val_loader: DataLoader for validation batches.
        model: Module to train, moved to ``device``.
        epochs: Maximum number of epochs.
        criterion: Loss function ``criterion(outputs, targets)``.
        optimizer: Optimizer over (a subset of) ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch after
            validation; ``None`` disables scheduling.
        save_path: Where the best ``state_dict`` is written.
        patience: Epochs without validation improvement before stopping early.
        max_grad_norm: Gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        ``(train_loss, val_loss, total_train_time_min, avg_epoch_time)``: the
        per-epoch mean losses, the total wall time in minutes, and the mean
        epoch time in seconds (NaN if no epoch ran). Epoch time covers both
        training and validation.

    If ``val_loader`` yields no batches, the validation loss is NaN. NaN never
    counts as an improvement, so no checkpoint is saved in that case.
    """
    train_loss = []
    val_loss = []
    model.to(device)

    print("\nTraining configuration (real vehicle transfer):")
    print(f"  - Device: {device}")
    print(f"  - Training set size: {len(train_loader.dataset)}")
    print(f"  - Validation set size: {len(val_loader.dataset)}")
    print(f"  - Batch size: {train_loader.batch_size}")
    print(f"  - Total epochs: {epochs}")
    print(f"  - Model save path: {save_path}\n")

    best_val_loss = float("inf")
    patience_counter = 0
    total_train_time = 0.0
    epoch_times = []

    for epoch in range(epochs):
        epoch_start_time = time.time()

        train_loss.append(_train_one_epoch(
            model, train_loader, criterion, optimizer, max_grad_norm,
            desc=f"[Real vehicle transfer] Training Epoch {epoch + 1}/{epochs}",
        ))
        val_loss.append(_evaluate_loss(model, val_loader, criterion))

        epoch_time = time.time() - epoch_start_time
        epoch_times.append(epoch_time)
        total_train_time += epoch_time

        if scheduler is not None:
            _step_scheduler(scheduler, val_loss[-1])
            lr_str = f", LR: {optimizer.param_groups[0]['lr']:.2e}"
        else:
            lr_str = ""

        improved = val_loss[-1] < best_val_loss
        if improved:
            best_val_loss = val_loss[-1]
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1

        print(
            f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss[-1]:.6f}, "
            f"Val Loss: {val_loss[-1]:.6f}{lr_str}, Time: {epoch_time:.2f}s"
            + (" ⭐ (saved best model)" if improved else "")
        )

        if patience_counter >= patience:
            print(f"\nEarly stopping triggered: validation loss not improved for {patience} epochs")
            break

    avg_epoch_time = _mean_or_nan(epoch_times)
    total_train_time_min = total_train_time / 60

    print("\nReal vehicle transfer training completed!")
    print(f"  - Best validation loss: {best_val_loss:.6f}")
    print(f"  - Total training time: {total_train_time_min:.2f} minutes")
    print(f"  - Average epoch time: {avg_epoch_time:.2f} seconds")

    return train_loss, val_loss, total_train_time_min, avg_epoch_time
