In [11]:
import os
from anndata import read_h5ad
import numpy as np
import pandas as pd
import os
import time
from tqdm import tqdm

import sys
# Add the src/ directory as one where we can import modules
src_dir = "../src"
sys.path.append(src_dir)
from utils import VisiumClassificationDataset
from dkl import initial_values, GP, DKL, GP_Natural

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
# from torchvision.io import read_image
# from torchvision.transforms import ToTensor
from torchvision.models import resnet34, ResNet34_Weights

from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import SoftmaxLikelihood
from gpytorch import settings as gpytorch_settings
from gpytorch.optim import NGD

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import matplotlib.pyplot as plt

In [2]:
# update the location where models will be saved to
if torch.hub.get_dir() == '/clusterdata/uqsmac12/.cache/torch/hub':
    torch.hub.set_dir('/scratch/smp/uqsmac12/.cache/torch/hub')

In [3]:
DIR_DATA = '/scratch/smp/uqsmac12/stimage2_data'

# DIR_WANDB = DIR_DATA
DIR_TILES = '/scratch/smp/uqsmac12/dataset_breast_cancer_9visium'
DIR_ANNDATA_PROCESSED = '/scratch/smp/uqsmac12/dataset_breast_cancer_9visium'
file_processed_alex_data = 'all_adata.h5ad'
# DIR_PROCESSED_DATA = '/afm03/Q2/Q2051/STimage_project/STimage_dataset/PROCESSED/dataset_breast_cancer_9visium'
DIR_RAW_DATA = '/afm03/Q2/Q2051/STimage_project/STimage_dataset/RAW/Alex_NatGen_6BreastCancer'
DIR_RAW_METADATA = os.path.join(DIR_RAW_DATA, 'metadata')

DIR_CHECKPOINTS = os.path.join(DIR_DATA, 'checkpoints/')
# location to save data
DIR_PROCESSED_DATASET = os.path.join(DIR_DATA, 'data_processed')
FILE_PROCESSED_VISIUM9 = os.path.join(DIR_PROCESSED_DATASET, 'df_adata_rna_logcpt_images_labels_visium9.csv')

In [4]:
# load csv
df_data = pd.read_csv(FILE_PROCESSED_VISIUM9, index_col='Unnamed: 0.1')

  df_data = pd.read_csv(FILE_PROCESSED_VISIUM9, index_col='Unnamed: 0.1')


In [5]:
# drop instances without cancer labels
df_data = df_data[df_data['cancer_class'].notna()]

In [6]:
# split into train val and test datasets
df_test = df_data[df_data['library'] == 'CID4465']
df_train = df_data[df_data['library'] != 'CID4465']

In [7]:
# shuffle the rows preserving unique set of instances 
df_train = df_train.sample(frac=1, replace=False, random_state=42)

In [8]:
df_val = df_train.iloc[:1000, :].copy()
df_train = df_train.iloc[1000:, :]

In [9]:
# get transformations
transform_scale = transforms.Lambda(lambda x: x / 255.)
transform_normalise = ResNet34_Weights.DEFAULT.transforms()
composed_transforms = transforms.Compose([transform_scale, transform_normalise])
# preprocess = weights.transforms()

In [10]:
# dataset objects
dataset_train = VisiumClassificationDataset(df_train, transform=composed_transforms)
dataset_val = VisiumClassificationDataset(df_val, transform=composed_transforms)
dataset_test = VisiumClassificationDataset(df_test, transform=composed_transforms)

In [11]:
# dataloaders
batch_size = 384
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, )
dataloader_val = DataLoader(dataset_val, batch_size=batch_size)
dataloader_test = DataLoader(dataset_test)

In [12]:
## load pretrained model and reset final fully connected layer
feature_extractor = resnet34(weights=ResNet34_Weights.DEFAULT)
num_final_fc_in = feature_extractor.fc.in_features # get number of features
# replace final layer
num_features_out = 128
feature_extractor.fc = nn.Linear(num_final_fc_in, num_features_out)

In [13]:
num_inducing_points = 200
initial_inducing_points, initial_lengthscale = initial_values(
    dataset_train, feature_extractor, num_inducing_points,
)



In [12]:
kernel = 'Matern52' # Matern52 # RBF # Matern32
num_classes = 2
gp = GP(
# gp = GP_Natural(
    num_outputs=num_classes,
    initial_lengthscale=initial_lengthscale,
    initial_inducing_points=initial_inducing_points,
    kernel=kernel,
)

NameError: name 'initial_lengthscale' is not defined

In [15]:
model = DKL(feature_extractor, gp)

In [16]:
likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
likelihood = likelihood.cuda()

In [17]:
elbo_fn = VariationalELBO(likelihood, gp, num_data=len(dataset_train))
loss_fn = lambda x, y: -elbo_fn(x, y)

In [28]:
# optimizer = torch.optim.Adam(
#     model.parameters(),
#     lr=1e-4,
#     weight_decay=1e-2,
# )
variational_ngd_optimizer = NGD(
    model.variational_parameters(), 
    num_data=len(dataset_train), 
    lr=1e-4
)
hyperparameter_optimizer = torch.optim.Adam([
    {'params': model.hyperparameters()},
    {'params': likelihood.parameters()},
], lr=1e-4)

In [23]:
if variational_ngd_optimizer:
    print('True')

True


In [24]:
exp_lr_scheduler = optim.lr_scheduler.StepLR(hyperparameter_optimizer, step_size=10, gamma=0.2)

In [25]:
dataloaders = {'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test}
dataset_sizes = {'train': len(dataloader_train.dataset), 'val': len(dataloader_val.dataset), 'test': len(dataloader_test.dataset)}

In [30]:
def train_model(model, criterion, optimizer, scheduler, dir_best_model_params=DIR_CHECKPOINTS, num_epochs=50, optimizer_variational_ngd=None):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    best_model_params_path = os.path.join(dir_best_model_params, 'best_model_params.pt')
    torch.save(model.state_dict(), best_model_params_path)
    best_acc = 0.0
    model = model.cuda()
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase], total=len(dataloaders[phase]) , leave = False):
                inputs = inputs.cuda()
                labels = labels.cuda()
                
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # zero the parameter gradients
                    optimizer.zero_grad()
                    pred_MVN = model(inputs)
                    # Convert a … x N MVN distribution into a batch of independent Normal distributions. 
                    # Essentially, this throws away all covariance information and treats all dimensions as batch dimensions.
                    pred_MVN = pred_MVN.to_data_independent_dist() 
                    
                    with gpytorch_settings.num_likelihood_samples(15):
                        # the average of samples from the likelihood; matrix dims=(batch_size, num_classes) 
                        pred_mean_prob = likelihood(pred_MVN).probs.mean(0)
                        # predicted index/class_id come from 2nd index position from the output tupple
                        pred_class_id = pred_mean_prob.max(dim=1)[1]
                    # criterion is the (negative) variational elbo corresponding the the SoftMaxLikelihood
                    loss = criterion(pred_MVN, labels)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        # get the gradients
                        loss.backward()
                        # optimize 
                        optimizer.step()
                        ### ### ### ### ### ### ### ### ### ### ### ### 
                        # potential second step for NGD
                        if optimizer_variational_ngd:
                            variational_ngd_optimizer.zero_grad()
                            pred_MVN = model(inputs)
                            loss = criterion(pred_MVN, labels)
                            loss.backward()
                            optimizer_variational_ngd.step()
                        ### ### ### ### ### ### ### ### ### ### ### ### 
                        

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(pred_class_id == labels.data)
            if phase == 'train':
                # step the learning rate schedule
                scheduler.step() 

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_model_params_path)
        
        # epoch ended share the time.
        time_elapsed = time.time() - since
        print(f'Epoch {epoch} finished at time {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(torch.load(best_model_params_path))
    
    return model

In [27]:
model = train_model(model, loss_fn, hyperparameter_optimizer, exp_lr_scheduler, optimizer_variational_ngd=variational_ngd_optimizer)

Epoch 0/49
----------


                                               

train Loss: 0.2984 Acc: 0.8615


                                             

val Loss: 44.0485 Acc: 0.9080
Epoch 0 finished at time 3m 28s
Epoch 1/49
----------


                                               

train Loss: 0.0815 Acc: 0.9659


                                             

val Loss: 37.6067 Acc: 0.9280
Epoch 1 finished at time 6m 50s
Epoch 2/49
----------


                                               

train Loss: 0.0243 Acc: 0.9915


                                             

val Loss: 42.0358 Acc: 0.9290
Epoch 2 finished at time 10m 13s
Epoch 3/49
----------


                                               

train Loss: 0.0154 Acc: 0.9938


                                             

val Loss: 43.3935 Acc: 0.9270
Epoch 3 finished at time 13m 33s
Epoch 4/49
----------


                                               

train Loss: 0.0131 Acc: 0.9934


                                             

val Loss: 45.5707 Acc: 0.9270
Epoch 4 finished at time 16m 52s
Epoch 5/49
----------


                                               

train Loss: 0.0115 Acc: 0.9959


                                             

val Loss: 66.9438 Acc: 0.9040
Epoch 5 finished at time 20m 11s
Epoch 6/49
----------


                                               

train Loss: 0.0109 Acc: 0.9955


                                             

val Loss: 41.5599 Acc: 0.9300
Epoch 6 finished at time 23m 32s
Epoch 7/49
----------


                                               

train Loss: 0.0106 Acc: 0.9974


                                             

val Loss: 50.4073 Acc: 0.9240
Epoch 7 finished at time 26m 48s
Epoch 8/49
----------


                                               

train Loss: 0.0100 Acc: 0.9981


                                             

val Loss: 44.9913 Acc: 0.9250
Epoch 8 finished at time 30m 9s
Epoch 9/49
----------


                                               

KeyboardInterrupt: 

In [37]:
model = train_model(model, loss_fn, optimizer, exp_lr_scheduler)

Epoch 0/99
----------


                                               

train Loss: 98.7711 Acc: 0.8239


                                             

val Loss: 93.1365 Acc: 0.8090
Epoch 0 finished at time 2m 60s
Epoch 1/99
----------


                                               

train Loss: 93.8696 Acc: 0.8480


                                             

val Loss: 88.7781 Acc: 0.8140
Epoch 1 finished at time 5m 53s
Epoch 2/99
----------


                                               

train Loss: 89.9032 Acc: 0.8710


                                             

val Loss: 83.2975 Acc: 0.8660
Epoch 2 finished at time 8m 50s
Epoch 3/99
----------


                                               

train Loss: 86.0231 Acc: 0.8877


                                             

val Loss: 80.5704 Acc: 0.8780
Epoch 3 finished at time 11m 47s
Epoch 4/99
----------


                                               

train Loss: 82.5547 Acc: 0.8974


                                             

val Loss: 78.1572 Acc: 0.8830
Epoch 4 finished at time 14m 43s
Epoch 5/99
----------


                                               

train Loss: 79.5191 Acc: 0.9076


                                             

val Loss: 80.7858 Acc: 0.8570
Epoch 5 finished at time 17m 43s
Epoch 6/99
----------


                                               

train Loss: 76.8293 Acc: 0.9192


                                             

val Loss: 75.2085 Acc: 0.9020
Epoch 6 finished at time 20m 37s
Epoch 7/99
----------


                                               

train Loss: 73.0879 Acc: 0.9298


                                             

val Loss: 69.6625 Acc: 0.9060
Epoch 7 finished at time 23m 39s
Epoch 8/99
----------


                                               

train Loss: 70.3950 Acc: 0.9368


                                             

val Loss: 71.1320 Acc: 0.8990
Epoch 8 finished at time 26m 34s
Epoch 9/99
----------


                                               

train Loss: 68.2058 Acc: 0.9429


                                             

val Loss: 69.3369 Acc: 0.9060
Epoch 9 finished at time 29m 28s
Epoch 10/99
----------


                                               

train Loss: 64.2379 Acc: 0.9530


                                             

val Loss: 70.7596 Acc: 0.8890
Epoch 10 finished at time 32m 21s
Epoch 11/99
----------


                                              

KeyboardInterrupt: 

In [None]:
def compute_uncertainties(prob_tensor):
    """
    Compute uncertainties from a collection of simulation probabilities.

    Args:
    - prob_tensor: Tensor of shape [num_simulations, batch_size, num_classes]

    Returns:
    - epistemic_uncertainty: Epistemic uncertainty for each item in the batch.
    - aleatoric_uncertainty: Aleatoric uncertainty for each item in the batch.
    - predictive_uncertainty: Total predictive uncertainty (Shannon's entropy) for each item in the batch.
    """
    # Calculate mean and variance across simulations
    mean_probs = prob_tensor.mean(dim=0)  # shape: [batch_size, num_classes]
    var_probs = prob_tensor.var(dim=0)  # shape: [batch_size, num_classes]

    # Epistemic Uncertainty: variance of the expected predictions across simulations
    epistemic_uncertainty = var_probs  # summing over classes

    # Aleatoric Uncertainty: entropy of the average prediction
    aleatoric_uncertainty = -mean_probs * torch.log(mean_probs + 1e-10)

    # Predictive Uncertainty (Shannon's entropy): averaged entropy over all simulations
    entropies = -torch.sum(prob_tensor * torch.log(prob_tensor + 1e-10), dim=2)  # shape: [num_simulations, batch_size]
    predictive_uncertainty = entropies.mean(dim=0)

    return epistemic_uncertainty, aleatoric_uncertainty, predictive_uncertainty

def evaluate_and_return_predictions(model, criterion, dataloader_test):
    """
    Evaluate the model's performance on the test set and return a DataFrame with predictions.
    
    Args:
    - model: Trained model.
    - criterion: Loss function used during training.
    - dataloader_test: DataLoader object for the test set.
    
    Returns:
    - df: DataFrame with columns "label", "pred_class_id", and "loss".
    """
    model = model.cuda()
    model.eval()  # Set model to evaluate mode
    
    all_labels = []
    all_preds = []
    all_losses = []
    
    # Iterate over the test data
    for inputs, labels in tqdm(dataloader_test, total=len(dataloader_test), leave=False):
        inputs = inputs.cuda()
        labels = labels.cuda()
        
        # forward
        with torch.no_grad():
            pred_MVN = model(inputs)
            pred_MVN = pred_MVN.to_data_independent_dist()
            with gpytorch_settings.num_likelihood_samples(15):
                prob_sims = likelihood(pred_MVN).probs
                epistemic_unc, aleatoric_unc, predictive_unc = compute_uncertainties(prob_sims)
                pred_mean_prob = prob_sims.mean(0)
                pred_class_id = pred_mean_prob.max(dim=1)[1]
                
            loss = criterion(pred_MVN, labels)

        # collect results
        all_labels.extend(labels.cpu().numpy().tolist())
        all_preds.extend(pred_class_id.cpu().numpy().tolist())
        all_losses.extend([loss.item()] * inputs.size(0))

    df = pd.DataFrame({
        'label': all_labels,
        'pred_class_id': all_preds,
        'loss': all_losses
    })

    return df


In [None]:
import torch
import torch.nn.functional as F



# Example
prob_tensor = torch.rand([10, 384, 2])
prob_tensor = F.softmax(prob_tensor, dim=-1)  # Make sure the tensor represents valid probabilities
epistemic, aleatoric, predictive = compute_uncertainties(prob_tensor)
print(epistemic, aleatoric, predictive)


In [31]:
for inputs, labels in dataloaders['train']:
    inputs = inputs.cuda()
    labels = labels.cuda()
    break

In [32]:
pred_MVN = model(inputs)

In [37]:
pred_MVN.base_sample_shape

torch.Size([128, 2])

In [33]:
likelihood(pred_MVN).probs.shape

torch.Size([10, 384, 2])

In [65]:
pred_mean_prob = likelihood(pred_MVN).probs.mean(0)
pred = pred_mean_prob.max(dim=1)[1]

In [72]:
loss = loss_fn(pred_MVN, labels)

In [73]:
loss

tensor(55.0969, device='cuda:0', grad_fn=<NegBackward0>)

In [63]:
pred_mean_prob

tensor([[0.6005, 0.3995],
        [0.5145, 0.4855],
        [0.4735, 0.5265],
        [0.5037, 0.4963],
        [0.5575, 0.4425],
        [0.4396, 0.5604],
        [0.4145, 0.5855],
        [0.4609, 0.5391],
        [0.4923, 0.5077],
        [0.5740, 0.4260],
        [0.5006, 0.4994],
        [0.3756, 0.6244],
        [0.4784, 0.5216],
        [0.4453, 0.5547],
        [0.4729, 0.5271],
        [0.2812, 0.7188],
        [0.5067, 0.4933],
        [0.5373, 0.4627],
        [0.4764, 0.5236],
        [0.6365, 0.3635],
        [0.4937, 0.5063],
        [0.4666, 0.5334],
        [0.5238, 0.4762],
        [0.4606, 0.5394],
        [0.5955, 0.4045],
        [0.5134, 0.4866],
        [0.4719, 0.5281],
        [0.5662, 0.4338],
        [0.5675, 0.4325],
        [0.5378, 0.4622],
        [0.4792, 0.5208],
        [0.4352, 0.5648],
        [0.5509, 0.4491],
        [0.6218, 0.3782],
        [0.6447, 0.3553],
        [0.4988, 0.5012],
        [0.6204, 0.3796],
        [0.5447, 0.4553],
        [0.5

In [66]:
pred

tensor([1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0,
        1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 1, 0], device='cuda:0')

In [None]:
x, y = next(iter(dataloader_train))

In [None]:
y_pred = feature_extractor(x)

In [None]:
y_pred = y_pred.detach().numpy()

In [None]:
plt.hist(y_pred.reshape(-1,1), bins=30)
plt.show()

In [None]:
VisiumClassificationDataset??

In [None]:
weights.transforms??

In [None]:
x,y = next(iter(dataset_train))

In [None]:
x = x.numpy().transpose(1,2,0)

In [None]:
plt.imshow(x)
plt.show()

In [None]:
plt.hist(x[2,:,:].numpy().reshape(-1,1), bins=30)
plt.show()

In [None]:
plt.hist(x.numpy().reshape(-1,1), bins=30)
plt.show()

In [None]:
img_array = x.numpy().transpose(1,2,0)

In [None]:
img_array[:,:,0]

In [None]:
plt.imshow(img_array)
plt.show()