In [1]:
import pandas as pd
import numpy as np
import torch 
from tqdm import tqdm 
from sklearn.model_selection import train_test_split
import glob, os, pickle
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cuda')

# Relationship Between Patches in a WSI 
- Need to figure out how to turn a WSI into a graph, and how to know the relationships between different patches
- If the x,y coords represent the center of each patch, I suppose what we can say is that adjacent patches need to be sqrt(2x^2) away if they are truly neighboring patches. 
- Apparently, two nodes might be neighbors if their embeddings are sufficiently similar. 
- Josh says that if patches are sqrt(2)*256 away from each other, then we should consider them neighbors 

In [4]:
df = pd.read_pickle("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/scc_tumor_data/prelim_patch_info_v2/109_A1c_ASAP_tumor_map.pkl")

In [5]:
df

Unnamed: 0,ID,x,y,patch_size,annotation,y_true,inflamm,scc
0,109_A1c_ASAP,1024,16640,256,0,0,0,0
1,109_A1c_ASAP,1280,15872,256,0,0,0,0
2,109_A1c_ASAP,1280,16128,256,0,0,0,0
3,109_A1c_ASAP,1280,16384,256,0,0,0,0
4,109_A1c_ASAP,1280,16640,256,0,0,0,0
...,...,...,...,...,...,...,...,...
41290,109_A1c_ASAP,98560,25600,256,0,0,0,0
41291,109_A1c_ASAP,98560,25856,256,0,0,0,0
41292,109_A1c_ASAP,98560,26112,256,0,0,0,0
41293,109_A1c_ASAP,98560,26368,256,0,0,0,0


In [6]:
#create node -> (x,y) map 
pos_map = {}

In [7]:
for index, row in tqdm(df.iterrows()):
    pos_map[index] = (row["x"], row["y"])

41295it [00:01, 30492.68it/s]


# Create edge adj. list
- Here we will use a library because n^2 is too slow
- I am relying on the standard indexing in the df. Meaning, the first row should be patch 1, the second should be patch 2, and so on... 

In [8]:
import dgl
import torch

In [9]:
nodes = torch.tensor([])

In [10]:
for i in tqdm(pos_map):
    x = torch.tensor([[float(pos_map[i][0]), float(pos_map[i][1])]])
    nodes = torch.cat((nodes, x), 0)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41295/41295 [00:00<00:00, 43656.97it/s]


In [11]:
len(nodes)

41295

In [12]:
# create graph where points sqrt(2)*256 away from each other are considered neighbors 
r_g, dist = dgl.radius_graph(nodes, 256*(2**(1/2)), get_distances=True) 

In [13]:
r_g.edges()

(tensor([    0,     0,     0,  ..., 41294, 41294, 41294]),
 tensor([    3,     4,     5,  ..., 41229, 41230, 41293]))

In [14]:
adj_list = list(zip(list(r_g.edges()[0]), list(r_g.edges()[1])))  #this is the adj list 

In [15]:
for i in tqdm(range(len(adj_list))): #converting everything from tensors to ints 
    adj_list[i] = [adj_list[i][0].item(), adj_list[i][1].item()]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 326540/326540 [00:00<00:00, 709379.03it/s]


In [16]:
adj_list

[[0, 3],
 [0, 4],
 [0, 5],
 [1, 2],
 [1, 12],
 [1, 13],
 [1, 14],
 [2, 1],
 [2, 3],
 [2, 13],
 [2, 14],
 [2, 15],
 [3, 0],
 [3, 2],
 [3, 4],
 [3, 14],
 [3, 15],
 [3, 16],
 [4, 0],
 [4, 3],
 [4, 5],
 [4, 15],
 [4, 16],
 [4, 17],
 [5, 0],
 [5, 4],
 [5, 6],
 [5, 16],
 [5, 17],
 [5, 18],
 [6, 5],
 [6, 7],
 [6, 17],
 [6, 18],
 [6, 19],
 [7, 6],
 [7, 8],
 [7, 18],
 [7, 19],
 [7, 20],
 [8, 7],
 [8, 9],
 [8, 19],
 [8, 20],
 [8, 21],
 [9, 8],
 [9, 10],
 [9, 20],
 [9, 21],
 [9, 22],
 [10, 9],
 [10, 21],
 [10, 22],
 [10, 23],
 [11, 12],
 [11, 31],
 [11, 32],
 [11, 33],
 [12, 1],
 [12, 11],
 [12, 13],
 [12, 32],
 [12, 33],
 [12, 34],
 [13, 1],
 [13, 2],
 [13, 12],
 [13, 14],
 [13, 33],
 [13, 34],
 [13, 35],
 [14, 1],
 [14, 2],
 [14, 3],
 [14, 13],
 [14, 15],
 [14, 34],
 [14, 35],
 [14, 36],
 [15, 2],
 [15, 3],
 [15, 4],
 [15, 14],
 [15, 16],
 [15, 35],
 [15, 36],
 [15, 37],
 [16, 3],
 [16, 4],
 [16, 5],
 [16, 15],
 [16, 17],
 [16, 36],
 [16, 37],
 [16, 38],
 [17, 4],
 [17, 5],
 [17, 6],
 [17, 16],

In [17]:
adj_list = torch.tensor(adj_list)

In [18]:
adj_list.shape

torch.Size([326540, 2])

# Create a graph data object 
- Need to use torch.geometric here
- Also need to define the model class
- Also need to get embeddings for each patch here 
- Also need the y matrix
- We need to save all graph data objects in a seperate directory, and get their file location and map them to the meta file
- The idea is that we will save all of this to a df that will contain columns= ["sample_id", "file_loc"] 

In [19]:
#define the model class 
model = torch.hub.load('pytorch/vision', 'resnet50')
model.fc = nn.Sequential(nn.Linear(2048, 100), 
                         nn.ReLU(), 
                         nn.Dropout(p=.5), 
                         nn.Linear(100,2))

Using cache found in /dartfs-hpc/rc/home/9/f003xr9/.cache/torch/hub/pytorch_vision_main


In [20]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [21]:
#load the best model 
model_path = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/Saved_Models/resnet50.pt"
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [22]:
#modify the model so that the output is the embedding layer 
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

model.fc = Identity() # remove the fc layer 
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [23]:
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [24]:
#get the patches
path = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/scc_tumor_data/prelim_patch_info/109_A1c_ASAP_tumor_map.npy"
arr = np.load(path)

In [25]:
embed_dic = {} #map the patch_id -> embedding
patches = []

In [26]:
for patch_id in tqdm(range(0, len(arr))):
    patch = preprocess(arr[patch_id])
    patches.append((patch_id, patch))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41295/41295 [00:13<00:00, 3063.88it/s]


In [27]:
patches_loader = DataLoader(dataset = patches, batch_size = 1)

In [None]:
with torch.no_grad():
    for idx, patch in tqdm(patches_loader):
        patch = patch.to(device=device)
        embed = model(patch)
        embed_dic[idx.item()] = embed.detach().cpu().tolist()[0]

 13%|████████████████████▋                                                                                                                                             | 5275/41295 [00:57<06:16, 95.57it/s]

In [None]:
len(embed_dic)
embeds = []

In [None]:
for patch_idx in embed_dic:
    embeds.append(embed_dic[patch_idx])

In [None]:
embeds = torch.tensor(embeds)

In [None]:
embeds.shape

In [None]:
adj_list = adj_list.T

In [None]:
y = torch.tensor(list(df["scc"])) # this is the scc for each patch 

In [None]:
y.shape

In [None]:
from torch_geometric.data import Data

In [None]:
data = Data(x=embeds, edge_index=adj_list, y=y)

In [None]:
#see if you can save this object 
torch.save(data, "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/109_A1c_ASAP.pt")

In [None]:
#see if you can load it 
recovered_data = torch.load("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/109_A1c_ASAP.pt")

In [None]:
recovered_data

In [None]:
#combine all of this into a function

#takes: df of WSI, np array of WSI, and model 
#returns: saves the graph data in a directory 
def create_graph(df, arr, model, preprocess, device, save_dir, sample_id):
    
    #create node -> (x,y) map 
    pos_map = {}
    for index, row in tqdm(df.iterrows()):
        pos_map[index] = (row["x"], row["y"])
        
    #use these nodes, which are made sequentially, to create a graph and eventually and edge list
    nodes = torch.tensor([])
    for i in tqdm(pos_map):
        x = torch.tensor([[float(pos_map[i][0]), float(pos_map[i][1])]])
        nodes = torch.cat((nodes, x), 0)
    
    # create graph where points sqrt(2)*256 away from each other are considered neighbors 
    r_g, dist = dgl.radius_graph(nodes, 256*(2**(1/2)), get_distances=True) 
    
    #get the adj_list
    adj_list = list(zip(list(r_g.edges()[0]), list(r_g.edges()[1])))  #this is the adj list 
    for i in tqdm(range(len(adj_list))): #converting everything from tensors to ints 
        adj_list[i] = [adj_list[i][0].item(), adj_list[i][1].item()]
        
    #make it a tensor
    adj_list = torch.tensor(adj_list).T
    
    #now, create embeddings for all of the patches within the WSI 
    embed_dic = {} #map the patch_id -> embedding
    patches = []
    
    for patch_id in tqdm(range(0, len(arr))): #get (idx, patch array)
        patch = preprocess(arr[patch_id])
        patches.append((patch_id, patch))

    patches_loader = DataLoader(dataset = patches, batch_size = 1)

    with torch.no_grad():
        for idx, patch in tqdm(patches_loader): # get the embeddings here 
            patch = patch.to(device=device)
            embed = model(patch)
            embed_dic[idx.item()] = embed.detach().cpu().tolist()[0]
    
    #now create an array for these embeddings 
    embeds = []
    for patch_idx in embed_dic:
        embeds.append(embed_dic[patch_idx])
    embeds = torch.tensor(embeds)
    
    #get the SCC array 
    y = torch.tensor(list(df["scc"])) # this is the scc for each patch 
    
    #make the graph data object 
    data = Data(x=embeds, edge_index=adj_list, y=y)
    
    #save this object 
    print(torch.save(data, save_dir+sample_id+".pt"))

In [None]:
#test function 
save_dir = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/"
sample_id = "109_A1c_ASAP"

create_graph(df = df, arr= arr, model = model, preprocess=preprocess, device=device, save_dir = save_dir, sample_id = sample_id)

# Create & Save Graphs From All of the WSI 

In [61]:
metadata = pd.DataFrame()
#here we will store the sample ids and their paths 
samples = []
paths = []

In [62]:
#now create an np array containing all of the included patches from the relevant tumor maps
parent_dir = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/scc_tumor_data/prelim_patch_info/"
save_dir = "/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Gokul_Srinivasan/SCC-Tumor-Detection/Gokul_files/graph_data/"

In [63]:
#iterate through all of the WSI samples
for f in tqdm(os.listdir(parent_dir)):
    id, ext = f.split(".")
    id = id[0:id.find("tumor")-1]
    print(id)
#     #if it is the npy file
#     if ext == "npy" and id in ids:
#         #get the np array 
#         arr = np.load(parent_dir+f)
#         #here, we can basically save each persons image in a seperate directory
#         sub_dir = id + "/"
#         data = [] # collect all a persons patch level data
#         for i in range(0, arr.shape[0]):
#             data.append(arr[i])
#         n_data = np.array(data)
#         #make the new dir if it doesn't exist
#         path = save_dir + sub_dir
#         os.mkdir(path)
#         np.save(save_dir + sub_dir + "data.npy", n_data) #save that persons data in their own folder
        

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 190/190 [00:00<00:00, 26594.07it/s]

365_A1b_ASAP
109_A1c_ASAP
354_A1d_ASAP
354_A3c_ASAP
109_A1c_ASAP
354_A3a_ASAP
14_A2b_ASAP
364_A1b_ASAP
367_A2b_ASAP
356_A1b_ASAP
364_A1b_ASAP
363_A1b_ASAP
14_A1b_ASAP
344_b_ASAP
369_A2b_ASAP
351_A2b_ASAP
358_A1a_ASAP
70_A2b_ASAP
365_A1b_ASAP
70_A2b_ASAP
361_a_ASAP
345_b_ASAP
344_b_ASAP
112_a_ASAP
10_A1b_ASAP
350_A1d_ASAP
354_A3b_ASAP
12_A1c_ASAP
7_A1e_ASAP
10_A2b_ASAP
362_A1b_ASAP
363_A1c_ASAP
270_A2b_ASAP
352_A1d_ASAP
366_A1b_ASAP
350_A1a_ASAP
112_b_ASAP
112_b_ASAP
10_A2b_ASAP
353_A2b_ASAP
354_A1b_ASAP
344_a_ASAP
350_A1e_ASAP
343_b_ASAP
270_A1b_ASAP
341_b_ASAP
7_A1c_ASAP
368_A1d_ASAP
368_A1b_ASAP
110_A2b_ASAP
341_a_ASAP
363_A2b_ASAP
123_A1a_ASAP
327_A1d_ASAP
10_A1b_ASAP
354_A1c_ASAP
327_B1c_ASAP
370_A2a_ASAP
369_A1b_ASAP
358_A1b_ASAP
361_b_ASAP
270_A2b_ASAP
370_A2b_ASAP
370_A1b_ASAP
342_a_ASAP
362_A1c_ASAP
358_A1b_ASAP
352_A1g_ASAP
367_A2b_ASAP
366_A1c_ASAP
366_A1a_ASAP
354_D1b_ASAP
270_A1d_ASAP
10_A1a_ASAP
7_A1e_ASAP
355_A1d_ASAP
311_A2c_ASAP
327_A1a_ASAP
364_A2b_ASAP
362_A1a_ASAP
36


