In [11]:
import os
import random
import numpy as np
import pandas as pd
import json
import torchvision
import torch
import pytorch_lightning as pl
from torch import nn
import torch.nn.functional  as F
from collections import Counter
import warnings
warnings.filterwarnings("ignore")
from datetime import timedelta
pd.set_option('display.max_columns', None)

In [2]:
class EHRModel(nn.Module):

    def __init__(self, 
                 hidden_dim: int =256, 
                 input_dim: int =76,  
                 batch_first: bool = True, 
                 dropout: float = 0.0, 
                 layers: int = 1,
                 projection_dim: int = 512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layers = layers
        for layer in range(layers):
            setattr(self, f'layer{layer}', nn.LSTM(
                input_dim, hidden_dim,
                batch_first=batch_first,
                dropout = dropout)
            )
            input_dim = hidden_dim

        self.do = nn.Dropout(dropout)
        self.feats_dim = hidden_dim
        self.projection_layer = nn.Linear(hidden_dim, projection_dim)
        self.initialize_weights()

    def initialize_weights(self):
        for model in self.modules():

            if type(model) in [nn.Linear]:
                nn.init.xavier_uniform_(model.weight)
                nn.init.zeros_(model.bias)
            elif type(model) in [nn.LSTM, nn.RNN, nn.GRU]:
                nn.init.orthogonal_(model.weight_hh_l0)
                nn.init.xavier_uniform_(model.weight_ih_l0)
                nn.init.zeros_(model.bias_hh_l0)
                nn.init.zeros_(model.bias_ih_l0)

    def forward(self, x, seq_lengths):
        x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True, enforce_sorted=False)
        for layer in range(self.layers):
             x, (ht, _) = getattr(self, f'layer{layer}')(x)
        feats = ht.squeeze()
        out = self.do(feats)
        out = self.projection_layer(out)
        return out

In [3]:
EHRModel()(torch.randn(2,113,76),[33,69])

tensor([[-0.0452,  0.0201,  0.0941,  ...,  0.0908, -0.0756,  0.1801],
        [-0.0902,  0.0859,  0.2003,  ..., -0.1018,  0.0363,  0.0034]],
       grad_fn=<AddmmBackward>)

In [4]:
class CXRModel(nn.Module):

    def __init__(self,
                 backbone: str = 'resnet34',
                 projection_dim: int = 512):
        super().__init__()
        
        self.vision_backbone = getattr(torchvision.models, backbone)(pretrained=False)
        self.vision_backbone.fc = nn.Linear(self.vision_backbone.fc.in_features,projection_dim)



    def forward(self, x: torch.Tensor):
        visual_feats = self.vision_backbone(x)
        return  visual_feats

In [5]:
cxr = CXRModel()

In [6]:
cxr(torch.randn(2,3,224,224))

tensor([[ 0.2145,  0.5096,  0.0987,  ..., -0.6843, -1.3869,  0.7160],
        [ 0.3406,  0.7467,  0.3305,  ..., -0.6176, -1.4298,  0.6900]],
       grad_fn=<AddmmBackward>)

In [7]:
class ALIGN(nn.Module):
    def __init__(self,
                 hidden_dim: int =256, 
                 input_dim: int =76, 
                 batch_first: bool = True, 
                 dropout: float = 0.0, 
                 layers: int = 1,
                 backbone: str = 'resnet34',
                 projection_dim: int = 512):
        super().__init__()
        
        self.cxr_encoder = CXRModel(backbone=backbone,
                                     projection_dim=projection_dim)
        
        self.ehr_encoder = EHRModel(hidden_dim=hidden_dim,
                                    input_dim=input_dim,
                                    batch_first=batch_first,
                                    dropout=dropout,
                                    layers=layers,
                                    projection_dim=projection_dim)
        
    def forward(self,
               cxr: torch.Tensor,
               ehr: torch.Tensor,
               seq_lengths: list):
        
        cxr_projections = self.cxr_encoder(cxr)
        ehr_projections = self.ehr_encoder(ehr,seq_lengths)
        
        return {'cxr': cxr_projections, 
                'ehr': ehr_projections}

In [8]:
embeds = ALIGN()(torch.randn(5,3,224,224),torch.randn(5,113,76),[33,69,77,36,96])

In [9]:
embeds

{'cxr': tensor([[ 1.7096, -0.2067, -0.8512,  ...,  1.1037,  0.5145,  0.5686],
         [ 1.7030, -0.4490, -0.6695,  ...,  1.1321,  0.4178,  0.3128],
         [ 1.5802, -0.4385, -0.7142,  ...,  1.2737,  0.3301,  0.2907],
         [ 1.5846, -0.4202, -0.4880,  ...,  0.9678,  0.3520,  0.2787],
         [ 1.9196, -0.3036, -0.6729,  ...,  1.0678,  0.4076,  0.3792]],
        grad_fn=<AddmmBackward>),
 'ehr': tensor([[ 0.1214,  0.0020, -0.0869,  ..., -0.0972, -0.0283,  0.0600],
         [ 0.0198, -0.0503,  0.0010,  ...,  0.0978, -0.0425,  0.0070],
         [-0.1725, -0.0335,  0.0275,  ...,  0.1081,  0.0453,  0.0440],
         [-0.0454,  0.0133, -0.0996,  ..., -0.0259,  0.1215, -0.0659],
         [-0.1060,  0.0310,  0.0621,  ..., -0.1430, -0.1087, -0.0959]],
        grad_fn=<AddmmBackward>)}

In [10]:
class ContrastiveLoss(nn.Module):
    def __init__(self,
                temperature: float =0.07):
        
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor(temperature))


    def forward(self, cxr_feats, ehr_feats):

        cos_sim = F.cosine_similarity(cxr_feats[:,None,:], ehr_feats[None,:,:], dim=-1)

        cos_sim = cos_sim / self.temperature
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)

        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0) 

        # Total loss 
        loss = -(nll_1 + nll_2).mean()
                     
        return loss

In [11]:
ContrastiveLoss()(embeds['ehr'],embeds['cxr'])

tensor(2.9495, grad_fn=<NegBackward>)

In [12]:
class ALIGNTrainer(pl.LightningModule):
    def __init__(self,

                 hidden_dim: int =256, 
                 input_dim: int =76, 
                 batch_first: bool = True, 
                 dropout: float = 0.0, 
                 layers: int = 1,
                 backbone: str = 'resnet34',
                 projection_dim: int = 512,
                 temperature: float = 0.07,
                 lr: float = 0.0001,
                 wd=0.001,
                 max_epochs: int = 100):
        super().__init__()


        self.model = ALIGN(hidden_dim=hidden_dim,
                          input_dim=input_dim,
                          batch_first=batch_first,
                          dropout=dropout,
                          layers=layers,
                          backbone=backbone,
                          projection_dim=projection_dim)
        


        self.criterion = ContrastiveLoss(temperature=temperature)        
        
        
        self.lr = lr
        self.wd = wd
        self.max_epochs = max_epochs
        
    
        
    def training_step(self, batch, batch_idx):
        
        ehr, cxr,_ , _, seq_lengths, _ = batch
        ehr,seq_lengths = self._swap(ehr,seq_lengths)
        
        embeddings = self.model(cxr.cuda(),ehr.cuda(),seq_lengths.to('cpu'))
        
        loss = self.criterion(embeddings['cxr'], embeddings['ehr']) 
        self.log("train_loss", loss, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        
        return loss



    def configure_optimizers(self):

        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.lr,
                                      weight_decay=self.wd)
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                               eta_min=0.0,
                                                               T_max=self.max_epochs)
        return {'optimizer': optimizer,
               'lr_scheduler': scheduler
               }
    
    def _swap(self,ehr,seqs):
    
        ehr = torch.tensor(ehr,dtype= torch.float32)
        seqs = torch.tensor(seqs, dtype= torch.float32)
        b_size = ehr.shape[0]
    
        # number of samples to sap
        count = random.randint(int(0.16*b_size),int(0.2*b_size))
    
        # first slice limits retrieval 
        group1_start = random.randint(0,int(0.4*b_size))
        group1_end = group1_start + count
        ehr1 = torch.clone(ehr[group1_start:group1_end])
        seqs1 = torch.clone(seqs[group1_start:group1_end])
    
        # second slice limits retrieval
        group2_start = random.randint(int(0.6*b_size),int(0.8*b_size))
        group2_end = group2_start + count
        ehr2 = torch.clone(ehr[group2_start:group2_end])
        seqs2 = torch.clone(seqs[group2_start:group2_end])
    
        # perform swapping
        ehr[group1_start:group1_end] = ehr2
        seqs[group1_start:group1_end] = seqs2
    
        ehr[group2_start:group2_end] = ehr1
        seqs[group2_start:group2_end] = seqs1
    
        return ehr, seqs

In [13]:
# ALIGNTrainer().training_step(z,0)

In [14]:
import argparse

def initiate_parsing():
    parser = argparse.ArgumentParser()
    
    # Task setup
    parser.add_argument('--device', type=str, help='cuda device', default='0')
    parser.add_argument('--num_gpu', type=int, help='number of gpus for training', default=1)
    parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0001, help='base learning rate simclr pretraining')
    parser.add_argument('--save_dir', type=str, help='Directory where all output files are stored', default='.')
    parser.add_argument('--labels_set', type=str, default='pheno', help='pheno, radiology')
    parser.add_argument('--task', type=str, default='phenotyping', help='train or eval for in-hospital-mortality or phenotyping')
    parser.add_argument('--retrive_cxr',type=str,default='recent', choices=['recent','all'], help='either to retrieve all cxr or only the most recent')
    parser.add_argument('--data_pairs', type=str, default='paired', help='paired, ehr_only, radiology, joint_ehr')
    parser.add_argument('--mode', type=str, default="train", help='mode: train or eval')  
    parser.add_argument('--tag', type=str, default="simclr train", help='simclr_train_phenotyping')      
    parser.add_argument('--pretrain_type', type=str, default="simclr", help='type of pretraining')    
    parser.add_argument('--file_name', type=str, default=None, help='prefix of model file name')      
    parser.add_argument('--load_state', type=str, default=None, help='state dir path for simclr model')
    parser.add_argument('--eval_set', type=str, default='val', help='evaluation set: val or test')
    parser.add_argument('--job_number', type=str, default='0', help='slurm job number for jubail')
    parser.add_argument('--eval_epoch', type=int, help='epoch to evaluate for model selection', default=0)

    
    # EHR setup
    parser.add_argument('--load_state_ehr', type=str, default=None, help='state dir path for uni ehr model')
    parser.add_argument('--num_classes', type=int, default=25, help='number of classes for ehr related tasks')
    parser.add_argument('--rec_dropout', type=float, default=0.0, help="dropout rate for recurrent connections")
    parser.add_argument('--timestep', type=float, default=1.0, help="fixed timestep used in the dataset")
    parser.add_argument('--imputation', type=str, default='previous')
    parser.add_argument('--ehr_data_root', type=str, help='Path to the ehr data', default='/scratch/fs999/shamoutlab/data/mimic-iv-extracted')
    parser.add_argument('--layers', default=1, type=int, help='number of lstm stacked layers')
    parser.add_argument('--dim', type=int, default=256,
                        help='number of hidden units for uni ehr lstm model')

    # CXR setup
    parser.add_argument('--load_state_cxr', type=str, default=None, help='state dir path for uni cxr model')
    parser.add_argument('--cxr_data_root', type=str, help='Path to the cxr data', default='/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0')
    parser.add_argument('--vision-backbone', default='resnet34', type=str, help='[densenet121, densenet169, densenet201, resnet34]')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',  help='load imagenet pretrained model for cxr')
    parser.add_argument('--layer_after', default=4, type=int, help='apply mmtm/daft module after fourth layer[-1, 0,1,2,3,4] -1 indicates mmtm after every layer')
    parser.add_argument('--vision_num_classes', default=14, type=int, help='number of cxr classes')
    parser.add_argument('--resize', default=256, type=int, help='cxr transform resize')
    parser.add_argument('--crop', default=224, type=int, help='cxr transform crop size')
    parser.add_argument('--dropout', type=float, default=0.0)# TODO: double check
    parser.add_argument('--hidden_dim', type=int, default=128)# TODO: double check



    # SimCLR setup
    parser.add_argument('--load_state_simclr', type=str, default=None, help='state dir path for simclr model')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--transforms_cxr', type=str, default='simclrv2', help='set image transforms of simclrv2')
    parser.add_argument('--temperature', type=float, default=0.01, help='simclr temperature')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--finetune', action='store_true',  help='finetune simclr')  
    parser.add_argument('--dataset', type=str, default='evaluation_task', help='type of dataset to work with (all being unrestricted pairs)')
    parser.add_argument('--width', type=int, default=1, help='width of projection module')
    parser.add_argument('--save_features', action='store_true', help='save features after each epoch')
    parser.add_argument('--beta_infonce', action='store_true', help='include time difference in loss computation')


    
    parser.add_argument('--linearclassify', action='store_true',  help='perform linear classification after simclr')  
    parser.add_argument('--load_state_lc', type=str, default=None, help='state dir path for linear class model')
    parser.add_argument('--lr_linearclassify', type=float, default=0.0001, help='learning rate for linear classification')
    parser.add_argument('--epochs_linearclassify', type=int, default=100, help='number of epochs to train for for linear class')
    parser.add_argument('--overwrite_classifier', action='store_true',  help='retrain the logistic regression model and overwrite')    

    
    # Fusion setup
    parser.add_argument('--fusion_type', type=str, default='None', help='train or eval for fusion types [joint, early, uni_cxr, uni_ehr, lstm]')
    parser.add_argument('--data_ratio', type=float, default=1.0, help='percentage of uppaired data samples')
    parser.add_argument('--mmtm_ratio', type=float, default=4, help='mmtm ratio hyperparameter')
    parser.add_argument('--fusion_layer', type=int, default=0, help='fusion layer')



    # Unknown classify later 
    parser.add_argument('--beta_1', type=float, default=0.9,
                        help='beta_1 param for Adam optimizer')
    parser.add_argument('--normalizer_state', type=str, default=None,
                        help='Path to a state file of a normalizer. Leave none if you want to use one of the provided ones.')
    
    
    # Vicreg
    parser.add_argument('--sim_coeff', type=float, default=25, help='vicreg sim coeff')
    parser.add_argument('--std_coeff', type=float, default=25, help='vicreg std coeff')
    parser.add_argument('--cov_coeff', type=float, default=1, help='vicreg cov coeff')
    parser.add_argument('--vicreg', action='store_true', help='vicreg loss computation')
    
    parser.add_argument('-f')
    return parser

In [471]:
parser = initiate_parsing()
args = parser.parse_args()
args.beta_infonce =True

In [472]:
for arg in vars(args):
    print (arg,'=' ,getattr(args, arg))

device = 0
num_gpu = 1
epochs = 300
lr = 0.0001
save_dir = .
labels_set = pheno
task = phenotyping
retrive_cxr = recent
data_pairs = paired
mode = train
tag = simclr train
pretrain_type = simclr
file_name = None
load_state = None
eval_set = val
job_number = 0
eval_epoch = 0
load_state_ehr = None
num_classes = 25
rec_dropout = 0.0
timestep = 1.0
imputation = previous
ehr_data_root = /scratch/fs999/shamoutlab/data/mimic-iv-extracted
layers = 1
dim = 256
load_state_cxr = None
cxr_data_root = /scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0
vision_backbone = resnet34
pretrained = False
layer_after = 4
vision_num_classes = 14
resize = 256
crop = 224
dropout = 0.0
hidden_dim = 128
load_state_simclr = None
batch_size = 256
transforms_cxr = simclrv2
temperature = 0.01
weight_decay = 0.0001
finetune = False
dataset = evaluation_task
width = 1
save_features = False
beta_infonce = True
linearclassify = False
load_state_lc = None
lr_linearclassify = 0.0001
epochs_linearclassi

In [473]:
import os
import numpy as np
from PIL import Image
import pandas as pd 
import json
import torch
from torch.utils.data import Dataset
import glob
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import random
import matplotlib.pyplot as plt

In [474]:
R_CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
       'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
       'Pneumonia', 'Pneumothorax', 'Support Devices']

CLASSES = [
       'Acute and unspecified renal failure', 'Acute cerebrovascular disease',
       'Acute myocardial infarction', 'Cardiac dysrhythmias',
       'Chronic kidney disease',
       'Chronic obstructive pulmonary disease and bronchiectasis',
       'Complications of surgical procedures or medical care',
       'Conduction disorders', 'Congestive heart failure; nonhypertensive',
       'Coronary atherosclerosis and other heart disease',
       'Diabetes mellitus with complications',
       'Diabetes mellitus without complication',
       'Disorders of lipid metabolism', 'Essential hypertension',
       'Fluid and electrolyte disorders', 'Gastrointestinal hemorrhage',
       'Hypertension with complications and secondary hypertension',
       'Other liver diseases', 'Other lower respiratory disease',
       'Other upper respiratory disease',
       'Pleurisy; pneumothorax; pulmonary collapse',
       'Pneumonia (except that caused by tuberculosis or sexually transmitted disease)',
       'Respiratory failure; insufficiency; arrest (adult)',
       'Septicemia (except in labor)', 'Shock'
    ]

# CXR Dataset

In [475]:
def get_transforms(args):
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    
    train_transforms = []
    train_transforms.append(transforms.Resize(args.resize))
    train_transforms.append(transforms.RandomHorizontalFlip())
    train_transforms.append(transforms.RandomAffine(degrees=45, scale=(.85, 1.15), shear=0, translate=(0.15, 0.15)))
    train_transforms.append(transforms.CenterCrop(224))#args.crop))
    train_transforms.append(transforms.ToTensor())
    #train_transforms.append(normalize)      


    test_transforms = []
    test_transforms.append(transforms.Resize(args.resize))
    test_transforms.append(transforms.CenterCrop(224))#args.crop))
    test_transforms.append(transforms.ToTensor())
    #test_transforms.append(normalize)

    return train_transforms, test_transforms


In [476]:
class Clip(object):
    """Transformation to clip image values between 0 and 1."""

    def __call__(self, sample):
        return torch.clip(sample, 0, 1)

In [477]:
class RandomCrop(object):
    "Randomly crop an image"
    
    def __call__(self, sample):
        resize = 256
        #print(np.random.uniform(0.4*resize,resize,1))
        random_crop_size = int(np.random.uniform(0.6*resize,resize,1))
        sample=transforms.RandomCrop(random_crop_size)(sample)
        return sample

In [478]:
class RandomColorDistortion(object):
    "Apply random color distortions to the image"
    
    def __call__(self, sample):
        resize=256

        # Random color distortion
        strength = 1.0 # 1.0 imagenet setting and CIFAR uses 0.5
        brightness = 0.8 * strength 
        contrast = 0.8 * strength
        saturation = 0.8 * strength
        hue = 0.2 * strength
        prob = np.random.uniform(0,1,1) 
        if prob < 0.8:
            sample=transforms.ColorJitter(brightness, contrast, saturation, hue)(sample)

        # Random Grayscale
        sample=transforms.RandomGrayscale(p=0.2)(sample)

        # Gaussian blur also based on imagenet but not used for CIFAR
        #prob = np.random.uniform(0,1,1)
        #if prob < 0.3:
        #    sample=transforms.GaussianBlur(kernel_size=resize//10)(sample)
        #    sample=transforms.Pad(0)(sample)
        return sample 
    

In [479]:
def get_transforms_simclr(args):
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    
    train_transforms = []
    # Resize all images to same size, then randomly crop and resize again
    train_transforms.append(transforms.Resize([args.resize, args.resize]))
    # Random affine
    train_transforms.append(transforms.RandomAffine(degrees=(-45, 45), translate=(0.1,0.1), scale=(0.7, 1.5), shear=(-25, 25)))
    # Random crop
    train_transforms.append(RandomCrop())
    # Resize again
    # train_transforms.append(transforms.Resize([args.resize, args.resize], interpolation=3))
    train_transforms.append(transforms.Resize([224, 224], interpolation=3))
    # Random horizontal flip 
    train_transforms.append(transforms.RandomHorizontalFlip())
    # Random color distortions
    train_transforms.append(RandomColorDistortion())
    # Convert to tensor
    train_transforms.append(transforms.ToTensor())
    # Clip values between 0 and 1 and normalize
    #train_transforms.append(Clip())
    #train_transforms.append(normalize)      

    test_transforms = []
    # Resize all images to same size, then center crop and resize again
    test_transforms.append(transforms.Resize([args.resize, args.resize]))
    crop_proportion=0.875
    test_transforms.append(transforms.CenterCrop([int(0.875*args.resize), int(0.875*args.resize)]))
    # test_transforms.append(transforms.Resize([args.resize, args.resize], interpolation=3))
    test_transforms.append(transforms.Resize([224, 224], interpolation=3))
    #Convert to tensor
    test_transforms.append(transforms.ToTensor())
    # Clip values between 0 and 1 and normalize
    #test_transforms.append(Clip())
    #test_transforms.append(normalize)

    return train_transforms, test_transforms

In [480]:
# Note this function needs to be editted to mimic function above 
def visualize_transforms_simclr(args, orig_img, split='train'):
    # Create array of images 
    print(orig_img)
    new_images = [orig_img]
    tt = ['Original image']
    #normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if split == 'train':
        # Resize all images to same size
        new_images = new_images + [transforms.Resize([args.resize, args.resize])(orig_img)]
        tt = tt + ['Resize original image']
        # Random affine
        new_images = new_images + [transforms.RandomAffine(degrees=(-45, 45), translate=(0.1,0.1), scale=(0.7, 1.5), shear=(-25, 25))(new_images[-1])]
        tt = tt + ['Random affine']
        # Random crop
        new_images = new_images + [RandomCrop()(new_images[-1])]
        tt = tt + ['Random Crop']
        # Resize to 256 x 256
        new_images = new_images + [transforms.Resize([args.resize, args.resize], interpolation=3)(new_images[-1])]
        tt = tt + ['Resize patch']
        # Random horizontal flip 
        new_images = new_images + [transforms.RandomHorizontalFlip()(new_images[-1])]
        tt = tt + ['Random horizontal flip']
        # Random color distortions
        new_images = new_images + [RandomColorDistortion()(new_images[-1])]
        tt = tt + ['Random color distortion']
        
        # Convert all to tensors
        for i in range(0, len(new_images)):
            new_images[i]=transforms.ToTensor()(new_images[i])
#         # Clip values between 0 and 1 and normalize
#         new_images = new_images + [Clip()(new_images[-1])]
#         tt = tt + ['Clip values (0,1)']
#         # Normalize values
#         new_images = new_images + [normalize(new_images[-1])]
#         tt = tt + ['Normalize values']
    return new_images, tt


In [481]:
class MIMICCXR(Dataset):
    def __init__(self, paths, args, transform=None, split='train'):
        self.data_dir = args.cxr_data_root
        self.args = args
        self.CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
       'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
       'Pneumonia', 'Pneumothorax', 'Support Devices']
        self.filenames_to_path = {path.split('/')[-1].split('.')[0]: path for path in paths}

        metadata = pd.read_csv(f'{self.data_dir}/mimic-cxr-2.0.0-metadata.csv')
        labels = pd.read_csv(f'{self.data_dir}/mimic-cxr-2.0.0-chexpert.csv')
        labels[self.CLASSES] = labels[self.CLASSES].fillna(0)
        labels = labels.replace(-1.0, 0.0)
        
        splits = pd.read_csv(f'{self.data_dir}/mimic-cxr-ehr-split.csv')


        metadata_with_labels = metadata.merge(labels[self.CLASSES+['study_id'] ], how='inner', on='study_id')


        self.filesnames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[self.CLASSES].values))
        self.filenames_loaded = splits.loc[splits.split==split]['dicom_id'].values
        self.transform = transform
        self.filenames_loaded = [filename  for filename in self.filenames_loaded if filename in self.filesnames_to_labels]

    def __getitem__(self, index):
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            #print(self.filenames_to_path[index])
            labels = torch.tensor(self.filesnames_to_labels[index]).float()
            if self.transform is not None:
                img = self.transform(img)
            return img, labels
          
        
        filename = self.filenames_loaded[index]
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')
        labels = torch.tensor(self.filesnames_to_labels[filename]).float()

        if self.transform is not None:
            img = self.transform(img)
        return img, labels
    
    def __len__(self):
        return len(self.filenames_loaded)

In [482]:
def get_cxr_datasets(args):
    if args.transforms_cxr=='simclrv2':
        print("Appling SimCLR image transforms...")
        train_transforms, test_transforms = get_transforms_simclr(args)
    else:
        print("Applying linear evaluation transforms...")
        train_transforms, test_transforms = get_transforms(args)

    data_dir = args.cxr_data_root
    filepath = f'{args.cxr_data_root}/new_paths.npy'
    if os.path.exists(filepath):
        paths = np.load(filepath)
    else:
        paths = glob.glob(f'{data_dir}/resized/**/*.jpg', recursive = True)
        np.save(filepath, paths)
    
    dataset_train = MIMICCXR(paths, args, split='train', transform=transforms.Compose(train_transforms))
    dataset_validate = MIMICCXR(paths, args, split='validate', transform=transforms.Compose(test_transforms),)
    dataset_test = MIMICCXR(paths, args, split='test', transform=transforms.Compose(test_transforms),)

    return dataset_train, dataset_validate, dataset_test

# EHR utils

In [483]:
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import platform
import pickle
import json
import os

In [484]:
class Discretizer:
    def __init__(self, timestep=0.8, store_masks=True, impute_strategy='zero', start_time='zero',
                 config_path= '../ehr_utils/resources/discretizer_config.json'):

        with open(config_path) as f:
            config = json.load(f)
            self._id_to_channel = config['id_to_channel']
            self._channel_to_id = dict(zip(self._id_to_channel, range(len(self._id_to_channel))))
            self._is_categorical_channel = config['is_categorical_channel']
            self._possible_values = config['possible_values']
            self._normal_values = config['normal_values']

        self._header = ["Hours"] + self._id_to_channel
        self._timestep = timestep
        self._store_masks = store_masks
        self._start_time = start_time
        self._impute_strategy = impute_strategy

        # for statistics
        self._done_count = 0
        self._empty_bins_sum = 0
        self._unused_data_sum = 0

    def transform(self, X, header=None, end=None):
        if header is None:
            header = self._header
        assert header[0] == "Hours"
        eps = 1e-6

        N_channels = len(self._id_to_channel)
        ts = [float(row[0]) for row in X]
        for i in range(len(ts) - 1):
            assert ts[i] < ts[i+1] + eps

        if self._start_time == 'relative':
            first_time = ts[0]
        elif self._start_time == 'zero':
            first_time = 0
        else:
            raise ValueError("start_time is invalid")

        if end is None:
            max_hours = max(ts) - first_time
        else:
            max_hours = end - first_time

        N_bins = int(max_hours / self._timestep + 1.0 - eps)

        cur_len = 0
        begin_pos = [0 for i in range(N_channels)]
        end_pos = [0 for i in range(N_channels)]
        for i in range(N_channels):
            channel = self._id_to_channel[i]
            begin_pos[i] = cur_len
            if self._is_categorical_channel[channel]:
                end_pos[i] = begin_pos[i] + len(self._possible_values[channel])
            else:
                end_pos[i] = begin_pos[i] + 1
            cur_len = end_pos[i]

        data = np.zeros(shape=(N_bins, cur_len), dtype=float)
        mask = np.zeros(shape=(N_bins, N_channels), dtype=int)
        original_value = [["" for j in range(N_channels)] for i in range(N_bins)]
        total_data = 0
        unused_data = 0

        def write(data, bin_id, channel, value, begin_pos):
            channel_id = self._channel_to_id[channel]
            if self._is_categorical_channel[channel]:
                # print("list: ", self._possible_values[channel], "val: ", value, "channel:", channel)
                category_id = self._possible_values[channel].index(value)
                N_values = len(self._possible_values[channel])
                one_hot = np.zeros((N_values,))
                one_hot[category_id] = 1
                for pos in range(N_values):
                    data[bin_id, begin_pos[channel_id] + pos] = one_hot[pos]
            else:
                data[bin_id, begin_pos[channel_id]] = float(value)

        for row in X:
            t = float(row[0]) - first_time
            if t > max_hours + eps:
                continue
            bin_id = int(t / self._timestep - eps)
            assert 0 <= bin_id < N_bins

            for j in range(1, len(row)):
                if row[j] == "":
                    continue
                channel = header[j]
                channel_id = self._channel_to_id[channel]

                total_data += 1
                if mask[bin_id][channel_id] == 1:
                    unused_data += 1
                mask[bin_id][channel_id] = 1

                write(data, bin_id, channel, row[j], begin_pos)
                original_value[bin_id][channel_id] = row[j]

        # impute missing values

        if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
            raise ValueError("impute strategy is invalid")

        if self._impute_strategy in ['normal_value', 'previous']:
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if self._impute_strategy == 'normal_value':
                        imputed_value = self._normal_values[channel]
                    if self._impute_strategy == 'previous':
                        if len(prev_values[channel_id]) == 0:
                            imputed_value = self._normal_values[channel]
                        else:
                            imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        if self._impute_strategy == 'next':
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins-1, -1, -1):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if len(prev_values[channel_id]) == 0:
                        imputed_value = self._normal_values[channel]
                    else:
                        imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        empty_bins = np.sum([1 - min(1, np.sum(mask[i, :])) for i in range(N_bins)])
        self._done_count += 1
        self._empty_bins_sum += empty_bins / (N_bins + eps)
        self._unused_data_sum += unused_data / (total_data + eps)

        if self._store_masks:
            data = np.hstack([data, mask.astype(np.float32)])

        # create new header
        new_header = []
        for channel in self._id_to_channel:
            if self._is_categorical_channel[channel]:
                values = self._possible_values[channel]
                for value in values:
                    new_header.append(channel + "->" + value)
            else:
                new_header.append(channel)

        if self._store_masks:
            for i in range(len(self._id_to_channel)):
                channel = self._id_to_channel[i]
                new_header.append("mask->" + channel)

        new_header = ",".join(new_header)

        return (data, new_header)

    def print_statistics(self):
        print("statistics of discretizer:")
        print("\tconverted {} examples".format(self._done_count))
        print("\taverage unused data = {:.2f} percent".format(100.0 * self._unused_data_sum / self._done_count))
        print("\taverage empty  bins = {:.2f} percent".format(100.0 * self._empty_bins_sum / self._done_count))


In [485]:
class Normalizer:
    def __init__(self, fields=None):
        self._means = None
        self._stds = None
        self._fields = None
        if fields is not None:
            self._fields = [col for col in fields]

        self._sum_x = None
        self._sum_sq_x = None
        self._count = 0

    def _feed_data(self, x):
        x = np.array(x)
        self._count += x.shape[0]
        if self._sum_x is None:
            self._sum_x = np.sum(x, axis=0)
            self._sum_sq_x = np.sum(x**2, axis=0)
        else:
            self._sum_x += np.sum(x, axis=0)
            self._sum_sq_x += np.sum(x**2, axis=0)

    def _save_params(self, save_file_path):
        eps = 1e-7
        with open(save_file_path, "wb") as save_file:
            N = self._count
            self._means = 1.0 / N * self._sum_x
            self._stds = np.sqrt(1.0/(N - 1) * (self._sum_sq_x - 2.0 * self._sum_x * self._means + N * self._means**2))
            self._stds[self._stds < eps] = eps
            pickle.dump(obj={'means': self._means,
                             'stds': self._stds},
                        file=save_file,
                        protocol=2)

    def load_params(self, load_file_path):
        with open(load_file_path, "rb") as load_file:
            if platform.python_version()[0] == '2':
                dct = pickle.load(load_file)
            else:
                dct = pickle.load(load_file, encoding='latin1')
            self._means = dct['means']
            self._stds = dct['stds']

    def transform(self, X):
        if self._fields is None:
            fields = range(X.shape[1])
        else:
            fields = self._fields
        ret = 1.0 * X
        for col in fields:
            ret[:, col] = (X[:, col] - self._means[col]) / self._stds[col]
        return ret

# EHR Dataset

In [486]:
discretizer = Discretizer()
normalizer = Normalizer()
normalizer.load_params('../ph_ts0.8.input_str:previous.start_time:zero.normalizer')

In [487]:
class MultiTransform(object):

    def __init__(
        self,
        views,
        normal_values,
        _is_categorical_channel,
        augmentation,
        begin_pos
    ):
        self.views = views
        self.normal_values = normal_values
        self.rows = np.array([value for value in self.normal_values.values()])
        self.augmentation = augmentation
        self.continuous_variable = [0 if _is_categorical_channel[key] == True else 1 for key in _is_categorical_channel]
        self.begin_pos = begin_pos
        
    def vertical_mask(self, data, max_percent=0.4):
        # mask over each timestep (t, features)
        length = data.shape[0]
        if length < 4:
            return data
        size = int(np.random.randint(low=0, high=max(int(max_percent*length),1), size=1))
        a = np.zeros(length , dtype=int)
        a[:size] = 1
        np.random.shuffle(a)
        a = a.astype(bool)
        data[a,1:] = self.rows
        return data

    def horizontal_mask(self, data, max_percent=0.4):
        # mask over each feature (t, features)
        length = data.shape[1] - 1
        size = int(np.random.randint(low=0, high=max(int(max_percent*length),1), size=1))
        features = np.unique(np.random.randint(low=1, high=length, size=size))
        for i in features:
            data[:,i+1] = self.normal_values[i]
        return data
    
    def drop_start(self, data, max_percent=0.4):
        length = data.shape[0]
        start = int(np.random.randint(low=0, high=max(int(max_percent*length),1), size=1))
        return data[start:,:]

    def gaussian_blur(self, data):
        mean, std = 1,0 
        data[:, self.begin_pos] = data[:, self.begin_pos]  + np.random.normal(mean, std, (data.shape[0], len(self.begin_pos)))
        return data

    def rotation(self, data):
        if choice([0,1]):
            return np.flip(data, axis=0)
        return data

    def downsample(self, data):
        if data.shape[0] < 20:
            return data
        step = choice([1, 2, 3])
        return data[::step]

    def __call__(self, data):
        data_views = []                    
        data_views.append(self.vertical_mask(data))
        data_views.append(self.horizontal_mask(data))
        data_views.append(self.horizontal_mask(self.vertical_mask(data)))
        data_views.append((self.drop_start(data)))
        data_views.append(data)

        return data_views

In [488]:
def get_transforms(args):
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    
    train_transforms = []
    train_transforms.append(transforms.Resize(args.resize))
    train_transforms.append(transforms.RandomHorizontalFlip())
    train_transforms.append(transforms.RandomAffine(degrees=45, scale=(.85, 1.15), shear=0, translate=(0.15, 0.15)))
    train_transforms.append(transforms.CenterCrop(args.crop))
    train_transforms.append(transforms.ToTensor())
    #train_transforms.append(normalize)      


    test_transforms = []
    test_transforms.append(transforms.Resize(args.resize))
    test_transforms.append(transforms.CenterCrop(args.crop))
    test_transforms.append(transforms.ToTensor())
    #test_transforms.append(normalize)

    return train_transforms, test_transforms

In [559]:
class EHRdataset(Dataset):
    def __init__(self, discretizer, normalizer, listfile, dataset_dir, return_names=True, period_length=48.0, transforms=None):
        self.return_names = return_names
        self.discretizer = discretizer
        self.normalizer = normalizer
        self._period_length = period_length

        self._dataset_dir = dataset_dir
        listfile_path = listfile
        with open(listfile_path, "r") as lfile:
            self._data = lfile.readlines()
        self._listfile_header = self._data[0]
        self.CLASSES = self._listfile_header.strip().split(',')[3:]
        self._data = self._data[1:]
        self.transforms = transforms


        self._data = [line.split(',') for line in self._data]
        self.data_map = {
            mas[0]: {
                'labels': list(map(int, mas[3:])),
                'stay_id': float(mas[2]),
                'time': float(mas[1]),
                }
                for mas in self._data
        }

        self.names = list(self.data_map.keys())
    
#     def _read_timeseries(self, ts_filename):
        
#         ret = []
#         with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
#             header = tsfile.readline().strip().split(',')
#             assert header[0] == "Hours"
#             for line in tsfile:
#                 mas = line.strip().split(',')
#                 ret.append(np.array(mas))
#         return (np.stack(ret), header)

    def read_timeseries(self,ts_filename, lower_bound=0,upper_bound=12):
        
        ret = []
        with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
            header = tsfile.readline().strip().split(',')
            assert header[0] == "Hours"
            for line in tsfile:
                mas = line.strip().split(',')
            
                t = float(mas[0])
                if t < lower_bound:
                    continue
                elif (t> lower_bound) & (t <upper_bound) :
                    ret.append(np.array(mas))
                elif t > upper_bound:
                    break
            
    #             if time_bound is not None:
    #                 t = float(mas[0])
    #                 if t > time_bound + 1e-6:
    #                     break
    #             ret.append(np.array(mas))
        try: 
            return (np.stack(ret), header)
        except ValueError:
            
            ret = ([['0.11666666666666667', '', '', '', '', '', '', '', '', '109', '',
                     '', '', '30', '', '', '', ''],
                    ['0.16666666666666666', '', '61.0', '', '', '', '', '', '', '109',
                    '', '64', '97.0', '29', '74.0', '', '', '']])
            return (np.stack(ret), header)
            
         
    
#     def read_by_file_name(self, index):
#         t = self.data_map[index]['time']
#         y = self.data_map[index]['labels']
#         stay_id = self.data_map[index]['stay_id']
#         (X, header) = self._read_timeseries(index)
#         print(index)
#         return {"X": X,
#                 "t": t,
#                 "y": y,
#                 'stay_id': stay_id,
#                 "header": header,
#                 "name": index}

    def read_by_file_name(self,index, lower_bound=0,upper_bound=12):
        t = self.data_map[index]['time'] #if upper_bound is None else upper_bound
        y = self.data_map[index]['labels']
        stay_id = self.data_map[index]['stay_id']
        (X, header) = self.read_timeseries(index, lower_bound=lower_bound,upper_bound=upper_bound)
        
        return {"X": X,
                "t": t,
                "y": y,
                'stay_id': stay_id,
                "header": header,
                "name": index}

    def __getitem__(self, index,lower,upper):
        if isinstance(index, int):
            index = self.names[index]
        ret = self.read_by_file_name(index,lower,upper)
        data = ret["X"]
#         print(index)
        ts = data.shape[0]#ret["t"] if ret['t'] > 0.0 else self._period_length
        
        
        ## Added block
        if self.transforms is not None:
            data = self.transforms(data)
            
            for i in range(len(data)):
                data[i] = self.discretizer.transform(data[i], end=ts)[0]
                print(data[i]).shape
                if 'gaussian' in self.transforms.augmentation and i != 0:
                    data[i] = self.transforms.gaussian_blur(data[i])
                if 'sampling' in self.transforms.augmentation and i != 0: #carry last value forward 
                    data[i] = self.transforms.downsample(data[i])
                if (self.normalizer is not None):
                    data[i] = self.normalizer.transform(data[i])
        else:
            data = self.discretizer.transform(data, end=ts)[0]
            if (self.normalizer is not None):
                data = self.normalizer.transform(data)
        #########  
        
        # data = self.discretizer.transform(data, end=ts)[0] 
        # if (self.normalizer is not None):
        #     data = self.normalizer.transform(data)
        # print(data.shape)

        ys = ret["y"]
        names = ret["name"]
        ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
        stay_ids = ret['stay_id']
        return data, ys

    
    def __len__(self):
        return len(self.names)

In [560]:
def get_datasets(discretizer, normalizer, args):
    # if context == True:
    #     transform = MultiTransform(views=11, normal_values=discretizer._id_normal_values, _is_categorical_channel=discretizer._is_categorical_channel, augmentation=augmentation, begin_pos=begin_pos)
    # else:
    #     transform = None
    # changed definition of normal_values
    # augmentation = 'gaussian'
    # transform = MultiTransform(views=11, normal_values=discretizer._normal_values, _is_categorical_channel=discretizer._is_categorical_channel, augmentation=augmentation, begin_pos=discretizer._start_time)
    transform = None
    train_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/train_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'), transforms=transform)
    val_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/val_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'), transforms = transform)
    test_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/test_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/test'), transforms = transform)
    return train_ds, val_ds, test_ds

In [561]:
def my_collate(batch):
    x = [item[0] for item in batch]
    x, seq_length = pad_zeros(x)
    targets = np.array([item[1] for item in batch])
    return [x, targets, seq_length]

In [562]:
def pad_zeros(arr, min_length=None):

    dtype = arr[0].dtype
    seq_length = [x.shape[0] for x in arr]
    max_len = max(seq_length)
    ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
           for x in arr]
    if (min_length is not None) and ret[0].shape[0] < min_length:
        ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
               for x in ret]
    return np.array(ret), seq_length

In [563]:
def get_data_loader(discretizer, normalizer, dataset_dir, batch_size):
    train_ds, val_ds, test_ds = get_datasets(discretizer, normalizer, args)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=my_collate, pin_memory=True, num_workers=16)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=my_collate, pin_memory=True, num_workers=16)

    return train_dl, val_dl

In [564]:
ehrtr, ehrva, ehrts = get_datasets(discretizer,normalizer,args)
cxrtr, cxrva, cxrts = get_cxr_datasets(args)

Appling SimCLR image transforms...


In [566]:
ehrtr.__getitem__(4,0,12)

(array([[-1.66062755e+01, -6.02181801e-02,  9.75482416e-03, ...,
         -5.58505849e-01,  6.88005369e+00, -2.91264127e-01],
        [-1.66062755e+01, -6.02181801e-02,  4.87291875e-02, ...,
         -5.58505849e-01, -1.45347657e-01, -2.91264127e-01],
        [-1.66062755e+01, -6.02181801e-02,  2.27462786e-02, ...,
         -5.58505849e-01, -1.45347657e-01, -2.91264127e-01],
        ...,
        [-1.66062755e+01, -6.02181801e-02, -2.67396204e-01, ...,
         -5.58505849e-01, -1.45347657e-01, -2.91264127e-01],
        [-1.66062755e+01, -6.02181801e-02, -2.67396204e-01, ...,
         -5.58505849e-01, -1.45347657e-01, -2.91264127e-01],
        [-1.66062755e+01, -6.02181801e-02, -2.67396204e-01, ...,
         -5.58505849e-01, -1.45347657e-01, -2.91264127e-01]]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0], dtype=int32),
 31205490.0)

# Fusion Dataset

In [496]:
class MIMIC_CXR_EHR(Dataset):
    def __init__(self, args, metadata_with_labels, ehr_ds, cxr_ds, split='train'):
        
        # select classes
        self.CLASSES = CLASSES
        if 'radiology' in args.labels_set:
            self.CLASSES = R_CLASSES
        
        self.metadata_with_labels = metadata_with_labels
        
        self.cxr_files_paired = self.metadata_with_labels.dicom_id.values
        self.ehr_files_paired = (self.metadata_with_labels['stay'].values)
        self.time_diff = self.metadata_with_labels.time_diff
        self.lower = self.metadata_with_labels.lower
        self.upper = self.metadata_with_labels.upper
        
        self.cxr_files_all = cxr_ds.filenames_loaded
        self.ehr_files_all = ehr_ds.names
        
        self.ehr_files_unpaired = list(set(self.ehr_files_all) - set(self.ehr_files_paired))
        
        self.ehr_ds = ehr_ds
        self.cxr_ds = cxr_ds
        
        self.args = args
        self.split = split
        self.data_ratio = self.args.data_ratio if split=='train' else 1.0

    def __getitem__(self, index):
        if self.args.data_pairs == 'paired':
            cxr_data, labels_cxr = self.cxr_ds[self.cxr_files_paired[index]]
            
            lower = self.metadata_with_labels.iloc[index].lower
            upper = self.metadata_with_labels.iloc[index].upper
            
            ehr_data, labels_ehr = self.ehr_ds.__getitem__(self.ehr_files_paired[index],lower,upper)
            time_diff = self.metadata_with_labels.iloc[index].time_diff
            
            #dicom_id =  self.cxr_files_paired[index]
            #stay_id = self.ehr_files_paired[index]
            #time_diff = 
                        
            if self.args.beta_infonce:
                return ehr_data, cxr_data, labels_ehr, labels_cxr, time_diff
            else:
                return ehr_data, cxr_data, labels_ehr, labels_cxr
        
        elif self.args.data_pairs == 'radiology':
            ehr_data, labels_ehr = np.zeros((1, 10)), np.zeros(self.args.num_classes)
            cxr_data, labels_cxr = self.cxr_ds[self.cxr_files_all[index]]
            return ehr_data, cxr_data, labels_ehr, labels_cxr
        
        elif self.args.data_pairs == 'ehr_only':
            ehr_data, labels_ehr = self.ehr_ds[self.ehr_files_all[index]]
            cxr_data, labels_cxr = None, None
            return ehr_data, cxr_data, labels_ehr, labels_cxr
        
        elif self.args.data_pairs == 'joint_ehr':
            if index < len(self.ehr_files_paired):
                ehr_data, labels_ehr = self.ehr_ds[self.ehr_files_paired[index]]
                cxr_data, labels_cxr = self.cxr_ds[self.cxr_files_paired[index]]
            else:
                index = random.randint(0, len(self.ehr_files_unpaired)-1) 
                ehr_data, labels_ehr = self.ehr_ds[self.ehr_files_unpaired[index]]
                cxr_data, labels_cxr = None, None
            return ehr_data, cxr_data, labels_ehr, labels_cxr

       
    
    def __len__(self):
        if self.args.data_pairs == 'paired':
            return len(self.ehr_files_paired)
        elif self.args.data_pairs == 'ehr_only':
            return len(self.ehr_files_all)
        elif self.args.data_pairs == 'radiology':
            return len(self.cxr_files_all)
        elif self.args.data_pairs == 'joint_ehr':
            return len(self.ehr_files_paired) + int(self.data_ratio * len(self.ehr_files_unpaired)) 

In [497]:
MIMIC_CXR_EHR(args,train_meta_with_labels,ehrtr,cxrtr).__getitem__(3)[0].shape

(43, 76)

In [498]:
def loadmetadata(args):

    def time_offsets(table):
        ids = table.stay_id.unique()
        data = []
        for id in ids:
            temp = table[table.stay_id == id]
            offsets = list(range(0,int(temp.LOS.max())+12,12))
            times = []
            for i,time in enumerate(temp.intime):
                times.append(time+ timedelta(hours=offsets[i])) 
            temp.intime = times
            data.append(temp)
        data = pd.concat(data,ignore_index=True)
        data.time_diff = (data.StudyDateTime - data.intime).apply(lambda x: np.round(x.total_seconds()/60/60,3))
        data['lower'] = data.LOS + (data.intime - data.outtime).apply(lambda x: np.round(x.total_seconds()/60/60,3))
        data['upper'] = data.apply(lambda x: x.lower + 12 if (x.lower + 12) < x.LOS else (x.LOS+1),axis=1)
        return data
    
    cxr_metadata = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-metadata.csv')
    print('Number of CXR images=', len(cxr_metadata))
    icu_stay_metadata = pd.read_csv(f'{args.ehr_data_root}/root/all_stays.csv')
    print('Number of ICU stays=', len(icu_stay_metadata))
    columns = ['subject_id', 'stay_id', 'intime', 'outtime']
    
    # only common subjects with both icu stay and an xray
    # Note that inner merge includes rows if a chest X-ray is associated with multiple stays
    cxr_merged_icustays = cxr_metadata.merge(icu_stay_metadata[columns], how='inner', on='subject_id')
    print('Number of CXR associated with ICU stay based on subject ID=', len(cxr_merged_icustays))
    print('Number of unique CXR dicoms=', len(cxr_merged_icustays.dicom_id.unique()))
    print('Number of unique CXR study id=', len(cxr_merged_icustays.study_id.unique()))
        
    # combine study date time
    cxr_merged_icustays['StudyTime'] = cxr_merged_icustays['StudyTime'].apply(lambda x: f'{int(float(x)):06}' )
    cxr_merged_icustays['StudyDateTime'] = pd.to_datetime(cxr_merged_icustays['StudyDate'].astype(str) + ' ' + cxr_merged_icustays['StudyTime'].astype(str) ,format="%Y%m%d %H%M%S")

    cxr_merged_icustays.intime=pd.to_datetime(cxr_merged_icustays.intime)
    cxr_merged_icustays.outtime=pd.to_datetime(cxr_merged_icustays.outtime)
    
    cxr_merged_icustays['time_diff'] = cxr_merged_icustays.StudyDateTime-cxr_merged_icustays.intime
    cxr_merged_icustays['time_diff'] = cxr_merged_icustays['time_diff'].apply(lambda x: np.round(x.total_seconds()/60/60,3))
    
    cxr_merged_icustays['LOS'] = cxr_merged_icustays.outtime-cxr_merged_icustays.intime
    cxr_merged_icustays['LOS'] = cxr_merged_icustays['LOS'].apply(lambda x: np.round(x.total_seconds()/60/60,3))
    
    
    # For LE/ FT  (evaluation datasets)
    if (args.dataset !='all'):

        if args.task == 'decompensation' or args.task == 'length-of-stay':
            train_listfile = pd.read_csv(f'/scratch/se1525/mml-ssl/{args.task}/train_listfile.csv')
            train_listfile.columns = ['stay' , 'period_length' , 'stay_id' ,'y_true', 'intime' , 'endtime']
            test_listfile = pd.read_csv(f'/scratch/se1525/mml-ssl/{args.task}/test_listfile.csv')
            test_listfile.columns = ['stay' , 'period_length' , 'stay_id' ,'y_true', 'intime' , 'endtime']
            listfile = train_listfile.append(test_listfile)
            listfile['subject_id'] = listfile['stay'].apply(lambda x: x.split("_")[0])
            print(listfile.head)

            columns2 = ['subject_id', 'endtime']
            listfile['subject_id'] = listfile['subject_id'].astype('int64')
            cxr_merged_icustays = cxr_merged_icustays.merge(listfile[columns2], how='inner', on='subject_id')
            cxr_merged_icustays.endtime=pd.to_datetime(cxr_merged_icustays.endtime)
            cxr_merged_icustays_during = cxr_merged_icustays.loc[((cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&(cxr_merged_icustays.StudyDateTime<=cxr_merged_icustays.endtime))]

        if args.task == 'in-hospital-mortality':
            end_time = cxr_merged_icustays.intime + pd.DateOffset(hours=48)
            cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&((cxr_merged_icustays.StudyDateTime<=end_time))]

        if args.task == 'phenotyping':
            end_time = cxr_merged_icustays.outtime
            cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&
                                                                 ((cxr_merged_icustays.StudyDateTime<=end_time))]
        
        # select cxrs with the ViewPosition == 'AP
        cxr_merged_icustays_AP = cxr_merged_icustays_during[cxr_merged_icustays_during['ViewPosition'] == 'AP']
        
        if args.retrive_cxr == 'recent':
            groups = cxr_merged_icustays_AP.groupby('stay_id')
            groups_selected = []
            for group in groups:
                # select the latest cxr for the icu stay
                selected = group[1].sort_values('StudyDateTime').tail(1).reset_index()
                groups_selected.append(selected)
            groups = pd.concat(groups_selected, ignore_index=True)
            groups = groups.groupby('study_id').first()
            groups = groups.reset_index()
            groups = groups.groupby('study_id').first().sort_values(by=['stay_id','StudyDateTime'])
            groups = groups.reset_index()
            #groups['num_cxr_windows'] = groups.groupby(['stay_id'])['stay_id'].transform('count')
            #groups['cxr_window_length'] = groups['LOS']/groups['num_cxr_windows']
            groups['num_ehr_windows'] = np.ceil(groups['LOS']/12).astype(int)
            groups = groups.loc[groups.index.repeat(groups.num_ehr_windows)].reset_index(drop=True)
            groups = time_offsets(groups)
        else: 
            groups = cxr_merged_icustays_AP.groupby('study_id').first()
            groups = groups.reset_index()
            groups = groups.groupby('study_id').first().sort_values(by=['stay_id','StudyDateTime'])
            groups = groups.reset_index()
            #groups['num_cxr_windows'] = groups.groupby(['stay_id'])['stay_id'].transform('count')
            #groups['cxr_window_length'] = groups['LOS']/groups['num_cxr_windows']
            groups['num_ehr_windows'] = np.ceil(groups['LOS']/12).astype(int)
            
            
    # For SIMCLR pretraining (large dataset)
#     else:
#         # print(cxr_merged_icustays.ViewPosition.unique())
#         cxr_merged_icustays_AP = cxr_merged_icustays[cxr_merged_icustays['ViewPosition'] == 'AP']
#         print("Number of CXR associated with ICU stay and in AP view=", len(cxr_merged_icustays_AP))
#         groups = cxr_merged_icustays_AP
        
    print("Mean time cxr - intime= ", groups.time_diff.mean())
    print("Minimum time =", groups.time_diff.min())
    print("Maximum time =", groups.time_diff.max())

#     plt.hist(groups.time_diff.apply(lambda x: x.days).astype("float64"))
#     plt.xlabel('Time difference in days')
#     plt.show()

    #print(groups.iloc[0])
    return groups

In [499]:
meta = loadmetadata(args)

Number of CXR images= 377110
Number of ICU stays= 59372
Number of CXR associated with ICU stay based on subject ID= 368350
Number of unique CXR dicoms= 181195
Number of unique CXR study id= 122087
Mean time cxr - intime=  74.23168108953072
Minimum time = -601.239
Maximum time = 2368.942


In [500]:
def load_cxr_ehr(args, ehr_train_ds, ehr_val_ds, ehr_test_ds, cxr_train_ds, cxr_val_ds, cxr_test_ds):
    
    # Load cxr and ehr groups
    cxr_merged_icustays = loadmetadata(args) 
    
    # Add the labels 
    splits_labels_train = pd.read_csv(f'{args.ehr_data_root}/{args.task}/train_listfile.csv')
    splits_labels_val = pd.read_csv(f'{args.ehr_data_root}/{args.task}/val_listfile.csv')
    splits_labels_test = pd.read_csv(f'{args.ehr_data_root}/{args.task}/test_listfile.csv')
    
    #TODO: investigate why total size of cxr_merged_icustays drops after the three steps below
    train_meta_with_labels = cxr_merged_icustays.merge(splits_labels_train, how='inner', on='stay_id')#change dataset size here
    val_meta_with_labels = cxr_merged_icustays.merge(splits_labels_val, how='inner', on='stay_id')
    test_meta_with_labels = cxr_merged_icustays.merge(splits_labels_test, how='inner', on='stay_id')
    
    # Get rid of chest X-rays that don't have radiology reports
    metadata = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-metadata.csv')
    labels = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-chexpert.csv')
    metadata_with_labels = metadata.merge(labels[['study_id']], how='inner', on='study_id').drop_duplicates(subset=['dicom_id'])
    train_meta_with_labels = train_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    val_meta_with_labels = val_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    test_meta_with_labels = test_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    
    print("Excluding CXR with missing radiology reports = ",len(train_meta_with_labels))

    # Multimodal class
    train_ds = MIMIC_CXR_EHR(args, train_meta_with_labels, ehr_train_ds, cxr_train_ds)
    print(len(train_ds))
    val_ds = MIMIC_CXR_EHR(args, val_meta_with_labels, ehr_val_ds, cxr_val_ds, split='val')
    print(len(val_ds))
    test_ds = MIMIC_CXR_EHR(args, test_meta_with_labels, ehr_test_ds, cxr_test_ds, split='test')
    print(len(test_ds))
    
    if args.beta_infonce:
        collate = my_collate_beta
    else:
        collate = my_collate
    
    # Multimodal data loader 
    train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, collate_fn=collate, drop_last=True)#, pin_memory=True, num_workers=24)
    val_dl = DataLoader(val_ds, args.batch_size, shuffle=False, collate_fn=collate, drop_last=False) #pin_memory=True, num_workers=16,
    test_dl = DataLoader(test_ds, args.batch_size, shuffle=False, collate_fn=collate, drop_last=False) # pin_memory=True,num_workers=16,

    return train_dl, val_dl, test_dl


In [501]:
def my_collate(batch):
    x = [item[0] for item in batch]
    pairs = [False if item[1] is None else True for item in batch]
    img = torch.stack([torch.zeros(3, 224, 224) if item[1] is None else item[1] for item in batch])
    x, seq_length = pad_zeros(x)
    targets_ehr = np.array([item[2] for item in batch])
    targets_cxr = torch.stack([torch.zeros(14) if item[3] is None else item[3] for item in batch])
    return [x, img, targets_ehr, targets_cxr, seq_length, pairs]

In [502]:
def my_collate_beta(batch, beta_infonce=False):
    x = [item[0] for item in batch]
    pairs = [False if item[1] is None else True for item in batch]
    img = torch.stack([torch.zeros(3, 224, 224) if item[1] is None else item[1] for item in batch])
    x, seq_length = pad_zeros(x)
    targets_ehr = np.array([item[2] for item in batch])
    targets_cxr = torch.stack([torch.zeros(14) if item[3] is None else item[3] for item in batch])
    time_diff = [item[4] for item in batch]
    
    return [x, img, targets_ehr, targets_cxr, seq_length, pairs, time_diff]

In [503]:
def pad_zeros(arr, min_length=None):
    dtype = arr[0].dtype
    seq_length = [x.shape[0] for x in arr]
    max_len = max(seq_length)
    ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
           for x in arr]
    if (min_length is not None) and ret[0].shape[0] < min_length:
        ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
               for x in ret]
    return np.array(ret), seq_length

In [504]:
splits_labels_train = pd.read_csv(f'{args.ehr_data_root}/{args.task}/train_listfile.csv')
splits_labels_train.head(1)

Unnamed: 0,stay,period_length,stay_id,Acute and unspecified renal failure,Acute cerebrovascular disease,Acute myocardial infarction,Cardiac dysrhythmias,Chronic kidney disease,Chronic obstructive pulmonary disease and bronchiectasis,Complications of surgical procedures or medical care,Conduction disorders,Congestive heart failure; nonhypertensive,Coronary atherosclerosis and other heart disease,Diabetes mellitus with complications,Diabetes mellitus without complication,Disorders of lipid metabolism,Essential hypertension,Fluid and electrolyte disorders,Gastrointestinal hemorrhage,Hypertension with complications and secondary hypertension,Other liver diseases,Other lower respiratory disease,Other upper respiratory disease,Pleurisy; pneumothorax; pulmonary collapse,Pneumonia (except that caused by tuberculosis or sexually transmitted disease),Respiratory failure; insufficiency; arrest (adult),Septicemia (except in labor),Shock
0,10000032_episode1_timeseries.csv,9.846389,39553978,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0


In [505]:
train_meta_with_labels = meta.merge(splits_labels_train, how='inner', on='stay_id')
train_meta_with_labels.head(1)

Unnamed: 0,study_id,index,dicom_id,subject_id,PerformedProcedureStepDescription,ViewPosition,Rows,Columns,StudyDate,StudyTime,ProcedureCodeSequence_CodeMeaning,ViewCodeSequence_CodeMeaning,PatientOrientationCodeSequence_CodeMeaning,stay_id,intime,outtime,StudyDateTime,time_diff,LOS,num_ehr_windows,lower,upper,stay,period_length,Acute and unspecified renal failure,Acute cerebrovascular disease,Acute myocardial infarction,Cardiac dysrhythmias,Chronic kidney disease,Chronic obstructive pulmonary disease and bronchiectasis,Complications of surgical procedures or medical care,Conduction disorders,Congestive heart failure; nonhypertensive,Coronary atherosclerosis and other heart disease,Diabetes mellitus with complications,Diabetes mellitus without complication,Disorders of lipid metabolism,Essential hypertension,Fluid and electrolyte disorders,Gastrointestinal hemorrhage,Hypertension with complications and secondary hypertension,Other liver diseases,Other lower respiratory disease,Other upper respiratory disease,Pleurisy; pneumothorax; pulmonary collapse,Pneumonia (except that caused by tuberculosis or sexually transmitted disease),Respiratory failure; insufficiency; arrest (adult),Septicemia (except in labor),Shock
0,59469162,291809,5dfd960b-2e6378a2-1de9c84f-24ec4b38-f9d1a19a,17938576,CHEST (PORTABLE AP),AP,2539,3050,21580123,213357,CHEST (PORTABLE AP),antero-posterior,Erect,30002498,2158-01-23 16:00:00,2158-01-24 17:36:04,2158-01-23 21:33:57,5.566,25.601,3,0.0,12.0,17938576_episode1_timeseries.csv,25.601111,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1


In [506]:
metadata = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-metadata.csv')
labels = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-chexpert.csv')

In [507]:
metadata.head(1)

Unnamed: 0,dicom_id,subject_id,study_id,PerformedProcedureStepDescription,ViewPosition,Rows,Columns,StudyDate,StudyTime,ProcedureCodeSequence_CodeMeaning,ViewCodeSequence_CodeMeaning,PatientOrientationCodeSequence_CodeMeaning
0,02aa804e-bde0afdd-112c0b34-7bc16630-4e384014,10000032,50414267,CHEST (PA AND LAT),PA,3056,2544,21800506,213014.531,CHEST (PA AND LAT),postero-anterior,Erect


In [508]:
labels.head(1)

Unnamed: 0,subject_id,study_id,Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices
0,10000032,50414267,,,,,,,,,1.0,,,,,


In [509]:
metadata_with_labels = metadata.merge(labels[['study_id']], how='inner', on='study_id').drop_duplicates(subset=['dicom_id'])
metadata_with_labels

Unnamed: 0,dicom_id,subject_id,study_id,PerformedProcedureStepDescription,ViewPosition,Rows,Columns,StudyDate,StudyTime,ProcedureCodeSequence_CodeMeaning,ViewCodeSequence_CodeMeaning,PatientOrientationCodeSequence_CodeMeaning
0,02aa804e-bde0afdd-112c0b34-7bc16630-4e384014,10000032,50414267,CHEST (PA AND LAT),PA,3056,2544,21800506,213014.531,CHEST (PA AND LAT),postero-anterior,Erect
1,174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962,10000032,50414267,CHEST (PA AND LAT),LATERAL,3056,2544,21800506,213014.531,CHEST (PA AND LAT),lateral,Erect
2,2a2277a9-b0ded155-c0de8eb9-c124d10e-82c5caab,10000032,53189527,CHEST (PA AND LAT),PA,3056,2544,21800626,165500.312,CHEST (PA AND LAT),postero-anterior,Erect
3,e084de3b-be89b11e-20fe3f9f-9c8d8dfe-4cfd202c,10000032,53189527,CHEST (PA AND LAT),LATERAL,3056,2544,21800626,165500.312,CHEST (PA AND LAT),lateral,Erect
4,68b5c4b1-227d0485-9cc38c3f-7b84ab51-4b472714,10000032,53911762,CHEST (PORTABLE AP),AP,2705,2539,21800723,80556.875,CHEST (PORTABLE AP),antero-posterior,
...,...,...,...,...,...,...,...,...,...,...,...,...
377090,428e2c18-5721d8f3-35a05001-36f3d080-9053b83c,19999733,57132437,CHEST (PA AND LAT),PA,3056,2544,21520708,224550.171,CHEST (PA AND LAT),postero-anterior,Erect
377091,58c403aa-35ff8bd9-73e39f54-8dc9cc5d-e0ec3fa9,19999733,57132437,CHEST (PA AND LAT),LATERAL,3056,2544,21520708,224550.171,CHEST (PA AND LAT),lateral,Erect
377092,58766883-376a15ce-3b323a28-6af950a0-16b793bd,19999987,55368167,CHEST (PORTABLE AP),AP,2544,3056,21451104,51448.218,CHEST (PORTABLE AP),antero-posterior,Erect
377093,7ba273af-3d290f8d-e28d0ab4-484b7a86-7fc12b08,19999987,58621812,CHEST (PORTABLE AP),AP,3056,2544,21451102,202809.234,CHEST (PORTABLE AP),antero-posterior,Erect


In [510]:
train_meta_with_labels = train_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
train_meta_with_labels.head(5)

Unnamed: 0,study_id,index,dicom_id,subject_id,PerformedProcedureStepDescription,ViewPosition,Rows,Columns,StudyDate,StudyTime,ProcedureCodeSequence_CodeMeaning,ViewCodeSequence_CodeMeaning,PatientOrientationCodeSequence_CodeMeaning,stay_id,intime,outtime,StudyDateTime,time_diff,LOS,num_ehr_windows,lower,upper,stay,period_length,Acute and unspecified renal failure,Acute cerebrovascular disease,Acute myocardial infarction,Cardiac dysrhythmias,Chronic kidney disease,Chronic obstructive pulmonary disease and bronchiectasis,Complications of surgical procedures or medical care,Conduction disorders,Congestive heart failure; nonhypertensive,Coronary atherosclerosis and other heart disease,Diabetes mellitus with complications,Diabetes mellitus without complication,Disorders of lipid metabolism,Essential hypertension,Fluid and electrolyte disorders,Gastrointestinal hemorrhage,Hypertension with complications and secondary hypertension,Other liver diseases,Other lower respiratory disease,Other upper respiratory disease,Pleurisy; pneumothorax; pulmonary collapse,Pneumonia (except that caused by tuberculosis or sexually transmitted disease),Respiratory failure; insufficiency; arrest (adult),Septicemia (except in labor),Shock
0,59469162,291809,5dfd960b-2e6378a2-1de9c84f-24ec4b38-f9d1a19a,17938576,CHEST (PORTABLE AP),AP,2539,3050,21580123,213357,CHEST (PORTABLE AP),antero-posterior,Erect,30002498,2158-01-23 16:00:00,2158-01-24 17:36:04,2158-01-23 21:33:57,5.566,25.601,3,0.0,12.0,17938576_episode1_timeseries.csv,25.601111,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1
1,59469162,291809,5dfd960b-2e6378a2-1de9c84f-24ec4b38-f9d1a19a,17938576,CHEST (PORTABLE AP),AP,2539,3050,21580123,213357,CHEST (PORTABLE AP),antero-posterior,Erect,30002498,2158-01-24 04:00:00,2158-01-24 17:36:04,2158-01-23 21:33:57,-6.434,25.601,3,12.0,24.0,17938576_episode1_timeseries.csv,25.601111,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1
2,59469162,291809,5dfd960b-2e6378a2-1de9c84f-24ec4b38-f9d1a19a,17938576,CHEST (PORTABLE AP),AP,2539,3050,21580123,213357,CHEST (PORTABLE AP),antero-posterior,Erect,30002498,2158-01-24 16:00:00,2158-01-24 17:36:04,2158-01-23 21:33:57,-18.434,25.601,3,24.0,26.601,17938576_episode1_timeseries.csv,25.601111,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1
3,51967233,323833,d2bae3a3-3917d71b-f44edcd7-81c7017a-f15288e7,18730522,CHEST (PORTABLE AP),AP,2539,3050,21530912,123408,DX CHEST PORTABLE PICC LINE PLACEMENT,antero-posterior,Erect,30004391,2153-09-05 13:12:00,2153-09-13 18:21:18,2153-09-12 12:34:08,167.369,197.155,17,0.0,12.0,18730522_episode2_timeseries.csv,197.155,1,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,1,1,1
4,51967233,323833,d2bae3a3-3917d71b-f44edcd7-81c7017a-f15288e7,18730522,CHEST (PORTABLE AP),AP,2539,3050,21530912,123408,DX CHEST PORTABLE PICC LINE PLACEMENT,antero-posterior,Erect,30004391,2153-09-06 01:12:00,2153-09-13 18:21:18,2153-09-12 12:34:08,155.369,197.155,17,12.0,24.0,18730522_episode2_timeseries.csv,197.155,1,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,1,1,1


In [567]:
train_meta_with_labels.columns

Index(['study_id', 'index', 'dicom_id', 'subject_id',
       'PerformedProcedureStepDescription', 'ViewPosition', 'Rows', 'Columns',
       'StudyDate', 'StudyTime', 'ProcedureCodeSequence_CodeMeaning',
       'ViewCodeSequence_CodeMeaning',
       'PatientOrientationCodeSequence_CodeMeaning', 'stay_id', 'intime',
       'outtime', 'StudyDateTime', 'time_diff', 'LOS', 'num_ehr_windows',
       'lower', 'upper', 'stay', 'period_length',
       'Acute and unspecified renal failure', 'Acute cerebrovascular disease',
       'Acute myocardial infarction', 'Cardiac dysrhythmias',
       'Chronic kidney disease',
       'Chronic obstructive pulmonary disease and bronchiectasis',
       'Complications of surgical procedures or medical care',
       'Conduction disorders', 'Congestive heart failure; nonhypertensive',
       'Coronary atherosclerosis and other heart disease',
       'Diabetes mellitus with complications',
       'Diabetes mellitus without complication',
       'Disorders of lipi

In [511]:
tr, va, ts = load_cxr_ehr(args,
             ehr_train_ds=ehrtr,
             ehr_val_ds=ehrva,
             ehr_test_ds=ehrts,
             cxr_train_ds=cxrtr,
             cxr_val_ds=cxrva,
             cxr_test_ds=cxrts)

Number of CXR images= 377110
Number of ICU stays= 59372
Number of CXR associated with ICU stay based on subject ID= 368350
Number of unique CXR dicoms= 181195
Number of unique CXR study id= 122087
Mean time cxr - intime=  74.23168108953072
Minimum time = -601.239
Maximum time = 2368.942
Excluding CXR with missing radiology reports =  69657
69657
7696
20442


In [512]:
a = (next(iter(tr)))

In [513]:
a

[array([[[-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.29126413],
         [-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.29126413],
         [-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.29126413],
         ...,
         [  0.        ,   0.        ,   0.        , ...,   0.        ,
            0.        ,   0.        ],
         [  0.        ,   0.        ,   0.        , ...,   0.        ,
            0.        ,   0.        ],
         [  0.        ,   0.        ,   0.        , ...,   0.        ,
            0.        ,   0.        ]],
 
        [[-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.29126413],
         [-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.29126413],
         [-16.60627547,  -0.06021818,  -0.2673962 , ...,  -0.55850585,
           -0.14534766,  -0.2912

In [159]:
import torch
ehr = torch.randn(5,512)
ehr.shape

torch.Size([5, 512])

In [160]:
cxr = torch.randn(5,512)
cxr.shape

torch.Size([5, 512])

In [161]:
time_diff1 =  [0.1,0.2,0.1,0.1,0.1]
time_diff2 =  [10.,50.,30.,70.,90.]

In [174]:
def modified_info_nce_loss(feats_ehr, feats_img, time_diff, mode='train'):
        # Calculate cosine similarity matrix
        cos_sim = F.cosine_similarity(feats_img[:,None,:], feats_ehr[None,:,:], dim=-1)
#         print(cos_sim)
        cos_sim = cos_sim /  0.07
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)
        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute the values of beta
        k = 1
        time_diff = torch.FloatTensor(time_diff)
#         print(time_diff.mean())
#         time_diff = time_diff/ time_diff.mean()
#         print(time_diff)
        beta = torch.exp(-k*time_diff)#.to(self.device)
        print(beta)
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
#         print(nll_1)
        nll_1 = (k-beta)*nll_1
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0)
        nll_2 = (k-beta)*nll_2
        
        # Total loss 
        loss = -(nll_1 + nll_2).mean()
       
        return loss

In [175]:
modified_info_nce_loss(ehr,cxr,time_diff1)

tensor([0.9048, 0.8187, 0.9048, 0.9048, 0.9048])


tensor(0.3130)

In [176]:
modified_info_nce_loss(ehr,cxr,time_diff2)

tensor([4.5400e-05, 1.9287e-22, 9.3576e-14, 3.9754e-31, 8.1940e-40])


tensor(2.4332)