In [None]:
import os
from anndata import read_h5ad
import numpy as np
import pandas as pd
import os
import time

import sys
# Add the src/ directory as one where we can import modules
src_dir = "../src"
sys.path.append(src_dir)
from utils import VisiumClassificationDataset
from dkl import initial_values, GP, DKL

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
# from torchvision.io import read_image
# from torchvision.transforms import ToTensor
from torchvision.models import resnet34, ResNet34_Weights

from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import SoftmaxLikelihood

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import matplotlib.pyplot as plt

In [None]:
# update the location where models will be saved to
# torch.hub's directory is where downloaded pretrained weights are cached; when it points at the
# /clusterdata location, redirect it to /scratch instead.
_CLUSTER_HOME_HUB = '/clusterdata/uqsmac12/.cache/torch/hub'
_current_hub_dir = torch.hub.get_dir()
if _current_hub_dir == _CLUSTER_HOME_HUB:
    torch.hub.set_dir('/scratch/smp/uqsmac12/.cache/torch/hub')

In [None]:
# --- data locations (absolute, cluster-specific) ---
DIR_DATA = '/scratch/smp/uqsmac12/stimage2_data'

# image tiles and the processed AnnData file live in the same directory
DIR_TILES = '/scratch/smp/uqsmac12/dataset_breast_cancer_9visium'
DIR_ANNDATA_PROCESSED = DIR_TILES
file_processed_alex_data = 'all_adata.h5ad'

# raw data and its metadata sub-folder
DIR_RAW_DATA = '/afm03/Q2/Q2051/STimage_project/STimage_dataset/RAW/Alex_NatGen_6BreastCancer'
DIR_RAW_METADATA = os.path.join(DIR_RAW_DATA, 'metadata')

# outputs written under DIR_DATA: model checkpoints and the processed dataset table
DIR_CHECKPOINTS, DIR_PROCESSED_DATASET = (
    os.path.join(DIR_DATA, sub_dir) for sub_dir in ('checkpoints', 'data_processed')
)
FILE_PROCESSED_VISIUM9 = os.path.join(
    DIR_PROCESSED_DATASET,
    'df_adata_rna_logcpt_images_labels_visium9.csv',
)

In [None]:
# load csv
# Load the processed table. 'Unnamed: 0.1' is presumably the original row index left behind
# after the CSV was re-saved with its index — TODO confirm.
df_data = pd.read_csv(
    FILE_PROCESSED_VISIUM9,
    index_col='Unnamed: 0.1',
)

In [None]:
# drop instances without cancer labels
# keep only rows that carry a cancer label (re-running is harmless: the filter is idempotent)
df_data = df_data.dropna(subset=['cancer_class'])

In [None]:
# split into train val and test datasets
df_test = df_data[df_data['library'] == 'CID4465']
df_train = df_data[df_data['library'] != 'CID4465']

In [None]:
# shuffle the rows preserving unique set of instances 
df_train = df_train.sample(frac=1, replace=False, random_state=42)

In [None]:
df_val = df_train.iloc[:1000, :].copy()
df_train = df_train.iloc[1000:, :]

In [None]:
# get transformations
# Preprocessing pipeline:
#   1. divide raw pixel values by 255 (assumes the dataset yields 0-255 values — TODO confirm);
#   2. apply the preset transform bundled with the default ResNet34 weights.
transform_scale = transforms.Lambda(lambda img: img / 255.)
transform_normalise = ResNet34_Weights.DEFAULT.transforms()
composed_transforms = transforms.Compose(
    [
        transform_scale,
        transform_normalise,
    ]
)

In [None]:
# dataset objects
# one dataset per split, all sharing the same preprocessing (built in train, val, test order)
dataset_train, dataset_val, dataset_test = (
    VisiumClassificationDataset(split_df, transform=composed_transforms)
    for split_df in (df_train, df_val, df_test)
)

In [None]:
# dataloaders
# DataLoaders.
# - train: reshuffled every epoch (the one-off pandas shuffle alone gave an identical batch order
#   each epoch); a seeded generator keeps the order reproducible across runs.
# - val/test: same batch size as training instead of DataLoader's default batch_size=1; order kept.
BATCH_SIZE = 128
SHUFFLE_SEED = 42
dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
## load pretrained model and reset final fully connected layer
# Backbone: ResNet34 with ResNet34_Weights.DEFAULT, whose final fully connected layer is swapped
# for a freshly initialised linear layer producing `num_features_out` features.
num_features_out = 128
feature_extractor = resnet34(weights=ResNet34_Weights.DEFAULT)
num_final_fc_in = feature_extractor.fc.in_features  # width of the layer feeding the old head
feature_extractor.fc = nn.Linear(in_features=num_final_fc_in, out_features=num_features_out)

In [None]:
# Starting inducing points and kernel lengthscale for the GP, computed by dkl.initial_values
# from the training dataset and the feature extractor.
num_inducing_points = 50
(initial_inducing_points,
 initial_lengthscale) = initial_values(dataset_train, feature_extractor, num_inducing_points)

In [None]:
# GP head with one output per class, initialised from the values computed above.
kernel = 'Matern52'  # alternatives: 'RBF', 'Matern32' (presumably also accepted by dkl.GP)
num_classes = 2
gp = GP(
    kernel=kernel,
    initial_inducing_points=initial_inducing_points,
    initial_lengthscale=initial_lengthscale,
    num_outputs=num_classes,
)

In [None]:
# combine the backbone and the GP head via dkl.DKL
model = DKL(
    feature_extractor,
    gp,
)

In [None]:
# softmax likelihood over the GP's class outputs, without mixing weights, moved to the GPU
likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False).cuda()

In [None]:
# variational ELBO for the GP head; num_data is the training-set size
elbo_fn = VariationalELBO(likelihood, gp, num_data=len(dataset_train))


def loss_fn(x, y):
    """Training loss: the negative ELBO of GP output ``x`` against targets ``y``."""
    return -elbo_fn(x, y)

In [None]:
# SGD with momentum and L2 weight decay over every parameter registered on `model`
# (presumably both the backbone and the GP head — depends on dkl.DKL).
optimizer = optim.SGD(
    params=model.parameters(),
    lr=1e-4,
    momentum=0.9,
    weight_decay=1e-2,
)

In [None]:
# criterion = nn.CrossEntropyLoss()

In [None]:
# # Observe that all parameters are being optimized
# optimizer_ft = optim.SGD(feature_extractor.parameters(), lr=1e-4, momentum=0.9)

In [None]:
# multiply the learning rate by gamma=0.2 every 30 scheduler steps (train_model steps once per epoch)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    gamma=0.2,
    step_size=30,
)

In [None]:
# phase name -> DataLoader, as iterated by train_model
dataloaders = dict(train=dataloader_train, val=dataloader_val)

In [None]:
def _resolve_checkpoint_path(path, filename='best_model_params.pt'):
    """Return a concrete checkpoint *file* path for ``path``, creating directories as needed.

    ``path`` is treated as a directory when it already is one or has no file extension
    (e.g. ``DIR_CHECKPOINTS``); ``filename`` is appended in that case. Otherwise ``path`` is
    used as the file path and its parent directory is created if missing.
    """
    if os.path.isdir(path) or not os.path.splitext(path)[1]:
        os.makedirs(path, exist_ok=True)
        return os.path.join(path, filename)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return path


def _count_correct(outputs, labels, pred_likelihood):
    """Return how many argmax class predictions in the batch equal ``labels``.

    NOTE(review): assumes ``outputs`` is a gpytorch (multitask) MVN and that
    ``pred_likelihood(...)`` returns a Categorical whose ``probs`` carry a leading
    likelihood-sample dimension, and that ``labels`` are integer class ids — confirm
    against ``dkl.DKL`` / ``VisiumClassificationDataset``.
    """
    with torch.no_grad():
        # sample class probabilities independently per data point, then average over samples
        if hasattr(outputs, 'to_data_independent_dist'):
            outputs = outputs.to_data_independent_dist()
        probs = pred_likelihood(outputs).probs.mean(0)
        preds = probs.argmax(dim=-1)
        return int((preds == labels.reshape(preds.shape)).sum().item())


def train_model(model, criterion, optimizer, scheduler, best_model_params_path=DIR_CHECKPOINTS, num_epochs=100,
                loaders=None, pred_likelihood=None):
    """Train ``model`` and return it with the weights that had the best validation accuracy.

    Each epoch runs a ``'train'`` phase (forward, loss, backward, optimizer step, then one
    ``scheduler.step()``) followed by a ``'val'`` phase (forward and loss only, gradients
    disabled). Whenever validation accuracy improves, ``model.state_dict()`` is written to the
    checkpoint file; at the end the best checkpoint is loaded back into ``model``.

    Args:
        model: module to train; moved to CUDA.
        criterion: callable ``criterion(outputs, labels) -> scalar loss tensor``
            (in this notebook, the negative variational ELBO ``loss_fn``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per training phase.
        best_model_params_path: checkpoint file path, or a directory (such as the default
            ``DIR_CHECKPOINTS``) in which ``best_model_params.pt`` is written.
        num_epochs: number of epochs to run.
        loaders: mapping with ``'train'`` and ``'val'`` DataLoaders; defaults to the
            notebook-level ``dataloaders``.
        pred_likelihood: likelihood that turns model outputs into class probabilities for the
            accuracy metric; defaults to the notebook-level ``likelihood``.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if pred_likelihood is None:
        pred_likelihood = likelihood
    # DIR_CHECKPOINTS is a directory: saving straight to it fails (or creates a file named 'checkpoints')
    best_model_params_path = _resolve_checkpoint_path(best_model_params_path)

    since = time.time()

    # save the starting weights so there is always a checkpoint to reload at the end
    torch.save(model.state_dict(), best_model_params_path)
    best_acc = 0.0
    model = model.cuda()
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train() for the training phase, eval() for validation

            running_loss = 0.0
            running_corrects = 0
            num_seen = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.cuda()
                labels = labels.cuda()

                optimizer.zero_grad()

                # track history only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: loss is a per-batch mean, so weight it by the batch size
                running_loss += loss.item() * inputs.size(0)
                running_corrects += _count_correct(outputs, labels, pred_likelihood)
                num_seen += inputs.size(0)

            if is_train:
                scheduler.step()

            # normalise by the samples actually seen (equals len(dataset) unless the loader drops batches)
            epoch_loss = running_loss / max(num_seen, 1)
            epoch_acc = running_corrects / max(num_seen, 1)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # checkpoint whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_model_params_path)

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(torch.load(best_model_params_path))

    return model

In [None]:
# Keep the returned module (CUDA, best-validation weights loaded). Assigning also stops the
# cell from echoing the full model repr as output.
model = train_model(model, loss_fn, optimizer, exp_lr_scheduler)

In [None]:
x, y = next(iter(dataloader_train))

In [None]:
y_pred = feature_extractor(x)

In [None]:
y_pred = y_pred.detach().numpy()

In [None]:
plt.hist(y_pred.reshape(-1,1), bins=30)
plt.show()

In [None]:
# (leftover `VisiumClassificationDataset??` source introspection removed — run it interactively
# when needed; it is IPython-only syntax and not part of the analysis)

In [None]:
# `weights` was never defined; display the ResNet34 preset transform actually used above
transform_normalise

In [None]:
x,y = next(iter(dataset_train))

In [None]:
x = x.numpy().transpose(1,2,0)

In [None]:
plt.imshow(x)
plt.show()

In [None]:
plt.hist(x[2,:,:].numpy().reshape(-1,1), bins=30)
plt.show()

In [None]:
plt.hist(x.numpy().reshape(-1,1), bins=30)
plt.show()

In [None]:
img_array = x.numpy().transpose(1,2,0)

In [None]:
img_array[:,:,0]

In [None]:
plt.imshow(img_array)
plt.show()