In [1]:
# Import Pytorch 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision


class LSTM(nn.Module):

    def __init__(self, hidden_dim=128, input_dim=76, num_classes=1, batch_first=True, dropout=0.0, layers=1):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.layers = layers
        for layer in range(layers):
            setattr(self, f'layer{layer}', nn.LSTM(
                input_dim, hidden_dim,
                batch_first=batch_first,
                dropout = dropout)
            )
            input_dim = hidden_dim

        self.do = nn.Dropout(dropout)
        self.feats_dim = hidden_dim
        self.dense_layer = nn.Identity() #nn.Linear(hidden_dim, num_classes)
        self.initialize_weights()

    def initialize_weights(self):
        for model in self.modules():

            if type(model) in [nn.Linear]:
                nn.init.xavier_uniform_(model.weight)
                nn.init.zeros_(model.bias)
            elif type(model) in [nn.LSTM, nn.RNN, nn.GRU]:
                nn.init.orthogonal_(model.weight_hh_l0)
                nn.init.xavier_uniform_(model.weight_ih_l0)
                nn.init.zeros_(model.bias_hh_l0)
                nn.init.zeros_(model.bias_ih_l0)

    def forward(self, x, seq_lengths):
        x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True, enforce_sorted=False)
        for layer in range(self.layers):
             x, (ht, _) = getattr(self, f'layer{layer}')(x)
        feats = ht.squeeze()
        out = self.do(feats)
        out = self.dense_layer(out)
#         scores = torch.sigmoid(out)
        return out


class CXRModels(nn.Module):

    def __init__(self, args, hidden_dim, device='cpu'):
        super(CXRModels, self).__init__()
        self.args = args
        self.device = device
        self.vision_backbone = getattr(torchvision.models, self.args.vision_backbone)(pretrained=self.args.pretrained) #,
                                                                                     #num_classes=4*hidden_dim)
        #d_visual = self.vision_backbone.fc.in_features
        # try changing to adaptive averge pooling 
        classifiers = [ 'classifier', 'fc']
        for classifier in classifiers:
            cls_layer = getattr(self.vision_backbone, classifier, None)
            # print(cls_layer)
            if cls_layer is None:
                continue
            d_visual = cls_layer.in_features
            setattr(self.vision_backbone, classifier, nn.Identity(d_visual))
            break
            
        self.bce_loss = torch.nn.BCELoss(size_average=True)
        #self.classifier = nn.Sequential(nn.Linear(d_visual, self.args.vision_num_classes), nn.Sigmoid())
        self.feats_dim = d_visual
        # self.feats_dim = 2048
        self.feats_dim = 512
        

    def forward(self, x, labels=None, n_crops=0, bs=16):
        lossvalue_bce = torch.zeros(1).to(self.device)

        visual_feats = self.vision_backbone(x)
#         preds = self.vision_backbone.fc(visual_feats)
#         if n_crops > 0:
#             preds = preds.view(bs, n_crops, -1).mean(1)
#         if labels is not None:
#             lossvalue_bce = self.bce_loss(preds, labels)

        return  visual_feats

In [2]:
# Copyright 2022 Farah E. Shamout
#
# TODO: licsense
# ==============================================================================
"""This script defines the different fusion functions that can be used with SimCLR and baseline models."""

#TODO: move to models folder later

import os

# Import Pytorch 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision

# Pytorch flask to get LARS
import flash
from flash.core.optimizers import LARS

## Performance metrics
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

# Other
from copy import deepcopy
from tqdm import tqdm

# Custom
import load_tasks as tasks


class Fusion(nn.Module): 
    def __init__(self, args):
        super().__init__()
        self.args = args
        hidden_dim=args.hidden_dim
        fusion_dim = {0:{},
                      3:{}}

        if 'ehr' not in args.fusion_type:
            # Base model chest X-ray modality f(.)
            self.cxr_model = CXRModels(args, args.hidden_dim)
            # MLP for chest X-ray modality g(.)
            w=args.width
            print(self.cxr_model.vision_backbone.fc)
            self.cxr_model_g = nn.Sequential(
                self.cxr_model.vision_backbone.fc,  # Linear(ResNet output, 4*hidden_dim)
                # shaza: changed linear layer input from 512 to 2048
                nn.Linear(512, w*hidden_dim, bias=True),
                nn.ReLU(inplace=True),
                nn.Linear(w*hidden_dim, w*hidden_dim, bias=True),
                nn.ReLU(inplace=True),
                nn.Linear(w*hidden_dim, w*hidden_dim, bias=False)
            )

            fusion_dim[0]['cxr']= self.cxr_model.feats_dim
            fusion_dim[3]['cxr']=w*hidden_dim

        if 'cxr' not in args.fusion_type:
            # Base model EHR f(.)
            w=args.width
            self.ehr_model = LSTM(hidden_dim=args.hidden_dim, input_dim=76, num_classes=w*args.hidden_dim, dropout=args.dropout, layers=args.layers)

            # MLP for EHR modality g(.) 
            self.ehr_model_g = nn.Sequential(
                self.ehr_model.dense_layer,  # this is identify in encoders.py
                nn.Linear(128, w*hidden_dim, bias=True),
                nn.ReLU(inplace=True),
                nn.Linear(w*hidden_dim, w*hidden_dim, bias=True),
                nn.ReLU(inplace=True),
                nn.Linear(w*hidden_dim, w*hidden_dim, bias=False)
            )

            fusion_dim[0]['ehr']=self.ehr_model.feats_dim
            fusion_dim[3]['ehr']=w*hidden_dim
        
        
        # Single layer for linear evaluation of representations
        if self.args.fusion_type == 'lineareval_ehr':
            feats_dim = fusion_dim[args.fusion_layer]['ehr']
            #feats_dim = self.ehr_model.feats_dim
        
        elif self.args.fusion_type == 'lineareval_cxr':
            feats_dim = fusion_dim[args.fusion_layer]['cxr']
            #feats_dim = self.cxr_model.feats_dim
        
        else:
            feats_dim = fusion_dim[args.fusion_layer]['ehr'] + fusion_dim[args.fusion_layer]['cxr']
            #feats_dim = self.ehr_model.feats_dim + self.cxr_model.feats_dim
        
        # print(self.args.fusion_type)
        if self.args.fusion_type != 'None':
            self.fused_cls = nn.Sequential(
                nn.Linear(feats_dim, self.args.num_classes),
                nn.Sigmoid()
            )
        
    def forward(self, x=None, seq_lengths=None, img=None, pairs=None):
        # New for SimCLR
        if self.args.fusion_type == 'lineareval_ehr':
            return self.forward_uni_eval_ehr(x, seq_lengths=seq_lengths)
        elif self.args.fusion_type == 'lineareval_cxr':
            return self.forward_uni_eval_cxr(img=img)
        elif self.args.fusion_type in ['joint',  'early', 'late_avg', 'unified', 'lineareval']:
            return self.forward_fused(x, seq_lengths=seq_lengths, img=img, pairs=pairs)
        else:
            return self.forward_simclr(x, seq_lengths=seq_lengths, img=img, pairs=pairs)
        
    def forward_simclr(self, x, seq_lengths, img, pairs=None):
        if self.args.mode == 'eval':
            feats_img_0 = self.cxr_model(img)
            feats_img_3 = self.cxr_model_g(feats_img_0)
            feats_ehr_0 = self.ehr_model(x, seq_lengths)
            feats_ehr_3 = self.ehr_model_g(feats_ehr_0)
            
            return feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3
            
        else:
            feats_img = self.cxr_model(img)
            feats_img = self.cxr_model_g(feats_img)
            feats_ehr = self.ehr_model(x, seq_lengths)
            feats_ehr = self.ehr_model_g(feats_ehr)
        
            return feats_ehr, feats_img
    
    def forward_fused(self, x, seq_lengths=None, img=None, pairs=None ):
        if ('lineareval' in self.args.fusion_type) & (not self.args.finetune):
            ehr_feats = x
            cxr_feats = img
        else:
            ehr_feats = self.ehr_model(x, seq_lengths) #ehr_preds , 
            cxr_feats = self.cxr_model(img) #cxr_preds, _ , 
#         projected = self.projection(cxr_feats)

        feats = torch.cat([ehr_feats, cxr_feats], dim=1)
        fused_preds = self.fused_cls(feats)

        return {
            'early': fused_preds, 
            'joint': fused_preds, 
            'lineareval': fused_preds,
            'ehr_feats': ehr_feats,
#             'cxr_feats': projected,
            'unified': fused_preds
            }
    
    def forward_uni_eval_cxr(self, img ):
        if ('lineareval' in self.args.fusion_type) & (not self.args.finetune):
            cxr_feats = img
        else:
            cxr_feats = self.cxr_model(img)
        preds = self.fused_cls(cxr_feats)
        return {
            'lineareval_cxr': preds,
            }
    
    def forward_uni_eval_ehr(self, x, seq_lengths=None):
        if ('lineareval' in self.args.fusion_type) & (not self.args.finetune):
            ehr_feats = x
        else:
            ehr_feats = self.ehr_model(x, seq_lengths)
        preds = self.fused_cls(ehr_feats)
        return {
            'lineareval_ehr': preds,
            }   
                      

In [3]:
# Import Pytorch 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, Callback, TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision
import math

# Import other useful libraries
from sklearn.linear_model import LogisticRegression as LR
from sklearn.neural_network import MLPClassifier
import pickle
from flash.core.optimizers import LARS
import os
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from copy import deepcopy
from tqdm import tqdm

# Import custom libraries/functions
import load_tasks as tasks

gpu_num=os.environ['CUDA_VISIBLE_DEVICES']

# Set cuda device
if gpu_num=='0':
    gpu=[0]
elif gpu_num=='1':
    gpu=[1]
elif gpu_num=='2':
    gpu=[2]
elif gpu_num=='3':
    gpu=[3]
elif gpu_num=='4':
    gpu=[4]
elif gpu_num=='5':
    gpu=[5]
elif gpu_num=='6':
    gpu=[6]
elif gpu_num=='7':
    gpu=[7]
elif gpu_num=='8':
    gpu=[8]
else:
    gpu=['None']
print('Using {} device...'.format(gpu)) 
       
class SimCLR(pl.LightningModule):

    def __init__(self, args, train_dl):
        super().__init__()
        assert args.temperature > 0.0, 'The temperature must be a positive float!'
        self.warmup_epochs= 10 #int(0.05*max_epochs) (10 as in SimCLR)
        self.automatic_optimization = False
        
        self.num_train_batches=len(train_dl)
        self.batch_size=args.batch_size
        hidden_dim=args.hidden_dim
        self.args=args
        self.LABEL_COLUMNS = tasks.load_labels(args.task)
        self.task = args.task
        
        # Load the architecture based on args
        self.model = Fusion(args)
        self.load_weights()
        self.freeze_weights()
        
        
    def load_weights(self):
        # loads both encoders for simclr
        load_dir_simclr = self.args.save_dir

        if self.args.load_state is not None: 
            # what is 'LC' and why this condition?
            # if 'LC' not in self.args.load_state:
            #     if 'mortality' in self.args.save_dir:
            #         load_dir_simclr = load_dir_simclr.replace('mortality', 'phenotyping')
                    
            if 'epoch' in self.args.load_state:
                model_dir='/'+self.args.load_state.split('_epoch')[0] + '/'
                if self.args.tag == 'eval_epoch':
                    checkpoint = torch.load(load_dir_simclr + model_dir + self.args.load_state+".ckpt", map_location="cpu")
                else:
                    checkpoint = torch.load(load_dir_simclr + model_dir + self.args.load_state+".ckpt")    
            else:
                if self.args.tag == 'eval_epoch':
                    checkpoint = torch.load(os.path.join(load_dir_simclr, self.args.load_state+".ckpt"), map_location="cpu")
                else:
                    checkpoint = torch.load(os.path.join(load_dir_simclr, self.args.load_state+".ckpt"))
            own_state = self.model.state_dict()
            own_keys = list(own_state.keys())
            checkpoint_keys = list(checkpoint['state_dict'].keys())
            
            print('Total number of checkpoint params = {}'.format(len(checkpoint_keys)))
            print('Total number of current model params = {}'.format(len(own_keys)))

            count = 0
            changed = []
            for name in own_keys:
                if name not in checkpoint_keys:
                    # print(name)
                    # double check if name exists in a different format
                    for x in checkpoint_keys:
                        if name in x:
                            param=checkpoint['state_dict'][x]
                            if isinstance(param, torch.nn.Parameter):
                                param=param.data
                            own_state[name].copy_(param)
                            count+=1
                else:
                    param=checkpoint['state_dict'][name]
                    if isinstance(param, torch.nn.Parameter):
                        param=param.data
                    own_state[name].copy_(param)
                    count+=1
            print('Total number params loaded for model weights = {}'.format(count))
        
    def freeze_weights(self):
        if self.args.finetune:
            if 'ehr' not in self.args.fusion_type:
                self.freeze(self.model.cxr_model_g)
            if 'cxr' not in self.args.fusion_type:    
                self.freeze(self.model.ehr_model_g)
        else: 
            if 'lineareval' in self.args.fusion_type:
                print('freezing encoders')
                if 'ehr' not in self.args.fusion_type:
                    self.freeze(self.model.cxr_model)
                    self.freeze(self.model.cxr_model_g)
                if 'cxr' not in self.args.fusion_type:
                    self.freeze(self.model.ehr_model)
                    self.freeze(self.model.ehr_model_g) 
        
    def freeze(self, model):
        for p in model.parameters():
            p.requires_grad = False     
    
    def configure_optimizers(self):
        
        if self.args.fusion_type == 'None':
            # Scaled learning rate in case of multiple GPUs
            if self.args.num_gpu > 1:
                effective_batchsize = self.args.batch_size*self.args.num_gpu
                scaled_lr = self.args.lr*effective_batchsize/self.args.batch_size
            else:
                scaled_lr = self.args.lr 
                        
            # Optimizer
            optimizer = LARS(self.parameters(), lr=scaled_lr, momentum=0.9, weight_decay=self.args.weight_decay)
            
            # Note that the order of the below affects the initial starting learning rate, hence do not change.
            # Main scheduler
            mainscheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 500, verbose=False)
            # Learning rate warmup
            lambda1= lambda epoch : (epoch+1)/self.warmup_epochs
            warmupscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda1, verbose=False)
                         
            return [optimizer], [mainscheduler, warmupscheduler]
        
        
        else:
            optimizer_adam = optim.AdamW(self.parameters(), lr=self.args.lr) #, weight_decay=self.args.weight_decay)
            lr_scheduler_adam = optim.lr_scheduler.MultiStepLR(optimizer_adam,milestones=[int(self.args.epochs*0.6),
                                                       int(self.args.epochs*0.8)],gamma=0.1)
            return [optimizer_adam], [lr_scheduler_adam]
                
    def logging_status(self, mode):
        if mode == 'train':
            on_step=True
            on_epoch=True
        else:
            on_step=False # Report for the sake of naming but it's not useful
            on_epoch=True
        return on_step, on_epoch
    
#     # TODO: Make this more efficient
#     def accuracy_top_k(self, k, temp):
#         temp = temp.argsort(dim=1, descending=True)[:, :k]
#         batchsize=temp.shape[0]
#         b_idx = np.arange(0,batchsize)
#         tot=0
#         for j in range(0, batchsize):
#             tot+= b_idx[j] in temp[j]
#         return tot*100/batchsize
    
    def bce_loss(self, preds, y, mode='train'):
        
        loss = nn.BCELoss()(preds, y)
        
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()
            
        auroc = np.round(roc_auc_score(y, preds.detach().cpu()), 4)
        auprc = np.round(average_precision_score(y, preds.detach().cpu()), 4)
        
        on_step=False
        on_epoch=True
        #self.log(mode + '_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        self.log(mode + '_auroc', auroc, on_step=on_step, on_epoch=on_epoch) #, logger=True)
        self.log(mode + '_auprc', auprc, on_step=on_step, on_epoch=on_epoch) #, logger=True)
        
        return loss 
    
    
    def info_nce_loss(self, feats_ehr, feats_img, mode='train'):
        # Calculate cosine similarity matrix
        cos_sim = F.cosine_similarity(feats_img[:,None,:], feats_ehr[None,:,:], dim=-1)
        #print(cos_sim.size())
        cos_sim = cos_sim /  self.args.temperature
        # double-check 
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)
        #print(self_mask.size())
        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
        #temp_1=torch.reshape(cos_sim, (cos_sim.shape[0],cos_sim.shape[1]))
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0) 
        #temp_2=torch.reshape(cos_sim_2, (cos_sim_2.shape[0],cos_sim_2.shape[1]))
        
        # Total loss 
        loss = -(nll_1 + nll_2).mean()
            
        # Logging ranking metrics
        #self.log(mode+'_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        on_step, on_epoch = self.logging_status(mode)
        #self.log(mode+'_acc_top1', self.accuracy_top_k(1, temp_1), on_step=on_step, on_epoch=on_epoch) #, logger=True)
        #self.log(mode+'_acc_top5', self.accuracy_top_k(5, temp_1), on_step=on_step, on_epoch=on_epoch) #, logger=True)
                     
        return loss
    
    
    
    def modified_info_nce_loss(self, feats_ehr, feats_img, time_diff, mode='train'):
        # Calculate cosine similarity matrix
        cos_sim = F.cosine_similarity(feats_img[:,None,:], feats_ehr[None,:,:], dim=-1)
        cos_sim = cos_sim /  self.args.temperature
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)
        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute the values of beta
        k = 1
        time_diff = torch.FloatTensor(time_diff)
        beta = torch.exp(-k*time_diff).to(self.device)
        
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
        nll_1 = beta*nll_1
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0)
        nll_2 = beta*nll_2
        
        # Total loss 
        loss = -(nll_1 + nll_2).mean()
            
        # Logging ranking metrics
        #self.log(mode+'_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        on_step, on_epoch = self.logging_status(mode)
       
        return loss 
    
    
    def off_diagonal(self,x):
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    
    def vicreg_loss(self, feats_ehr, feats_img, mode='train'):
        x = feats_ehr
        y = feats_img
        repr_loss = F.mse_loss(x, y)

        #x = torch.cat(FullGatherLayer.apply(x), dim=0) #
        #y = torch.cat(FullGatherLayer.apply(y), dim=0) #
        
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2
                
        cov_x = (x.T @ x) / (self.args.batch_size - 1)
        cov_y = (y.T @ y) / (self.args.batch_size - 1)
        
        num_features = len(cov_x) #TODO as arg
                
        cov_loss = self.off_diagonal(cov_x).pow_(2).sum().div(num_features) + self.off_diagonal(cov_y).pow_(2).sum().div(num_features)

        loss = (
            self.args.sim_coeff * repr_loss
            + self.args.std_coeff * std_loss
            + self.args.cov_coeff * cov_loss
        )
        on_step, on_epoch = self.logging_status(mode)
        
        return loss
        
        
    
    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        mode = 'train'
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False) & (self.args.vicreg == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            print("ehr", ehr.shape, "seq_lengths" , len(seq_lengths) , "imgs", imgs.shape)
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            print(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            # print(feats_ehr.shape, feats_img.shape)
            # Compute and log infoNCE loss
            loss = self.info_nce_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
            
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == True) & (self.args.vicreg == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log infoNCE loss
            loss = self.modified_info_nce_loss(feats_ehr, feats_img, time_diff, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == False) & (self.args.vicreg == True)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log vicreg loss
            loss = self.vicreg_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
                        
        else:
            
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            # print(preds)

            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
    
        # import pdb; pdb.set_trace()
        # print("loss" , loss.shape)
        # Backpropagate
        self.manual_backward(loss)
        # Optimizer step
        opt.step()
        # Learning rate step
        if self.args.fusion_type=='None':
            mainscheduler, warmupscheduler = self.lr_schedulers()
            if (self.trainer.is_last_batch) and (self.trainer.current_epoch < self.warmup_epochs-1):
                warmupscheduler.step()
            elif (self.trainer.is_last_batch) and (self.trainer.current_epoch >= self.warmup_epochs-1):
                mainscheduler.step()
                
#             if (batch_idx==self.num_train_batches-1) & (self.trainer.current_epoch < self.warmup_epochs-1):
#                 warmupscheduler.step()
#             elif (batch_idx==self.num_train_batches-1) & (self.trainer.current_epoch >= self.warmup_epochs-1):
#                 mainscheduler.step()
            

            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
        else:
            return {'loss': loss}

    
    def validation_step(self, batch, batch_idx):
        mode='val'
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs) 
            
            # Compute and log infoNCE loss
            loss = self.info_nce_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
        
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == True)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log infoNCE loss
            loss = self.modified_info_nce_loss(feats_ehr, feats_img, time_diff, mode)
            self.log(mode+'_loss', loss, on_step=False, on_epoch=True) #, logger=True)
            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
            
        else:
            
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            # print(preds)
            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            # Compute and log BCE loss
            #loss = self.bce_loss(batch, mode='val')
            self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
            return {'loss': loss}
        
    def test_step(self, batch, batch_idx):
        mode='test'
        
        
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False)):
            if self.args.beta_infonce == True:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            else:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            
            # At test time of SIMCLR, always return all the layer features
            if self.args.mode == 'eval':
                feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3 = self.model(ehr, seq_lengths, imgs) 
            
                # Compute and log infoNCE loss
                if self.args.beta_infonce == True:
                    loss = self.modified_info_nce_loss(feats_ehr_3, feats_img_3, time_diff, mode)
                else:
                    loss = self.info_nce_loss(feats_ehr_3, feats_img_3, mode)
                self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
                return {'loss': loss,   'feats_ehr_0': feats_ehr_0.detach().cpu(), 
                                        'feats_ehr_3': feats_ehr_3.detach().cpu(), 
                                        'feats_img_0': feats_img_0.detach().cpu(), 
                                        'feats_img_3': feats_img_3.detach().cpu(), 
                                        'y_ehr':y_ehr}        
        
        else:
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            
            # print(y.shape, preds.shape)
            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            #loss = self.bce_loss(batch, mode=mode)
            return {'loss': loss, 'preds': preds, 'y_ehr': y}
                
    
    def process_features(self, outputs, mode):
        y = []
        if self.args.mode=='eval':
            feats_ehr_0=[]
            feats_ehr_3=[]
            feats_img_0=[]
            feats_img_3=[]
        elif mode == 'test':
            preds = []
        else:
            feats_ehr = []
            feats_img = []
        # Iterate through batches and append
        i=0
        for output in outputs:
            if i ==0:
                if self.args.mode == 'eval':
                    feats_ehr_0 = output['feats_ehr_0'].detach().cpu()
                    feats_ehr_3 = output['feats_ehr_3'].detach().cpu()
                    feats_img_0 = output['feats_img_0'].detach().cpu()
                    feats_img_3 = output['feats_img_3'].detach().cpu()
                elif mode == 'test':
                    preds = output['preds'].detach().cpu()
                else: 
                    feats_ehr = output['feats_ehr'].detach().cpu()
                    feats_img = output['feats_img'].detach().cpu()
                y = output['y_ehr'].tolist()
                
            else:
                if self.args.mode == 'eval':
                    feats_ehr_0 = torch.cat((feats_ehr_0, output['feats_ehr_0'].detach().cpu()))
                    feats_ehr_3 = torch.cat((feats_ehr_3, output['feats_ehr_3'].detach().cpu()))
                    feats_img_0 = torch.cat((feats_img_0, output['feats_img_0'].detach().cpu()))
                    feats_img_3 = torch.cat((feats_img_3, output['feats_img_3'].detach().cpu()))
                elif mode == 'test':
                    preds = torch.cat((preds, output['preds'].detach().cpu()))
                else:
                    feats_ehr = torch.cat((feats_ehr, output['feats_ehr'].detach().cpu()))
                    feats_img = torch.cat((feats_img, output['feats_img'].detach().cpu()))
                y.extend(output['y_ehr'].tolist())
            i+=1
        if self.args.mode=='eval':
            return feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3, y
        elif mode =='test':
            return y, preds
        else:
            return feats_ehr, feats_img, y
    
    def save_features(self, x, descrip, mode):
        model_path = self.args.save_dir+'/simclr_lr/'+self.args.file_name
        if not os.path.exists(model_path):
          os.makedirs(model_path)
        
        torch.save(x, model_path+'/{}_{}_epoch_{}.pt'.format(mode, descrip, self.current_epoch))
    
    def training_epoch_end(self, outputs):
        mode='train'
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            feats_ehr, feats_img, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr, 'feats_ehr', mode)
            self.save_features(feats_img, 'feats_img', mode)      
            self.save_features(y, 'y', mode)   
        
    def validation_epoch_end(self, outputs):
        mode='val'
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            feats_ehr, feats_img, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr, 'feats_ehr', mode)
            self.save_features(feats_img, 'feats_img', mode)      
            self.save_features(y, 'y', mode)
            

    def test_epoch_end(self, outputs):
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            mode = self.args.eval_set
            feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr_0, 'feats_ehr_0', mode)
            self.save_features(feats_ehr_3, 'feats_ehr_3', mode)
            self.save_features(feats_img_0, 'feats_img_0', mode)
            self.save_features(feats_img_3, 'feats_img_3', mode)      
            self.save_features(y, 'y', mode)
        else:
            if self.task =='phenotyping':
                mode = 'test'
                y, preds = self.process_features(outputs, mode)

                auroc_per_label = np.round(roc_auc_score(y, preds, average=None), 4)
                auprc_per_label = np.round(average_precision_score(y, preds, average=None), 4)


                auroc_label={}
                auprc_label={}
                for i, name in enumerate(self.LABEL_COLUMNS):
                    auroc_label[name]=auroc_per_label[i].item()
                    auprc_label[name]=auprc_per_label[i].item()
                    #print(name, auroc_per_label[i], auprc_per_label[i])

                self.log('auroc_label', auroc_label)
                self.log('auprc_label', auprc_label)
            
    def calculate_auroc_epoch(self, outputs, mode):
        labels = []
        predictions = []
        auroc_label={}
        outputs=outputs[self.args.fusion_type].squeeze()
        
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)

        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(self.LABEL_COLUMNS):
            class_roc_auc = roc_auc_score(labels[:, i], predictions[:, i])
            auroc_label[name]=class_roc_auc
            
        auroc = roc_auc_score(labels, predictions)
        auprc = average_precision_score(labels, predictions)
        
        return auroc, auprc
       
    
def return_model_version(trainer):
    filename = trainer.checkpoint_callback.filename
    best_model_path = trainer.checkpoint_callback.best_model_path
    return  filename+best_model_path.split('.ckpt')[0].split(filename)[1]



def train(model, args, train_loader, val_loader, **kwargs): 
    filename = args.file_name+'_epoch_{epoch:02d}'
    
    model_path = args.save_dir+'/'+args.file_name
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    
    # For model selection
    if 'logger' not in kwargs:
        checkpoint_callback = ModelCheckpoint(monitor="val_auroc", mode='max', 
                                              filename=args.load_state+'_{epoch:02d}',
                                              every_n_epochs=1,             
                                              save_on_train_epoch_end=True) 
        
        trainer = pl.Trainer(default_root_dir=os.path.join(model_path),
                         accelerator="auto",
                         max_epochs=args.epochs, gpus=gpu,
                         callbacks=[checkpoint_callback],
                         enable_progress_bar=False,
                         num_sanity_val_steps=0)
                         
        
    # For SIMCLR training
    else:
        logger = kwargs['logger']
        checkpoints = ModelCheckpoint(dirpath=model_path,
                                  filename=filename,
                                  save_weights_only=True, 
                                  save_top_k=-1,
                                  auto_insert_metric_name=False, 
                                  every_n_epochs=1,             
                                  save_on_train_epoch_end=True)
        if args.num_gpu == 1:
            strategy = None
        else:
            strategy = 'ddp'
        
        trainer = pl.Trainer(default_root_dir=os.path.join(model_path),
                             max_epochs=args.epochs, #gpus=gpu,
                             callbacks=[checkpoints, LearningRateMonitor('epoch')],
                             logger=logger,  
                             log_every_n_steps=5,  enable_progress_bar=True,
                             num_sanity_val_steps=0,
                            accelerator='gpu', devices=args.num_gpu, strategy=strategy)

   
    num_batches_train = len(train_loader)   
    # import pdb; pdb.set_trace()
    trainer.fit(model, train_loader, val_loader)
        
    return trainer
    

# Call this function if doing test without training
def test(model, args, test_loader, **kwargs):
    if 'logger' not in kwargs:
        trainer = pl.Trainer(default_root_dir=os.path.join(args.save_dir))
        
    else:
        logger = kwargs['logger']
        trainer = pl.Trainer(default_root_dir=os.path.join(args.save_dir), logger=logger)
    
    trainer.test(model, test_loader)
    
    return trainer

    
# Prepare data features for downstream tasks 
@torch.no_grad()
def prepare_data_features(device, model, data_loader, bs, fusion_layer, fusion_type):
    #print(fusion_layer)
    # Prepare model
    network = deepcopy(model)
    if 'ehr' not in fusion_type:
        network.model.cxr_model.vision_backbone.fc = nn.Identity() # Removing projection head g(.) 
     
    if 'cxr' not in fusion_type:
        network.model.ehr_model.dense_layer = nn.Identity() # Removing projection head g(.)
    
    network.eval()
    network.to(device)

    # Encode all images
    feats_ehr, feats_imgs, labels_ehr, labels_imgs = [], [], [], []
    
    for batch_ehr, batch_imgs, batch_ehr_labels, batch_cxr_labels, seq_lengths, pairs in data_loader:
        labels_ehr.append(torch.from_numpy(batch_ehr_labels).detach())
        #time_diff.append(torch.from_numpy(np.array(batch_time)).detach())
        
        if 'cxr' not in fusion_type:
            batch_ehr = torch.from_numpy(batch_ehr).float().to(device)
            #batch_ehr = batch_ehr.to(device)
            batch_ehr_feats = network.model.ehr_model(batch_ehr, seq_lengths)
            if fusion_layer == 3:
                batch_ehr_feats = network.model.ehr_model_g(batch_ehr_feats)
                
            #print('ehr batch shape', np.shape(batch_ehr_feats))
            #batch_ehr_feats=torch.reshape(batch_ehr_feats, (1, np.shape(batch_ehr_feats)[0])) #TODO need this for other code
            feats_ehr.append(batch_ehr_feats.detach().cpu()) 

        if 'ehr' not in fusion_type:
            batch_imgs = batch_imgs.to(device)
            batch_imgs_feats = network.model.cxr_model(batch_imgs)
            if fusion_layer == 3:
                batch_imgs_feats = network.model.cxr_model_g(batch_imgs_feats)
                
            #print('cxr batch shape', np.shape(batch_imgs_feats))
            feats_imgs.append(batch_imgs_feats.detach().cpu())
            labels_imgs.append(batch_cxr_labels)
    
    labels_ehr = torch.cat(labels_ehr, dim=0)
    #time_diff = torch.cat(time_diff, dim=0)
    
#     print('shape ehr', np.shape(feats_ehr))
#     print('shape imgs', np.shape(feats_imgs))
    
#     print('type ehr', type(feats_ehr))
#     print('type cxr', type(feats_imgs))
    
#     print(type(feats_ehr[0]))
#     print(type(feats_ehr[0][0]))
#     print(feats_ehr[0][0])
    
#     print(type(feats_imgs[0]))
#     print(type(feats_imgs[0][0]))
#     print(feats_imgs[0][0])
    
    if 'cxr' not in fusion_type:
        #if len(feats_ehr) == len(labels_ehr):
        #    feats_ehr=torch.as_tensor(feats_ehr)
        feats_ehr = torch.cat(feats_ehr, dim=0)
        

    if 'ehr' not in fusion_type:
        feats_imgs = torch.cat(feats_imgs, dim=0)
        labels_imgs = torch.cat(labels_imgs, dim=0)

    if 'cxr' in fusion_type:
        return data.DataLoader(data.TensorDataset(feats_imgs, labels_imgs, labels_ehr), batch_size=bs, shuffle=False, drop_last=False)
    elif 'ehr' in fusion_type:
        return data.DataLoader(data.TensorDataset(feats_ehr, labels_ehr), batch_size=bs, shuffle=False, drop_last=False)
    else:
        return data.DataLoader(data.TensorDataset(feats_ehr, feats_imgs, labels_ehr, labels_imgs), batch_size=bs, shuffle=False, drop_last=False)


Using [0] device...


In [4]:
# %load run_gpu.py
# Copyright 2022 Farah E. Shamout
#
# TODO: licsense
# ==============================================================================
"""This script defines the SimCLR model and performs training and evaluation."""


data_dir = '/scratch/fs999/shamoutlab/data/mimic-iv-extracted/'
img_dir = '/scratch/fs999/shamoutlab/data/physionet.org/files//mimic-cxr-jpg/2.0.0'
code_dir = '/scratch/se1525/mml-ssl'
task = 'phenotyping'

# Import libraries
import sys
#sys.path.append(f'{code_dir_medfuse}')
import numpy as np
import argparse
import os
import importlib as imp
import re
from pathlib import Path
import pandas as pd
import neptune.new as neptune
from pathlib import Path

# ## Visualization
# import matplotlib.pyplot as plt
# from matplotlib.pyplot import figure
# from tqdm.notebook import tqdm
# import matplotlib
# matplotlib.use('Agg')

# Import Pytorch 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision


## Performance metrics
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score, average_precision_score

# Import custom functions
import custom_parser as par
import data_utils as prep
# from fusion_trainer_farah import FusionTrainer
# from mmtm_trainer import MMTMTrainer
# from daft_trainer import DAFTTrainer


# Import functions from MedFuse
import datasets.ehr_dataset 
from datasets.ehr_dataset import get_datasets
from datasets.ehr_dataset import EHRdataset
from datasets.cxr_dataset import get_cxr_datasets
import datasets.fusion
from datasets.fusion import load_cxr_ehr
from ehr_preprocess import ehr_funcs
import load_tasks as tasks

#sys.path.append('/home/shamoutlab/.local/bin')

import warnings
warnings.filterwarnings("ignore")

import importlib
%load_ext autoreload
%autoreload 2

In [5]:
%reload_ext autoreload

In [6]:
def initiate_logger(tags):
    logger = pl_loggers.NeptuneLogger(project="shaza-workspace/mml-ssl",
    api_key="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI4NDU3ZDlmMi01OGEyLTQzMTAtODJmYS01Mjc5N2U4ZjgyMTAifQ==", tags=tags, log_model_checkpoints=False)
    return logger

seed = 1002
torch.manual_seed(seed)
np.random.seed(seed)

In [7]:
# Parse arguments
parser = par.initiate_parsing()
args = parser.parse_args([ '--device' , '$CUDA_VISIBLE_DEVICES',
'--vision-backbone', 'resnet34' ,
'--resize', '256' , 
'--job_number' , '${SLURM_JOBID}',
'--file_name' , 'SIMCLR-${SLURM_JOBID}' ,
'--epochs' , '2' , '--transforms_cxr' , 'simclrv2' , '--temperature' , '0.01' ,
'--batch_size' , '30' , '--lr' , '0.8' ,
'--num_gpu' , '1' ,
'--pretrain_type' , 'simclr' ,
'--mode' , 'train' ,
'--fusion_type' , 'None' ,
'--save_dir' , '/scratch/se1525/mml-ssl/checkpoints/phenotyping/models' ,
'--tag' , 'simclr_train'])
job_number = args.job_number

path = Path(args.save_dir)
path.mkdir(parents=True, exist_ok=True)

In [8]:
# Set cuda device
if args.device=='0':
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  
elif args.device=='1':
    device = 'cuda:1' if torch.cuda.is_available() else 'cpu'   
elif args.device=='2':
    device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
elif args.device=='3':
    device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
elif args.device=='4':
    device = 'cuda:4' if torch.cuda.is_available() else 'cpu'
elif args.device=='5':
    device = 'cuda:5' if torch.cuda.is_available() else 'cpu'
elif args.device=='6':
    device = 'cuda:6' if torch.cuda.is_available() else 'cpu'
elif args.device=='7':
    device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
elif args.device=='8':
    device = 'cuda:8' if torch.cuda.is_available() else 'cpu'
else:
    device = 'None'
print('Using {} device...'.format(device))        


Using None device...


In [9]:
# Load datasets and initiate dataloaders
importlib.reload(datasets.fusion)
print('Loading datasets...')
discretizer, normalizer = ehr_funcs(args)
ehr_train_ds, ehr_val_ds, ehr_test_ds = get_datasets(discretizer, normalizer, args)
cxr_train_ds, cxr_val_ds, cxr_test_ds = get_cxr_datasets(args)
train_dl, val_dl, test_dl = load_cxr_ehr(args, ehr_train_ds, ehr_val_ds, cxr_train_ds, cxr_val_ds, ehr_test_ds, cxr_test_ds)

Loading datasets...
Appling SimCLR image transforms...
Number of CXR images= 377110
Number of ICU stays= 59372
Number of CXR associated with ICU stay based on subject ID= 368350
Number of unique CXR dicoms= 181195
Number of unique CXR study id= 122087
Mean time cxr - intime=  68.3979428307123
Minimum time = 0.009
Maximum time = 2368.942
Excluding CXR with missing radiology reports =  7756
7756
882
2166


In [10]:
# Store arguments after loading datasets
# line added by shaza
os.makedirs(os.path.dirname(f"{args.save_dir}/args/args_{job_number}.txt"), exist_ok=True)
with open(f"{args.save_dir}/args/args_{job_number}.txt", 'w') as results_file:
    print("Storing arguments...")
    for arg in vars(args): 
        print(f"  {arg:<40}: {getattr(args, arg)}") 
        results_file.write(f"  {arg:<40}: {getattr(args, arg)}\n")


Storing arguments...
  device                                  : $CUDA_VISIBLE_DEVICES
  num_gpu                                 : 1
  epochs                                  : 2
  lr                                      : 0.8
  save_dir                                : /scratch/se1525/mml-ssl/checkpoints/phenotyping/models
  labels_set                              : pheno
  task                                    : phenotyping
  data_pairs                              : paired
  mode                                    : train
  tag                                     : simclr_train
  pretrain_type                           : simclr
  file_name                               : SIMCLR-${SLURM_JOBID}
  load_state                              : None
  eval_set                                : val
  job_number                              : ${SLURM_JOBID}
  eval_epoch                              : 0
  load_state_ehr                          : None
  num_classes                             

In [11]:
# Initiate logger
neptune_logger = initiate_logger([args.tag, args.job_number])  
neptune_logger.experiment["args"] = vars(args)


https://app.neptune.ai/shaza-workspace/mml-ssl/e/MMLSSL-433
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [12]:
# Load the model, weights (if any), and freeze layers (if any)
print("Loading model...")
if args.pretrain_type == 'simclr':
    model = SimCLR(args, train_dl)
print('Printing model architecture...')
print(model)

Loading model...




Identity()
Printing model architecture...
SimCLR(
  (model): Fusion(
    (cxr_model): CXRModels(
      (vision_backbone): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): 

In [24]:
# # EHR input 
# ehr_input = torch.randn(30, 76, 128, 512)
# # Vision input 
# vision_input = torch.randn(30, 3, 512, 615)

# print(model.model.cxr_model.feats_dim)
# print(model)

# data_iter = iter(train_dl)
# batch = next(data_iter)
# ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

# print(ehr.shape, len(seq_lengths), imgs.shape)

# vision_output = model.model.cxr_model.vision_backbone(imgs)
# print(vision_output.shape)

# print(model.model.cxr_model.vision_backbone)

# Print vision backbone arch 
# print(model.model.cxr_model.vision_backbone)
 # Pass through vision backbone
# imgs = imgs.to("cuda:0")
# vision_output = model.model.cxr_model.vision_backbone(imgs)
# print("Output Shape of vision_backbone" , vision_output.shape)

# print(" ---------------------- ")
# # Pass through the layers one by one and examine the output
# input_data = imgs
# layer_outputs = []
# for name, layer in model.model.cxr_model.vision_backbone.named_children():
#     input_data = layer(input_data)
#     layer_outputs.append((name, input_data.clone()))

# for name, output in layer_outputs:
#     print(f"Layer: {name}, Output Shape: {output.shape}")
    
    
#  # dim 512 x 128
#  # self.cxr_model.vision_backbone.fc: Identity() 
#  # nn.Linear(512, w*hidden_dim, bias=True)
# cxr_proj = model.model.cxr_model_g(output)

    

# ehr = torch.from_numpy(ehr).float().to("cuda:0")
# imgs = imgs.to("cuda:0")
# feats_img = model.model.cxr_model(imgs)
# print(feats_img.shape)
# feats_img = model.model.cxr_model_g(feats_img)
# print(feats_img.shape)
# feats_ehr = model.model.ehr_model(ehr, seq_lengths)
# print(feats_ehr.shape)
# feats_ehr = model.model.ehr_model_g(feats_ehr)
# print(feats_ehr.shape)

# print(feats_ehr.shape, feats_img.shape)

# # # Pass through EHR model
# ehr = torch.from_numpy(ehr)
# ehr = ehr.float()
# ehr_output = model.model.ehr_model(ehr, seq_lengths)
# final_output = model.model.ehr_model_g(ehr_output)




# # Pass through CXR model
# cxr_output = model.model.cxr_model(imgs)
# print(cxr_output.shape)

# ehr (30, 628, 76) seq_lengths 30 imgs torch.Size([30, 3, 256, 256])

data_iter = iter(train_dl)
batch = next(data_iter)
ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
print(ehr.shape)
print(len(seq_lengths))
print(pairs)
test = torch.randn(30,3,224,224)

vision_output = model.model.ehr_model(ehr, seq_lengths)
print("Output Shape of vision_backbone" , vision_output.shape)

print(" ---------------------- ")
# Pass through the layers one by one and examine the output
input_data = imgs
layer_outputs = []
for name, layer in model.model.ehr_model.named_children():
    input_data = layer(input_data)
    layer_outputs.append((name, input_data.clone()))

for name, output in layer_outputs:
    print(f"Layer: {name}, Output Shape: {output.shape}")
    
    
 # self.cxr_model.vision_backbone.fc: Identity() 
 # nn.Linear(512, w*hidden_dim, bias=True) :  512 x 128
# cxr_proj = model.model.cxr_model_g(output)


(30, 382, 76)
30
[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]


AttributeError: 'numpy.ndarray' object has no attribute 'device'

In [16]:
print('==> training')        
filename = args.file_name+'_epoch_{epoch:02d}'
    
model_path = args.save_dir+'/'+args.file_name
if not os.path.exists(model_path):
        os.makedirs(model_path)
        
logger = neptune_logger
checkpoints = ModelCheckpoint(dirpath=model_path,
                                  filename=filename,
                                  save_weights_only=True, 
                                  save_top_k=-1,
                                  auto_insert_metric_name=False, 
                                  every_n_epochs=1,             
                                  save_on_train_epoch_end=True)

strategy = None
trainer = pl.Trainer(default_root_dir=os.path.join(model_path),
                             max_epochs=args.epochs, #gpus=gpu,
                             callbacks=[checkpoints, LearningRateMonitor('epoch')],
                             logger=logger,  
                             log_every_n_steps=5,  enable_progress_bar=True,
                             num_sanity_val_steps=0,
                            accelerator='gpu', devices=args.num_gpu, strategy=strategy)

train_loader=train_dl
val_loader=val_dl
num_batches_train = len(train_loader)
print(num_batches_train)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


==> training
258


In [17]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name  | Type   | Params
---------------------------------
0 | model | Fusion | 21.5 M
---------------------------------
21.5 M    Trainable params
0         Non-trainable params
21.5 M    Total params
86.152    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

ehr (30, 959, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 745, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 555, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 425, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 648, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 539, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 406, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 557, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 331, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 577, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 279, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 512, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 302, 76) seq_lengths 30 imgs torch.Size([30, 3, 224, 224])
cuda:0
ehr (30, 359, 76) seq_len

In [17]:
# # For efficiency, change data loaders for lineareval (not finetune)
# if ('lineareval' in args.fusion_type) & (not args.finetune):
#     print("Processing features for linear evaluation...")
#     train_dl = prepare_data_features(device, model, train_dl, args.batch_size, args.fusion_layer, args.fusion_type) 
#     val_dl = prepare_data_features(device, model, val_dl, args.batch_size, args.fusion_layer, args.fusion_type)
#     test_dl = prepare_data_features(device, model, test_dl, args.batch_size, args.fusion_layer, args.fusion_type)

# if args.mode == 'train':
#     print('==> training')        
#     print(len(train_dl))
#     train(model, args, train_dl, val_dl,
#           logger=neptune_logger,
#           load_state_prefix=args.load_state_simclr)

# elif args.mode == 'eval':
#     print('==> evaluating on the '+args.eval_set)
#     if args.eval_set=='val':
#         test_dl=val_dl
#     elif args.eval_set=='train':
#         test_dl=train_dl
#     print(len(test_dl))
#     test(model, args, test_dl, logger=neptune_logger)
#     #trainer.eval()

# else:
#     raise ValueError("Incorrect value for args.mode")