### Packages and Libraries

In [None]:
# Choose which CUDA devices are visible for parallel computing.
# This cell runs before the cell that imports `torch`, so the restriction to
# physical GPUs 0, 2 and 5 is already in the environment when torch is loaded.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,5"
# Print the process id of the kernel running this notebook.
print("This notebook's PID:", os.getpid())

In [3]:
import os
import numpy as np
from tqdm import tqdm
import random
import pandas as pd
from collections import defaultdict

import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16_bn, VGG16_BN_Weights

# Import models
import import_ipynb
from models import HyperspectralTransferCNN, ImprovedHybrid3D2DCNN_v2, VGG16WithAttention, VGG16WithCBAM

In [4]:
# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Current device: {device}.")

Current device: cuda.


### Paths

In [5]:
# HYPSO paths
# NOTE(review): these are absolute paths on one machine; adjust them (or derive
# them from a single base directory) before running elsewhere.
cube_path = '/home/salyken/PRISMA/HYPSO_data/cube'
labels_path = '/home/salyken/PRISMA/HYPSO_data/labels'
list_path = '/home/salyken/PRISMA/HYPSO_data/list/hypso_labels.xlsx'
split_save_path = '/home/salyken/PRISMA/HYPSO_data/list/hypso_train_val_test_split.csv'

# All `.npy` files in the cube directory, in sorted (deterministic) order.
cube_files = sorted([f for f in os.listdir(cube_path) if f.endswith('.npy')])

### Investigate HYPSO-TreeLabels Files

In [None]:
def file_identifier(file_path):
    """Open a file according to its extension and report what was found.

    * ``.xlsx``: read with pandas (openpyxl engine) and returned as a DataFrame.
    * ``.tif`` / ``.tiff``: the first band is read with rasterio and shown as a
      grayscale image; nothing is returned.
    * anything else: a "skipping" message is printed; nothing is returned.

    The extension check is case-insensitive.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        pandas.DataFrame for Excel files, otherwise None.
    """
    lowered = file_path.lower()

    if lowered.endswith(".xlsx"):
        df = pd.read_excel(file_path, engine="openpyxl")
        print(f"Opened Excel file: {file_path}")
        return df

    if lowered.endswith((".tif", ".tiff")):
        # Imported here because the notebook's import cell does not import
        # rasterio; without this the branch fails with NameError.
        import rasterio

        with rasterio.open(file_path) as src:
            img = src.read(1)  # Read the first band (1-based index)

        # Report the shape of the band actually read, so single-band rasters
        # work too (reading band 3 only for its shape fails on them).
        print(f"Opened TIFF file: {file_path}, Shape:", img.shape)

        # Display the image
        plt.imshow(img, cmap="gray")
        plt.colorbar()
        plt.title("GeoTIFF - Single Band")
        plt.xlabel("Width (X)")
        plt.ylabel("Height (Y)")
        plt.show()
        return None

    print(f"Skipping unknown file: {file_path}")
    return None

# Load the label spreadsheet at `list_path`; for an .xlsx path file_identifier returns a DataFrame.
df_excel = file_identifier(list_path)

Opened Excel file: /home/salyken/PRISMA/HYPSO_data/list/hypso_labels.xlsx


In [None]:
# Spreadsheet columns holding the per-species sample counts
species = ['spruce', 'pine', 'deciduous']

# Dataset-wide total for each species column
total_counts = df_excel[species].sum()

# One RGBA colour per bar, in the same order as `species`
bar_colors = [
    (180/255, 100/255, 20/255, 0.8),
    (60/255, 180/255, 220/255, 0.8),
    (200/255, 100/255, 220/255, 0.8),
]

# Bar chart of the overall class distribution
fig, ax = plt.subplots(figsize=(6,5))
ax.bar(total_counts.index, total_counts.values, color=bar_colors)
ax.set_title('HYPSO-TreeLabels: Overall Tree species Class Distribution')
ax.set_xlabel('Tree Species')
ax.set_ylabel('Number of Samples')
ax.grid(axis='y')
fig.tight_layout()
plt.show()

### Preprocessing

In [None]:
def compute_bandwise_norm_from_cubes(cube_dir, train_cube_list, drop_first_n_bands=3):
    """Compute per-band mean and standard deviation over a set of cube files.

    Each file is loaded with ``np.load`` and indexed as (H, W, Bands). The first
    ``drop_first_n_bands`` bands are discarded, and running sums of x and x**2
    are accumulated over all pixels of all cubes, so only one cube is held in
    memory at a time.

    Accumulation is done in float64 regardless of the stored dtype: squaring in
    the cube's native dtype overflows for integer cubes and loses precision for
    float32 ones.

    Args:
        cube_dir: Directory containing the cube ``.npy`` files.
        train_cube_list: File names (relative to ``cube_dir``) to include.
        drop_first_n_bands: Number of leading bands to discard.

    Returns:
        (band_means, band_stds): two float64 arrays of shape (usable_bands,).
        The std is the population standard deviation.

    Raises:
        ValueError: If ``train_cube_list`` is empty.
    """
    total_sum = None
    total_sum_sq = None
    total_pixels = 0

    for fname in tqdm(train_cube_list, desc="Computing HYPSO band stats"):
        cube = np.load(os.path.join(cube_dir, fname))  # (H, W, Bands)
        usable = cube[:, :, drop_first_n_bands:]       # (H, W, usable_Bands)

        # One row per pixel, one column per band, promoted to float64.
        flat = usable.reshape(-1, usable.shape[-1]).astype(np.float64)

        sum_ = flat.sum(axis=0)              # (Bands,)
        sum_sq = np.square(flat).sum(axis=0) # (Bands,)

        if total_sum is None:
            total_sum = sum_
            total_sum_sq = sum_sq
        else:
            total_sum += sum_
            total_sum_sq += sum_sq

        total_pixels += flat.shape[0]

    if total_sum is None or total_pixels == 0:
        raise ValueError("train_cube_list is empty (or the cubes contain no pixels); cannot compute band statistics.")

    band_means = total_sum / total_pixels
    # E[x^2] - E[x]^2 can come out marginally negative through rounding;
    # clamp so the square root never yields NaN.
    band_vars = np.maximum(total_sum_sq / total_pixels - band_means ** 2, 0.0)
    band_stds = np.sqrt(band_vars)

    print(" Band statistics computed.")
    return band_means, band_stds


In [None]:
def compute_bandwise_norm_from_projected_cubes(cube_dir, train_cube_list, projection_matrix, drop_first_n_bands=3):
    """Compute per-band mean/std of cubes after a linear spectral projection.

    Each cube is loaded as (H, W, Bands), the first ``drop_first_n_bands`` bands
    are removed, and every pixel spectrum is multiplied by ``projection_matrix``
    (shape (P, usable_Bands)) to give P projected bands. Running sums of x and
    x**2 are accumulated in float64 over all pixels of all cubes.

    Args:
        cube_dir: Directory containing the cube ``.npy`` files.
        train_cube_list: File names (relative to ``cube_dir``) to include.
        projection_matrix: Array of shape (P, usable_Bands); its second axis is
            contracted with the cube's band axis.
        drop_first_n_bands: Number of leading bands to discard before projecting.

    Returns:
        (band_means, band_stds): two float64 arrays of shape (P,). The std is
        the population standard deviation.

    Raises:
        ValueError: If ``train_cube_list`` is empty.
    """
    # float64 keeps the sum-of-squares accumulation accurate for any input dtype.
    weights = np.asarray(projection_matrix, dtype=np.float64)  # (P, usable_Bands)

    total_sum = None
    total_sum_sq = None
    total_pixels = 0

    for fname in tqdm(train_cube_list, desc="Computing band stats after projection"):
        cube = np.load(os.path.join(cube_dir, fname))  # (H, W, Bands)
        usable = cube[:, :, drop_first_n_bands:]       # (H, W, usable_Bands)

        # (H*W, usable_Bands) @ (usable_Bands, P) -> (H*W, P): the same
        # contraction as tensordot over the band axis, already flattened.
        pixels = usable.reshape(-1, usable.shape[-1]).astype(np.float64)
        flat = pixels @ weights.T

        sum_ = flat.sum(axis=0)              # (P,)
        sum_sq = np.square(flat).sum(axis=0) # (P,)

        if total_sum is None:
            total_sum = sum_
            total_sum_sq = sum_sq
        else:
            total_sum += sum_
            total_sum_sq += sum_sq

        total_pixels += flat.shape[0]

    if total_sum is None or total_pixels == 0:
        raise ValueError("train_cube_list is empty (or the cubes contain no pixels); cannot compute band statistics.")

    band_means = total_sum / total_pixels
    # Clamp tiny negative variances caused by rounding so sqrt never yields NaN.
    band_vars = np.maximum(total_sum_sq / total_pixels - band_means ** 2, 0.0)
    band_stds = np.sqrt(band_vars)

    print(" Projected band statistics computed.")
    return band_means, band_stds


### Create Patch Dataset

In [None]:
from scipy.stats import mode
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import random
from collections import defaultdict

class HYPSOPatchDataset(Dataset):
    """Square patches cut on the fly from preloaded cubes.

    Defined at module level (rather than inside ``create_patch_datasets``) so
    instances can be pickled, e.g. by DataLoader worker start methods that
    pickle the dataset.

    Args:
        cube_data: List of ``(cube, label)`` pairs; ``cube`` is indexed as
            (H, W, Bands) and ``label`` as (H, W).
        index_map: List of ``(cube_idx, i, j)`` tuples, or
            ``(cube_idx, i, j, majority)`` tuples when ``majority_label`` is True.
            ``(i, j)`` is the patch centre.
        band_means, band_stds: Optional per-band statistics (NumPy arrays or
            tensors). Normalisation is applied only when both are given.
        patch_size: Side length of the patch in pixels.
        augment: Whether to apply random flips / rotation / noise.
        majority_label: Whether labels come from the stored majority value
            instead of the centre pixel.

    Attributes:
        class_counts: Dict mapping class index (0, 1, 2) to the number of
            samples of that class, computed from ``index_map`` without
            extracting any patch.
    """

    _VALID_LABELS = (1, 2, 3)

    def __init__(self, cube_data, index_map, band_means, band_stds, patch_size=5, augment=False, majority_label=False):
        self.cube_data = cube_data
        self.index_map = index_map
        self.band_means = band_means
        self.band_stds = band_stds
        self.patch_size = patch_size
        self.half = patch_size // 2
        self.augment = augment
        self.majority_label = majority_label  # store flag

        # Build the broadcastable (Bands, 1, 1) statistics once instead of on
        # every __getitem__; as_tensor accepts arrays and tensors alike.
        if band_means is not None and band_stds is not None:
            self._mean = torch.as_tensor(band_means, dtype=torch.float32).reshape(-1, 1, 1)
            self._std = torch.as_tensor(band_stds, dtype=torch.float32).reshape(-1, 1, 1)
        else:
            self._mean = None
            self._std = None

        self.class_counts = self._count_classes()

    def _count_classes(self):
        """Count samples per class index from the index map alone."""
        counts = {0: 0, 1: 0, 2: 0}
        for entry in self.index_map:
            if self.majority_label:
                raw = int(entry[3])
            else:
                raw = int(self.cube_data[entry[0]][1][entry[1], entry[2]])
            # Invalid labels are not counted here; __getitem__ reports them.
            if raw in self._VALID_LABELS:
                counts[raw - 1] += 1
        return counts

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return ``(patch, label)``: a float32 (Bands, P, P) tensor and a long scalar."""
        entry = self.index_map[idx]
        cube_idx, i, j = entry[0], entry[1], entry[2]
        cube, label = self.cube_data[cube_idx]

        if self.majority_label:
            label_val = int(entry[3]) - 1  # already validated when indexed
        else:
            raw_label = int(label[i, j])
            if raw_label not in self._VALID_LABELS:
                raise ValueError(f"Invalid label {raw_label} at index {idx}")
            label_val = raw_label - 1

        patch = cube[
            i - self.half:i + self.half + 1,
            j - self.half:j + self.half + 1,
            :
        ]
        patch = np.transpose(patch, (2, 0, 1))  # (Bands, P, P)
        patch = torch.tensor(patch, dtype=torch.float32)  # copies; the cube is never modified

        if self._mean is not None:
            patch = (patch - self._mean) / (self._std + 1e-6)

        if self.augment:
            patch = self.apply_augmentations(patch)

        return patch, torch.tensor(label_val).long()

    def apply_augmentations(self, x):
        """Apply each augmentation independently with probability 0.5."""
        if torch.rand(1) < 0.5:
            x = torch.flip(x, dims=[1])
        if torch.rand(1) < 0.5:
            x = torch.flip(x, dims=[2])
        if torch.rand(1) < 0.5:
            x = torch.rot90(x, k=1, dims=[1, 2])
        if torch.rand(1) < 0.5:
            x += torch.randn_like(x) * 0.01
        return x


def create_patch_datasets(
    cube_dir, label_dir, cube_list,
    band_means=None, band_stds=None,
    patch_size=33, stride=4,
    train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
    seed=42,
    projection_matrix=None,
    majority_label=False,  # toggleable for center- vs mode-based labeling scheme
    drop_first_n_bands=3
):
    """Build train/val/test patch datasets from HYPSO cubes and label maps.

    For every file in ``cube_list`` the cube is loaded from ``cube_dir`` and its
    label map from ``label_dir`` (file name with ``_l1d_cube.npy`` replaced by
    ``_labels.csv``, read by ``np.loadtxt`` with its default delimiter). Patch
    centres are taken on a regular grid with the given ``stride``; a centre is
    kept when its label (centre pixel, or the most frequent of 1/2/3 inside the
    window when ``majority_label`` is True) is 1, 2 or 3. Class indices are
    label - 1.

    The kept centres are shuffled with a private RNG seeded by ``seed`` (the
    global ``random`` state is left untouched) and cut into train / val / test.

    NOTE(review): the split is made over individual patches, so when
    ``stride < patch_size`` overlapping pixels can end up in different splits.

    Args:
        cube_dir: Directory with the cube ``.npy`` files.
        label_dir: Directory with the label files.
        cube_list: Cube file names to use.
        band_means, band_stds: Optional per-band normalisation statistics.
        patch_size: Side length of the patches.
        stride: Step between neighbouring patch centres.
        train_ratio, val_ratio: Fractions of patches for the first two splits.
        test_ratio: Kept for interface symmetry; the test split is whatever
            remains after train and val.
        seed: Seed of the shuffle.
        projection_matrix: Optional (P, usable_Bands) matrix applied along the
            band axis of every cube.
        majority_label: Use the window's majority label instead of the centre pixel.
        drop_first_n_bands: Number of leading bands removed from every cube.

    Returns:
        (train_ds, val_ds, test_ds); only the training dataset augments.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio`` exceeds 1.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio}"
        )

    # Step 1: preload cubes
    cube_data = []
    for fname in cube_list:
        cube = np.load(os.path.join(cube_dir, fname))[:, :, drop_first_n_bands:]
        if projection_matrix is not None:
            cube = np.tensordot(projection_matrix, cube, axes=([1], [2]))
            # Patches are converted to float32 anyway, so storing the projected
            # cube as float32 gives the same values at half the memory.
            cube = np.moveaxis(cube, 0, -1).astype(np.float32, copy=False)

        label = np.loadtxt(
            os.path.join(label_dir, fname.replace('_l1d_cube.npy', '_labels.csv')),
            dtype=np.uint8
        )
        cube_data.append((cube, label))

    # Step 2: collect forest patches
    class_map = defaultdict(list)
    half = patch_size // 2

    for cube_idx, (cube, label) in enumerate(tqdm(cube_data, desc="Indexing patches")):
        h, w = label.shape
        for i in range(half, h - half, stride):
            for j in range(half, w - half, stride):
                if majority_label:
                    window = label[i - half:i + half + 1, j - half:j + half + 1]
                    # Votes for labels 1, 2, 3; argmax picks the smallest label
                    # on ties, the same tie-break as scipy.stats.mode.
                    votes = np.bincount(window.ravel(), minlength=4)[1:4]
                    if votes.any():
                        majority = int(votes.argmax()) + 1
                        class_map[majority - 1].append((cube_idx, i, j, majority))  # include majority
                else:
                    class_label = int(label[i, j])
                    if class_label in (1, 2, 3):
                        class_map[class_label - 1].append((cube_idx, i, j))

    # Step 3: print class counts
    for cls in sorted(class_map.keys()):
        print(f" Class {cls}: selected {len(class_map[cls])} patches")

    # Step 4–6: shuffle, split, build datasets
    indices = [item for lst in class_map.values() for item in lst]
    # A private generator gives the same permutation as seeding the global one,
    # without resetting the global random state as a side effect.
    random.Random(seed).shuffle(indices)

    n_total = len(indices)
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    train_ds = HYPSOPatchDataset(cube_data, train_idx, band_means, band_stds, patch_size, augment=True, majority_label=majority_label)
    val_ds = HYPSOPatchDataset(cube_data, val_idx, band_means, band_stds, patch_size, augment=False, majority_label=majority_label)
    test_ds = HYPSOPatchDataset(cube_data, test_idx, band_means, band_stds, patch_size, augment=False, majority_label=majority_label)

    return train_ds, val_ds, test_ds


### With Mapping Layer

In [7]:
# Load the precomputed per-band normalisation statistics.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load files you trust.
stats = torch.load('/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/mean_std/mean_std.pt', weights_only=False)

# Access the stored statistics by key
band_means = stats['band_means']
band_stds = stats['band_stds']

In [None]:
# Sanity check: print the shape of the loaded per-band means
print(band_means.shape)

In [None]:
# Build the patch datasets without a projection matrix: 71x71 patches,
# stride 15, labelled by the majority class inside each window.
# NOTE(review): the "With Projection Matrix" cell further down assigns the same
# three names, so whichever of the two cells ran last defines the datasets.
train_dataset, val_dataset, test_dataset = create_patch_datasets(
    cube_dir = cube_path, label_dir=labels_path, cube_list=cube_files,
    band_means=band_means, band_stds=band_stds,
    patch_size=71, stride=15,
    train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
    seed=42,
    majority_label=True
)

### With Projection Matrix

In [9]:
# Load the normalisation statistics stored for the projected cubes.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load files you trust.
projection_stats = torch.load('/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/mean_std/projection_mean_std_47.pt', weights_only=False)

# Access the stored statistics by key
projection_band_means = projection_stats['band_means']
projection_band_stds = projection_stats['band_stds']

In [None]:
# Sanity check: print the shape of the loaded projected-band means
print(projection_band_means.shape)

In [None]:
# Projection matrix applied along the band axis of every cube.
W = np.load("/home/salyken/PRISMA/hypso_to_prisma_projection/hypso_to_prisma_projection.npy") # shape: (47,117)
print(W.shape)

# Build the patch datasets with every cube projected through W.
# NOTE(review): this rebinds train_dataset / val_dataset / test_dataset, replacing
# the datasets built in the "With Mapping Layer" cell if that cell was run before.
train_dataset, val_dataset, test_dataset = create_patch_datasets(
    cube_dir = cube_path, label_dir=labels_path, cube_list=cube_files,
    band_means=projection_band_means, band_stds=projection_band_stds,
    patch_size=71, stride=15,
    train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
    seed=42,
    projection_matrix = W,
    majority_label = True
)

### Training Loop for Pretrained PRISMA models

In [9]:
# Remember to change backbone freezing for each model

# EarlyStopping Helper
class EarlyStopping:
    """Track a validation loss and raise a flag once it stops improving.

    Call the instance with each new validation loss. A loss counts as an
    improvement only if it is lower than the best seen so far by more than
    ``min_delta``. After ``patience`` consecutive non-improving calls,
    ``early_stop`` becomes True.
    """

    def __init__(self, patience=70, min_delta=0.0001):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation: nothing to compare against yet.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # Real improvement: remember it and restart the patience window.
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# Unified transfer learning trainer
def train_model(
    model,
    train_loader,
    val_loader,
    train_dataset,
    device,
    checkpoint_path,
    save_dir,
    resume=True,
    start_epoch=0,
    num_epochs=100,
    head_lr=1e-3,
    backbone_lr=2e-4,
    unfreeze_epoch=8,
    scheduler_patience=3,
    scheduler_factor=0.5
):
    """Fine-tune a model from a pretrained checkpoint, with a delayed backbone unfreeze.

    Behaviour:
      * If ``resume`` is True and ``save_dir/checkpoint.pth`` exists, model,
        optimizer, epoch counter and best validation accuracy are restored from it.
        Otherwise the weights in ``checkpoint_path`` are loaded, skipping keys
        starting with ``mapping.`` / ``classifier.`` and keys whose shape does
        not match the model.
      * Parameters whose names start with one of ``backbone_prefixes`` are frozen
        until ``unfreeze_epoch`` (kept trainable when training starts after it).
      * The loss is class-weighted cross-entropy. Counts come from
        ``train_dataset.class_counts`` when present, else from scanning the dataset.
      * After every epoch ``checkpoint.pth`` is written; ``best_val_acc.pth`` is
        written whenever validation accuracy improves.

    Args:
        model: Network to train (moved to ``device``).
        train_loader, val_loader: Iterables of ``(inputs, targets)`` batches.
        train_dataset: Dataset used to derive the class weights.
        device: Device for model and batches.
        checkpoint_path: Pretrained checkpoint (a state dict, or a dict holding
            ``model_state_dict``); used only when not resuming.
        save_dir: Directory for ``checkpoint.pth`` and ``best_val_acc.pth``.
        resume: Continue from ``save_dir/checkpoint.pth`` when it exists.
        start_epoch: First epoch index when not resuming.
        num_epochs: Exclusive upper bound of the epoch index.
        head_lr, backbone_lr: Learning rates of the two parameter groups.
        unfreeze_epoch: Epoch index at which the backbone becomes trainable.
        scheduler_patience, scheduler_factor: ReduceLROnPlateau settings
            (stepped on validation loss).

    Returns:
        None.

    Raises:
        RuntimeError: If the class-count scan finds no valid label.
    """
    os.makedirs(save_dir, exist_ok=True)
    chkpt_file = os.path.join(save_dir, "checkpoint.pth")
    best_path = os.path.join(save_dir, "best_val_acc.pth")
    print('Hypso save directory:', chkpt_file)
    print(" File already exists:", os.path.exists(chkpt_file))

    best_val_acc = float('-inf')

    if resume and os.path.exists(chkpt_file):
        print(f" Resuming training from: {chkpt_file}")
        resume_ckpt = torch.load(chkpt_file, map_location=device)
        model.load_state_dict(resume_ckpt['model_state_dict'])
        optimizer_state = resume_ckpt.get('optimizer_state_dict')

        start_epoch = resume_ckpt.get('epoch', start_epoch) + 1

        # Carry the best accuracy across the restart; otherwise the first
        # resumed epoch would overwrite best_val_acc.pth even if it is worse.
        # Checkpoints written before this key existed fall back to the value
        # stored in the best-model file.
        restored_best = resume_ckpt.get('best_val_acc')
        if restored_best is None and os.path.exists(best_path):
            restored_best = torch.load(best_path, map_location='cpu').get('val_acc')
        if restored_best is not None:
            best_val_acc = float(restored_best)
    else:
        print(f" Loading pretrained checkpoint from: {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt.get('model_state_dict', ckpt)

        # Strip the "module." prefix present on the keys of wrapped models.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                k = k[len("module."):]
            new_state_dict[k] = v

        excluded_prefixes = ['mapping.', 'classifier.']
        filtered = {
            k: v for k, v in new_state_dict.items()
            if not any(k.startswith(prefix) for prefix in excluded_prefixes)
        }

        model_dict = model.state_dict()
        compatible_weights = {
            k: v for k, v in filtered.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }

        model_dict.update(compatible_weights)
        model.load_state_dict(model_dict)

        print(f" Loaded {len(compatible_weights)} compatible pretrained layers.")
        skipped = [k for k in filtered if k not in compatible_weights]
        if skipped:
            print(f" Skipped {len(skipped)} incompatible layers (shape mismatch):")
            for k in skipped:
                # Report the model-side shape, not the whole tensor.
                model_shape = model_dict[k].shape if k in model_dict else 'N/A'
                print(f" - {k} (saved shape: {filtered[k].shape}, model shape: {model_shape})")

        optimizer_state = None

    model.to(device)

    # VGG-16 CBAM:
    backbone_prefixes = ['block1', 'block2', 'block3', 'block4', 'block5']
    # VGG-16 sSE:
    # backbone_prefixes = [f'blocks.{i}' for i in range(5)]
    # VGG-16:
    # backbone_prefixes = [f"features.{i}" for i in range(0, 43)]
    # Hybrid: 
    # backbone_prefixes = [
    #     *[f"encoder3d.{i}" for i in range(5)],  # freeze early 3D conv layers
    #     *[f"features2d.{i}" for i in range(11)]  # 2D CNN head
    # ]
    # 2D: 
    # backbone_prefixes = [f"features.{i}" for i in range(0, 3)]

    def is_backbone(param_name):
        return any(param_name.startswith(p) for p in backbone_prefixes)

    def set_backbone_trainable(flag):
        for name, param in model.named_parameters():
            if is_backbone(name):
                param.requires_grad = flag

    def print_trainable_summary():
        for name, param in model.named_parameters():
            status = "✅ trainable" if param.requires_grad else "❌ frozen"
            print(f"{name:50s} | {status}")

    # Starting past the unfreeze epoch (by resuming or via an explicit
    # start_epoch) must not freeze the backbone, because the in-loop unfreeze
    # at `epoch == unfreeze_epoch` would never be reached.
    if start_epoch > unfreeze_epoch:
        print(f"↪ Resuming after unfreeze epoch; keeping backbone unfrozen")
        set_backbone_trainable(True)
        print("\n🔓 Re-checking trainable parameters after unfreeze:")
        print_trainable_summary()
    else:
        print(f"↪ Freezing backbone until epoch {unfreeze_epoch}")
        set_backbone_trainable(False)

    print("\n🔍 Trainable Parameters Summary:")
    print_trainable_summary()

    if hasattr(train_dataset, 'class_counts'):
        counts = np.array([train_dataset.class_counts[c] for c in sorted(train_dataset.class_counts)])
    else:
        print("⚖️ Estimating class weights from the full training dataset...")
        all_labels = []
        for idx in tqdm(range(len(train_dataset)), desc="Scanning dataset"):
            # Best effort: a sample that fails to load is reported and skipped.
            try:
                label = int(train_dataset[idx][1])
                all_labels.append(label)
            except Exception as e:
                print(f" Error on sample {idx}: {e}")

        if len(all_labels) == 0:
            raise RuntimeError("No valid labels found — check your dataset logic!")

        num_classes = model.num_classes if hasattr(model, 'num_classes') else len(set(all_labels))
        counts = np.bincount(all_labels, minlength=num_classes)

    # Inverse-frequency weights relative to the largest class. A class without
    # training samples would give a division by zero (infinite weight), so it
    # gets weight 0 and a warning instead.
    counts = np.asarray(counts, dtype=np.float64)
    missing = counts == 0
    if missing.any():
        print(f" Warning: no training samples for class index(es) {np.flatnonzero(missing).tolist()}; their loss weight is set to 0.")
    weights = np.where(missing, 0.0, counts.max() / np.where(missing, 1.0, counts))
    class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    head_params = [p for n, p in model.named_parameters() if not is_backbone(n)]
    backbone_params = [p for n, p in model.named_parameters() if is_backbone(n)]

    # The optimizer must have as many param groups as the saved state so the
    # state can be loaded back.
    # NOTE(review): a saved head-only optimizer leaves the backbone without an
    # optimizer group, so it is not updated even after being unfrozen.
    if optimizer_state and len(optimizer_state.get("param_groups", [])) == 1:
        print("🔁 Restoring optimizer with 1 param group (head only)")
        param_groups = [{'params': head_params, 'lr': head_lr}]
    else:
        if optimizer_state:
            print("🔁 Restoring optimizer with 2 param groups (head + backbone)")
        else:
            print("⚙️ Initializing fresh optimizer with 2 param groups")
        param_groups = [
            {'params': head_params, 'lr': head_lr},
            {'params': backbone_params, 'lr': backbone_lr}
        ]
    optimizer = optim.Adam(param_groups, weight_decay=1e-5)

    if optimizer_state:
        try:
            optimizer.load_state_dict(optimizer_state)
            print(" Optimizer state successfully restored.")
        except ValueError as e:
            print(f" Optimizer state mismatch: {e}")
            print(" Proceeding with freshly initialized optimizer.")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience)

    def run_epoch(loader, desc, training):
        """One pass over `loader`; returns (mean loss, accuracy) over its samples."""
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.set_grad_enabled(training):
            for inputs, targets in tqdm(loader, desc=desc, leave=True):
                inputs, targets = inputs.to(device), targets.to(device)
                if training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                if training:
                    loss.backward()
                    optimizer.step()
                loss_sum += loss.item() * targets.size(0)
                correct += (outputs.argmax(1) == targets).sum().item()
                seen += targets.size(0)
        return loss_sum / seen, correct / seen

    def snapshot(epoch, val_acc, best_so_far):
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'best_val_acc': best_so_far,
        }

    for epoch in range(start_epoch, num_epochs):
        if epoch == unfreeze_epoch:
            print(f" Unfreezing backbone at epoch {epoch}")
            set_backbone_trainable(True)

        train_loss, train_acc = run_epoch(train_loader, f"Epoch {epoch+1} Training", training=True)
        val_loss, val_acc = run_epoch(val_loader, f"Epoch {epoch+1} Validating", training=False)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Train: {train_loss:.4f}, {train_acc:.4f} | Val: {val_loss:.4f}, {val_acc:.4f}")

        scheduler.step(val_loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(snapshot(epoch, val_acc, best_val_acc), best_path)
            print(f" Saved best model at epoch {epoch+1} with val acc {val_acc:.4f} → {best_path}")

        torch.save(snapshot(epoch, val_acc, best_val_acc), chkpt_file)

    print("🏁 Training complete; best val acc:", best_val_acc)


### Training Loop From Scratch

In [10]:
def train_without_pretrained_model(
    model,
    train_loader,
    val_loader,
    train_dataset,
    device,
    checkpoint_path,
    save_dir,
    resume=False,  # default to False for scratch training
    start_epoch=0,
    num_epochs=100,
    head_lr=1e-3,
    backbone_lr=2e-4,  # kept in case you later want to use two groups again
    unfreeze_epoch=8,
    scheduler_patience=3,
    scheduler_factor=0.5
):
    """Train a model from its current weights with all parameters trainable.

    Behaviour:
      * If ``resume`` is True and ``save_dir/checkpoint.pth`` exists, model,
        optimizer, epoch counter and best validation accuracy are restored from it;
        otherwise training starts from the weights the model already has.
      * Every parameter is set to ``requires_grad=True`` and optimised by a single
        Adam group with learning rate ``head_lr``.
      * The loss is class-weighted cross-entropy. Counts come from
        ``train_dataset.class_counts`` when present, else from scanning the dataset.
      * After every epoch ``checkpoint.pth`` is written; ``best_val_acc.pth`` is
        written whenever validation accuracy improves.

    Args:
        model: Network to train (moved to ``device``).
        train_loader, val_loader: Iterables of ``(inputs, targets)`` batches.
        train_dataset: Dataset used to derive the class weights.
        device: Device for model and batches.
        checkpoint_path: Unused; kept so the signature matches ``train_model``.
        save_dir: Directory for ``checkpoint.pth`` and ``best_val_acc.pth``.
        resume: Continue from ``save_dir/checkpoint.pth`` when it exists.
        start_epoch: First epoch index when not resuming.
        num_epochs: Exclusive upper bound of the epoch index.
        head_lr: Learning rate for all parameters.
        backbone_lr: Unused here.
        unfreeze_epoch: Unused here (nothing is frozen).
        scheduler_patience, scheduler_factor: ReduceLROnPlateau settings
            (stepped on validation loss).

    Returns:
        None.

    Raises:
        RuntimeError: If the class-count scan finds no valid label.
    """
    os.makedirs(save_dir, exist_ok=True)
    chkpt_file = os.path.join(save_dir, "checkpoint.pth")
    best_path = os.path.join(save_dir, "best_val_acc.pth")
    print('Hypso save directory:', chkpt_file)
    print(" File already exists:", os.path.exists(chkpt_file))

    best_val_acc = float('-inf')

    if resume and os.path.exists(chkpt_file):
        print(f"🔄 Resuming training from: {chkpt_file}")
        resume_ckpt = torch.load(chkpt_file, map_location=device)
        model.load_state_dict(resume_ckpt['model_state_dict'])
        optimizer_state = resume_ckpt.get('optimizer_state_dict')
        start_epoch = resume_ckpt.get('epoch', start_epoch) + 1

        # Carry the best accuracy across the restart; otherwise the first
        # resumed epoch would overwrite best_val_acc.pth even if it is worse.
        # Checkpoints written before this key existed fall back to the value
        # stored in the best-model file.
        restored_best = resume_ckpt.get('best_val_acc')
        if restored_best is None and os.path.exists(best_path):
            restored_best = torch.load(best_path, map_location='cpu').get('val_acc')
        if restored_best is not None:
            best_val_acc = float(restored_best)
    else:
        print(" Training from scratch — no pretrained weights loaded.")
        optimizer_state = None

    model.to(device)

    print(" Training all parameters (no freezing needed for scratch training)")
    for param in model.parameters():
        param.requires_grad = True

    print("\n🔍 Trainable Parameters Summary:")
    for name, param in model.named_parameters():
        status = "✅ trainable" if param.requires_grad else "❌ frozen"
        print(f"{name:50s} | {status}")

    if hasattr(train_dataset, 'class_counts'):
        counts = np.array([train_dataset.class_counts[c] for c in sorted(train_dataset.class_counts)])
    else:
        print("⚖️ Estimating class weights from the full training dataset...")
        all_labels = []
        for idx in tqdm(range(len(train_dataset)), desc="Scanning dataset"):
            # Best effort: a sample that fails to load is reported and skipped.
            try:
                label = int(train_dataset[idx][1])
                all_labels.append(label)
            except Exception as e:
                print(f" Error on sample {idx}: {e}")

        if len(all_labels) == 0:
            raise RuntimeError("No valid labels found — check your dataset logic!")

        num_classes = model.num_classes if hasattr(model, 'num_classes') else len(set(all_labels))
        counts = np.bincount(all_labels, minlength=num_classes)

    # Inverse-frequency weights relative to the largest class. A class without
    # training samples would give a division by zero (infinite weight), so it
    # gets weight 0 and a warning instead.
    counts = np.asarray(counts, dtype=np.float64)
    missing = counts == 0
    if missing.any():
        print(f" Warning: no training samples for class index(es) {np.flatnonzero(missing).tolist()}; their loss weight is set to 0.")
    weights = np.where(missing, 0.0, counts.max() / np.where(missing, 1.0, counts))
    class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # The optimizer is built the same way in both cases; only the message differs.
    if optimizer_state:
        print(" Restoring optimizer")
    else:
        print(" Initializing fresh optimizer (all parameters)")
    optimizer = optim.Adam(model.parameters(), lr=head_lr, weight_decay=1e-5)

    if optimizer_state:
        try:
            optimizer.load_state_dict(optimizer_state)
            print(" Optimizer state successfully restored.")
        except ValueError as e:
            print(f" Optimizer state mismatch: {e}")
            print(" Proceeding with freshly initialized optimizer.")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience)

    def run_epoch(loader, desc, training):
        """One pass over `loader`; returns (mean loss, accuracy) over its samples."""
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.set_grad_enabled(training):
            for inputs, targets in tqdm(loader, desc=desc, leave=True):
                inputs, targets = inputs.to(device), targets.to(device)
                if training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                if training:
                    loss.backward()
                    optimizer.step()
                loss_sum += loss.item() * targets.size(0)
                correct += (outputs.argmax(1) == targets).sum().item()
                seen += targets.size(0)
        return loss_sum / seen, correct / seen

    def snapshot(epoch, val_acc, best_so_far):
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'best_val_acc': best_so_far,
        }

    for epoch in range(start_epoch, num_epochs):
        train_loss, train_acc = run_epoch(train_loader, f"Epoch {epoch+1} Training", training=True)
        val_loss, val_acc = run_epoch(val_loader, f"Epoch {epoch+1} Validating", training=False)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Train: {train_loss:.4f}, {train_acc:.4f} | Val: {val_loss:.4f}, {val_acc:.4f}")

        scheduler.step(val_loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(snapshot(epoch, val_acc, best_val_acc), best_path)
            print(f" Saved best model at epoch {epoch+1} with val acc {val_acc:.4f} → {best_path}")

        torch.save(snapshot(epoch, val_acc, best_val_acc), chkpt_file)

    print("🏁 Training complete; best val acc:", best_val_acc)


### Train models

In [15]:
# Instantiate the CBAM variant of VGG-16 for 117 input channels; the sample
# printed in the next cell has 117 channels as well.
model = VGG16WithCBAM(in_channels=117)

In [16]:
# Sanity check: print the shape of the first training patch.
# NOTE(review): this rebinds the notebook-level name `label`.
sample, label = train_dataset[0]
print(sample.shape)

torch.Size([117, 71, 71])


In [None]:
# Batch loaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

# Pretrained PRISMA checkpoint to start from, and the output directory for this run.
prisma_ckpt = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/model/VGG16_w_cbam_71_patch_clipped/best_val_acc.pth"
save_dir = "/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/models/VGG16_w_cbam_71_patch_clipped"

# Transfer learning: resume=False forces loading from `prisma_ckpt`; the
# backbone is unfrozen at epoch index 4.
train_model(model, train_loader, val_loader, train_dataset, device, prisma_ckpt, save_dir, resume=False, unfreeze_epoch=4)

In [None]:
# Batch loaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

# `prisma_ckpt` is defined but not passed to the call below.
prisma_ckpt = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/model/VGG16_w_cbam_71_patch_clipped/best_val_acc.pth"
# NOTE(review): this is the same `save_dir` as in the transfer-learning cell above,
# so running both cells writes checkpoint.pth / best_val_acc.pth to the same files.
save_dir = "/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/models/VGG16_w_cbam_71_patch_clipped"

# NOTE(review): `model` is not re-instantiated in this cell; if the transfer-learning
# cell ran before it, training continues from those weights rather than from a
# fresh initialisation.
train_without_pretrained_model(
    model=model,  
    train_loader=train_loader,
    val_loader=val_loader,
    train_dataset=train_dataset,
    device=device,
    checkpoint_path=None,  # not used for scratch training
    save_dir=save_dir,  
    resume=False,  
    num_epochs=100,  
    head_lr=1e-3  
)