In [None]:
import os
os.environ['PATH'] = '/scratch/smp/uqsmac12/.conda/env/lit_torch_gp/bin:' + os.environ['PATH']

In [None]:
!which python

In [53]:
import os
from anndata import read_h5ad
import numpy as np
import pandas as pd
import os
import torch

from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F

from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from torchvision.io import read_image
# from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping



In [2]:
# update the location where models will be saved to
if torch.hub.get_dir() == '/clusterdata/uqsmac12/.cache/torch/hub':
    torch.hub.set_dir('/scratch/smp/uqsmac12/.cache/torch/hub')

In [3]:
import matplotlib.pyplot as plt

In [4]:
torch.cuda.is_available()

True

## load data

In [5]:
DIR_DATA = '/scratch/smp/uqsmac12/stimage2_data'
DIR_CHECKPOINTS = os.path.join(DIR_DATA, 'checkpoints')
# DIR_WANDB = DIR_DATA
DIR_TILES = '/scratch/smp/uqsmac12/dataset_breast_cancer_9visium'
DIR_ANNDATA_PROCESSED = '/scratch/smp/uqsmac12/dataset_breast_cancer_9visium'
file_processed_alex_data = 'all_adata.h5ad'
# DIR_PROCESSED_DATA = '/afm03/Q2/Q2051/STimage_project/STimage_dataset/PROCESSED/dataset_breast_cancer_9visium'
DIR_RAW_DATA = '/afm03/Q2/Q2051/STimage_project/STimage_dataset/RAW/Alex_NatGen_6BreastCancer'
DIR_RAW_METADATA = os.path.join(DIR_RAW_DATA, 'metadata')

In [6]:
adata_all = read_h5ad(os.path.join(DIR_ANNDATA_PROCESSED, file_processed_alex_data))

## Get image and RNA data

In [7]:
# update metadata for annadata such that it maps to the correct location
adata_all.obs["tile_path"] = adata_all.obs.tile_path.map(
    lambda x: x.replace("/clusterdata/uqxtan9/Xiao/breast_cancer_9visium",
                        DIR_TILES))

In [8]:
# ensure change of datapath executed correctly
assert 'uqsmac12' in adata_all.obs['tile_path'][0]

In [9]:
gene_list = ["COX6C","TTLL12", "PABPC1", "GNAS", "HSP90AB1", "TFF3", "ATP1A1", "B2M", "FASN", "SPARC", "CD74", "CD63", "CD24", "CD81"]

In [10]:
adata_all.obs['library_id'].value_counts()

1160920F    4783
1142243F    4704
block2      3770
block1      3519
FFPE        2338
CID4290     2300
CID4465     1106
CID44971    1046
CID4535     1012
Name: library_id, dtype: int64

In [11]:
X_counts = np.exp(adata_all.to_df().values)

In [12]:
X_counts.shape

(24578, 14664)

## Get class labels

In [13]:
dict_meta = {}
dict_meta_counts = {}
unique_classes = set()
list_library = []
for fname in os.listdir(DIR_RAW_METADATA):
    if fname[0] != '.':
        library_id = fname.split('_')[0]
        list_library.append(library_id)
        dict_meta[library_id] = pd.read_csv(os.path.join(DIR_RAW_METADATA, fname))
        dict_meta[library_id]['library'] = library_id
        dict_meta_counts[library_id] = dict_meta[library_id]['Classification'].value_counts()
        print(library_id)
        print(dict_meta[library_id]['Classification'].value_counts())
        unique_classes.update(set(dict_meta[library_id]['Classification'].unique()))
        print()

1142243F
Invasive cancer + stroma + lymphocytes    3627
Necrosis                                   568
Stroma                                     445
Artefact                                   119
Lymphocytes                                 15
TLS                                         10
Name: Classification, dtype: int64

1160920F
Invasive cancer + stroma + lymphocytes      3146
Stroma                                      1132
Normal glands + lymphocytes                  278
Lymphocytes                                  186
Adipose tissue                                83
Artefact                                      48
DCIS                                          12
Cancer trapped in lymphocyte aggregation       9
Name: Classification, dtype: int64

CID4290
Invasive cancer + stroma                  2082
Invasive cancer + stroma + lymphocytes     215
Stroma                                     122
Artefact                                     7
Name: Classification, dtype: int64

CID4

In [14]:
df_meta = pd.DataFrame(dict_meta_counts)
df_meta = df_meta.fillna(0)

In [17]:
df_meta

Unnamed: 0,1142243F,1160920F,CID4290,CID4535,CID44971,CID4465
Adipose tissue,0.0,83.0,0.0,8.0,0.0,0.0
Artefact,119.0,48.0,7.0,23.0,1.0,4.0
Cancer trapped in lymphocyte aggregation,0.0,9.0,0.0,0.0,0.0,0.0
DCIS,0.0,12.0,0.0,0.0,273.0,0.0
Invasive cancer,0.0,0.0,0.0,418.0,0.0,0.0
Invasive cancer + adipose tissue + lymphocytes,0.0,0.0,0.0,3.0,0.0,0.0
Invasive cancer + lymphocytes,0.0,0.0,0.0,361.0,317.0,0.0
Invasive cancer + stroma,0.0,0.0,2082.0,0.0,0.0,0.0
Invasive cancer + stroma + lymphocytes,3627.0,3146.0,215.0,0.0,0.0,1131.0
Lymphocytes,15.0,186.0,0.0,69.0,81.0,0.0


In [None]:
set(dict_meta[library_id]['Classification'].unique())

In [18]:
dict_meta.keys()

dict_keys(['1142243F', '1160920F', 'CID4290', 'CID4535', 'CID44971', 'CID4465'])

In [20]:
dict_meta[library_id]

Unnamed: 0.1,Unnamed: 0,nCount_RNA,nFeature_RNA,subtype,patientid,Classification,library
0,AGCTTATAGAGACCTG-1,12930,4501,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
1,AACTAGCGTATCGCAC-1,9500,3867,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
2,AACTTTAGCTGCTGAG-1,17978,5603,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
3,CCCAAGACAGAGTATG-1,5676,2901,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
4,GGCATCAACGAGCACG-1,28000,6936,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
...,...,...,...,...,...,...,...
1206,TAGAGTGTTCCGGGTA-1,20786,5926,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
1207,GCTTCCATGTAACCGC-1,20039,5990,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
1208,GGTTCGCATTTGCCGT-1,21465,6019,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465
1209,TCAATCCGGGAAGTTT-1,10138,4330,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465


In [25]:
mapper_celltype_to_binary = { "Artefact": 0, "Lymphocytes": 0, "Stroma": 0, "DCIS": 0, "Invasive cancer + lymphocytes": 1, "Normal + stroma + lymphocytes": 0, "Stroma + adipose tissue": 0,
    "Invasive cancer + stroma + lymphocytes": 1, "Normal duct": 0, "Adipose tissue": 0, "Invasive cancer": 1, "Invasive cancer + adipose + lymphocytes": 1, "Uncertain": 0,
    "Invasive cancer + stroma": 1, "Cancer trapped in lymphocyte aggregation": 1, "Normal glands + lymphocytes": 0, "Necrosis": 0, "TLS": 0}

In [26]:
dict_meta[library_id]['cancer_class'] = dict_meta[library_id]['Classification'].map(mapper_celltype_to_binary)

In [28]:
dict_meta[library_id]['cancer_class'].value_counts()

1    1131
0      80
Name: cancer_class, dtype: int64

In [36]:
for library_id in dict_meta.keys():
    dict_meta[library_id]['cancer_class'] = dict_meta[library_id]['Classification'].map(mapper_celltype_to_binary)
    dict_meta[library_id]['cancer_class'] = dict_meta[library_id]['cancer_class'].fillna(0)
    print(library_id)
    print(dict_meta[library_id]['cancer_class'].value_counts())
    print(dict_meta[library_id]['cancer_class'].isna().sum())
    print()

1142243F
1    3627
0    1157
Name: cancer_class, dtype: int64
0

1160920F
1.0    3155
0.0    1740
Name: cancer_class, dtype: int64
0

CID4290
1.0    2297
0.0     135
Name: cancer_class, dtype: int64
0

CID4535
1.0    779
0.0    348
Name: cancer_class, dtype: int64
0

CID44971
0.0    845
1.0    317
Name: cancer_class, dtype: int64
0

CID4465
1    1131
0      80
Name: cancer_class, dtype: int64
0



In [41]:
for library_id in dict_meta.keys():
    dict_meta[library_id].index = dict_meta[library_id]['Unnamed: 0'] + '-' + library_id

In [42]:
dict_meta[library_id]

Unnamed: 0_level_0,Unnamed: 0,nCount_RNA,nFeature_RNA,subtype,patientid,Classification,library,cancer_class
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
AGCTTATAGAGACCTG-1-CID4465,AGCTTATAGAGACCTG-1,12930,4501,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
AACTAGCGTATCGCAC-1-CID4465,AACTAGCGTATCGCAC-1,9500,3867,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
AACTTTAGCTGCTGAG-1-CID4465,AACTTTAGCTGCTGAG-1,17978,5603,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
CCCAAGACAGAGTATG-1-CID4465,CCCAAGACAGAGTATG-1,5676,2901,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
GGCATCAACGAGCACG-1-CID4465,GGCATCAACGAGCACG-1,28000,6936,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
...,...,...,...,...,...,...,...,...
TAGAGTGTTCCGGGTA-1-CID4465,TAGAGTGTTCCGGGTA-1,20786,5926,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
GCTTCCATGTAACCGC-1-CID4465,GCTTCCATGTAACCGC-1,20039,5990,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
GGTTCGCATTTGCCGT-1-CID4465,GGTTCGCATTTGCCGT-1,21465,6019,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1
TCAATCCGGGAAGTTT-1-CID4465,TCAATCCGGGAAGTTT-1,10138,4330,TNBC,CID4465,Invasive cancer + stroma + lymphocytes,CID4465,1


In [47]:
df_meta = pd.concat(list(dict_meta.values()), axis=0)


In [49]:
df_meta.head()

Unnamed: 0_level_0,Unnamed: 0,nCount_RNA,nFeature_RNA,patientid,subtype,Classification,library,cancer_class
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TACCGATCCAACACTT-1-1142243F,TACCGATCCAACACTT-1,4073,2071,1142243F,TNBC,Artefact,1142243F,0.0
GATAAGGGACGATTAG-1-1142243F,GATAAGGGACGATTAG-1,4628,2320,1142243F,TNBC,Artefact,1142243F,0.0
TGTTGGCTGGCGGAAG-1-1142243F,TGTTGGCTGGCGGAAG-1,5116,2494,1142243F,TNBC,Artefact,1142243F,0.0
GCGAGGGACTGCTAGA-1-1142243F,GCGAGGGACTGCTAGA-1,8170,3464,1142243F,TNBC,Artefact,1142243F,0.0
GCGCGTTTAAATCGTA-1-1142243F,GCGCGTTTAAATCGTA-1,7534,3345,1142243F,TNBC,Artefact,1142243F,0.0


In [43]:
df_adata = adata_all.obs.copy()

In [50]:
df_adata = df_adata.merge(df_meta, how='left', left_index=True, right_index=True)

In [51]:
df_adata.shape

(24578, 17)

In [52]:
df_adata.head()

Unnamed: 0.1,in_tissue,array_row,array_col,imagecol,imagerow,tile_tissue_mask_path,tissue_area,tile_path,library_id,Unnamed: 0,nCount_RNA,nFeature_RNA,patientid,subtype,Classification,library,cancer_class
GATAAGGGACGATTAG-1-1142243F,1,1,3,12601,4511,/tmp/1142243F_tissue_mask/1142243F-12601-4511-...,0.733437,/scratch/smp/uqsmac12/dataset_breast_cancer_9v...,1142243F,GATAAGGGACGATTAG-1,4628.0,2320.0,1142243F,TNBC,Artefact,1142243F,0.0
TGTTGGCTGGCGGAAG-1-1142243F,1,1,5,12872,4512,/tmp/1142243F_tissue_mask/1142243F-12872-4512-...,0.878391,/scratch/smp/uqsmac12/dataset_breast_cancer_9v...,1142243F,TGTTGGCTGGCGGAAG-1,5116.0,2494.0,1142243F,TNBC,Artefact,1142243F,0.0
GCGAGGGACTGCTAGA-1-1142243F,1,1,7,13144,4513,/tmp/1142243F_tissue_mask/1142243F-13144-4513-...,0.884632,/scratch/smp/uqsmac12/dataset_breast_cancer_9v...,1142243F,GCGAGGGACTGCTAGA-1,8170.0,3464.0,1142243F,TNBC,Artefact,1142243F,0.0
GCGCGTTTAAATCGTA-1-1142243F,1,1,9,13416,4514,/tmp/1142243F_tissue_mask/1142243F-13416-4514-...,0.813425,/scratch/smp/uqsmac12/dataset_breast_cancer_9v...,1142243F,GCGCGTTTAAATCGTA-1,7534.0,3345.0,1142243F,TNBC,Artefact,1142243F,0.0
ATCTATCGATGATCAA-1-1142243F,1,3,3,12599,4984,/tmp/1142243F_tissue_mask/1142243F-12599-4984-...,0.879218,/scratch/smp/uqsmac12/dataset_breast_cancer_9v...,1142243F,ATCTATCGATGATCAA-1,7042.0,3108.0,1142243F,TNBC,Invasive cancer + stroma + lymphocytes,1142243F,1.0


In [54]:
adata_all.obs = df_adata

In [57]:
adata_all.obs['library'].value_counts()

1160920F    4783
1142243F    4704
CID4290     2300
CID4465     1106
CID44971    1046
CID4535     1012
Name: library, dtype: int64

In [56]:
adata_all.obs['cancer_class'].value_counts()

1.0    10975
0.0     3976
Name: cancer_class, dtype: int64

In [55]:
adata_all.obs['cancer_class'].isna().value_counts()

False    14951
True      9627
Name: cancer_class, dtype: int64

In [58]:
df_adata = adata_all.obs

In [63]:
# drop instances without labels
df_adata = df_adata[df_adata['cancer_class'].notna()]

In [67]:
# drop instances no in the tissue
df_adata = df_adata[df_adata['in_tissue'] == 1]

In [70]:
img_test = read_image(df_adata['tile_path'][0])

In [71]:
img_test.shape

torch.Size([3, 299, 299])

In [135]:
class VisiumClassificationDataset(Dataset):
    """
    X (image) -> y (class)
    """
#     def __init__(self, adata, dim=(299, 299), n_channels=3, genes=None, transform=None, target_transform=None):
    def __init__(self, df_data, transform=None, target_transform=None):
        self.df_data = df_data
        # testing on dataloader
        self.transform = transform
        self.target_transform = target_transform
        
    def __len__(self):
        return len(self.df_data)
    
    def __getitem__(self, idx):
        idx_name = self.df_data.index[idx]
        X_img = self.load_img(idx_name)
#         y = self.get_expression(idx_name)
        c = self.get_class(idx_name)
#         return idx_name
        print(c)
        return X_img, c
    
    def load_img(self, key):
        """load an image"""
        img_path = self.df_data.loc[key, 'tile_path']
        X_img = read_image(img_path)
#         # if shape illegal cast to proper shape
#         if X_img.shape != (self.n_channels, dim[0], dim[1]):
#             assert False, 'not tested'
#             X_img = transforms.Resize(self.dim)(X_img)
            
        if self.transform:
#             X_img = seq_aug(image=X_img)
            X_img = self.transform(X_img)
        
        return X_img
    
#     def get_expression(self, key):
#         y = torch.Tensor(adata_all[idx_name, gene_list].to_df().values)
        
#         if self.target_transform:
#             y = self.target_transform(y)
        
#         return y
    
    def get_class(self, key):
        c = self.df_data.loc[key, 'cancer_class']
        print(c)
        return c

In [136]:
# split into train validation and test

In [137]:
df_test = df_adata[df_adata['library'] == 'CID4465']
df_train = df_adata[df_adata['library'] != 'CID4465']

In [138]:
train_dataset = VisiumClassificationDataset(df_train)

In [139]:
test_dataset = VisiumClassificationDataset(df_test)

In [None]:
kwargs = {"num_workers": 4, "pin_memory": True}

In [140]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [141]:
# train DKL model

In [None]:
feature_extractor = WideResNet(
        input_size,
        hparams.spectral_conv,
        hparams.spectral_bn,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
    )

In [None]:
initial_inducing_points, initial_lengthscale = dkl.initial_values(
    train_dataset, feature_extractor, hparams.n_inducing_points
)

gp = dkl.GP(
    num_outputs=num_classes,
    initial_lengthscale=initial_lengthscale,
    initial_inducing_points=initial_inducing_points,
    kernel=hparams.kernel,
)

model = dkl.DKL(feature_extractor, gp)

likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
likelihood = likelihood.cuda()

elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))
loss_fn = lambda x, y: -elbo_fn(x, y)

In [None]:
model = model.cuda()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=hparams.learning_rate,
    momentum=0.9,
    weight_decay=hparams.weight_decay,
)

milestones = [60, 120, 160]

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=0.2
)

In [None]:
def step(engine, batch):
    model.train()
    if not hparams.sngp:
        likelihood.train()

    optimizer.zero_grad()

    x, y = batch
    x, y = x.cuda(), y.cuda()

    y_pred = model(x)
    loss = loss_fn(y_pred, y)

    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
def eval_step(engine, batch):
    model.eval()
    if not hparams.sngp:
        likelihood.eval()

    x, y = batch
    x, y = x.cuda(), y.cuda()

    with torch.no_grad():
        y_pred = model(x)

    return y_pred, y

In [None]:
trainer = Engine(step)
evaluator = Engine(eval_step)

metric = Average()
metric.attach(trainer, "loss")

In [None]:
def output_transform(output):
    y_pred, y = output

    # Sample softmax values independently for classification at test time
    y_pred = y_pred.to_data_independent_dist()

    # The mean here is over likelihood samples
    y_pred = likelihood(y_pred).probs.mean(0)

    return y_pred, y



In [None]:
metric = Accuracy(output_transform=output_transform)
metric.attach(evaluator, "accuracy")

In [None]:
metric = Loss(lambda y_pred, y: -likelihood.expected_log_prob(y, y_pred).mean())
metric.attach(evaluator, "loss")

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_results(trainer):
    metrics = trainer.state.metrics
    train_loss = metrics["loss"]

    result = f"Train - Epoch: {trainer.state.epoch} "
    if hparams.sngp:
        result += f"Loss: {train_loss:.2f} "
    else:
        result += f"ELBO: {train_loss:.2f} "
    print(result)

    writer.add_scalar("Loss/train", train_loss, trainer.state.epoch)

    if hparams.spectral_conv:
        for name, layer in model.feature_extractor.named_modules():
            if isinstance(layer, torch.nn.Conv2d):
                writer.add_scalar(
                    f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                )

    if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
        _, auroc, aupr = get_ood_metrics(
            hparams.dataset, "SVHN", model, likelihood, hparams.data_root
        )
        print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
        writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
        writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    acc = metrics["accuracy"]
    test_loss = metrics["loss"]

    result = f"Test - Epoch: {trainer.state.epoch} "
    if hparams.sngp:
        result += f"Loss: {test_loss:.2f} "
    else:
        result += f"NLL: {test_loss:.2f} "
    result += f"Acc: {acc:.4f} "
    print(result)
    writer.add_scalar("Loss/test", test_loss, trainer.state.epoch)
    writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

    scheduler.step()

In [None]:
pbar = ProgressBar(dynamic_ncols=True)
pbar.attach(trainer)

trainer.run(train_loader, max_epochs=200)

# Done training - time to evaluate
results = {}

evaluator.run(test_loader)
test_acc = evaluator.state.metrics["accuracy"]
test_loss = evaluator.state.metrics["loss"]
results["test_accuracy"] = test_acc
results["test_loss"] = test_loss

_, auroc, aupr = get_ood_metrics(
    hparams.dataset, "SVHN", model, likelihood, hparams.data_root
)
results["auroc_ood_svhn"] = auroc
results["aupr_ood_svhn"] = aupr

print(f"Final accuracy {results['test_accuracy']:.4f}")

results_json = json.dumps(results, indent=4, sort_keys=True)
(results_dir / "results.json").write_text(results_json)

torch.save(model.state_dict(), results_dir / "model.pt")
if likelihood is not None:
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

writer.close()