In [1]:
import pandas as pd
import numpy as np
import torch 
from tqdm import tqdm 
from sklearn.model_selection import train_test_split
import glob, os, pickle
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms
import dgl
from torch_geometric.utils import dropout_edge
from sklearn.metrics import roc_auc_score
import torch.optim as optim
import torchvision.ops.focal_loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn.functional as F 
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, SAGEConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

In [3]:
!nvidia-smi

Mon Jan  9 13:11:29 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:18:00.0 Off |                    0 |
| N/A   35C    P0    40W / 300W |      3MiB / 32768MiB |      0%   E. Process |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   29C    P0    41W / 300W |      3MiB / 32768MiB |      0%      Default |
|       

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
device 

device(type='cuda', index=0)

# Need to define the data class 
- Here focus mainly on the get() method. We don't need to process anything
- We also return masks for each graph, that will help with training 
- Actually, no masks. Inductive training.
- We are no longer using this data class

In [6]:
# class WSI_Graph_Class(Dataset):
    
#     def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
#         super().__init__(None, transform, pre_transform)
#         self.root_dir = root
#         self.WSI_df = pd.read_csv(root) #get the WSI metadata
#         self.masks = {} #map node num -> train/val/test masks
#         self.create_all_masks()
        
#     def create_all_masks(self):
        
#         for idx in tqdm(range(len(self.WSI_df["sample_id"]))):
#             path = self.WSI_df["path"].iloc[idx]
            
#             #this is the graph. We also need to return the training/validation/testing masks 
#             data = torch.load(path)
#             nodes = [i for i in range(data.x.shape[0])] #node 0 is in 0th pos, 1 in 1, and so on 
            
#             #all of the masks 
#             train_mask = [False] * len(nodes)
#             val_mask = [False] * len(nodes)
#             test_mask = [False] * len(nodes)
#             self.create_mask(nodes, train_mask, val_mask, test_mask)
#             #now add them to dictionary
#             self.masks[idx] = [train_mask, val_mask, test_mask]
        
        
#     def create_mask(self, nodes, train_mask, val_mask, test_mask):        
#         #create train/test/val nodes (75/25)
#         train, test = train_test_split(nodes)
#         test, val = train_test_split(test)
        
#         #now create masks
#         for i in range(len(nodes)):
#             if i in train: 
#                 train_mask[i] = True 
                
#         for i in range(len(nodes)):
#             if nodes[i] in val: 
#                 val_mask[i] = True 
                
#         for i in range(len(nodes)):
#             if nodes[i] in test: 
#                 test_mask[i] = True 
                
#     #just pass here, we aren't going to return any raw file names
#     def raw_file_names(self):
#         pass 
#     #here we can return each of the WSI 
#     def processed_file_names(self):
#         return list(self.WSI_df["sample_id"])
    
#     def len(self):
#         return len(self.processed_file_names())
    
#     #return the graph class for that idx 
#     def get(self, idx):
#         path = self.WSI_df["path"].iloc[idx]
#         #this is the graph. We also need to return the training/validation/testing masks 
#         data = torch.load(path)
#         masks = self.masks[idx]
#         train_mask = masks[0]
#         val_mask = masks[1]
#         test_mask = masks[2]
#         return (data, torch.tensor(train_mask), torch.tensor(val_mask), torch.tensor(test_mask))

In [7]:
# root = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/metadata.csv"

# dataset = WSI_Graph_Class(root = root, transform = None, pre_transform = None, pre_filter = None)

# Define Model 
- This mainly draws upon HIV project code 

In [8]:
torch.manual_seed(42)

# class GNN(torch.nn.Module):
#     def __init__(self, feature_size):
#         super(GNN, self).__init__()
#         num_classes = 2
#         embedding_size = 2048 # from resnet  

#         #define the GNN layers 

#         #layer 1
#         #the first graph attention layer which will create 3*embed size embeddings for each node. This will also take care of all the message passing and aggregation
#         self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout = 0.3)
#         #reduce the dimensionality back
#         self.head_transform1 = Linear(embedding_size*3, embedding_size)
#         self.pool1 = TopKPooling(embedding_size, ratio=0.8)

#         #layer 2
#         self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout = 0.3)
#         self.head_transform2 = Linear(embedding_size*3, embedding_size)
#         self.pool2 = TopKPooling(embedding_size, ratio=0.5)

#         #layer 3
#         self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout = 0.3)
#         self.head_transform3 = Linear(embedding_size*3, embedding_size)
#         self.pool3 = TopKPooling(embedding_size, ratio=0.2)


#         #linear layers - these need to be modified to match the output size? Or maybe not
#         self.linear1 = Linear(embedding_size*2, embedding_size)
#         self.linear2 = Linear(embedding_size, 2)

#     def forward(self, x, edge_attr, edge_index, batch_index):
#         #block 1 
#         x = self.conv1(x, edge_index)
#         x = self.head_transform1(x)

#         x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x, edge_index, None, batch_index)
#         #graph rep. 
#         x1 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)
#         #block 2 
#         x = self.conv2(x, edge_index)
#         x = self.head_transform2(x)

#         x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x, edge_index, None, batch_index)
#         #graph rep. 
#         x2 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)
#         #block 3
#         x = self.conv3(x, edge_index)
#         x = self.head_transform3(x)

#         x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x, edge_index, None, batch_index)
#         #graph rep. 
#         x3 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)
#         #element wise addition , and each is 2048 
#         x = x1 + x2 + x3
#         #output block 
#         x = self.linear1(x).relu()
#         x = F.dropout(x, p=0.5)
#         x = self.linear2(x)

#         return x

<torch._C.Generator at 0x2b6f5fc05730>

In [38]:
torch.manual_seed(42)

class simple_GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(simple_GNN, self).__init__()
        
        self.num_classes = 3 # 0 = bengign, 1 = scc, 2 = inflamm
        self.embedding_size = 2048 # this is what we want the embedding to be
        
        self.linear1 = Linear(self.embedding_size, 128)
        #define the GNN layers 
        
        self.drop_edge = lambda edge_index: dropout_edge(edge_index,p=0.3)[0]
        self.layer_norm1 = nn.LayerNorm(128)
        self.layer_norm2 = nn.LayerNorm(256)
        #layer 1
        #the first graph attention layer which will create 3*embed size embeddings for each node. This will also take care of all the message passing and aggregation
        self.conv1 = GATConv(128, 128, heads=3, dropout = 0.3)
        #reduce the dimensionality back
        self.head_transform1 = Linear(128*3, 128)
        
        #layer 2
        self.conv2 = GATConv(128, 128, heads=3, dropout = 0.3)
        self.head_transform2 = Linear(128*3, 128)

        #layer 3
        self.conv3 = GATConv(128, 256, heads=3, dropout = 0.3)
        self.head_transform3 = Linear(256*3, 256)
        
        #layer 4
        self.conv4 = GATConv(256, 256, heads=3, dropout = 0.3)
        self.head_transform4 = Linear(256*3, 256)
        
        #linear layers - these need to be modified to match the output size? Or maybe not
        self.linear2 = Linear(128, 64) 
        self.linear3 = Linear(64, self.num_classes) #prediction for each class

    def forward(self, x, edge_index, batch):
        x = x 
        edge_index = edge_index
        batch = batch 
        # downsize the embeddings
        x = self.linear1(x).relu()
        
        #block 1 
        x = self.conv1(x, edge_index) #this is does all the aggregation and message passing
        x = self.head_transform1(x)       
        x = self.layer_norm1(x)
        #block 2
#         edge_index = self.drop_edge(edge_index)
        x = self.conv2(x, edge_index) 
        x = self.head_transform2(x)      
        x = self.layer_norm1(x)

#         #block 3
# #         edge_index = self.drop_edge(edge_index)
#         x = self.conv3(x, edge_index) #this is does all the aggregation and message passing
#         x = self.head_transform3(x)   
#         x = self.layer_norm2(x)

#         #block 4
# #         edge_index = self.drop_edge(edge_index)
#         x = self.conv4(x, edge_index) 
#         x = self.head_transform4(x)   
        
        #output block 
        x = self.linear2(x).relu()
        x = F.dropout(x, p=0.5)
        x = self.linear3(x)
        
#         print("Inside model, after all computations", torch.cuda.memory_summary(device=None, abbreviated=False)) #bulk of the memory is used here, somehow

        return x

In [10]:
class sage_GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(sage_GNN, self).__init__()
        
        self.num_classes = 2 #scc or normal
        self.embedding_size = 2048 # this is what we want the embedding to be
        
        self.linear1 = Linear(2048, 512)

        #define the GNN layers 
    
        #layer 1
        self.conv1 = SAGEConv(512, 128)
        
        #layer 2
        self.conv2 = SAGEConv(128, 128)
           
        #layer 3
        self.conv3 = SAGEConv(128,128)

        self.linear2 = Linear(128, 64) 
        self.linear3 = Linear(64, 2)

    def forward(self, x, edge_index, batch):
        x = self.linear1(x).relu()
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)

        return x

In [11]:
class simple_NN(torch.nn.Module):
    def __init__(self, embedding_size):
        super(simple_NN, self).__init__()
        

        self.linear = Linear(2048, 512) 
        self.linear2 = Linear(512, 64)
        self.linear3 = Linear(64, 2)

    def forward(self, x, edge_index, batch):
        
        x = F.relu(self.linear(x))
        x = F.relu(self.linear2(x))
        x = F.dropout(x, p=0.5)
        x = self.linear3(x)

        return x

# Dataloader

In [12]:
# num_features = 2048

# # device_ids = [0, 1]
# model = simple_GNN(2048)
# # model= nn.DataParallel(model, device_ids = device_ids)
# model = model.to(device)

In [13]:
# model

In [14]:
# #loss and optimizer 
# import torch.optim as optim
# import torchvision.ops.focal_loss

# # loss_fn = torchvision.ops.focal_loss.sigmoid_focal_loss
# loss_fn = nn.CrossEntropyLoss()
# opt = optim.Adam(model.parameters(), lr=1e-7)

In [15]:
# #prepare training 
# from torch_geometric.data import DataLoader

# # train_loader = DataLoader(training_data, batch_size=1, shuffle=True)
# # val_loader = DataLoader(val_data, batch_size=1, shuffle=True)
# # test_loader = DataLoader(testing_data, batch_size=1, shuffle=True)

# # num_epochs = 500

In [16]:
# from sklearn.metrics import roc_auc_score

# softmax = nn.Softmax(dim=1)

# Inductive Model Training on Gokul Data

In [17]:
# for epoch in range(num_epochs):
#     #training portion
#     model.train()
#     epoch_loss = []
#     for data in tqdm(train_loader):
#         #get graph and the relevant stuff
#         graph = data[0]
#         x = data[0].x
#         edge_index = data[0].edge_index
#         y = data[0].y
        
#         #move to device
#         x = x.to(device)
#         edge_index = edge_index.to(device)
#         y = y.to(device)

#         #get predictions 
#         logits = model(x, edge_index)
#         loss = loss_fn(logits, y) #for CE

#         epoch_loss.append(loss.item())

#         opt.zero_grad()
#         loss.backward()
#         opt.step()
#     #now find the average training loss for this epoch 
#     epoch_loss = sum(epoch_loss)/len(epoch_loss)
#     print("Epoch :%d. Epoch loss: %f" %(epoch, epoch_loss))    
#     #validation portion
#     validation_correct = 0
#     validation_total = 0
#     model.eval()
#     with torch.no_grad():
#         for data in tqdm(val_loader):
#             #get graph
#             graph = data[0]
#             x = data[0].x
#             edge_index = data[0].edge_index
#             y = data[0].y
        
#             #move to device
#             x = x.to(device)
#             edge_index = edge_index.to(device)
#             y = y.to(device)

#             #get predictions 
#             logits = model(x, edge_index)
#             #get them into label predictions
#             _, indices = torch.max(logits, dim=1)
# #             print(indices)
#             validation_correct += sum(indices == y).item()
#             validation_total += len(y)
# #             print("Accuracy on this graph's val set", sum(indices == y).item()/len(y))
# #             print("SCC percent", sum(y == 1).item()/len(y))
    
#     print("Epoch :%d. Validation accuracy: %f" %(epoch, validation_correct/validation_total))

In [18]:
#  #test portion
# test_correct = 0
# test_total = 0
# model.eval()
# with torch.no_grad():
#     for data in tqdm(data_loader):
#         #get graph
#         graph = data[0]
#         x = graph.x 
#         edge_index = graph.edge_index
#         y = graph.y 
#         #move to device
#         x = x.to(device)
#         edge_index = edge_index.to(device)
#         y = y.to(device)
#         #get masks
#         test_mask = data[3].T.reshape([data[3].T.shape[0]])

#         #get predictions 
#         logits = model(x, edge_index)
#         #get them into label predictions
#         _, indices = torch.max(logits, dim=1)
#         print(1 in indices)
#         test_correct += sum(indices[test_mask] == y[test_mask]).item()
#         test_total += sum(test_mask == True).item()

# print("Test accuracy: %f" %(test_correct/test_total))

# Training With Sophie Data
- Here, use inductive training 
- split the ids themselves into different categories 

In [19]:
sophie_data = pd.read_pickle("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/graph_dataset/graph_dataset_modified.pkl")

In [20]:
train_ids = []
test_ids = []
val_ids = []

for id in sophie_data:
    if "train" in sophie_data[id]:
        train_ids.append(id)
    elif "test" in sophie_data[id]:
        test_ids.append(id)
    elif "val" in sophie_data[id]:
        val_ids.append(id)

In [21]:
save_dir = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/Graph_Data/"

train_dataset = []
test_dataset = []
val_dataset = []

In [22]:
for id in tqdm(train_ids):
    train_dataset.append(torch.load(save_dir + id +".pt"))
    
for id in tqdm(test_ids):
    test_dataset.append(torch.load(save_dir + id +".pt"))

for id in tqdm(val_ids):
    val_dataset.append(torch.load(save_dir + id +".pt"))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:06<00:00, 10.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:01<00:00, 11.90it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00, 10.58it/s]


In [23]:
train_dataset[0]

Data(x=[41295, 2048], edge_index=[2, 326540], y=[41295])

In [24]:
from torch_geometric.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)



In [25]:
print(len(train_loader), len(val_loader), len(test_loader))
num_epochs = 200

62 18 15


In [39]:
num_features = 2048

# device_ids = [0, 1]
model = simple_GNN(2048)
# model= nn.DataParallel(model, device_ids = device_ids)
model = model.to(device)

#loss and optimizer 
# loss_fn = torchvision.ops.focal_loss.sigmoid_focal_loss
loss_fn = nn.CrossEntropyLoss()

#optim and scheduler
opt = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, factor = 0.1, patience = 5, verbose = True)
softmax = nn.Softmax(dim=1)

In [40]:
lossess = []
validation_aucs = []

In [28]:
from sklearn.metrics import roc_auc_score

# code from: https://stackoverflow.com/questions/39685740/calculate-sklearn-roc-auc-score-for-multi-class

def roc_auc_score_multiclass(actual_class, pred_class, average = "macro"):

    #creating a set of all the unique classes using the actual class list
    unique_class = set([0,1,2])
    roc_auc_dict = {}

    for per_class in tqdm(unique_class):
        #creating a list of all the classes except the current class 
        other_class = [x for x in unique_class if x != per_class]

        #marking the current class as 1 and all other classes as 0
        new_actual_class = [0 if x in other_class else 1 for x in actual_class]
        new_pred_class = [0 if x in other_class else 1 for x in pred_class]

        #using the sklearn metrics method to calculate the roc_auc_score
        roc_auc = roc_auc_score(new_actual_class, new_pred_class, average = average)
        roc_auc_dict[per_class] = roc_auc

    return roc_auc_dict

# Experiment notes
- So 2 layers seems to work well with fairly small embedding sizes
- increasing embed size doesn't seem to improve performance 
- 1e-5 works well with CE
- simple NN baseline easily achieves .70 AUC
- So far, best performance has been found with layer size = 256, and 3 layers

In [None]:
for epoch in tqdm(range(num_epochs)):
    #training portion
    model.train()
    epoch_loss = []
    for data in train_loader:
        #get graph and the relevant stuff
        graph = data
        x = data.x
        edge_index = data.edge_index
        y = data.y
        batch = data.batch 
        
        #move to device
        x = x.to(device)
        edge_index = edge_index.to(device)
        y = y.to(device)
        batch = batch.to(device)
        #get predictions 
        logits = model(x, edge_index, batch) #for CE - CE takes logics 
#         scores = softmax(model(x, edge_index, batch))[:, 1] # for FL - takes the class prob
        loss = loss_fn(logits, y) #for CE
#         loss = loss_fn(scores, y.float()).sum() #for focal loss
        epoch_loss.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()
    
    #now find the average training loss for this epoch 
    epoch_loss = sum(epoch_loss)/len(epoch_loss)
    lossess.append(epoch_loss) #append to master array 
    print("Epoch :%d. Epoch loss: %f" %(epoch, epoch_loss)) 
    
    scheduler.step(epoch_loss) #show the scheduler the epoch loss and adjust lr accordingly
    
    
    #model in test mode 
    if epoch % 10 == 0:
        model.eval()
        predictions = torch.Tensor([])
        ground_truth = torch.Tensor([])
        #here, we can collect the AUC for each class
        with torch.no_grad():
            for data in val_loader:
                #get graph
                graph = data
                x = graph.x 
                edge_index = graph.edge_index
                y = graph.y 
                batch = graph.batch

                #move to device
                x = x.to(device)
                edge_index = edge_index.to(device)
                y = y.to(device)
                batch = batch.to(device)

                #find the probs
                scores = softmax(model(x, edge_index, batch))
                scores = torch.argmax(scores, dim=1) #transform into indices

                #move to cpu
                scores = scores.detach().cpu()
                y = y.detach().cpu()

                #concat them 
                predictions = torch.cat((predictions, scores))
                ground_truth = torch.cat((ground_truth, y))

        aucs = roc_auc_score_multiclass(ground_truth, predictions)
        validation_aucs.append(aucs) #add these aucs
        print("Epoch : "+ str(epoch))
        print(aucs)

  0%|                                                                                                                                                                               | 0/500 [00:00<?, ?it/s]

Epoch :0. Epoch loss: 0.836680



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.85s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.33s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.15s/it][A
  0%|▎                                                                                                                                                                 

Epoch : 0
{0: 0.528258794409841, 1: 0.5334483934406621, 2: 0.49987062564022583}


  0%|▋                                                                                                                                                                    | 2/500 [00:24<1:29:19, 10.76s/it]

Epoch :1. Epoch loss: 0.677719


  1%|▉                                                                                                                                                                    | 3/500 [00:28<1:03:08,  7.62s/it]

Epoch :2. Epoch loss: 0.621884


  1%|█▎                                                                                                                                                                     | 4/500 [00:32<50:47,  6.14s/it]

Epoch :3. Epoch loss: 0.591677


  1%|█▋                                                                                                                                                                     | 5/500 [00:36<43:50,  5.31s/it]

Epoch :4. Epoch loss: 0.571721


  1%|██                                                                                                                                                                     | 6/500 [00:39<39:45,  4.83s/it]

Epoch :5. Epoch loss: 0.557661


  1%|██▎                                                                                                                                                                    | 7/500 [00:43<37:24,  4.55s/it]

Epoch :6. Epoch loss: 0.546715


  2%|██▋                                                                                                                                                                    | 8/500 [00:47<35:28,  4.33s/it]

Epoch :7. Epoch loss: 0.538385


  2%|███                                                                                                                                                                    | 9/500 [00:51<34:04,  4.16s/it]

Epoch :8. Epoch loss: 0.530673


  2%|███▎                                                                                                                                                                  | 10/500 [00:55<33:32,  4.11s/it]

Epoch :9. Epoch loss: 0.523861
Epoch :10. Epoch loss: 0.517765



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.88s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.49s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.28s/it][A
  2%|███▌                                                                                                                                                              

Epoch : 10
{0: 0.7617424886495542, 1: 0.8137432294783499, 2: 0.5}


  2%|███▉                                                                                                                                                                | 12/500 [01:20<1:02:21,  7.67s/it]

Epoch :11. Epoch loss: 0.511730


  3%|████▎                                                                                                                                                                 | 13/500 [01:24<52:47,  6.50s/it]

Epoch :12. Epoch loss: 0.506705


  3%|████▋                                                                                                                                                                 | 14/500 [01:28<46:24,  5.73s/it]

Epoch :13. Epoch loss: 0.502346


  3%|████▉                                                                                                                                                                 | 15/500 [01:31<41:30,  5.14s/it]

Epoch :14. Epoch loss: 0.497730


  3%|█████▎                                                                                                                                                                | 16/500 [01:35<38:01,  4.71s/it]

Epoch :15. Epoch loss: 0.494202


  3%|█████▋                                                                                                                                                                | 17/500 [01:39<35:53,  4.46s/it]

Epoch :16. Epoch loss: 0.489540


  4%|█████▉                                                                                                                                                                | 18/500 [01:43<34:23,  4.28s/it]

Epoch :17. Epoch loss: 0.485661


  4%|██████▎                                                                                                                                                               | 19/500 [01:47<33:03,  4.12s/it]

Epoch :18. Epoch loss: 0.482000


  4%|██████▋                                                                                                                                                               | 20/500 [01:50<32:13,  4.03s/it]

Epoch :19. Epoch loss: 0.478611
Epoch :20. Epoch loss: 0.474859



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.64s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.35s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.15s/it][A
  4%|██████▉                                                                                                                                                           

Epoch : 20
{0: 0.7770867702557109, 1: 0.8330991965587675, 2: 0.5010012278694805}


  4%|███████▎                                                                                                                                                              | 22/500 [02:15<59:07,  7.42s/it]

Epoch :21. Epoch loss: 0.472181


  5%|███████▋                                                                                                                                                              | 23/500 [02:19<50:42,  6.38s/it]

Epoch :22. Epoch loss: 0.468660


  5%|███████▉                                                                                                                                                              | 24/500 [02:23<44:31,  5.61s/it]

Epoch :23. Epoch loss: 0.466067


  5%|████████▎                                                                                                                                                             | 25/500 [02:26<39:51,  5.04s/it]

Epoch :24. Epoch loss: 0.462461


  5%|████████▋                                                                                                                                                             | 26/500 [02:30<36:49,  4.66s/it]

Epoch :25. Epoch loss: 0.459882


  5%|████████▉                                                                                                                                                             | 27/500 [02:34<35:03,  4.45s/it]

Epoch :26. Epoch loss: 0.457295


  6%|█████████▎                                                                                                                                                            | 28/500 [02:38<33:51,  4.31s/it]

Epoch :27. Epoch loss: 0.454908


  6%|█████████▋                                                                                                                                                            | 29/500 [02:42<32:31,  4.14s/it]

Epoch :28. Epoch loss: 0.451995


  6%|█████████▉                                                                                                                                                            | 30/500 [02:46<31:44,  4.05s/it]

Epoch :29. Epoch loss: 0.449421
Epoch :30. Epoch loss: 0.447202



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.58s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.37s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.19s/it][A
  6%|██████████▏                                                                                                                                                       

Epoch : 30
{0: 0.7842373511650612, 1: 0.8362038282937058, 2: 0.5268075171945424}


  6%|██████████▌                                                                                                                                                           | 32/500 [03:10<57:47,  7.41s/it]

Epoch :31. Epoch loss: 0.444948


  7%|██████████▉                                                                                                                                                           | 33/500 [03:14<49:35,  6.37s/it]

Epoch :32. Epoch loss: 0.442547


  7%|███████████▎                                                                                                                                                          | 34/500 [03:18<43:41,  5.63s/it]

Epoch :33. Epoch loss: 0.440849


  7%|███████████▌                                                                                                                                                          | 35/500 [03:21<39:11,  5.06s/it]

Epoch :34. Epoch loss: 0.438142


  7%|███████████▉                                                                                                                                                          | 36/500 [03:25<36:10,  4.68s/it]

Epoch :35. Epoch loss: 0.436413


  7%|████████████▎                                                                                                                                                         | 37/500 [03:29<34:28,  4.47s/it]

Epoch :36. Epoch loss: 0.434512


  8%|████████████▌                                                                                                                                                         | 38/500 [03:33<33:08,  4.30s/it]

Epoch :37. Epoch loss: 0.432398


  8%|████████████▉                                                                                                                                                         | 39/500 [03:37<31:51,  4.15s/it]

Epoch :38. Epoch loss: 0.430854


  8%|█████████████▎                                                                                                                                                        | 40/500 [03:41<31:17,  4.08s/it]

Epoch :39. Epoch loss: 0.428869
Epoch :40. Epoch loss: 0.426534



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.74s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.35s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.20s/it][A
  8%|█████████████▍                                                                                                                                                    

Epoch : 40
{0: 0.7895628105155985, 1: 0.8369310827416582, 2: 0.5536559360087847}


  8%|█████████████▉                                                                                                                                                        | 42/500 [04:05<57:45,  7.57s/it]

Epoch :41. Epoch loss: 0.424547


  9%|██████████████▎                                                                                                                                                       | 43/500 [04:09<48:59,  6.43s/it]

Epoch :42. Epoch loss: 0.422301


  9%|██████████████▌                                                                                                                                                       | 44/500 [04:13<42:51,  5.64s/it]

Epoch :43. Epoch loss: 0.419945


  9%|██████████████▉                                                                                                                                                       | 45/500 [04:17<38:47,  5.11s/it]

Epoch :44. Epoch loss: 0.418980


  9%|███████████████▎                                                                                                                                                      | 46/500 [04:21<35:59,  4.76s/it]

Epoch :45. Epoch loss: 0.417090


  9%|███████████████▌                                                                                                                                                      | 47/500 [04:25<33:55,  4.49s/it]

Epoch :46. Epoch loss: 0.415189


 10%|███████████████▉                                                                                                                                                      | 48/500 [04:29<32:18,  4.29s/it]

Epoch :47. Epoch loss: 0.413410


 10%|████████████████▎                                                                                                                                                     | 49/500 [04:32<31:29,  4.19s/it]

Epoch :48. Epoch loss: 0.412212


 10%|████████████████▌                                                                                                                                                     | 50/500 [04:36<30:42,  4.09s/it]

Epoch :49. Epoch loss: 0.410288
Epoch :50. Epoch loss: 0.408859



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.90s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.36s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.21s/it][A
 10%|████████████████▋                                                                                                                                                 

Epoch : 50
{0: 0.7867774756247603, 1: 0.8315516720581189, 2: 0.5645470173841974}


 10%|█████████████████▎                                                                                                                                                    | 52/500 [05:01<56:24,  7.56s/it]

Epoch :51. Epoch loss: 0.407409


 11%|█████████████████▌                                                                                                                                                    | 53/500 [05:05<47:49,  6.42s/it]

Epoch :52. Epoch loss: 0.405213


 11%|█████████████████▉                                                                                                                                                    | 54/500 [05:09<41:58,  5.65s/it]

Epoch :53. Epoch loss: 0.403960


 11%|██████████████████▎                                                                                                                                                   | 55/500 [05:13<38:10,  5.15s/it]

Epoch :54. Epoch loss: 0.402375


 11%|██████████████████▌                                                                                                                                                   | 56/500 [05:16<35:19,  4.77s/it]

Epoch :55. Epoch loss: 0.400618


 11%|██████████████████▉                                                                                                                                                   | 57/500 [05:20<33:05,  4.48s/it]

Epoch :56. Epoch loss: 0.398975


 12%|███████████████████▎                                                                                                                                                  | 58/500 [05:24<31:54,  4.33s/it]

Epoch :57. Epoch loss: 0.397576


 12%|███████████████████▌                                                                                                                                                  | 59/500 [05:28<30:56,  4.21s/it]

Epoch :58. Epoch loss: 0.397163


 12%|███████████████████▉                                                                                                                                                  | 60/500 [05:32<29:53,  4.08s/it]

Epoch :59. Epoch loss: 0.395036
Epoch :60. Epoch loss: 0.393430



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.94s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.48s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.29s/it][A
 12%|████████████████████                                                                                                                                              

Epoch : 60
{0: 0.7799482853184118, 1: 0.8232923058285146, 2: 0.5689798804852056}


 12%|████████████████████▌                                                                                                                                                 | 62/500 [05:57<55:01,  7.54s/it]

Epoch :61. Epoch loss: 0.391840


 13%|████████████████████▉                                                                                                                                                 | 63/500 [06:01<46:47,  6.42s/it]

Epoch :62. Epoch loss: 0.390225


 13%|█████████████████████▏                                                                                                                                                | 64/500 [06:05<41:32,  5.72s/it]

Epoch :63. Epoch loss: 0.388090


 13%|█████████████████████▌                                                                                                                                                | 65/500 [06:08<37:27,  5.17s/it]

Epoch :64. Epoch loss: 0.386563


 13%|█████████████████████▉                                                                                                                                                | 66/500 [06:12<34:33,  4.78s/it]

Epoch :65. Epoch loss: 0.385127


 13%|██████████████████████▏                                                                                                                                               | 67/500 [06:16<32:40,  4.53s/it]

Epoch :66. Epoch loss: 0.383103


 14%|██████████████████████▌                                                                                                                                               | 68/500 [06:20<31:12,  4.34s/it]

Epoch :67. Epoch loss: 0.381299


 14%|██████████████████████▉                                                                                                                                               | 69/500 [06:24<29:54,  4.16s/it]

Epoch :68. Epoch loss: 0.380339


 14%|███████████████████████▏                                                                                                                                              | 70/500 [06:28<29:12,  4.08s/it]

Epoch :69. Epoch loss: 0.378559
Epoch :70. Epoch loss: 0.377115



  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s][A
 33%|████████████████████████████████████████████████████████▎                                                                                                                | 1/3 [00:06<00:13,  6.74s/it][A
 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                        | 2/3 [00:11<00:05,  5.40s/it][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.23s/it][A
 14%|███████████████████████▎                                                                                                                                          

Epoch : 70
{0: 0.7740583709909353, 1: 0.8157905938650454, 2: 0.5716674041504111}


 14%|███████████████████████▉                                                                                                                                              | 72/500 [06:52<53:31,  7.50s/it]

Epoch :71. Epoch loss: 0.375283


In [None]:
#test portion - still need to modify
    
predictions = torch.Tensor([])
ground_truth = torch.Tensor([])

model.eval()
with torch.no_grad():
    for data in tqdm(test_loader):
        #get graph
        graph = data
        x = graph.x 
        edge_index = graph.edge_index
        y = graph.y 
        batch = graph.batch
        #move to device
        x = x.to(device)
        edge_index = edge_index.to(device)
        y = y.to(device)
        batch = batch.to(device)
        #find the probs
        scores = softmax(model(x, edge_index, batch))
        
        #move to cpu
        scores = scores.detach().cpu()
        y = y.detach().cpu()
        
        #concat them 
        probabilities = torch.cat((probabilities, scores))
        ground_truth = torch.cat((ground_truth, y))
  

In [None]:
probabilities

In [None]:
#predict the whole test cohort AUC-ROC

roc_auc_score(ground_truth, probabilities[:, 1])

In [None]:
#from sophie's code - viz. the curve 
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

# fpr and tpr of all thresohlds
true = ground_truth
preds = probabilities[:, 1]
fpr, tpr, threshold = metrics.roc_curve(true, preds)

#get the metrics 
roc_auc = metrics.auc(fpr, tpr)

#plot
plt.title('Test Cohort-wide AUC-ROC')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

In [None]:
!nvidia-smi

In [None]:
torch.cuda.device_count()