In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import h5py
import re
import sklearn
import umap


import torch 
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torch import optim, utils, Tensor



# from AttentionMIL_model import Attention  # disabled: the Attention model is defined inline in a later cell

# Directory holding the pre-extracted patch features: one HDF5 file per slide,
# named "<uuid>.h5" (see the cells that build and filter `files`).
path_to_extracted_features = '/omics/odcf/analysis/OE0585_projects/chromothripsis/histopathology/UKHD_Neuro/RetCLL_Features'


# Render matplotlib figures inline in the notebook output.
%matplotlib inline



  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
class Attention(L.LightningModule):
    """Attention-pooling classifier over bags of pre-extracted patch features.

    Each bag (slide) is a set of N patch feature vectors. Every patch is
    embedded, an attention network scores the patches, the scores are
    soft-maxed over the N patches, and the attention-weighted mean embedding
    is passed through a sigmoid classifier to give one probability per bag.

    Args:
        in_features: Length of each incoming patch feature vector
            (default 2048, the width used by the features in this notebook).
    """

    def __init__(self, in_features=2048):
        super().__init__()
        self.L = 500  # width of the per-patch embedding
        self.D = 128  # hidden width of the attention scorer
        self.K = 1    # number of attention branches (outputs of the scorer)

        # Features are already extracted upstream, so only a linear
        # projection of each patch vector is learned here.
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(in_features, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        # NOTE: forward() applies this layer to the last dimension of the
        # pooled embedding (size L), which only matches L*K while K == 1.
        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score one bag or a batch of bags.

        Args:
            x: Float tensor of shape (N, in_features) for a single bag or
                (B, N, in_features) for a batch of B bags.

        Returns:
            Tuple ``(Y_prob, Y_hat, A)``:
            Y_prob -- sigmoid output, shape (..., K, 1);
            Y_hat  -- Y_prob thresholded at 0.5, as floats, same shape;
            A      -- attention weights, shape (..., K, N), summing to 1 over N.
        """
        H = self.feature_extractor_part2(x)  # (..., N, L)

        A = self.attention(H)  # (..., N, K)
        # Negative dims keep the batched behaviour and also accept an
        # unbatched (N, in_features) bag.
        A = torch.transpose(A, -1, -2)  # (..., K, N)
        A = F.softmax(A, dim=-1)  # softmax over the N patches

        M = torch.matmul(A, H)  # (..., K, L)

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat)`` where error is the fraction of wrong predictions.

        Y must hold exactly one label per prediction; it is reshaped to the
        prediction shape so the comparison is element-wise.
        """
        _, Y_hat, _ = self.forward(X)
        # Without the reshape, (B,) labels broadcast against (B, 1, 1)
        # predictions into a (B, 1, B) all-pairs comparison.
        Y = Y.float().reshape(Y_hat.shape)
        error = 1. - Y_hat.eq(Y).float().mean().item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return the mean Bernoulli negative log-likelihood of labels Y given bags X.

        Y must hold exactly one label per prediction; a mismatching number of
        labels raises from ``reshape`` instead of being silently broadcast.
        """
        Y_prob, _, _ = self.forward(X)
        # Align labels with predictions: (B,) vs (B, 1, 1) would otherwise
        # broadcast to (B, 1, B) and average the loss over every
        # label/prediction pair rather than matched pairs.
        Y = Y.float().reshape(Y_prob.shape)
        # Clamp so neither log() below can see exactly 0.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood.mean()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(features, labels)`` batch."""
        x, y = batch
        loss = self.calculate_objective(x, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    


In [3]:
class RetCCLFeatureLoader(Dataset):
    """Dataset that serves a random sample of patch features for each slide.

    Args:
        slide_filenames: Feature file names, one per slide.
        feature_path: Directory that contains the feature files.
        labels: One label per entry of ``slide_filenames`` (same order).
        patches_per_iter: Number of patch rows drawn per item.
    """

    def __init__(self, slide_filenames, feature_path, labels, patches_per_iter = 10):
        assert len(labels) == len(slide_filenames), \
            "labels and slide_filenames must have the same length"
        self.labels = labels
        self.file_paths = [os.path.join(feature_path, name) for name in slide_filenames]
        self.slide_names = slide_filenames
        self.num_patches = patches_per_iter

    def __len__(self):
        """Number of slides in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for slide ``idx``.

        ``features`` is a float32 array with ``patches_per_iter`` rows taken
        from the file's 'feats' dataset. Rows are drawn uniformly *with
        replacement*, so repeated calls return different samples and a row
        may appear more than once.

        Raises:
            ValueError: If the feature file contains no patches.
        """
        cur_path = self.file_paths[idx]
        # Read inside a context manager so the HDF5 handle is closed after
        # every access instead of leaking one open file per item.
        with h5py.File(cur_path, 'r') as feature_file:
            features = feature_file['feats'][()]
        if features.shape[0] == 0:
            raise ValueError(f"no patches found in feature file: {cur_path}")
        sampled_pchs = np.random.randint(0, features.shape[0], self.num_patches)
        features = features[sampled_pchs, :]
        return features.astype(np.float32), self.labels[idx]

In [3]:
import os

# Send HTTP and HTTPS traffic through the proxy at www-int.dkfz-heidelberg.de,
# and point TENSORBOARD_BINARY at the tensorboard executable of the
# "marugoto" environment (presumably picked up by the %tensorboard magic
# used further down -- confirm).
os.environ.update(
    {
        "HTTP_PROXY": "http://www-int.dkfz-heidelberg.de:80",
        "HTTPS_PROXY": "http://www-int.dkfz-heidelberg.de:80",
        "TENSORBOARD_BINARY": "/home/p163v/mambaforge/envs/marugoto/bin/tensorboard",
    }
)


In [5]:
# Slide metadata, indexed by its "txt_idat" column.
slide_meta = pd.read_csv("../metadata/slides_FS_anno.csv").set_index("txt_idat")

# Chromothripsis (CT) scoring, indexed by the string form of its "idat"
# column so that it lines up with the slide metadata index.
ct_scoring = (
    pd.read_csv("../metadata/CT_3_Class_Draft.csv")
    .assign(txt_idat=lambda frame: frame["idat"].astype("str"))
    .set_index("txt_idat")
)

# Index-on-index join; columns present in both tables get an "l" suffix on
# the slide_meta side.
slide_annots = slide_meta.join(ct_scoring, lsuffix="l")

# Keep only slides carrying one of the two class labels of interest.
myx = slide_annots.CT_class.isin(["Chromothripsis", "No Chromothripsis"]).tolist()
slide_annots = slide_annots.loc[myx]
slide_names = slide_annots.uuid

slide_annots.CT_class

txt_idat
10003886253_R02C02       Chromothripsis
10003886253_R03C01    No Chromothripsis
10003886256_R03C02    No Chromothripsis
10003886258_R02C01    No Chromothripsis
10003886259_R02C01    No Chromothripsis
                            ...        
9969477124_R05C02     No Chromothripsis
9980102013_R06C01     No Chromothripsis
9980102032_R03C01     No Chromothripsis
9980102032_R04C01     No Chromothripsis
9980102032_R05C01     No Chromothripsis
Name: CT_class, Length: 1956, dtype: object

In [6]:

# Expected feature file name for every annotated slide: "<uuid>.h5".
# Each file is an HDF5 container; the dataset class above reads its 'feats'
# entry as a NumPy array via `h5py.File(path, 'r')['feats'][()]`.
files = [slide_uuid + ".h5" for slide_uuid in slide_names]
len(files)

1956

In [7]:
# Keep only slides whose feature file is present on disk.
# NOTE: `myx` is indexed into again when the labels are built further down,
# so it has to stay aligned with the unfiltered `files` list.
myx = [
    os.path.exists(f"{path_to_extracted_features}/{fname}")
    for fname in files
]
files = np.asarray(files)[myx]
len(files)

1530

In [8]:
# Drop slides that have fewer than 10 tiles according to TilesPerPat.csv.
TilesPerPat = pd.read_csv("../metadata/TilesPerPat.csv")
filestokeep = TilesPerPat.loc[TilesPerPat["Tiles Per Slide"] >= 10, "File"].tolist()

# One boolean per entry of `files`: True when the file is in `filestokeep`.
# (Membership is tested against a set built once from the list.)
myx2 = list(map(set(filestokeep).__contains__, files))

files = files[myx2]
len(files)

1529

In [9]:
# Integer targets for the two classes. The mapping is written out explicitly
# instead of using factorize(), whose codes follow order of first appearance
# and would therefore change if the rows were reordered. The codes below are
# the ones factorize() yields for the table as displayed above (first row is
# "Chromothripsis"), so existing label values are unchanged.
# NOTE(review): this makes "No Chromothripsis" the positive class (1) --
# confirm that is the intended polarity.
CT_LABEL_CODES = {"Chromothripsis": 0, "No Chromothripsis": 1}

ct_codes = slide_annots.CT_class.map(CT_LABEL_CODES)
if ct_codes.isna().any():
    raise ValueError("slide_annots.CT_class contains values outside CT_LABEL_CODES")

# Apply the same two boolean masks that were applied to `files`
# (feature file exists, then enough tiles) so labels stay in file order.
labels = ct_codes.to_numpy(dtype=np.int64)[myx][myx2]
assert len(labels) == len(files), "labels and files are no longer aligned"

In [10]:
# Wrap the remaining feature files and their labels in the Dataset.
RetCCLDataset = RetCCLFeatureLoader(
    slide_filenames=files,
    feature_path=path_to_extracted_features,
    labels=labels,
)

# Smoke test: fetch the first item (sampled features and its label).
x, y = RetCCLDataset[0]

In [11]:
# 80/20 train/test split. A seeded generator makes the split reproducible
# across kernel restarts; without it the held-out slides change on every run.
train_data, test_data = torch.utils.data.random_split(
    RetCCLDataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

In [12]:
# Batches of 64 slides, reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)

In [13]:
# Instantiate the attention model defined above with its default settings.
model = Attention()

In [4]:
# (Re)load the TensorBoard notebook extension and open it on lightning_logs/
# (presumably the directory the Trainer below writes its logs to -- confirm).
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
# Train for 10 epochs, logging every 10 steps.
# (limit_train_batches=100 was used previously and is currently disabled.)
trainer = L.Trainer(
    max_epochs=10,
    log_every_n_steps=10,
)
trainer.fit(model, train_dataloaders=train_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                    | Type       | Params
-------------------------------------------------------
0 | feature_extractor_part2 | Sequential | 1.0 M 
1 | attention               | Sequential | 64.3 K
2 | classifier              | Sequential | 501   
-------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.357     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

[rank: 0] Received SIGTERM: 15
