In [None]:
# --- Imports and global configuration ---------------------------------------
# NOTE(review): several names imported in this cell (sys, deepHSI, np, Trainer,
# HyperspectralDataset, BloodDetectionHSIDataModule, HSIFCModel,
# SpectralSpatialCNN, HyperspectralCNNDetector, HSIClassificationModule) are not
# used anywhere else in this notebook.
import sys
import deepHSI

from pathlib import Path

import numpy as np
from lightning.pytorch import Trainer, seed_everything

# Custom module imports
# NOTE(review): the star imports below hide where names come from; prefer
# importing the needed helpers explicitly.
from deepHSI.datamodule.components import HyperspectralDataset
from deepHSI.datamodule.components.utils import *

from deepHSI.datamodule.medical_datasets import BloodDetectionHSIDataModule

# `PaviaCDataModule` (used further down) is presumably provided by this star
# import; the explicit import commented out below would make that visible.
from deepHSI.datamodule.remote_sensing_datasets import *
# from deepHSI.datamodule.remote_sensing_datasets.paviaC import PaviaCDataModule

from deepHSI.models.architectures import HSIFCModel, \
    SpectralSpatialCNN, HyperspectralCNNDetector
from deepHSI.models.task_algos import HSIClassificationModule

# Fix the global random seed for reproducibility; workers=True also seeds the
# DataLoader worker processes.
seed_everything(42, workers=True)

# Importing from `lightning` instead of `pytorch_lightning`
import lightning as L

# PyTorch and metrics imports
import torch
from torchmetrics import F1Score, Precision, Recall

# from lightning import Trainer

# Allow float32 matrix multiplications to use faster, lower-precision internal
# math on hardware that supports it (PyTorch "medium" setting).
torch.set_float32_matmul_precision("medium")

In [None]:
import os
from pathlib import Path

# Directory where the ModelCheckpoint callback writes checkpoints. It is emptied
# by clear_directory() at the end of this cell, so never point it at anything
# you want to keep. Defaults to the original author's absolute path; set the
# DEEPHSI_CKPT_DIR environment variable to run the notebook on another machine
# without editing this cell.
ckpt_dir = Path(os.environ.get("DEEPHSI_CKPT_DIR", "/home/sayem/Desktop/deepHSI/notebooks/ckpt"))

def clear_directory(path: Path) -> None:
    """Empty ``path`` in place, or create it (with parents) if it does not exist.

    Every file, symlink and sub-directory inside ``path`` is removed, while
    ``path`` itself is kept. Symlinks found inside ``path`` are unlinked rather
    than followed, so the contents of a linked directory that lives elsewhere
    are never deleted.

    Args:
        path: Directory to clear or create.

    Raises:
        NotADirectoryError: If ``path`` exists but is not a directory.
    """
    if not path.exists():
        # Nothing to clear yet: just create the directory tree.
        path.mkdir(parents=True, exist_ok=True)
        return
    if not path.is_dir():
        raise NotADirectoryError(f"Expected a directory, got: {path}")
    for item in path.iterdir():
        # Path.is_dir() follows symlinks, so test for a link first: a symlinked
        # directory is removed as a link instead of being recursed into (which
        # would delete the target's contents and then fail on rmdir()).
        if item.is_dir() and not item.is_symlink():
            clear_directory(item)
            item.rmdir()
        else:
            item.unlink()


# Start every run from an empty checkpoint directory (created if it is missing).
clear_directory(path=ckpt_dir)

In [None]:
from deepHSI.models.task_algos import BaseModule

In [None]:
import torch
from torch.optim import Adam, lr_scheduler, SGD

# Keyword arguments forwarded to the optimizer constructor (Adam in this notebook).
optimizer_params = dict(
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# Keyword arguments forwarded to the LR scheduler constructor (StepLR): the
# learning rate is multiplied by `gamma` every `step_size` scheduler steps.
scheduler_params = dict(step_size=10, gamma=0.1)

In [None]:
import os

# Root directory for the datasets (the datamodule downloads into it if needed).
# Defaults to the original author's absolute path; set the DEEPHSI_DATA_DIR
# environment variable to run the notebook elsewhere without editing this cell.
data_dir = os.environ.get("DEEPHSI_DATA_DIR", "/home/sayem/Desktop/deepHSI/data")

# Datamodule settings (also logged to the W&B run config further down).
hyperparams = {
    "batch_size": 64,
    "num_workers": 24,
    "patch_size": 10,
    "center_pixel": True,
    "supervision": "full",
    "num_classes": 10,  # NOTE(review): confirm this matches the number of PaviaC label classes
}

# Number of spectral bands per pixel.
channels = 102

# Macro-averaged multiclass metrics keyed by the name they are reported under.
# NOTE(review): these are not passed to the VAEModule built later, so they are
# currently unused in this notebook.
custom_metrics = {
    metric_name: metric_cls(
        num_classes=hyperparams["num_classes"], average="macro", task="multiclass"
    )
    for metric_name, metric_cls in (
        ("precision", Precision),
        ("recall", Recall),
        ("f1", F1Score),
    )
}

In [None]:
from deepHSI.datamodule.transforms.hsi_transforms import *

# Individual augmentations, built once so each can be inspected or toggled.
normalize_transform = HSINormalize()
flip_transform = HSIFlip()
rotate_transform = HSIRotate()  # built but currently excluded from the pipeline
noise_transform = HSISpectralNoise(mean=0.0, std=0.01)
shift_transform = HSISpectralShift(shift=2)
drop_transform = HSIRandomSpectralDrop(drop_prob=0.1)

# Pipeline handed to the datamodule: normalisation first, then the augmentations
# (presumably applied in list order by Compose).
# NOTE(review): if PaviaCDataModule applies `transform` to the val/test splits as
# well, the noise/drop augmentations will perturb evaluation data too — confirm.
composed_transform = Compose(
    [
        normalize_transform,
        flip_transform,
        noise_transform,
        shift_transform,
        drop_transform,
    ]
)

In [None]:
# Upper bound on training epochs; consumed by the Trainer further down.
max_epochs = 200

# PaviaC datamodule: `hyperparams` carries batch size, workers, patch size,
# supervision mode and num_classes; `composed_transform` is the augmentation pipeline.
pavia_c_datamodule = PaviaCDataModule(
    data_dir=data_dir,
    hyperparams=hyperparams,
    transform=composed_transform,
)

# Fetch the data into `data_dir` if needed, then build the datasets.
pavia_c_datamodule.prepare_data()
pavia_c_datamodule.setup()

In [None]:
# Pull a single training batch to sanity-check the data pipeline (shapes, dtypes).
batch = next(iter(pavia_c_datamodule.train_dataloader()))

In [None]:
# First element of the batch — presumably the input patches, laid out like the
# dummy (batch, 1, bands, patch, patch) tensor built for the VAE below.
batch_inputs = batch[0]
batch_inputs.shape

## Test the VAE

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder mapping a hyperspectral patch to latent parameters.

    Three stride-2 ``Conv2d`` layers (bands -> 16 -> 32 -> 64 feature maps) shrink
    the patch spatially; the flattened feature map is projected by ``fc_z`` to
    ``2 * latent_dim`` values (presumably the mean / log-variance halves consumed
    by the VAE module — confirm against VAEModule).

    Args:
        channels: Number of spectral bands; the input channel count of the first conv.
        latent_dim: Size of the latent space; ``fc_z`` outputs ``2 * latent_dim``.
        patch_size: Spatial size of the input patch, an int (square patch) or a
            ``(height, width)`` pair. Only used to size ``fc_z``; the default of 10
            reproduces the former ``nn.Linear(256, 40)`` for ``latent_dim=20``.
    """

    def __init__(self, channels, latent_dim, patch_size=10):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        self.latent_dim = latent_dim

        patch_hw = (patch_size, patch_size) if isinstance(patch_size, int) else tuple(patch_size)
        # Infer the flattened conv output size with a dummy forward pass instead of
        # hard-coding it (256 only holds for 10x10 patches). torch.zeros draws no
        # random numbers, so parameter initialisation under a fixed seed is unchanged.
        with torch.no_grad():
            flat_features = self.conv_layers(torch.zeros(1, channels, *patch_hw)).numel()
        # Previously nn.Linear(256, 40), which silently ignored latent_dim.
        self.fc_z = nn.Linear(flat_features, self.latent_dim * 2)

    def forward(self, x):
        """Encode a batch of patches.

        Args:
            x: Tensor of shape ``(batch, 1, channels, height, width)``; the
               singleton dim 1 is squeezed away before the convolutions.

        Returns:
            Tensor of shape ``(batch, 2 * latent_dim)`` with the raw (unconstrained)
            latent parameters.
        """
        x = x.squeeze(1)
        x = self.conv_layers(x)
        x = x.flatten(start_dim=1)
        return self.fc_z(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder that maps latent vectors back to HSI patches.

    ``fc`` expands a latent vector into a ``64 x init_height x init_width`` feature
    map, three stride-2 ``ConvTranspose2d`` layers (64 -> 32 -> 16 -> ``channels``)
    upsample it 8x, bilinear interpolation resizes the result back to
    ``init_height x init_width``, and a sigmoid bounds every value to [0, 1].

    Args:
        latent_dim: Length of the latent vectors passed to ``forward``.
        channels: Number of spectral bands to reconstruct.
        init_height: Height of the reconstructed patch (the patch size, not the
            encoder's pre-flatten height).
        init_width: Width of the reconstructed patch.
    """

    def __init__(self, latent_dim, channels, init_height, init_width):
        super().__init__()
        self.init_height = init_height
        self.init_width = init_width

        # Latent vector -> flat 64-channel feature map at the target patch size.
        self.fc = nn.Linear(in_features=latent_dim, out_features=64 * self.init_height * self.init_width)

        # The last layer deliberately has no activation: forward() applies a sigmoid,
        # and the ReLU that used to sit here clamped its input to >= 0, so every
        # reconstructed value was stuck at or above 0.5. Dropping that trailing ReLU
        # keeps the parameterised layers at indices 0/2/4, so state_dict keys (and
        # existing checkpoints) are unchanged.
        self.conv_transpose_layers = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        """Decode latent vectors into patches.

        Args:
            x: Latent tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, channels, init_height, init_width)`` with
            values in [0, 1] — the same 5-D layout the Encoder consumes.
        """
        x = self.fc(x)
        x = x.view(-1, 64, self.init_height, self.init_width)
        x = self.conv_transpose_layers(x)
        # NOTE(review): the transposed convs upsample 8x only for the result to be
        # interpolated straight back down; stride-1 layers would be cheaper but
        # would change the architecture.
        x = F.interpolate(x, size=(self.init_height, self.init_width),
                          mode='bilinear', align_corners=False)
        # (B, C, H, W) -> (B, 1, C, H, W); identical to unsqueeze(2) + permute(0, 2, 1, 3, 4).
        return torch.sigmoid(x.unsqueeze(1))

In [None]:
# TODO: move model construction into the `setup(self, stage='fit')` hook.
from deepHSI.models.task_algos import VAEModule

# Reuse the band count (`channels`) and patch size from the datamodule settings
# instead of repeating the magic numbers 102 and 10 here.
w = h = hyperparams["patch_size"]  # spatial size of the (square) patch
latent_dim = 20

# Dummy batch in the layout Encoder.forward expects: (batch, 1, bands, height, width).
x = torch.randn(64, 1, channels, h, w)
print(f"Input shape: {x.shape}")

encoder = Encoder(channels, latent_dim)
decoder = Decoder(latent_dim=latent_dim,
                  channels=channels, init_height=h, init_width=w)

# Fail fast on shape mismatches before Lightning gets involved: the encoder must
# emit 2 * latent_dim values per sample, and decoding the first latent_dim of them
# must rebuild x's layout.
with torch.no_grad():
    z_params = encoder(x)
    assert z_params.shape == (x.shape[0], 2 * latent_dim), z_params.shape
    x_recon = decoder(z_params[:, :latent_dim])
    assert x_recon.shape == x.shape, x_recon.shape

# Lightning VAE module. (A VAEPyroModule built here before was immediately
# overwritten by this assignment, so it has been dropped; it is presumably what
# the commented-out Pyro/SVI cells below expect.)
vae = VAEModule(
    encoder=encoder,
    decoder=decoder,
    latent_dim=latent_dim,
    optimizer_constructor=Adam,
    optimizer_params=optimizer_params,
    scheduler_constructor=lr_scheduler.StepLR,
    scheduler_params=scheduler_params,
)

In [None]:
# import pyro
# from pyro.infer import SVI, Trace_ELBO
# from pyro.optim import Adam

# # # Set up the optimizer
# # adam_params = {"lr": 0.001}
# # optimizer = Adam(adam_params)

# # # # Set up the inference algorithm with the ELBO loss
# # # svi = SVI(vae.vae.model, vae.vae.guide, optimizer, loss=Trace_ELBO())
# # # Set up the inference algorithm with the ELBO loss
# # svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

# # Number of training epochs
# num_epochs = 10
# data_loader = pavia_c_datamodule.train_dataloader()

In [None]:
# import torch
# import pyro
# from pyro.infer import SVI, Trace_ELBO
# from torch.optim import Adam
# from torch.utils.data import DataLoader
# from pyro.optim import PyroLRScheduler, PyroOptim

# # Assuming 'vae' is your VAEPyroModule instance
# # Assuming 'train_loader' is your DataLoader instance for training data

# # optimizer = torch.optim.SGD

# # scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, \
# #     'optim_args': {'lr': 0.01}, 'gamma': 0.1})
# # # Initialize SVI with the model, guide, PyroLRScheduler, and loss function
# # svi = SVI(vae.model, vae.guide, scheduler, loss=Trace_ELBO())

In [None]:
# import torch
# from torch.optim import Adam
# from torch.optim.lr_scheduler import StepLR
# from pyro.optim import PyroOptim
# from pyro.optim.lr_scheduler import PyroLRScheduler
# from pyro.infer import SVI, Trace_ELBO
# from pyro.infer import TraceGraph_ELBO

# optimizer = torch.optim.SGD
# scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, \
#     'optim_args': {'lr': 0.01}, 'gamma': 0.1})
# svi = SVI(vae.model, vae.guide, scheduler, loss=TraceGraph_ELBO())

In [None]:
# from tqdm import tqdm

# # Training loop
# for epoch in range(num_epochs):
#     epoch_loss = 0
#     with tqdm(total=len(data_loader)) as progress_bar:
#         for data, _ in data_loader:  # Assuming you're not using labels in your VAE
#             if torch.cuda.is_available():
#                 data = data.cuda()  # Move your data to GPU

#             loss = svi.step(data)
#             epoch_loss += loss

#             # Update tqdm progress bar
#             progress_bar.update(1)
#             progress_bar.set_postfix(loss=loss / len(data))
#         scheduler.step()

#     # After each epoch, report the loss
#     print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {epoch_loss / len(data_loader)}")

# print("Training complete!")

In [None]:
# STOP

In [None]:
# import torch
# import pyro
# from pyro.infer import Predictive

# # Define the number of samples you want to generate
# num_samples = 10

# # Assume your model is trained, and you have the optimized parameters
# # and guide function ready

# # Create a predictive object
# predictive = Predictive(vae.model, \
#     guide=vae.guide, num_samples=num_samples, return_sites=("latent", "obs"))

# # You can use dummy data with the same shape as your training data to run the predictive model
# # The actual values don't matter, it's just to provide the right shape
# dummy_input, label = next(iter(data_loader)) # torch.randn([64, 1, 102, 10, 10])  # Adjust this shape as per your model's input

# # Get samples
# samples = predictive(dummy_input)

# # `samples` is a dictionary where keys are the names of the sampled sites
# # and values are tensors of shape (num_samples, original_site_shape)

# # For instance, to get the sampled observations, you can do:
# sampled_obs = samples["obs"]

# sampled_obs.shape

In [None]:
import os
import wandb
# Use the `lightning` package like the rest of the notebook (Trainer and callbacks
# come from `lightning`); mixing it with `pytorch_lightning` objects is unsupported.
from lightning.pytorch.loggers import WandbLogger

# Human-readable identifiers for the run, derived from the VAEModule built above.
# NOTE(review): assumes VAEModule stores its constructors as `optimizer_constructor`
# and `scheduler_constructor` attributes.
optimizer_id = vae.optimizer_constructor.__name__
scheduler_id = 'None' if vae.scheduler_constructor is None else vae.scheduler_constructor.__name__
# VAEModule has no separate `model` attribute, so its class name identifies the model.
model_name = vae.__class__.__name__

# Set the WandB notebook name for better organization
os.environ["WANDB_NOTEBOOK_NAME"] = "1.0-hsi-initial-data-exploration.ipynb"
wandb.login()

dynamic_run_name = f"{optimizer_id}-with-{scheduler_id}-{model_name}"

# Tags for filtering runs in the W&B UI; extend with dataset / model variants as needed.
dynamic_tags = [
    optimizer_id, scheduler_id, model_name,
]

print(dynamic_tags)

# Implicit string concatenation: the former backslash continuation *inside* the
# f-string copied the next line's indentation into the note ("with     StepLR").
experiment_notes = (
    f"Testing {optimizer_id} optimizer with "
    f"{scheduler_id} scheduler on {model_name}."
)

# W&B files default to the original author's absolute path; set DEEPHSI_WANDB_DIR
# to run the notebook elsewhere without editing this cell.
wandb_logger = WandbLogger(
    name=dynamic_run_name,
    project="PaviaC",
    save_dir=os.environ.get("DEEPHSI_WANDB_DIR", "/home/sayem/Desktop/deepHSI/notebooks/wandb"),
    offline=False,
    tags=dynamic_tags,
    notes=experiment_notes,
)

# Record the experiment configuration, including the optimizer / scheduler
# arguments (learning rate etc.), which were previously not logged.
wandb_logger.experiment.config.update({
    "optimizer": optimizer_id,
    "scheduler": scheduler_id,
    # "model": model_name,
    **hyperparams,
    "optimizer_params": optimizer_params,
    "scheduler_params": scheduler_params,
})

In [None]:
# Lightning callbacks, all reached through the `lightning` package imported as `L`.
callbacks_ns = L.pytorch.callbacks

# Early stopping on the (maximised) validation F1: stop after 20 consecutive
# checks without improvement. check_on_train_epoch_end=False runs the check after
# validation rather than at the end of each training epoch.
# NOTE(review): confirm VAEModule logs "val/f1" at all — a VAE typically logs a
# loss, and monitoring a key that is never logged can make these callbacks raise.
early_stop_callback = callbacks_ns.EarlyStopping(
    monitor="val/f1",
    mode="max",  # higher F1 is better
    patience=20,
    check_on_train_epoch_end=False,
    verbose=True,
)

# Keep only the single best checkpoint (highest "val/f1") inside ckpt_dir.
# auto_insert_metric_name=False substitutes just the values into `filename`;
# inserting the name "val/f1" would put a "/" in the path and create a sub-folder.
model_checkpoint = callbacks_ns.ModelCheckpoint(
    dirpath=str(ckpt_dir),
    filename="best-checkpoint-{epoch:02d}-{val/f1:.2f}",
    monitor="val/f1",
    mode="max",
    save_top_k=1,
    auto_insert_metric_name=False,
    verbose=True,
)

# Rich progress bar refreshed every batch; finished bars stay on screen.
rich_pbar_callback = callbacks_ns.RichProgressBar(refresh_rate=1, leave=True)

# Log the learning rate(s) once per epoch.
lr_monitor_callback = callbacks_ns.LearningRateMonitor(logging_interval='epoch')

In [None]:
# Initialize the PyTorch Lightning Trainer with fast_dev_run enabled
trainer = L.Trainer(
    fast_dev_run=False,  # Enable fast_dev_run
    precision="16-mixed",  # Use 16-bit precision
    accelerator="auto",  # Specify the accelerator as GPU
    max_epochs=max_epochs,
    log_every_n_steps=3,
    callbacks=[
        # lr_finder_callback,
        # early_stop_callback,
        model_checkpoint,
        # confusion_matrix_callback,
        # batch_finder_callback,
        # lr_monitor_callback,
        # model_weight_logger_callback,
    ],  # rich_pbar_callback],
    logger=wandb_logger,
    deterministic=False,
)

In [None]:
# Train the VAE on the dataloaders provided by the PaviaC datamodule.
trainer.fit(model=vae, datamodule=pavia_c_datamodule)

In [None]:
# Evaluate the trained VAE on the test split of the same datamodule. This cell
# previously passed `hsi_classifier`, which is never defined in this notebook
# (NameError on Restart & Run All); the model trained above is `vae`.
# NOTE(review): requires VAEModule to implement `test_step`.
# `dictionary` receives the list of metric dicts returned by Trainer.test.
dictionary = trainer.test(vae, datamodule=pavia_c_datamodule, verbose=True)

In [None]:
# After training/testing is done, close the W&B run started via WandbLogger so
# its remaining data is uploaded before the kernel moves on.
wandb.finish()