In [1]:
import pandas as pd
import numpy as np
import pickle

import os
from PIL import Image

import torch
import torch.nn as nn
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision

from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tqdm.notebook import trange, tqdm

import random

from transformers import BertTokenizer, AutoTokenizer, AutoModel, BertModel

import matplotlib.pyplot as plt

from py_files.models import BertEncoder, ResnetPreTrained, ImageEncoder
from py_files.datasets import WSIBatchedDataset, GetRepsDataset

import faiss

In [2]:
def check_cuda():
    """Pick the compute device, preferring CUDA when it is available.

    Prints the number of visible GPUs and the name of device 0, or a
    fallback notice when CUDA is not available.

    Returns:
        torch.device: ``cuda`` if available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")

    n_gpus = torch.cuda.device_count()
    print(f'There are {n_gpus} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
    return torch.device("cuda")

device = check_cuda()

There are 1 GPU(s) available.
Device name: Tesla K80


In [3]:
# NOTE(review): the CSV appears to have been written together with its index,
# which comes back as an 'Unnamed: 0' column; consider index_col=0 if nothing
# downstream relies on that column.
df = pd.read_csv(filepath_or_buffer='../df.csv')
df.head()

Unnamed: 0,patch_paths,pid,svs_paths,dtype,notes
0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
3,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
4,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...


In [4]:
# Unique patient ids; the clustering subset below is sampled from these.
pids = df['pid'].unique()
len(pids)

518

In [5]:
# Fraction of patients whose patches are used to fit the clustering model.
pid_percent_cluster = 0.1
n_pids_cluster = int(len(pids) * pid_percent_cluster)
n_pids_cluster

51

In [6]:
# Sample the patients used to fit the clustering model.
# - replace=False: the default (replace=True) can draw the same patient several
#   times and silently shrink the subset below n_pids_cluster distinct patients.
# - A seeded Generator keeps the subset reproducible across kernel restarts.
cluster_rng = np.random.default_rng(0)
pids_cluster = cluster_rng.choice(pids, size=n_pids_cluster, replace=False)

# Standard ImageNet channel mean/std values.
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ConvertImageDtype(torch.float),
    normalize,
])

cluster_dataset = GetRepsDataset(df, pids_cluster, transform)
# This loader only feeds feature extraction, so shuffling buys nothing; a fixed
# order keeps the extracted representation matrix (and any point subsampling
# done by the k-means fit) deterministic.
cluster_loader = torch.utils.data.DataLoader(cluster_dataset, batch_size=64, shuffle=False,
                                             num_workers=1, pin_memory=True)
len(cluster_loader.dataset)

6279

In [7]:
# Image encoder whose outputs are clustered by train_cluster_model below:
# a ResnetPreTrained backbone wrapped by ImageEncoder (both from py_files.models).
base_image_model = ResnetPreTrained()
image_encoder = ImageEncoder(base_image_model)

In [19]:
def train_cluster_model(image_encoder, cluster_loader, device,
                        ncentroids=8, niter=300, nredo=20, verbose=False):
    """Encode every image from ``cluster_loader`` and fit faiss k-means on the codes.

    The encoder is moved to ``device`` and put in eval mode; inference runs
    without gradients. Each batch's outputs are flattened to one row per image
    and all rows are stacked into a contiguous float32 matrix for faiss.

    Args:
        image_encoder: module mapping an image batch to representations.
        cluster_loader: iterable of ``(images, paths)`` batches; paths are ignored.
        device: device to run the encoder on.
        ncentroids: number of k-means clusters.
        niter: k-means iterations per run.
        nredo: number of k-means restarts (faiss keeps the best run).
        verbose: forward faiss's progress logging.

    Returns:
        faiss.Kmeans: the trained k-means model.

    Raises:
        ValueError: if ``cluster_loader`` yields no batches.
    """
    image_encoder = image_encoder.to(device)
    image_encoder.eval()

    rep_list = []
    # no_grad already yields tensors without autograd history, so no detach is needed.
    with torch.no_grad():
        for img, _path in tqdm(cluster_loader):
            img = img.to(device)
            reps = image_encoder(img)
            rep_list.append(reps.cpu().numpy().reshape(img.shape[0], -1))

    if not rep_list:
        raise ValueError("cluster_loader yielded no batches; nothing to cluster")

    # faiss requires a C-contiguous float32 matrix.
    X = np.ascontiguousarray(np.concatenate(rep_list), dtype=np.float32)

    d = X.shape[1]
    kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose, nredo=nredo)
    kmeans.train(X)

    return kmeans

In [21]:
kmeans = train_cluster_model(image_encoder=image_encoder, cluster_loader=cluster_loader, device=device)

  0%|          | 0/99 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...


In [4]:
def global_loss(cnn_code, rnn_code, eps=1e-8, temp3=10.0):
    """Symmetric contrastive loss between global image and text embeddings.

    Every image code is scored against every text code by cosine similarity
    (multiplied by ``temp3``), and cross-entropy is applied with the diagonal
    as the target, i.e. the i-th image is assumed to match the i-th text.

    Args:
        cnn_code: image embeddings, batch x dim.
        rnn_code: text embeddings with the same shape as ``cnn_code``.
        eps: lower bound on the product of norms, avoiding division by zero.
        temp3: multiplier applied to the cosine similarities.

    Returns:
        (loss0, loss1): cross-entropy over texts for each image, and over
        images for each text.
    """
    batch_size = cnn_code.shape[0]
    # Targets are the diagonal indices, built directly on the input's device
    # (the Variable wrapper is a deprecated no-op in modern PyTorch).
    # NOTE(review): for 3-D inputs shape[0] is not the batch size, so that path
    # looks unsupported by the label construction — TODO confirm before using it.
    labels = torch.arange(batch_size, dtype=torch.long, device=cnn_code.device)

    if cnn_code.dim() == 2:
        # Add a leading dim so bmm can be used.
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)

    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * temp3

    # --> batch_size x batch_size
    if scores0.shape[0] != 1:
        scores0 = scores0.squeeze()
    else:
        # squeeze(0) only, so a batch of size 1 stays 2-D for cross-entropy.
        scores0 = scores0.squeeze(0)

    scores1 = scores0.transpose(0, 1)
    loss0 = nn.CrossEntropyLoss()(scores0, labels)
    loss1 = nn.CrossEntropyLoss()(scores1, labels)
    return loss0, loss1

def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Cosine similarity of ``x1`` and ``x2`` along ``dim``.

    The product of the two L2 norms is clamped from below by ``eps`` so
    zero vectors give 0 instead of NaN; singleton dims of the result are
    squeezed away.
    """
    dot = (x1 * x2).sum(dim)
    denom = x1.norm(2, dim) * x2.norm(2, dim)
    return torch.div(dot, denom.clamp(min=eps)).squeeze()

def attention_fn(query, context, temp1):
    """Attend from each query vector over the spatial positions of ``context``.

    Shapes:
        query:   batch x ndf x queryL
        context: batch x ndf x ih x iw   (sourceL = ih * iw)

    Raw dot-product scores are first softmax-normalized across the query
    positions for each spatial location; they are then multiplied by
    ``temp1`` and softmax-normalized across spatial locations for each
    query position.

    Returns:
        weighted_context: batch x ndf x queryL, context averaged with the
            attention weights.
        attention map: batch x queryL x ih x iw.
    """
    n, q_len = query.size(0), query.size(2)
    height, width = context.size(2), context.size(3)
    n_regions = height * width

    # Flatten the spatial grid: batch x ndf x sourceL, and its transpose.
    flat_ctx = context.view(n, -1, n_regions)
    flat_ctx_t = flat_ctx.transpose(1, 2).contiguous()  # batch x sourceL x ndf

    # (batch x sourceL x ndf)(batch x ndf x queryL) -> batch x sourceL x queryL
    scores = torch.bmm(flat_ctx_t, query)
    # First softmax: over query positions, independently for each region.
    scores = torch.softmax(scores.view(n * n_regions, q_len), dim=-1)
    scores = scores.view(n, n_regions, q_len)

    # Second softmax: over regions, independently for each query position.
    per_query = scores.transpose(1, 2).contiguous().view(n * q_len, n_regions)
    per_query = torch.softmax(per_query * temp1, dim=-1)
    per_query = per_query.view(n, q_len, n_regions)

    # (batch x ndf x sourceL)(batch x sourceL x queryL) -> batch x ndf x queryL
    weighted_context = torch.bmm(flat_ctx, per_query.transpose(1, 2).contiguous())

    return weighted_context, per_query.view(n, -1, height, width)

def local_loss(
    img_features, words_emb, cap_lens, temp1=4.0, temp2=5.0, temp3=10.0, agg="sum"
):
    """Word-to-region contrastive loss (local term).

    For each caption i, its first ``cap_lens[i]`` word vectors attend over the
    spatial positions of every image in the batch (via ``attention_fn``). The
    cosine similarity between each word and its attended context is pooled
    into one image-caption score. Caption i is assumed to match image i.

    Args:
        img_features: batch x ndf x ih x iw image features.
        words_emb: n_captions x ndf x max_words word embeddings; n_captions
            must equal the image batch size for the cross-entropy targets.
        cap_lens: number of leading word positions to use for each caption.
        temp1: attention sharpness passed to ``attention_fn``.
        temp2: scale applied to word-context cosine similarities before pooling.
        temp3: scale applied to the pooled scores before cross-entropy.
        agg: "sum" pools with log-sum-exp; any other value uses log-mean-exp.

    Returns:
        (loss0, loss1, att_maps): cross-entropy over captions for each image,
        cross-entropy over images for each caption, and a list holding each
        caption's attention map over its own image.
    """
    import math  # local import: only needed for the log-mean-exp correction

    batch_size = img_features.shape[0]

    att_maps = []
    similarities = []
    for i in range(words_emb.shape[0]):
        # Get the i-th text description.
        words_num = cap_lens[i]
        # TODO: remove [SEP], e.g. words_emb[i, :, 1:words_num + 1]
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()  # 1 x ndf x words_num
        # Pair caption i with every image in the batch.
        word = word.repeat(batch_size, 1, 1)  # batch x ndf x words_num
        context = img_features

        weiContext, attn = attention_fn(word, context, temp1)

        # Keep only caption i's attention over its own image i.
        att_maps.append(attn[i].unsqueeze(0).contiguous())
        word = word.transpose(1, 2).contiguous()  # batch x words_num x ndf
        weiContext = weiContext.transpose(1, 2).contiguous()

        word = word.view(batch_size * words_num, -1)
        weiContext = weiContext.view(batch_size * words_num, -1)

        row_sim = cosine_similarity(word, weiContext)
        row_sim = row_sim.view(batch_size, words_num)

        # log(sum(exp(temp2 * s))) computed with logsumexp: the same value as
        # an explicit exp/sum/log, but it cannot overflow to inf in float32
        # when temp2 is large.
        row_sim = torch.logsumexp(row_sim * temp2, dim=1, keepdim=True)  # batch x 1
        if agg != "sum":
            # log(mean(exp(.))) == log(sum(exp(.))) - log(n)
            row_sim = row_sim - math.log(words_num)

        similarities.append(row_sim)

    similarities = torch.cat(similarities, 1) * temp3  # images x captions
    similarities1 = similarities.transpose(0, 1)  # captions x images

    labels = torch.arange(batch_size, dtype=torch.long, device=similarities.device)

    loss0 = nn.CrossEntropyLoss()(similarities, labels)  # labels: arange(batch_size)
    loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    return loss0, loss1, att_maps

In [5]:
# Alternative input kept for reference (a pre-built training frame):
# df = pd.read_pickle('../csv/generating_training_df.pickle')
# df.columns = ['patch_paths', 'pid', 'cluster_assignment', 'complete_tokens','dtype', 'notes']

# Standard ImageNet channel mean/std values.
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

transform = transforms.Compose(
    [
        transforms.Resize([224, 224]),
        transforms.ConvertImageDtype(torch.float),
        normalize,
    ]
)

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
wsi_batch_dataset = WSIBatchedDataset(
    df, dtype='train', tokenizer=tokenizer, img_transform=transform
)
train_loader = torch.utils.data.DataLoader(
    wsi_batch_dataset, batch_size=1, shuffle=True, num_workers=1, pin_memory=True
)

NameError: name 'df' is not defined

In [6]:
# Image and text backbones (from py_files.models) used by the loss-debugging loop below.
img_model = ResnetPreTrained()
text_model = BertEncoder()

In [6]:
%%time
# The patient-id field is bound to `batch_pids` (not `pids`) so this cell no
# longer overwrites the global `pids` array of unique patient ids.
for i, (img, text, attention, token_typ, img_seps, batch_pids) in enumerate(train_loader):

    # img_seps holds cumulative end offsets of each patient's rows in the
    # flattened batch; turn them into [start, end) slice bounds.
    seps = [0] + [sep.numpy()[0] for sep in img_seps]
    img_seps = [[seps[k], seps[k + 1]] for k in range(len(seps) - 1)]

    # train_loader uses batch_size=1: drop that leading loader dimension.
    img = img.squeeze(0)
    text = text.squeeze(0)
    attention, token_typ = attention.squeeze(0), token_typ.squeeze(0)

    text_outputs = text_model(text, attention, token_typ)
    img_outputs = img_model(img)

    # Split the flat model outputs back into per-patient chunks.
    pid_word_embeddings = [text_outputs[0][x:y] for x, y in img_seps]
    pid_sent_embeddings = [text_outputs[1][x:y] for x, y in img_seps]
    pid_img_embeddings = [img_outputs[x:y] for x, y in img_seps]

    # Global codes: mean of each patient's image embeddings, and the sentence
    # embedding of the first row in that patient's chunk.
    cnn_code = torch.stack([x.mean(dim=0) for x in pid_img_embeddings])
    rnn_code = torch.stack([x[0] for x in pid_sent_embeddings])

    # Arrange each patient's patch embeddings as a 4-D "spatial" map for
    # local_loss. NOTE(review): assumes img_outputs is (n_rows, dim) and every
    # patient span has the same length (required by torch.stack) — TODO confirm.
    img_features = [x.unsqueeze(2).unsqueeze(2) for x in pid_img_embeddings]
    img_features = [x.permute(2, 1, 0, 3) for x in img_features]
    img_features = torch.stack(img_features, dim=0).squeeze(1)

    words_embs = torch.stack([x[0] for x in pid_word_embeddings])

    # Per-sentence token count, ignoring tokens that start with "[", plus one.
    cap_lens = [len([w for w in sent if not w.startswith("[")]) + 1 for sent in text_outputs[2]]
    # Use the caption length of the same row words_embs takes (the first row of
    # each patient's span) rather than assuming every span is exactly 8 rows.
    cap_lens = [cap_lens[start] for start, _ in img_seps]

    gloss0, gloss1 = global_loss(cnn_code, rnn_code)
    lloss0, lloss1, att_maps = local_loss(img_features, words_embs, cap_lens)

    # Debug run: stop after the first batch.
    break

NameError: name 'train_loader' is not defined

In [12]:
# Row count per value of the 'dtype' column (e.g. 'train'); item access avoids
# confusing the column with DataFrame attributes such as .dtypes.
df['dtype'].value_counts()

train    51279
Name: dtype, dtype: int64

BertEncoder(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)