In [2]:
import glob, os, pickle

import dgl
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader as GeoDataLoader
from torchvision import transforms
from tqdm import tqdm

In [59]:
# Use the GPU when CUDA reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the dataset class
- The focus here is the `get()` method, which loads one saved graph from disk. No processing step is needed.

In [62]:
class WSI_Graph_Class(Dataset):
    """PyTorch Geometric dataset that serves pre-built whole-slide-image (WSI) graphs.

    The dataset is driven by a metadata CSV that must contain at least two columns:
    ``sample_id`` (an identifier per WSI) and ``path`` (the file that ``torch.load``
    reads to obtain that WSI's graph). Nothing is downloaded or processed here.

    Args:
        root: Path to the metadata CSV file (not a directory, despite the name).
        transform: Optional transform, forwarded to the PyG ``Dataset`` base class.
        pre_transform: Optional pre-transform, forwarded to the base class.
        pre_filter: Optional pre-filter, forwarded to the base class.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # The base class receives root=None because there is no raw/processed directory
        # layout to manage. pre_filter is now forwarded as well; it used to be dropped.
        super().__init__(None, transform, pre_transform, pre_filter)
        self.root_dir = root
        self.WSI_df = pd.read_csv(root)  # one row of metadata per WSI

    def raw_file_names(self):
        """Return the raw file names; there are none, so this is an empty list."""
        # NOTE(review): the PyG base class declares raw_file_names / processed_file_names
        # as properties. They are kept as plain methods so existing call sites that
        # invoke them with parentheses continue to work.
        return []

    def processed_file_names(self):
        """Return the ``sample_id`` of every WSI listed in the metadata CSV."""
        return list(self.WSI_df["sample_id"])

    def len(self):
        """Return the number of graphs, i.e. the number of rows in the metadata CSV."""
        # Same value as len(self.processed_file_names()) without building a list per call.
        return len(self.WSI_df)

    def get(self, idx):
        """Load and return the graph stored at the ``path`` of metadata row ``idx``."""
        path = self.WSI_df["path"].iloc[idx]
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
        # Only point the metadata CSV at files from a trusted source.
        return torch.load(path)

In [63]:
# Location of the metadata CSV that lists every WSI graph (one row per slide).
root = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/metadata.csv"

# transform / pre_transform / pre_filter are left at their defaults (None).
dataset = WSI_Graph_Class(root=root)

In [57]:
# Inspect the first graph. `dataset` is the object created in the cell above; the
# previous name `data` is not defined anywhere in this notebook and only worked
# because of leftover kernel state.
dataset.get(0)

Data(x=[23215, 2048], edge_index=[2, 179532], y=[23215])

# Define Model 
- The architecture is adapted mainly from the HIV project code.

In [64]:
import torch.nn.functional as F 
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(42)

class GNN(torch.nn.Module):
    """Graph-level classifier built from three GAT + TopK-pooling blocks.

    Each block runs a multi-head ``GATConv``, projects the concatenated heads back to
    ``embedding_size`` with a linear layer, and then discards nodes with ``TopKPooling``.
    After every block the remaining nodes are summarised per graph by concatenating
    global max pooling and global mean pooling. The three summaries are added together
    and passed through a two-layer linear head.

    Args:
        feature_size: Number of input features per node.
        num_classes: Number of output logits per graph. Defaults to 2.
        embedding_size: Hidden width used by every block. Defaults to 2048
            (chosen to match the ResNet feature size, per the original comment).
    """

    def __init__(self, feature_size, num_classes=2, embedding_size=2048):
        super(GNN, self).__init__()
        n_heads = 3  # attention heads per GAT layer

        # Block 1: the attention layer emits n_heads * embedding_size features per node
        # and handles message passing and aggregation; the linear layer reduces the
        # width back to embedding_size.
        self.conv1 = GATConv(feature_size, embedding_size, heads=n_heads, dropout=0.3)
        self.head_transform1 = Linear(embedding_size * n_heads, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)

        # Block 2
        self.conv2 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.3)
        self.head_transform2 = Linear(embedding_size * n_heads, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)

        # Block 3
        self.conv3 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.3)
        self.head_transform3 = Linear(embedding_size * n_heads, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Output head. Its input is 2 * embedding_size wide because each graph readout
        # concatenates a max-pooled and a mean-pooled vector.
        self.linear1 = Linear(embedding_size * 2, embedding_size)
        self.linear2 = Linear(embedding_size, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Compute per-graph class logits.

        Args:
            x: Node feature matrix.
            edge_attr: Accepted for interface compatibility only; it is never used.
            edge_index: Graph connectivity passed to the GAT and pooling layers.
            batch_index: Assignment of each node to its graph within the batch.

        Returns:
            The raw (pre-softmax) output of the final linear layer.
        """
        # Block 1
        x = self.conv1(x, edge_index)
        x = self.head_transform1(x)
        x, edge_index, _, batch_index, _, _ = self.pool1(x, edge_index, None, batch_index)
        x1 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)  # graph readout

        # Block 2
        x = self.conv2(x, edge_index)
        x = self.head_transform2(x)
        x, edge_index, _, batch_index, _, _ = self.pool2(x, edge_index, None, batch_index)
        x2 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Block 3
        x = self.conv3(x, edge_index)
        x = self.head_transform3(x)
        x, edge_index, _, batch_index, _, _ = self.pool3(x, edge_index, None, batch_index)
        x3 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Element-wise sum of the three readouts; each is 2 * embedding_size wide.
        x = x1 + x2 + x3

        # Output head. F.dropout defaults to training=True, so the module's own mode is
        # passed explicitly; otherwise dropout would stay active after model.eval().
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)

        return x

# Dataloader

In [66]:
# Hold out 20% of the graphs for testing. The training length is the remainder, so the
# two lengths always sum to len(dataset) as random_split requires. The previous
# int(n*.8) + int(n*.2) could fall short of n and raise a ValueError.
n_test = int(len(dataset) * .2)
train_set, test_set = torch.utils.data.random_split(dataset, [len(dataset) - n_test, n_test])

In [70]:
# Carve a validation set out of the training split (80/20). The validation length is
# the remainder, which replaces the hand-tuned "+1" that only summed correctly for
# some dataset sizes.
# NOTE: this overwrites train_set, so re-running the cell shrinks it again; re-run the
# train/test split cell first.
n_train = int(len(train_set) * .8)
train_set, val_set = torch.utils.data.random_split(train_set, [n_train, len(train_set) - n_train])

In [71]:
# Sanity check: sizes of the train, validation and test splits, in that order.
print(len(train_set), len(val_set), len(test_set))

19 5 6


In [72]:
# Build the network; each node carries a 2048-dimensional feature vector.
num_features = 2048
model = GNN(num_features)
# Module.to() moves the parameters in place and returns the same module object.
model = model.to(device=device)
print(model)

RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# Loss, optimizer and learning-rate schedule.
# Class 1 is weighted 10x relative to class 0 in the cross-entropy loss.
weights = torch.tensor([1, 10], dtype=torch.float32, device=device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
# SGD with momentum; the learning rate is multiplied by 0.95 on every scheduler step.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)

In [None]:
# Prepare the data loaders.
batch_size = 1

# The dataset yields torch_geometric Data objects, which torch.utils.data.DataLoader
# cannot collate. The PyG loader (imported as GeoDataLoader) batches graphs correctly
# and supplies the `batch` vector the model's pooling layers need.
train_loader = GeoDataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation loaders keep a fixed order so results are comparable between runs.
val_loader = GeoDataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = GeoDataLoader(test_set, batch_size=batch_size, shuffle=False)

# Actual Model Training 