### Packages and Libraries

In [None]:
import os

# Restrict this process to GPU index 3. The cell sits ahead of the torch import
# so the variable is already set by the time CUDA gets initialised.
os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

# Report the kernel's process id (handy for locating it in nvidia-smi / top).
notebook_pid = os.getpid()
print("This notebook's PID:", notebook_pid)

In [None]:
import os
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import rasterio
from rasterio.transform import from_origin
import numpy as np
import matplotlib.animation as animation
from skimage.transform import resize
from skimage import exposure
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16_bn, VGG16_BN_Weights

import random
from collections import defaultdict

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

import gc

import time
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    accuracy_score,
    cohen_kappa_score
)

import import_ipynb
from models import HyperspectralTransferCNN, ImprovedHybrid3D2DCNN_v2, VGG16WithAttention, VGG16WithCBAM

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Current device: {device}.")

Current device: cuda.


### Paths

In [5]:
# Locations of the HYPSO inputs: cube directory, label directory, the label
# spreadsheet and the CSV used to persist the train/val/test split.
cube_path = '/home/salyken/PRISMA/HYPSO_data/cube'
labels_path = '/home/salyken/PRISMA/HYPSO_data/labels'
list_path = '/home/salyken/PRISMA/HYPSO_data/list/hypso_labels.xlsx'
split_save_path = '/home/salyken/PRISMA/HYPSO_data/list/hypso_train_val_test_split.csv'

# Every .npy file in the cube directory, in lexicographic order so that the
# cube indices are stable between runs.
cube_files = sorted(
    entry for entry in os.listdir(cube_path)
    if entry.endswith('.npy')
)


### Creating Patch Dataset

In [3]:
from scipy.stats import mode
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import random
from collections import defaultdict

def create_patch_datasets(
    cube_dir, label_dir, cube_list,
    band_means=None, band_stds=None,
    patch_size=33, stride=4,
    train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
    seed=42,
    projection_matrix=None,
    majority_label=False  # toggleable
):
    """Load cubes + label maps, index labelled patches and split them randomly.

    Args:
        cube_dir: Directory containing the cube ``.npy`` files.
        label_dir: Directory containing the label files. The label file name is
            the cube file name with ``_l1d_cube.npy`` replaced by ``_labels.csv``;
            it is read with ``np.loadtxt`` using its default delimiter.
        cube_list: Iterable of cube file names to load.
        band_means, band_stds: Optional per-band statistics (one value per band
            of the cube *after* band dropping / projection). Normalisation is
            applied only when both are given.
        patch_size: Nominal patch edge. The extracted window is
            ``2 * (patch_size // 2) + 1`` pixels wide, so an even value yields a
            window of ``patch_size + 1``.
        stride: Step (in pixels) between candidate patch centres.
        train_ratio, val_ratio: Fractions of the shuffled patches assigned to
            the training and validation splits (each count is truncated to int).
        test_ratio: Accepted for interface compatibility but not used; the test
            split always receives every patch left after train and validation.
        seed: Seed for the shuffle. NOTE: this reseeds Python's *global*
            ``random`` module.
        projection_matrix: Optional ``(out_bands, in_bands)`` array applied
            along the band axis of every cube.
        majority_label: If True, a patch is labelled with the most frequent
            label among the pixels of its window whose value is 1, 2 or 3;
            otherwise the label of the centre pixel is used.

    Returns:
        ``(train_ds, val_ds, test_ds)``: three ``Dataset`` objects yielding
        ``(patch, label)`` with ``patch`` a float32 tensor shaped
        ``(bands, window, window)`` and ``label`` an int64 scalar in {0, 1, 2}.
        Only the training dataset applies random augmentation.

    Raises:
        ValueError: If ``train_ratio`` / ``val_ratio`` are negative or sum to
            more than 1.

    NOTE(review): whenever ``stride`` is smaller than the window, neighbouring
    patches overlap, and because all patches are shuffled together before the
    split, overlapping patches can land in different splits. Confirm this is
    acceptable for the reported test metrics.
    """
    # Previously an over-full train+val request silently truncated the
    # validation split and produced an empty test split.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    class HYPSOPatchDataset(Dataset):
        """Serves normalised (and optionally augmented) patches from preloaded cubes."""

        def __init__(self, cube_data, index_map, band_means, band_stds, patch_size=5, augment=False, majority_label=False):
            self.cube_data = cube_data
            self.index_map = index_map
            self.band_means = band_means
            self.band_stds = band_stds
            self.patch_size = patch_size
            self.half = patch_size // 2
            self.augment = augment
            self.majority_label = majority_label  # store flag

            # Build the broadcastable (bands, 1, 1) normalisation tensors once
            # instead of on every __getitem__ call. as_tensor accepts arrays,
            # lists and tensors, and avoids the copy-construct warning that
            # torch.tensor() emits for tensor inputs.
            self._mean = None
            self._scale = None
            if band_means is not None and band_stds is not None:
                self._mean = torch.as_tensor(band_means, dtype=torch.float32).reshape(-1, 1, 1)
                self._scale = torch.as_tensor(band_stds, dtype=torch.float32).reshape(-1, 1, 1) + 1e-6

        def __len__(self):
            return len(self.index_map)

        def __getitem__(self, idx):
            entry = self.index_map[idx]
            cube_idx, i, j = entry[:3]
            cube, label = self.cube_data[cube_idx]

            # Majority mode stores the patch label as a 4th element of the
            # index entry; otherwise the centre pixel of the label map is used.
            raw_label = int(entry[3]) if self.majority_label else int(label[i, j])
            if raw_label not in (1, 2, 3):
                raise ValueError(f"Invalid label {raw_label} at index {idx}")
            label_val = raw_label - 1  # shift {1, 2, 3} to zero-based {0, 1, 2}

            patch = cube[
                i - self.half:i + self.half + 1,
                j - self.half:j + self.half + 1,
                :
            ]
            # (rows, cols, bands) -> (bands, rows, cols)
            patch = np.transpose(patch, (2, 0, 1))
            patch = torch.tensor(patch, dtype=torch.float32)

            if self._mean is not None:
                patch = (patch - self._mean) / self._scale

            if self.augment:
                patch = self.apply_augmentations(patch)

            return patch, torch.tensor(label_val).long()

        def apply_augmentations(self, x):
            """Apply each augmentation independently with probability 0.5."""
            if torch.rand(1) < 0.5:
                x = torch.flip(x, dims=[1])
            if torch.rand(1) < 0.5:
                x = torch.flip(x, dims=[2])
            if torch.rand(1) < 0.5:
                x = torch.rot90(x, k=1, dims=[1, 2])
            if torch.rand(1) < 0.5:
                # Additive Gaussian noise, written out-of-place so the input
                # tensor is never modified.
                x = x + torch.randn_like(x) * 0.01
            return x

    # Step 1: preload cubes
    cube_data = []
    for fname in cube_list:
        # The first three bands of every cube are discarded.
        cube = np.load(os.path.join(cube_dir, fname))[:, :, 3:]
        if projection_matrix is not None:
            # Contract the band axis: (out, in) x (H, W, in) -> (out, H, W) -> (H, W, out)
            cube = np.tensordot(projection_matrix, cube, axes=([1], [2]))
            cube = np.moveaxis(cube, 0, -1)

        label = np.loadtxt(
            os.path.join(label_dir, fname.replace('_l1d_cube.npy', '_labels.csv')),
            dtype=np.uint8
        )
        cube_data.append((cube, label))

    # Step 2: collect forest patches
    class_map = defaultdict(list)
    half = patch_size // 2

    for cube_idx, (cube, label) in enumerate(tqdm(cube_data, desc="Indexing patches")):
        h, w = label.shape
        # Centres are kept `half` pixels away from the border so that every
        # window lies completely inside the label map.
        for i in range(half, h - half, stride):
            for j in range(half, w - half, stride):
                if majority_label:
                    patch_labels = label[i - half:i + half + 1, j - half:j + half + 1]
                    valid = patch_labels[np.isin(patch_labels, [1, 2, 3])]
                    if valid.size > 0:
                        majority = mode(valid, axis=None).mode.item()
                        if majority in (1, 2, 3):
                            class_map[majority - 1].append((cube_idx, i, j, majority))  # include majority
                else:
                    # Plain int so class keys have the same type in both modes.
                    class_label = int(label[i, j])
                    if class_label in (1, 2, 3):
                        class_map[class_label - 1].append((cube_idx, i, j))

    # Step 3: print class counts
    for cls in sorted(class_map.keys()):
        print(f" Class {cls}: selected {len(class_map[cls])} patches")

    # Step 4–6: shuffle, split, build datasets
    indices = [item for lst in class_map.values() for item in lst]
    random.seed(seed)  # reseeds the global RNG (kept for reproducibility with earlier runs)
    random.shuffle(indices)

    n_total = len(indices)
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]  # remainder; test_ratio is not consulted

    train_ds = HYPSOPatchDataset(cube_data, train_idx, band_means, band_stds, patch_size, augment=True, majority_label=majority_label)
    val_ds = HYPSOPatchDataset(cube_data, val_idx, band_means, band_stds, patch_size, augment=False, majority_label=majority_label)
    test_ds = HYPSOPatchDataset(cube_data, test_idx, band_means, band_stds, patch_size, augment=False, majority_label=majority_label)

    return train_ds, val_ds, test_ds


### With Mapping Layer

In [7]:
# Load the per-band normalisation statistics. weights_only=False unpickles the
# whole saved object, so this should only be pointed at trusted files.
stats = torch.load(
    '/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/mean_std/mean_std.pt',
    weights_only=False,
)

# Pull out the two entries consumed by create_patch_datasets.
band_means, band_stds = stats['band_means'], stats['band_stds']

In [None]:
# Sanity check: show the shape of the loaded band-mean statistics.
print(band_means.shape)

In [None]:
# Build the patch datasets on the cubes as loaded (no projection matrix),
# labelling each patch by the majority class of its window.
train_dataset, val_dataset, test_dataset = create_patch_datasets(
    cube_dir=cube_path,
    label_dir=labels_path,
    cube_list=cube_files,
    band_means=band_means,
    band_stds=band_stds,
    patch_size=71,
    stride=15,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    seed=42,
    majority_label=True,
)

### With Projection Matrix

In [9]:
# Load the normalisation statistics stored for the projected cubes.
# weights_only=False unpickles the whole saved object (trusted files only).
projection_stats = torch.load(
    '/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/mean_std/projection_mean_std_47.pt',
    weights_only=False,
)

# Pull out the two entries consumed by create_patch_datasets.
projection_band_means, projection_band_stds = (
    projection_stats['band_means'],
    projection_stats['band_stds'],
)

In [None]:
# Sanity check: show the shape of the band means loaded for the projected cubes.
print(projection_band_means.shape)

In [None]:
# HYPSO -> PRISMA band projection matrix; expected shape is (47, 117).
W = np.load("/home/salyken/PRISMA/hypso_to_prisma_projection/hypso_to_prisma_projection.npy")
print(W.shape)

# Rebuild the patch datasets with every cube projected through W and
# normalised with the projection-specific statistics.
# Note: this rebinds train_dataset / val_dataset / test_dataset.
train_dataset, val_dataset, test_dataset = create_patch_datasets(
    cube_dir=cube_path,
    label_dir=labels_path,
    cube_list=cube_files,
    band_means=projection_band_means,
    band_stds=projection_band_stds,
    patch_size=71,
    stride=15,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    seed=42,
    projection_matrix=W,
    majority_label=True,
)

In [12]:
# Inspect one test sample: print the shape of its patch tensor, then of its label.
sample_patch, sample_label = test_dataset[1]
for sample_part in (sample_patch, sample_label):
    print(sample_part.shape)

torch.Size([47, 15, 15])
torch.Size([])


### Evaluation

In [None]:
def _resolve_class_labels(all_labels, all_preds, class_names):
    """Pick the class ids (and their display names) used for every metric.

    The ids are taken from the union of true and predicted labels, so a class
    that is predicted but never present in the ground truth still gets a
    row/column. When ``class_names`` is given and every observed id is a valid
    index into it, the full ``0..len(class_names)-1`` range is used so that
    names stay attached to the right class even if a class is absent.
    """
    observed = np.union1d(all_labels, all_preds)
    if not class_names:
        return observed, [f"Class {i}" for i in observed]

    names = list(class_names)
    canonical = np.arange(len(names))
    if np.isin(observed, canonical).all():
        return canonical, names

    # Ids are not zero-based indices into class_names: fall back to matching
    # names to the sorted observed ids, which requires equal counts.
    if len(observed) != len(names):
        raise ValueError(
            f"Got {len(names)} class names but observed class ids {observed.tolist()}"
        )
    return observed, names


def evaluate_model(model, test_loader, device, class_names=None, show_confusion=True, show_timing=True):
    """Run ``model`` over ``test_loader`` and report classification metrics.

    Args:
        model: Network whose output has the class scores along dim 1.
        test_loader: Iterable yielding ``(inputs, labels)`` batches of tensors.
        device: Device the inputs are moved to before the forward pass.
        class_names: Optional display names; entry ``k`` names class id ``k``.
        show_confusion: Plot the confusion matrix when True.
        show_timing: Print total and per-sample wall-clock inference time.

    Returns:
        dict with ``"OA"`` (overall accuracy), ``"AA"`` (mean per-class
        accuracy over classes that have at least one true sample), ``"Kappa"``
        (Cohen's kappa) and ``"PerClassAccuracy"`` (one value per class id;
        NaN for a class with no true samples).

    Raises:
        ValueError: If the loader yields no samples, or ``class_names`` cannot
            be matched to the observed class ids.
    """
    model.eval()
    pred_batches = []
    label_batches = []

    start_time = time.time()
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            # .cpu() waits for the forward pass, so the timing below covers it.
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())
    total_time = time.time() - start_time

    if not pred_batches:
        raise ValueError("test_loader yielded no samples to evaluate")

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)

    if show_timing:
        # Divide by the samples actually evaluated (differs from
        # len(dataset) when the loader drops the last batch).
        avg_time = total_time / len(all_labels)
        print(f"\n Test Time: {total_time:.2f} sec")
        print(f" Avg Inference Time per Sample: {avg_time:.6f} sec")

    label_ids, display_names = _resolve_class_labels(all_labels, all_preds, class_names)

    # Classification Report
    print("\n Classification Report:")
    print(classification_report(
        all_labels, all_preds,
        labels=label_ids,
        target_names=display_names
    ))

    # One confusion matrix (rows = true class, columns = predicted class)
    # feeds both the plot and the per-class metrics.
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Confusion Matrix
    if show_confusion:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_names)
        disp.plot(cmap='Blues', values_format='d')
        plt.title("Confusion Matrix")
        plt.grid(False)
        plt.show()

    # === Compute Metrics ===
    oa = accuracy_score(all_labels, all_preds)
    support = cm.sum(axis=1)
    # Classes with no true samples get NaN instead of a 0/0 warning, and are
    # left out of the average accuracy.
    per_class_acc = np.full(len(label_ids), np.nan, dtype=float)
    np.divide(cm.diagonal(), support, out=per_class_acc, where=support > 0)
    aa = np.nanmean(per_class_acc)
    kappa = cohen_kappa_score(all_labels, all_preds)

    print(f"\n Overall Accuracy (OA): {oa*100:.2f}%")
    print(f" Average Accuracy (AA): {aa*100:.2f}%")
    print(f" Kappa Coefficient (K×100): {kappa*100:.2f}")

    # Print per-class accuracy
    print("\n Per-Class Accuracy:")
    for class_label, acc in zip(display_names, per_class_acc):
        print(f"{class_label}: {acc*100:.2f}%")

    return {
        "OA": oa,
        "AA": aa,
        "Kappa": kappa,
        "PerClassAccuracy": per_class_acc
    }


In [None]:

# Number of spectral bands passed to the network constructor for this checkpoint.
INPUT_BANDS = 117

# `test_dataset` is assigned by more than one dataset-building cell in this
# notebook (with and without a projection matrix), so whichever ran last wins.
# Fail fast with a readable message if its band count does not match the model.
n_dataset_bands = test_dataset[0][0].shape[0]
if n_dataset_bands != INPUT_BANDS:
    raise ValueError(
        f"test_dataset yields {n_dataset_bands}-band patches but the model expects "
        f"{INPUT_BANDS}; re-run the dataset cell that matches this checkpoint."
    )

model = VGG16WithAttention(input_bands=INPUT_BANDS)

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

# Restore the trained weights; strict=True makes any key mismatch an error.
ckpt = torch.load('/home/salyken/PRISMA/HYPSO_data/HYPSO_dataset_processed/models/VGG16_w_att_71_patch/without_prisma/checkpoint.pth', map_location=device)
print(f" Loaded checkpoint from epoch {ckpt['epoch']}")
model.load_state_dict(ckpt['model_state_dict'], strict=True)

model.to(device)

# Run evaluation
evaluate_model(
    model=model,
    test_loader=test_loader,
    device=device,
    class_names=["Spruce", "Pine", "Deciduous"]
)
