### Packages and Libraries

In [None]:
# Choose which CUDA devices this notebook may use for parallel computing.
# The device indices below are specific to the machine this notebook was written on.
# NOTE(review): this cell runs before `torch` is imported, which is presumably intentional so the
# restriction is already in place when CUDA is initialised - keep it as the first cell.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,5"
# Print the process id of the kernel running this notebook.
print("This notebook's PID:", os.getpid())

In [None]:
import os
import numpy as np
import pandas as pd
import glob

import rasterio
from rasterio.transform import from_origin

import h5py

import matplotlib.pyplot as plt
import matplotlib.animation as animation

from skimage.transform import resize
from skimage import exposure

from tqdm import tqdm

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset, Subset, random_split
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16_bn, VGG16_BN_Weights

from collections import defaultdict

import gc

from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, ConfusionMatrixDisplay, classification_report
from sklearn.utils.class_weight import compute_class_weight
from scipy.stats import mode

In [3]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Current device: {device}.")

Current device: cuda.


### Paths

In [None]:
# Absolute, machine-specific locations used throughout the notebook.
folder_path = "/home/_shared/ARIEL/Faubai/"  # root folder holding the Excel metadata files
test_folder_path = '/home/_shared/ARIEL/Faubai/TEST'  # not referenced by the cells visible in this notebook section
he5_directory = "/home/_shared/ARIEL/Faubai/datalake"  # scanned below for *.he5 files
labels_path = '/home/salyken/PRISMA/PRISMA_data/labels_csv'  # scanned below for *.csv label files
npy_cubes_dir = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/npy_cubes"  # passed to the dataset as image_dir

# Same workbook that the "Filter Files" section opens as `xlxs_ex`.
xlsx_path = os.path.join(folder_path, '2023_02_22_Faubai_dataset_v1.xlsx')

### Calculate the spatial patch sizes with respect to HSI missions' GSDs

In [None]:
# Load one example cube per mission; they are only used below to compare patch sizes.
# np.load without mmap_mode reads the whole array into memory.
prisma_image_path = '/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/npy_cubes/PRS_L1_STD_OFFL_20200421110225_20200421110229_0001.npy'
prisma_image = np.load(prisma_image_path) 


# NOTE(review): the patch helpers below treat the array as (H, W, Bands) - confirm the HYPSO cube uses that layout.
hypso_image_path = '/home/salyken/PRISMA/HYPSO_data/cube/trondheim_2024-05-24T09-50-09Z_l1d_cube.npy' 
hypso_image = np.load(hypso_image_path)

In [None]:
def generate_grid_centers(image_shape, stride=32, margin=0):
    """
    Lay out a regular grid of patch centre coordinates over an image.

    Centres start at `margin` along each axis and advance by `stride` while they
    stay strictly below `dimension - margin`. Rows vary slowest (row-major order).

    Args:
        image_shape (tuple): Image shape; only the first two entries (H, W) are used
        stride (int): Step in pixels between neighbouring centres
        margin (int): Number of pixels to keep clear at every image edge

    Returns:
        List of (row, col) tuples
    """
    height, width = image_shape[:2]
    row_positions = range(margin, height - margin, stride)
    col_positions = range(margin, width - margin, stride)
    return [(r, c) for r in row_positions for c in col_positions]


In [None]:
import numpy as np

def extract_patches_by_gsd(
    image: np.ndarray,
    centers: list,
    gsd: float,
    ground_coverage: float,
    pad_mode: str = 'reflect'
):
    """
    Cut square, centre-aligned patches whose side length corresponds to a ground distance.

    The side length in pixels is round(ground_coverage / gsd), bumped to the next odd
    number when even so that every patch has a single centre pixel. The image is padded
    by half a patch on both spatial axes, so centres at the border still give full patches.

    Args:
        image (np.ndarray): Hyperspectral image of shape (H, W, Bands)
        centers (list): (row, col) centre coordinates, expressed in the unpadded image
        gsd (float): Ground sampling distance in meters/pixel
        ground_coverage (float): Desired patch side length in meters
        pad_mode (str): Mode passed to np.pad for the spatial border ('reflect', 'constant', etc.)

    Returns:
        List[np.ndarray]: One (patch_size, patch_size, Bands) view into the padded image per centre
    """
    side = int(np.round(ground_coverage / gsd))
    if side % 2 == 0:
        side += 1  # odd side length -> exactly one centre pixel

    border = side // 2
    spatial_pad = (border, border)
    padded = np.pad(image, (spatial_pad, spatial_pad, (0, 0)), mode=pad_mode)

    # A centre (r, c) sits at (r + border, c + border) in the padded image, so its patch
    # starts at (r, c) there and spans `side` pixels along each spatial axis.
    return [padded[r:r + side, c:c + side, :] for (r, c) in centers]


In [28]:
from math import ceil

# Ground sampling distances, passed below as the `gsd` argument (documented there as meters/pixel).
prisma_gsd = 30
hypso_gsd = 142
coverage = 2100  # meters

# Number of pixels needed to span `coverage` on the ground, rounded up.
# NOTE(review): extract_patches_by_gsd derives its own patch size with np.round, not ceil.
# Both give 71 and 15 for these values, but they can disagree for other GSD/coverage pairs.
prisma_patch_size = ceil(coverage / prisma_gsd) 
hypso_patch_size = ceil(coverage / hypso_gsd)    

# Force odd sizes so each patch has a single centre pixel.
if prisma_patch_size % 2 == 0:
    prisma_patch_size += 1  
if hypso_patch_size % 2 == 0:
    hypso_patch_size += 1  

print(f"PRISMA patch size: {prisma_patch_size}x{prisma_patch_size}")
print(f"HYPSO patch size: {hypso_patch_size}x{hypso_patch_size}")

# For PRISMA: the patch size is passed as the grid stride; margin keeps its default of 0,
# so centres start at the image edge and border patches rely on the padding done during extraction.
prisma_centers = generate_grid_centers(prisma_image.shape, prisma_patch_size)
prisma_patches = extract_patches_by_gsd(prisma_image, prisma_centers, prisma_gsd, coverage)

# For HYPSO: same procedure with the HYPSO GSD.
hypso_centers = generate_grid_centers(hypso_image.shape, hypso_patch_size)
hypso_patches = extract_patches_by_gsd(hypso_image, hypso_centers, hypso_gsd, coverage)


print(prisma_patch_size)
print(hypso_patch_size)

PRISMA patch size: 71x71
HYPSO patch size: 15x15
71
15


### Filter Files

In [None]:
def file_identifier(file_path):
    """
    Open a file according to its extension, report what was found and return its content.

    Supported extensions (matched case-sensitively on the end of `file_path`):
        .xlsx  read with pandas (openpyxl engine)
        .tif   first band read with rasterio and shown as a grayscale image
        .mat   loaded with scipy.io.loadmat

    Args:
        file_path (str): Path of the file to open.

    Returns:
        pandas.DataFrame for .xlsx, the first band as a 2-D array for .tif, the dict
        returned by loadmat for .mat, and None for any other extension.
    """
    if file_path.endswith(".xlsx"):
        df = pd.read_excel(file_path, engine="openpyxl")
        print(f"Opened Excel file: {file_path}")
        return df

    if file_path.endswith(".tif"):
        with rasterio.open(file_path) as src:
            img = src.read(1)  # Read the first band (1-based index)
        # Report the shape of the band that was actually read; reading band 3 just for
        # its shape fails on rasters with fewer than three bands.
        print(f"Opened TIFF file: {file_path}, Shape:", img.shape)

        # Display the image
        plt.imshow(img, cmap="gray")
        plt.colorbar()
        plt.title("GeoTIFF - Single Band")
        plt.xlabel("Width (X)")
        plt.ylabel("Height (Y)")
        plt.show()
        return img

    if file_path.endswith(".mat"):
        # Imported here because the notebook's import cell does not provide loadmat.
        from scipy.io import loadmat

        mat_data = loadmat(file_path)
        print(f"Opened MAT file: {file_path}, Keys:", mat_data.keys())
        return mat_data

    print(f"Skipping unknown file: {file_path}")
    return None

# Example files of each supported type, located relative to `folder_path`.
xlxs_ex = os.path.join(folder_path, '2023_02_22_Faubai_dataset_v1.xlsx')
updated_list = os.path.join(folder_path, 'updated_list.xlsx') 
tif_ex = os.path.join(folder_path, 'TEST', 'PRS_L1_STD_OFFL_20200525105010_20200525105014_0001.tif')
                       

# Both workbooks are read into DataFrames; `df_excel` is filtered against the .he5 files further down.
df_excel = file_identifier(xlxs_ex)
df_updated_list = file_identifier(updated_list)
# tif_data = file_identifier(tif_ex)


Opened Excel file: /home/_shared/ARIEL/Faubai/2023_02_22_Faubai_dataset_v1.xlsx
Opened Excel file: /home/_shared/ARIEL/Faubai/updated_list.xlsx


In [11]:


# Full paths of every .he5 product in the datalake, sorted by path.
he5_files = sorted(
    os.path.join(he5_directory, fname)
    for fname in os.listdir(he5_directory)
    if fname.endswith(".he5")
)
print(f"Found {len(he5_files)} .he5 files.")

# Full paths of every label .csv file, sorted by path.
csv_files = sorted(
    os.path.join(labels_path, fname)
    for fname in os.listdir(labels_path)
    if fname.endswith(".csv")
)
print(f"Found {len(csv_files)} .csv files.")

# Not referenced again in the cells visible here.
target_cols = ["pine", "spruce", "deciduous", "water", "cloudsnow"]

# Product names (file name with ".he5" removed), used to select the matching workbook rows.
he5_basenames = [os.path.basename(path).replace(".he5", "") for path in he5_files]
name_in_datalake = df_excel["name"].astype(str).isin(he5_basenames)
xlsx_filtered = df_excel[name_in_datalake]
print(f"Found {len(xlsx_filtered)} data entries in xlsx file.")



Found 104 .he5 files.
Found 104 .csv files.
Found 104 data entries in xlsx file.


### Cut Non-Informative Bands and Save HSI Cubes

In [None]:

def load_and_save_vnir_cube(name, image_dir, output_dir):
    """
    Convert the VNIR cube of one .he5 product into a float32 .npy file.

    Reads `<image_dir>/<name>.he5`, moves a 66-element axis (if found on axis 1 or
    axis 0) to the last position, divides the values by 65535, drops the first three
    entries of the last axis and writes the result to `<output_dir>/<name>.npy`.

    Args:
        name (str): Product name without extension.
        image_dir (str): Directory containing the .he5 file.
        output_dir (str): Existing directory that receives the .npy file.
    """
    he5_path = os.path.join(image_dir, name + ".he5")
    npy_path = os.path.join(output_dir, name + ".npy")

    with h5py.File(he5_path, 'r') as he5:
        cube = he5['HDFEOS']['SWATHS']['PRS_L1_HCO']['Data Fields']['VNIR_Cube'][()]

    # Pick the permutation that puts the 66-element axis last; axis 1 is checked first.
    if cube.shape[1] == 66:
        axis_order = (0, 2, 1)
    elif cube.shape[0] == 66:
        axis_order = (1, 2, 0)
    else:
        axis_order = None  # layout left untouched
    if axis_order is not None:
        cube = np.transpose(cube, axis_order)

    # NOTE(review): 65535 is presumably the full scale of 16-bit digital numbers - confirm.
    scaled = cube.astype(np.float32) / 65535.0
    # Skip the first three entries of the last axis (63 remain when it holds 66 bands).
    np.save(npy_path, scaled[:, :, 3:])

# Run this once
# Run this once: converts every listed .he5 product to a .npy cube on disk.
df =  xlsx_filtered # same filtered_df used before
image_names = df['name'].tolist()

input_dir = he5_directory
# Same directory as `npy_cubes_dir` defined in the Paths section.
output_dir = "/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/npy_cubes"
os.makedirs(output_dir, exist_ok=True)

# Every run rewrites each cube, even if the .npy file already exists.
for name in tqdm(image_names, desc="Preprocessing cubes"):
    load_and_save_vnir_cube(name, input_dir, output_dir)


### Make patch dataset for Faubai

In [None]:

def _collect_patch_indices_for_one(args):
    name, label_dir, half, stride, valid_labels = args
    pattern = os.path.join(label_dir, f"{name}_*labels.csv")
    matches = glob.glob(pattern)
    if not matches:
        return []

    label_path = matches[0]
    df = pd.read_csv(label_path, header=None)
    label_array = df[0].apply(lambda x: list(map(int, str(x).split()))).tolist()
    label = np.array(label_array)

    h, w = label.shape
    index_map = []
    for i in range(half, h - half, stride):
        for j in range(half, w - half, stride):
            patch_labels = label[i - half:i + half + 1, j - half:j + half + 1]
            valid = patch_labels[np.isin(patch_labels, valid_labels)]
            if valid.size > 0:
                index_map.append((name, i, j))
    return index_map

class HyperspectralForestDataset(Dataset):
    _cube_cache = {}
    _label_cache = {}

    def __init__(
        self,
        filtered_df,
        image_dir,
        label_dir,
        patch_size=71,
        stride=15,
        valid_labels = tuple(range(20)),  # if you don't know all of them yet,
        majority_label=False,
    ):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.patch_size = patch_size
        self.stride = stride
        self.valid_labels = valid_labels
        self.majority_label = majority_label
        self.half = patch_size // 2

        self.image_names = filtered_df['name'].tolist()
        self.index_map = self._collect_patch_indices()

    def _load_vnir_cube(self, filepath):
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"VNIR cube not found: {filepath}")
        return np.load(filepath)

    def _load_label_mask(self, name):
        pattern = os.path.join(self.label_dir, f"{name}_*labels.csv")
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f"No label file found for: {name}")
        label_path = matches[0]
        df = pd.read_csv(label_path, header=None)
        label_array = df[0].apply(lambda x: list(map(int, str(x).split()))).tolist()
        return np.array(label_array)

    def _collect_patch_indices(self):
        print(" Starting parallel patch collection...")
        args_list = [
            (name, self.label_dir, self.half, self.stride, self.valid_labels)
            for name in self.image_names
        ]
        with ProcessPoolExecutor(max_workers=32) as executor:
            results = list(tqdm(executor.map(_collect_patch_indices_for_one, args_list), total=len(args_list)))
        return [item for sublist in results for item in sublist]

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        name, i, j = self.index_map[idx]

        if name not in self._cube_cache:
            image_path = os.path.join(self.image_dir, name + ".npy")
            self._cube_cache[name] = self._load_vnir_cube(image_path)
        cube = self._cube_cache[name]

        if name not in self._label_cache:
            self._label_cache[name] = self._load_label_mask(name)
        label = self._label_cache[name]

        patch = cube[i - self.half:i + self.half + 1, j - self.half:j + self.half + 1, :]
        patch = torch.from_numpy(patch).permute(2, 0, 1)  # (Bands, H, W)

        if self.majority_label:
            patch_labels = label[i - self.half:i + self.half + 1, j - self.half:j + self.half + 1]
            valid = patch_labels[np.isin(patch_labels, self.valid_labels)]
            if valid.size == 0:
                raise ValueError(f"No valid labels in patch at {name}, {i}, {j}")
            label_val = int(mode(valid, axis=None).mode.item())
        else:
            label_val = label[i, j]

        # PRISMA: 0:'other', 1:'spruce', 2:'pine', 3:'deciduous', 4:'water', 5:'cloudsnow'
        # mapped_label = {1: 0, 2: 1, 3: 2}.get(label_val, 3)  # Map 1–3 normally, others → 3
        mapped_label = {1: 0, 2: 1, 3: 2}[label_val] # Only using 1-3 labels, discard the rest

        return patch, torch.tensor(mapped_label).long()


In [None]:
# Build the patch index over all filtered images. Construction reads every label file
# (in parallel worker processes), so this cell can take a while.
dataset = HyperspectralForestDataset(
    filtered_df= xlsx_filtered,
    image_dir=npy_cubes_dir,
    label_dir=labels_path,
    patch_size=71,  # same value as the PRISMA patch size printed in the patch-size section
    stride=15,  # smaller than patch_size, so neighbouring patches overlap
    majority_label=True  # enable majority labeling here
)


In [None]:
# Checking
# Look at the labels of the first 2000 patches (requires the dataset to hold at least 2000 items).
labels = [dataset[i][1].item() for i in range(0,2000)]
print("Unique labels actually used in patches:", set(labels))

# Inspect a single sample (index 4000 must exist).
x, y = dataset[4000]
print(f"Patch shape: {x.shape}")  # Should be (63, 71, 71)
print(f"Label: {y}")

### Split Into Train/Test/Validation

In [None]:
# Split fractions for train / validation / test.
lengths = [0.7, 0.15, 0.15]
total = len(dataset)
# Turn the fractions into absolute sizes (truncated towards zero).
lengths = [int(total * l) for l in lengths]
lengths[-1] = total - sum(lengths[:-1])  # fix rounding

# The fixed seed makes the split reproducible.
# NOTE(review): the split is made per patch, and patches overlap (stride 15 < patch size 71), so
# overlapping patches of one image can land in different splits - confirm this is acceptable for evaluation.
train_set, val_set, test_set = random_split(dataset, lengths, generator=torch.Generator().manual_seed(42))

In [None]:
# Split sizes, printed in the order: train, test, validation.
print(len(train_set))
print(len(test_set))
print(len(val_set))

### Save Dataset Into Chunks

In [None]:
def save_dataset_in_chunks(dataset, save_dir, prefix="train", batch_size=1000):
    """
    Write a dataset to disk as a sequence of batch files.

    The dataset is read in order (no shuffling) by a DataLoader with 8 workers. Every
    batch is saved as `<save_dir>/<prefix>_chunk_<k>.pt`, k counting from 0, holding a
    dict with the inputs under 'X' and the labels under 'y'.

    Args:
        dataset: Dataset yielding (input, label) pairs.
        save_dir (str): Output directory; created if it does not exist.
        prefix (str): File name prefix identifying the split.
        batch_size (int): Number of samples per chunk file.
    """
    os.makedirs(save_dir, exist_ok=True)
    batches = DataLoader(dataset, batch_size=batch_size, num_workers=8, shuffle=False)

    print(f" Saving '{prefix}' batches one at a time to: {save_dir}")
    n_saved = 0
    for X, y in tqdm(batches, desc=f"Saving {prefix} chunks"):
        target = os.path.join(save_dir, f"{prefix}_chunk_{n_saved}.pt")
        torch.save({'X': X, 'y': y}, target)
        n_saved += 1

    print(f" Finished saving {n_saved} chunks for '{prefix}' to: {save_dir}")


In [None]:
# Check data

# Pull a single batch of size 1 in the main process (num_workers=0) to check shapes and label type.
loader = DataLoader(train_set, batch_size=1, num_workers=0)
for X, y in loader:
    print(X.shape, y)
    break  # the first batch is enough


torch.Size([1, 63, 71, 71]) tensor([2])


In [None]:
import warnings
# Silence warnings whose message matches "can only test a child process".
# NOTE(review): presumably emitted while DataLoader worker processes shut down - confirm.
warnings.filterwarnings("ignore", message=".*can only test a child process.*")

# Output folder for all three splits; the files are told apart by their prefix.
save_chunks_dir = '/home/salyken/PRISMA/PRISMA_data/PRISMA_dataset_processed/chuncked_dataset_patch_size_71'

# Existing chunk files with the same name are overwritten; chunk files left over from an
# earlier run with more chunks are not removed.
save_dataset_in_chunks(train_set, save_chunks_dir, prefix="train", batch_size=1000)
save_dataset_in_chunks(val_set, save_chunks_dir, prefix="val", batch_size=1000)
save_dataset_in_chunks(test_set, save_chunks_dir, prefix="test", batch_size=1000)