In [None]:
import os
import sys

# The repository's `src` package is imported straight from the project root
# (it is not installed), so the root must be on sys.path. Set the DEEPHSI_ROOT
# environment variable to use a different checkout; the default is the
# original author's location.
PROJECT_ROOT = os.environ.get("DEEPHSI_ROOT", "/home/sayem/Desktop/deepHSI")
# Append only once so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from pathlib import Path
import numpy as np

# Custom module imports
from src.dataset.components.hyperspectral_dataset import HyperspectralDataset
# NOTE(review): the wildcard import hides which helpers come from `utils` and can
# shadow names imported above it; replace with explicit imports once the used
# names are known.
from src.dataset.components.utils import *
from src.dataset.medical_datasets.bloodHSI import BloodDetectionHSIDataModule
from src.dataset.remote_sensing_datasets.paviaC import PaviaCDataModule
from src.dataset.remote_sensing_datasets.ksc import KSCDataModule
from src.models.hsi_classification_module import HSIClassificationLitModule
from src.models.components.simple_dense_net import HSIFCModel

# PyTorch and metrics imports
import torch
from torchmetrics import Precision, Recall, F1Score

# Importing from `lightning` instead of `pytorch_lightning`
import lightning as L

# Trade some float32 matmul precision for speed on hardware that supports
# faster internal formats (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

In [None]:
# Quick look at the raw Pavia Centre ground-truth file before it goes through
# the data module.
from scipy.io import loadmat

# Path to the ground-truth .mat file.
mat_file_path = '/home/sayem/Desktop/deepHSI/data/PaviaC/PaviaC/Pavia_gt.mat'

# loadmat returns a dict mapping MATLAB variable names to their values, plus
# loader metadata entries such as '__header__'.
data = loadmat(mat_file_path)

# Display the shape/dtype of each array instead of echoing the arrays
# themselves; non-array entries (the metadata) are shown as-is.
{
    name: (value.shape, value.dtype) if hasattr(value, "shape") else value
    for name, value in data.items()
}

In [None]:
# Directory the Pavia Centre data module downloads to / reads from.
data_dir = '/home/sayem/Desktop/deepHSI/data'

# Seed Python, NumPy and torch RNGs (and DataLoader workers) before the model is
# built, so weight initialisation and shuffling are reproducible across runs.
SEED = 42
L.pytorch.seed_everything(SEED, workers=True)

# Hyperparameters shared by the data module (which receives the whole dict) and
# the model / LightningModule (which read individual entries below).
hyperparams = {
    "batch_size": 64,
    "num_workers": 24,  # NOTE(review): high worker count; lower it on machines with fewer CPU cores
    "patch_size": 10,  # passed to both the data module and HSIFCModel
    "center_pixel": True,
    "supervision": "full",
    "num_classes": 10,  # NOTE(review): confirm whether this count includes an unlabeled/background label
}

# Number of input channels (spectral bands) the model expects; presumably the
# Pavia Centre band count -- confirm against the loaded data cube.
input_channels = 102

# Macro-averaged multiclass metrics handed to the LightningModule.
# NOTE(review): these are plain torchmetrics objects in a dict; confirm that
# HSIClassificationLitModule registers them (e.g. in a torch.nn.ModuleDict) so
# they follow the model onto the training device, and that it does not reuse a
# single instance across the train/val/test stages.
custom_metrics = {
    "precision": Precision(num_classes=hyperparams["num_classes"], average='macro', task='multiclass'),
    "recall": Recall(num_classes=hyperparams["num_classes"], average='macro', task='multiclass'),
    "f1": F1Score(num_classes=hyperparams["num_classes"], average='macro', task='multiclass'),
}

# Classifier network; patch size and class count come from the shared config.
model = HSIFCModel(
    input_channels=input_channels,
    patch_size=hyperparams["patch_size"],
    n_classes=hyperparams["num_classes"],
    dropout=True,
)

# LightningModule wrapping the network, configured with the optimizer
# name/params and the metrics defined above.
hsi_module = HSIClassificationLitModule(
    net=model,
    optimizer='Adam',
    optimizer_params={"lr": 1e-5},
    num_classes=hyperparams["num_classes"],
    custom_metrics=custom_metrics,
)

max_epochs = 15

# The data module receives the shared hyperparams dict. trainer.fit() invokes
# its prepare_data()/setup() hooks, so they are not called manually here.
pavia_c_datamodule = PaviaCDataModule(
    data_dir=data_dir,
    hyperparams=hyperparams,
)

# Stop training once the monitored validation F1 has not improved for 3
# consecutive validation checks; mode='max' because a higher F1 is better.
# NOTE(review): 'val/f1' must match the key HSIClassificationLitModule logs.
early_stop_callback = L.pytorch.callbacks.EarlyStopping(
    monitor='val/f1',
    patience=3,
    verbose=True,  # print a message when the callback acts
    mode='max',
)

# accelerator='auto' lets Lightning pick the available device.
# '16-mixed' is float16 automatic mixed precision, which targets CUDA GPUs.
trainer = L.Trainer(
    fast_dev_run=False,  # full training run (set True for a quick smoke test)
    precision='16-mixed',
    accelerator='auto',
    max_epochs=max_epochs,
    callbacks=[early_stop_callback],
)

In [None]:
# Train the LightningModule. Handing over the whole data module (rather than a
# single dataloader) lets Lightning drive its data hooks, including the
# train and validation dataloaders that early stopping depends on.
trainer.fit(model=hsi_module, datamodule=pavia_c_datamodule)