### Packages and Libraries

In [None]:
# Select the CUDA devices this kernel may use for parallel computing.
# This cell runs before torch is imported, so the setting is already in place
# when CUDA is first initialised.
import os

os.environ.update(CUDA_VISIBLE_DEVICES="0,2,5")

# Print the kernel's process id (e.g. to locate it in nvidia-smi / top).
notebook_pid = os.getpid()
print("This notebook's PID:", notebook_pid)

In [2]:
import pandas as pd
import h5py
import matplotlib.pyplot as plt
from rasterio.transform import from_origin
import numpy as np
import matplotlib.animation as animation
from skimage.transform import resize
from skimage import exposure
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16_bn, VGG16_BN_Weights

import random
from collections import defaultdict

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

import gc

# Import models
import import_ipynb
from models import HyperspectralTransferCNN, ImprovedHybrid3D2DCNN_v2, ImprovedHybrid3D2DCNN_v3, VGG16WithAttention, VGG16WithCBAM

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Current device: {device}.")

Current device: cuda.


### Paths

In [6]:
# --- Locations under the shared Faubai folder ---
folder_path = "/home/_shared/ARIEL/Faubai/"
test_folder_path = "/home/_shared/ARIEL/Faubai/TEST"
he5_directory = "/home/_shared/ARIEL/Faubai/datalake"
xlsx_path = os.path.join(folder_path, "2023_02_22_Faubai_dataset_v1.xlsx")

# --- Locations under the user's PRISMA working folder ---
labels_path = "/home/salyken/PRISMA/PRISMA_data/labels_csv"

# Directory holding the chunked ``.pt`` files read by ChunkedDataset.
# Alternative (currently disabled) chunk directory:
# save_dir_chunks = '/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/chuncked_dataset'
save_dir_chunks = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/chuncked_dataset_patch_size_71"

### Loading the Dataset Preprocessed in the `Faubai_preprocessing.ipynb` Notebook

In [None]:
class ChunkedDataset(Dataset):
    """Map-style dataset over samples stored in chunked ``.pt`` files.

    Every file in ``chunk_dir`` whose name starts with ``prefix`` and ends with
    ``.pt`` is read with ``torch.load`` and must contain a dict with the keys
    ``'X'`` (samples) and ``'y'`` (labels). Samples are addressed globally via
    ``index_map``, a list of ``(chunk_idx, sample_idx)`` pairs built in sorted
    file-name order.

    ``X`` is assumed to be laid out as (samples, bands, height, width): bands are
    selected on dim 1 and the statistics are broadcast as ``[:, None, None]``
    against a single sample -- TODO confirm against the preprocessing notebook.

    Args:
        chunk_dir: Directory containing the chunk files.
        prefix: File-name prefix selecting which chunks to use (e.g. "train", "test").
        band_means: Optional 1-D tensor of per-band means.
        band_stds: Optional 1-D tensor of per-band standard deviations.
            Normalisation is applied only when both statistics are given.
        augment: When true, ``apply_augmentations`` is applied to every sample.
        preload_all: When true, all chunks are kept in memory; otherwise only
            the most recently used chunk is cached and chunks are re-read on demand.
        clip_bands: When true, only the bands listed in ``keep_indices`` are
            kept, and the statistics are reduced to the same bands.
    """

    def __init__(self, chunk_dir, prefix="train", band_means=None, band_stds=None, augment=False, preload_all=False, clip_bands=False):
        self.chunk_paths = sorted(
            os.path.join(chunk_dir, f)
            for f in os.listdir(chunk_dir)
            if f.startswith(prefix) and f.endswith('.pt')
        )
        self.clip_bands = clip_bands
        self.augment = augment
        self.preload_all = preload_all

        # Band indices 16..62 (47 bands) kept when clip_bands is true.
        # Original author's note: HYPSO range, λ ≤ 802 nm, with bands sorted from
        # longest to shortest wavelength (not verifiable from this file).
        self.keep_indices = torch.arange(16, 63)

        # The statistics must describe exactly the bands that remain in X.
        self.band_means = self._prepare_stats(band_means)
        self.band_stds = self._prepare_stats(band_stds)

        self.index_map = []    # global index -> (chunk index, index inside chunk)
        self.full_data = []    # per-chunk (X, y), filled only when preload_all is true
        self.chunk_cache = {}  # holds at most one lazily loaded chunk

        for idx, path in enumerate(tqdm(self.chunk_paths, desc=f" Indexing {prefix} chunks")):
            data = torch.load(path, map_location='cpu')
            n = len(data['y'])
            self.index_map.extend((idx, i) for i in range(n))
            if self.preload_all:
                X = self._select_bands(data['X'])
                self.full_data.append((X.contiguous(), data['y'].contiguous()))

    def _prepare_stats(self, stats):
        """Return ``stats`` as float32, reduced to the kept bands when clipping.

        The previous implementation sliced ``stats[-47:]``, which only matches
        ``keep_indices`` when the statistics have exactly 63 entries. Indexing
        with ``keep_indices`` keeps data and statistics aligned for any band
        count. Statistics that already have one entry per kept band are used
        unchanged, as before.
        """
        if stats is None:
            return None
        stats = stats.float()
        if self.clip_bands and stats.shape[0] != self.keep_indices.numel():
            stats = stats[self.keep_indices]
        return stats

    def _select_bands(self, X):
        """Return ``X`` restricted to ``keep_indices`` on dim 1 when clipping is on."""
        return X[:, self.keep_indices] if self.clip_bands else X

    def _load_chunk(self, chunk_idx):
        """Return ``(X, y)`` for a chunk, reading it from disk on a cache miss."""
        if chunk_idx not in self.chunk_cache:
            # Keep a single chunk in memory: drop the previous one first.
            self.chunk_cache.clear()
            data = torch.load(self.chunk_paths[chunk_idx], map_location='cpu')
            self.chunk_cache[chunk_idx] = (self._select_bands(data['X']), data['y'])
        return self.chunk_cache[chunk_idx]

    def __len__(self):
        """Total number of samples across all indexed chunks."""
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for global index ``idx``.

        The sample is normalised per band when both statistics are available and
        augmented when ``augment`` is true.
        """
        chunk_idx, sample_idx = self.index_map[idx]

        if self.preload_all:
            X, y = self.full_data[chunk_idx]
        else:
            X, y = self._load_chunk(chunk_idx)

        sample = X[sample_idx]
        label = y[sample_idx]

        if self.band_means is not None and self.band_stds is not None:
            # The epsilon guards against division by zero for constant bands.
            sample = (sample - self.band_means[:, None, None]) / (self.band_stds[:, None, None] + 1e-6)

        if self.augment:
            sample = self.apply_augmentations(sample)

        return sample.contiguous(), label

    def apply_augmentations(self, x):
        """Randomly flip, rotate and add noise to one sample.

        Each of the three steps is applied independently with probability 0.5:
        a flip along dim 2, a 90-degree rotation in the (1, 2) plane, and
        additive Gaussian noise scaled by 0.01.
        NOTE(review): the rotation swaps dims 1 and 2, so batching presumably
        relies on square patches -- confirm.
        """
        if torch.rand(1).item() < 0.5:
            x = torch.flip(x, dims=[2])
        if torch.rand(1).item() < 0.5:
            x = torch.rot90(x, k=1, dims=[1, 2])
        if torch.rand(1).item() < 0.5:
            x = x + torch.randn_like(x) * 0.01
        return x


In [9]:
# Load the per-band normalisation statistics onto the CPU. ChunkedDataset keeps
# its samples on the CPU (its chunk loads use map_location='cpu'), so statistics
# saved from a GPU must not be restored to a CUDA device here.
stats = torch.load(
    '/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/mean_std/mean_std_71.pt',
    map_location='cpu',
)

# Access tensors
band_means = stats['band_means']
band_stds = stats['band_stds']

In [None]:
# Sanity check: show the shape of the per-band means loaded above.
# NOTE(review): presumably the band count should match the model's input_bands -- confirm.
print(band_means.shape)

In [None]:
# Build the dataset from the chunk files whose names start with "test".
# All chunks are preloaded into memory; no augmentation and no band clipping.
test_dataset = ChunkedDataset(
    save_dir_chunks,
    prefix="test",
    band_means=band_means,
    band_stds=band_stds,
    augment=False,
    preload_all=True,
    clip_bands=False,
)

### Evaluation

In [None]:

import time
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    accuracy_score,
    cohen_kappa_score
)
import matplotlib.pyplot as plt

def evaluate_model(model, test_loader, device, class_names=None, show_confusion=True, show_timing=True):
    """Run ``model`` over ``test_loader`` and report classification metrics.

    Args:
        model: Module whose output has the class scores on dim 1.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device the batches are moved to before the forward pass.
        class_names: Optional list of names; entry ``i`` names class id ``i``.
            When given, all metrics are computed over ``range(len(class_names))``
            so a class missing from the evaluated data keeps its row and name.
            When omitted, the classes are the sorted union of the observed
            labels and predictions and are named ``"Class <id>"``.
        show_confusion: Plot the confusion matrix when true.
        show_timing: Print total and per-sample inference time when true.

    Returns:
        dict with ``"OA"`` (overall accuracy), ``"AA"`` (mean per-class accuracy
        over the classes that have at least one true sample), ``"Kappa"`` and
        ``"PerClassAccuracy"`` (array with one entry per class; ``NaN`` for a
        class without true samples).

    Raises:
        ValueError: If the loader yields no samples, or if a label or prediction
            falls outside ``range(len(class_names))``.
    """
    model.eval()
    pred_batches = []
    label_batches = []

    # Start timing
    if show_timing:
        start_time = time.time()

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)

            pred_batches.append(preds.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    # Stop timing
    if show_timing:
        total_time = time.time() - start_time

    if not label_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate.")

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)
    n_samples = len(all_labels)
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate.")

    if show_timing:
        # Average over the samples actually processed; the loader may drop
        # samples or wrap a dataset without a length.
        avg_time = total_time / n_samples
        print(f"\n Test Time: {total_time:.2f} sec")
        print(f" Avg Inference Time per Sample: {avg_time:.6f} sec")

    # Fix the class set once so the report, the matrix and the per-class
    # accuracies all use the same rows in the same order.
    observed = np.union1d(all_labels, all_preds)
    if class_names:
        label_ids = list(range(len(class_names)))
        unexpected = np.setdiff1d(observed, label_ids)
        if unexpected.size:
            raise ValueError(
                f"Found class ids {unexpected.tolist()} outside the range covered by "
                f"class_names ({len(class_names)} names)."
            )
        display_names = list(class_names)
    else:
        label_ids = observed.tolist()
        display_names = [f"Class {i}" for i in label_ids]

    # Classification Report
    print("\n Classification Report:")
    print(classification_report(
        all_labels, all_preds,
        labels=label_ids,
        target_names=display_names
    ))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Confusion Matrix
    if show_confusion:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_names)
        disp.plot(cmap='Blues', values_format='d')
        plt.title("Confusion Matrix")
        plt.grid(False)
        plt.show()

    # === Compute Metrics ===
    oa = accuracy_score(all_labels, all_preds)
    support = cm.sum(axis=1)
    # Classes without true samples get NaN instead of a divide-by-zero warning.
    per_class_acc = np.full(len(label_ids), np.nan, dtype=float)
    np.divide(cm.diagonal(), support, out=per_class_acc, where=support > 0)
    # n_samples > 0 guarantees at least one class with support.
    aa = np.nanmean(per_class_acc)
    kappa = cohen_kappa_score(all_labels, all_preds)

    print(f"\n Overall Accuracy (OA): {oa*100:.2f}%")
    print(f" Average Accuracy (AA): {aa*100:.2f}%")
    print(f" Kappa Coefficient (K×100): {kappa*100:.2f}")

    # Print per-class accuracy
    print("\n Per-Class Accuracy:")
    for class_label, acc in zip(display_names, per_class_acc):
        if np.isnan(acc):
            print(f"{class_label}: n/a (no samples)")
        else:
            print(f"{class_label}: {acc*100:.2f}%")

    return {
        "OA": oa,
        "AA": aa,
        "Kappa": kappa,
        "PerClassAccuracy": per_class_acc
    }


In [None]:
model = VGG16WithAttention(input_bands=63)

prisma_ckpt = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/model/VGG16_w_att_71_patch/best_val_acc.pth"

# Load checkpoint
ckpt = torch.load(prisma_ckpt, map_location=device)
state_dict = ckpt['model_state_dict']

# Strip only a *leading* 'module.' from each key (presumably left by saving a
# DataParallel-wrapped model -- confirm). A plain str.replace would also rewrite
# keys that merely contain 'module.' further in, e.g. 'attention_module.conv.weight'.
wrapper_prefix = 'module.'
clean_state_dict = {
    (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
    for k, v in state_dict.items()
}

# Load into model
model.load_state_dict(clean_state_dict)

print(f" Loaded {len(clean_state_dict)} compatible layers from PRISMA")

print(f"Using device: {device}")
model.to(device)

test_loader = DataLoader(test_dataset, batch_size=64, num_workers=8, shuffle=False, pin_memory = True, persistent_workers=False)

# Last expression of the cell: the returned metrics dict is shown as the cell output.
evaluate_model(
    model=model,
    test_loader=test_loader,
    device=device,
    class_names=["Spruce", "Pine", "Deciduous"]
)