In [17]:
from __future__ import absolute_import
from __future__ import print_function

import os
import sys
import numpy as np
import pandas as pd 
import platform
import pickle
import json
from PIL import Image
import glob
import random
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import custom_parser as par

parser = par.initiate_parsing()
args = parser.parse_args([ '--device' , '$CUDA_VISIBLE_DEVICES',
'--vision-backbone', 'resnet34' ,
'--resize', '256' , 
'--task' , 'phenotyping' ,
'--job_number' , '${SLURM_JOBID}',
'--file_name' , 'SIMCLR-${SLURM_JOBID}' ,
'--epochs' , '2' , '--transforms_cxr' , 'simclrv2' , '--temperature' , '0.01' ,
'--batch_size' , '30' , '--lr' , '0.8' ,
'--num_gpu' , '1' ,
'--pretrain_type' , 'simclr' ,
'--mode' , 'train' ,
'--fusion_type' , 'None' ,
'--save_dir' , '/scratch/se1525/mml-ssl/checkpoints/phenotyping/models' ,
'--tag' , 'simclr_train'])
    

In [18]:
class Discretizer:
    def __init__(self, timestep=0.8, store_masks=True, impute_strategy='zero', start_time='zero',
                 config_path= 'ehr_utils/resources/discretizer_config.json'):

        with open(config_path) as f:
            config = json.load(f)
            # print(config)
            self._id_to_channel = config['id_to_channel']
            self._channel_to_id = dict(zip(self._id_to_channel, range(len(self._id_to_channel))))
            self._is_categorical_channel = config['is_categorical_channel']
            self._possible_values = config['possible_values']
            self._normal_values = config['normal_values']
            for channel in self._is_categorical_channel:
                    print(channel)
            #         print(len(self._possible_values[channel]))

        self._header = ["Hours"] + self._id_to_channel
        self._timestep = timestep
        self._store_masks = store_masks
        self._start_time = start_time
        self._impute_strategy = impute_strategy

        # for statistics
        self._done_count = 0
        self._empty_bins_sum = 0
        self._unused_data_sum = 0

    def transform(self, X, header=None, end=None):
        # print("end" , end)
        if header is None:
            # print("none")
            header = self._header
        assert header[0] == "Hours"
        eps = 1e-6

        N_channels = len(self._id_to_channel)
        ts = [float(row[0]) for row in X]
        for i in range(len(ts) - 1):
            assert ts[i] < ts[i+1] + eps

        if self._start_time == 'relative':
            first_time = ts[0]
        elif self._start_time == 'zero':
            first_time = 0
        else:
            raise ValueError("start_time is invalid")

        if end is None:
            max_hours = max(ts) - first_time
        else:
            max_hours = end - first_time

        N_bins = int(max_hours / self._timestep + 1.0 - eps)
        # print(max_hours, self._timestep, N_bins)

        cur_len = 0
        begin_pos = [0 for i in range(N_channels)]
        end_pos = [0 for i in range(N_channels)]
        for i in range(N_channels):
            channel = self._id_to_channel[i]
            begin_pos[i] = cur_len
            if self._is_categorical_channel[channel]:
                end_pos[i] = begin_pos[i] + len(self._possible_values[channel])
            else:
                end_pos[i] = begin_pos[i] + 1
            cur_len = end_pos[i]

        data = np.zeros(shape=(N_bins, cur_len), dtype=float)
        mask = np.zeros(shape=(N_bins, N_channels), dtype=int)
        original_value = [["" for j in range(N_channels)] for i in range(N_bins)]
        total_data = 0
        unused_data = 0

        def write(data, bin_id, channel, value, begin_pos):
            channel_id = self._channel_to_id[channel]
            if self._is_categorical_channel[channel]:
                # print("list: ", self._possible_values[channel], "val: ", value, "channel:", channel)
                category_id = self._possible_values[channel].index(value)
                N_values = len(self._possible_values[channel])
                one_hot = np.zeros((N_values,))
                one_hot[category_id] = 1
                for pos in range(N_values):
                    data[bin_id, begin_pos[channel_id] + pos] = one_hot[pos]
            else:
                data[bin_id, begin_pos[channel_id]] = float(value)

        for row in X:
            t = float(row[0]) - first_time
            if t > max_hours + eps:
                continue
            bin_id = int(t / self._timestep - eps)
            assert 0 <= bin_id < N_bins

            for j in range(1, len(row)):
                if row[j] == "":
                    continue
                channel = header[j]
                channel_id = self._channel_to_id[channel]

                total_data += 1
                if mask[bin_id][channel_id] == 1:
                    unused_data += 1
                mask[bin_id][channel_id] = 1

                write(data, bin_id, channel, row[j], begin_pos)
                original_value[bin_id][channel_id] = row[j]

        # impute missing values

        if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
            raise ValueError("impute strategy is invalid")

        if self._impute_strategy in ['normal_value', 'previous']:
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if self._impute_strategy == 'normal_value':
                        imputed_value = self._normal_values[channel]
                    if self._impute_strategy == 'previous':
                        if len(prev_values[channel_id]) == 0:
                            imputed_value = self._normal_values[channel]
                        else:
                            imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        if self._impute_strategy == 'next':
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins-1, -1, -1):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if len(prev_values[channel_id]) == 0:
                        imputed_value = self._normal_values[channel]
                    else:
                        imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        empty_bins = np.sum([1 - min(1, np.sum(mask[i, :])) for i in range(N_bins)])
        self._done_count += 1
        self._empty_bins_sum += empty_bins / (N_bins + eps)
        self._unused_data_sum += unused_data / (total_data + eps)

        if self._store_masks:
            data = np.hstack([data, mask.astype(np.float32)])

        # create new header
        new_header = []
        for channel in self._id_to_channel:
            if self._is_categorical_channel[channel]:
                values = self._possible_values[channel]
                for value in values:
                    new_header.append(channel + "->" + value)
            else:
                new_header.append(channel)

        if self._store_masks:
            for i in range(len(self._id_to_channel)):
                channel = self._id_to_channel[i]
                new_header.append("mask->" + channel)

        new_header = ",".join(new_header)

        return (data, new_header)

    def print_statistics(self):
        print("statistics of discretizer:")
        print("\tconverted {} examples".format(self._done_count))
        print("\taverage unused data = {:.2f} percent".format(100.0 * self._unused_data_sum / self._done_count))
        print("\taverage empty  bins = {:.2f} percent".format(100.0 * self._empty_bins_sum / self._done_count))
        


# data preprocessor
def read_timeseries(args):
    path = f'{args.ehr_data_root}/{args.task}/train/14991576_episode3_timeseries.csv'
    ret = []
    with open(path, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        assert header[0] == "Hours"
        for line in tsfile:
            mas = line.strip().split(',')
            # print(mas)
            ret.append(np.array(mas))
    return np.stack(ret)


def ehr_funcs_discretizer(args):
    
    discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero',
                          config_path=f'/scratch/se1525/mml-ssl/ehr_utils/resources/discretizer_config.json')

    # print(read_timeseries(args))
    discretizer_header = discretizer.transform(read_timeseries(args))[1].split(',')
    # print(discretizer_header)
    cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]
    indices_and_values = [(i, x) for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]
    # print(len(indices_and_values))
    # for i, x in indices_and_values:
        # print("Index:", i)
        # print("Value:", x)   
    return discretizer, cont_channels

discretizer, cont_channels = ehr_funcs_discretizer(args)

class Normalizer:
    def __init__(self, fields=None):
        self._means = None
        self._stds = None
        self._fields = None
        if fields is not None:
            self._fields = [col for col in fields]

        self._sum_x = None
        self._sum_sq_x = None
        self._count = 0

    def _feed_data(self, x):
        x = np.array(x)
        self._count += x.shape[0]
        if self._sum_x is None:
            self._sum_x = np.sum(x, axis=0)
            self._sum_sq_x = np.sum(x**2, axis=0)
        else:
            self._sum_x += np.sum(x, axis=0)
            self._sum_sq_x += np.sum(x**2, axis=0)

    def _save_params(self, save_file_path):
        eps = 1e-7
        with open(save_file_path, "wb") as save_file:
            N = self._count
            self._means = 1.0 / N * self._sum_x
            self._stds = np.sqrt(1.0/(N - 1) * (self._sum_sq_x - 2.0 * self._sum_x * self._means + N * self._means**2))
            self._stds[self._stds < eps] = eps
            pickle.dump(obj={'means': self._means,
                             'stds': self._stds},
                        file=save_file,
                        protocol=2)

    def load_params(self, load_file_path):
        with open(load_file_path, "rb") as load_file:
            if platform.python_version()[0] == '2':
                dct = pickle.load(load_file)
            else:
                dct = pickle.load(load_file, encoding='latin1')
            self._means = dct['means']
            self._stds = dct['stds']

    def transform(self, X):
        if self._fields is None:
            fields = range(X.shape[1])
        else:
            fields = self._fields
        ret = 1.0 * X
        for col in fields:
            ret[:, col] = (X[:, col] - self._means[col]) / self._stds[col]
        return ret
    
def ehr_funcs_normalizer(args, cont_channels):
    normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        normalizer_state = 'ph_ts{}.input_str:previous.start_time:zero.normalizer'.format(args.timestep)
        normalizer_state = os.path.join('/scratch/se1525/mml-ssl/', normalizer_state)
    normalizer.load_params(normalizer_state)
    
    return normalizer

normalizer = ehr_funcs_normalizer(args, cont_channels)

Capillary refill rate
Diastolic blood pressure
Fraction inspired oxygen
Glascow coma scale eye opening
Glascow coma scale motor response
Glascow coma scale total
Glascow coma scale verbal response
Glucose
Heart Rate
Height
Mean blood pressure
Oxygen saturation
Respiratory rate
Systolic blood pressure
Temperature
Weight
pH


In [23]:
class EHRdataset(Dataset):
    def __init__(self, discretizer, normalizer, listfile, dataset_dir, return_names=True, period_length=48.0, transforms=None):
        self.return_names = return_names
        self.discretizer = discretizer
        self.normalizer = normalizer
        self._period_length = period_length

        self._dataset_dir = dataset_dir
        listfile_path = listfile
        with open(listfile_path, "r") as lfile:
            self._data = lfile.readlines()
            print("Length of list file" , len(self._data))
        self._listfile_header = self._data[0]
        self.CLASSES = self._listfile_header.strip().split(',')[3:]
        # print(self.CLASSES)
        self._data = self._data[1:]
        # print(self._data[:1])
        self.transforms = transforms
        self._data = [line.split(',') for line in self._data]
        # print(self._data)
        # self._data_map_pheno = [(mas[0], float(mas[1]), list(map(int, mas[2:]))) for mas in self._data]

        self.data_map = {
            mas[0]: {
                'labels': list(map(int, mas[3:])),
                'stay_id': float(mas[2]),
                'time': float(mas[1]),
                }
                for mas in self._data
        }
        # self.data = [(x, float(t), int(stay_id) ,int(y)) for (x, t, stay_id , y) in self._data]
        print("Length of data_map" , len(list(self.data_map)))
        # print("Length of data" , len(self.data))

        
        # count = 0
        # for value in self._data_map_pheno:
        #     print(f"Value: {value}")
        #     count += 1
        #     if count == 5:
        #         break
             
        # print (' ---------------- ')
        # self._data_map_decomp = [(x, float(t), int(y)) for (x, t, s, y) in self._data]
        # count = 0
        # for value in self._data_map_decomp:
        #     print(f"Value: {value}")
        #     count += 1
        #     if count == 5:
        #         break

        self.names = list(self.data_map.keys())
    
    def _read_timeseries(self, ts_filename, time_bound):
        
        ret = []
        with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
            header = tsfile.readline().strip().split(',')
            assert header[0] == "Hours"
            for line in tsfile:
                mas = line.strip().split(',')
                if time_bound is not None:
                    t = float(mas[0])
                    if t > time_bound + 1e-6:
                        break
                ret.append(np.array(mas))
        return (np.stack(ret), header)

    
    def read_by_file_name(self, index):
        t = self.data_map[index]['time']
        time_bound = t
        print("time bound" , time_bound)
        print(len(list(self.data_map.keys())))
        y = self.data_map[index]['labels']
        stay_id = self.data_map[index]['stay_id']
        (X, header) = self._read_timeseries(index, time_bound)

        return {"X": X,
                "t": t,
                "y": y,
                'stay_id': stay_id,
                "header": header,
                "name": index}
 

    def __getitem__(self, index):
        if isinstance(index, int):
            index = self.names[index]
        ret = self.read_by_file_name(index)
        print("length after read_by_file_name" , len(ret["X"]))
        # print("ret" , ret)
        data = ret["X"]
        ts = ret["t"] if ret['t'] > 0.0 else self._period_length
        # print("ts" , ts)
        

        # print("data length before: ", len(data[0]))
        # print("data before: " , data[0])
        data = self.discretizer.transform(data, end=ts)[0]
        # print("data length after: ", len(data[0]))
        # print("data after:" ,  data[0])
        if (self.normalizer is not None):
            data = self.normalizer.transform(data)
               

        ys = ret["y"]
        names = ret["name"]
        ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
        return data, ys

    
    def __len__(self):
        return len(self.names)

In [24]:
def get_datasets(discretizer, normalizer, args):
    transform = None
    # print(f'{args.ehr_data_root}/{args.task}/train_listfile.csv')
    # print(os.path.join(args.ehr_data_root, f'{args.task}/train'))
    train_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/train_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'), transforms=transform)
    val_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/val_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'), transforms = transform)
    test_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/test_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/test'), transforms = transform)
    return train_ds, val_ds, test_ds
ehr_train_ds, ehr_val_ds, ehr_test_ds = get_datasets(discretizer, normalizer, args)
# print(len(ehr_train_ds[1][0]))
# print(ehr_train_ds[1])
get_item = ehr_test_ds[1]
get_item

Length of list file 42629
Length of data_map 42628
Length of list file 4803
Length of data_map 4802
Length of list file 11915
Length of data_map 11914
time bound 25.104444
11914
length after read_by_file_name 39


(array([[ 1.        ,  0.        , -0.00332347, ...,  0.        ,
          1.        ,  0.        ],
        [ 1.        ,  0.        , -0.05289614, ...,  0.        ,
          0.        ,  0.        ],
        [ 1.        ,  0.        , -0.03036311, ...,  1.        ,
          1.        ,  0.        ],
        ...,
        [ 1.        ,  0.        ,  0.01019634, ...,  0.        ,
          0.        ,  0.        ],
        [ 1.        ,  0.        ,  0.01019634, ...,  1.        ,
          0.        ,  0.        ],
        [ 1.        ,  0.        ,  0.01019634, ...,  0.        ,
          0.        ,  0.        ]]),
 array([0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0], dtype=int32))

In [1]:
# with open('/scratch/fs999/shamoutlab/data/mimic-iv-extracted/phenotyping/train_listfile.csv', "r") as lfile:
#     data = lfile.readlines()
# listfile_header = data[0]
# CLASSES = listfile_header.strip().split(',')[3:]
# data = data[1:]

# data = [line.split(',') for line in data]
# print(data[1])
# data_map = {
#     mas[0]: {
#         'labels': list(map(float, mas[3:])),
#         'stay_id': float(mas[2]),
#         'time': float(mas[1]),
#         }
#         for mas in data
# }
# print(data_map['16505791_episode1_timeseries.csv'])
# def _read_timeseries(ts_filename, time_bound=None):
#     ret = []
#     with open(os.path.join('/scratch/fs999/shamoutlab/data/mimic-iv-extracted/phenotyping/train', ts_filename), "r") as tsfile:
#         header = tsfile.readline().strip().split(',')
#         assert header[0] == "Hours"
#         for line in tsfile:
#             mas = line.strip().split(',')
#             if time_bound is not None:
#                 t = float(mas[0])
#                 if t > time_bound + 1e-6:
#                     break
#             ret.append(np.array(mas))
#     return (np.stack(ret), header)
# def read_by_file_name(data_map, index, time_bound=None):
#     t = data_map[index]['time'] if time_bound is None else time_bound
#     y = data_map[index]['labels']
#     stay_id = data_map[index]['stay_id']
#     (X, header) = read_timeseries(index, time_bound=time_bound)

#     return {"X": X,
#             "t": t,
#             "y": y,
#             'stay_id': stay_id,
#             "header": header,
#             "name": index}

# def __getitem__(data_map, index, time_bound=None):
#     if isinstance(index, int):
#         index = self.names[index]
#     ret = self.read_by_file_name(data_map, index, time_bound)
#     data = ret["X"]
#     ts = ret["t"] if ret['t'] > 0.0 else 48
#     ys = ret["y"]
#     names = ret["name"]
#     data = discretizer.transform(data, end=ts)[0] 
#     if (normalizer is not None):
#         data = self.normalizer.transform(data)
#     ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
#     return data, ys

In [3]:
# listfile = f'{args.ehr_data_root}/length-of-stay/train_listfile.csv'
# print(listfile)
# with open(listfile, "r") as lfile:
#     data = lfile.readlines()
# listfile_header = data[0]
# print(listfile_header)
# classes = listfile_header.strip().split(',')[3:]
# print(classes)
# data = data[1:]

# print(data[1:10])
# data = [line.split(',') for line in data]
# print(data[8][0])
# # data_map = {
#             mas[0]: {
#                 'labels': list(map(float, mas[3:])),
#                 'stay_id': float(mas[2]),
#                 'time': float(mas[1]),
#                 }
#                 for mas in data
#         }

# data_map = {
#     x: {
#         'time': float(t),
#         'stay_id': int(stay_id),
#         'labels': float(y)
#     }
#     for x, t, stay_id, y in data
# }
# for mas in data[0:10]:
#     print(mas)
#     print(mas[3])

data_map={}
# for mas in data:
#     data_map[mas[0]]=['labels': mas[3], 
#                     'stay_id': mas[2],
#                     'time': mas[1]}
# for mas in data:
#     data_map[mas[0]]= mas[3]
# data_map = {
#             mas[0]: {
#                 'labels': mas[3],
#                 'stay_id': mas[2],
#                 'time': mas[1],
#                 }
#                 for mas in data
#         }
# data_map = [(x, float(t), int(stay_id) ,float(y)) for (x, t, stay_id , y) in data]
# # print(data_map['10001884_episode1_timeseries.csv'])
# print(data_map[8])



In [4]:
# class EHRdataset(Dataset):
#     def __init__(self, discretizer, normalizer, listfile, dataset_dir, return_names=True, period_length=48.0):
#         self.return_names = return_names
#         self.discretizer = discretizer
#         self.normalizer = normalizer
#         self._period_length = period_length

#         self._dataset_dir = dataset_dir
#         listfile_path = listfile
#         with open(listfile_path, "r") as lfile:
#             self._data = lfile.readlines()
#         self._listfile_header = self._data[0]
#         self.CLASSES = self._listfile_header.strip().split(',')[3:]
#         self._data = self._data[1:]
#         print(self._data[12])


#         self._data = [line.split(',') for line in self._data]
#         if 'length-of-stay' or 'decompensation' in self._dataset_dir:
#             self.data_map = [(x, float(t), int(stay_id) ,[float(y)]) for (x, t, stay_id , y) in self._data]
#             self.names = [x[0] for x in self.data_map]
#             self.times= [x[1] for x in self.data_map]
#         else:
#             self.data_map = {
#             mas[0]: {
#                 'labels': list(map(float, mas[3:])),
#                 'stay_id': float(mas[2]),
#                 'time': float(mas[1]),
#                 }
#                 for mas in self._data
#         }
#             self.names = list(self.data_map.keys())
#         print(self.data_map[12])
        

#         # import pdb; pdb.set_trace()

#         # self._data = [(line_[0], float(line_[1]), line_[2], float(line_[3])  ) for line_ in self._data]



#         # self.names = [x[0] for x in self.data_map]
#         # print(self.names[0:10])
    
#     def _read_timeseries(self, ts_filename, time_bound=None):
        
#         ret = []
#         with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
#             header = tsfile.readline().strip().split(',')
#             assert header[0] == "Hours"
#             for line in tsfile:
#                 mas = line.strip().split(',')
#                 if time_bound is not None:
#                     t = float(mas[0])
#                     if t > time_bound + 1e-6:
#                         break
#                 ret.append(np.array(mas))
#         return (np.stack(ret), header)
    
#     def read_by_file_name(self, index, time, time_bound=None):
#         print("index", index)
#         if 'length-of-stay' or 'decompensation' in self._dataset_dir:
#             # for x in self.data_map:
#             #     if x[0] == index:
#             #         entry = x 
#             #         break
#             entry = next((x for x in self.data_map if x[0] == index and x[1] ==  time), None)
#             print("entry", entry)
#             if entry is None:
#                 raise ValueError(f"Entry with name {index} not found")
            
#             t = float(entry[1])  # time is the second element in the tuple
#             stay_id = int(entry[2])  # stay_id is the third element
#             y = entry[3]  # labels are the fourth element
#             print("this is entry 3", y)
#             (X, header) = self._read_timeseries(index, time_bound=time_bound if time_bound is not None else t)
#         else:     
#             t = self.data_map[index]['time'] if time_bound is None else time_bound
#             y = self.data_map[index]['labels']
#             stay_id = self.data_map[index]['stay_id']
#             (X, header) = self._read_timeseries(index, time_bound=time_bound)

#         return {"X": X,
#                 "t": t,
#                 "y": y,
#                 'stay_id': stay_id,
#                 "header": header,
#                 "name": index}

#     def get_decomp_los(self, index, time_bound=None):
#         # name = self._data[index][0]
#         # time_bound = self._data[index][1]
#         # ys = self._data[index][3]

#         # (data, header) = self._read_timeseries(index, time_bound=time_bound)
#         # data = self.discretizer.transform(data, end=time_bound)[0] 
#         # if (self.normalizer is not None):
#         #     data = self.normalizer.transform(data)
#         # ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
#         # return data, ys

#         # data, ys = 
#         return self.__getitem__(index, time_bound)


#     def __getitem__(self, index,  time_bound=None):
#         if 'length-of-stay' or 'decompensation' in self._dataset_dir: 
#             if isinstance(index, int):
#                 time = self.times[index]
#                 index = self.names[index]
#         else:
#             if isinstance(index, int):
#                 index = self.names[index]
#                 time = None
                
#         ret = self.read_by_file_name(index, time, time_bound)
#         data = ret["X"]
#         ts = ret["t"] if ret['t'] > 0.0 else self._period_length
#         ys = ret["y"]
#         print("this is ys" , ys)
#         names = ret["name"]
#         data = self.discretizer.transform(data, end=ts)[0] 
#         if (self.normalizer is not None):
#             data = self.normalizer.transform(data)
#         if 'length-of-stay' in self._dataset_dir:
#             ys = np.array(ys, dtype=np.float32) if len(ys) > 1 else np.array(ys, dtype=np.float32)[0]
#         else:
#             ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
#         return data, ys

    
#     def __len__(self):
#         return len(self.names)


# def get_datasets(discretizer, normalizer, args):
#     train_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/train_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/train'))
#     val_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/val_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/train'))
#     test_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/test_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/test'))
#     return train_ds, val_ds, test_ds

# def get_data_loader(discretizer, normalizer, dataset_dir, batch_size):
#     train_ds, val_ds, test_ds = get_datasets(discretizer, normalizer, dataset_dir)
#     train_dl = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=my_collate, pin_memory=True, num_workers=16)
#     val_dl = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=my_collate, pin_memory=True, num_workers=16)

#     return train_dl, val_dl
        
# def my_collate(batch):
#     x = [item[0] for item in batch]
#     x, seq_length = pad_zeros(x)
#     targets = np.array([item[1] for item in batch])
#     return [x, targets, seq_length]

# def pad_zeros(arr, min_length=None):

#     dtype = arr[0].dtype
#     seq_length = [x.shape[0] for x in arr]
#     max_len = max(seq_length)
#     ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
#            for x in arr]
#     if (min_length is not None) and ret[0].shape[0] < min_length:
#         ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
#                for x in ret]
#     return np.array(ret), seq_length

In [280]:
# class EHRdataset(Dataset):
#     def __init__(self, discretizer, normalizer, listfile, dataset_dir, return_names=True, period_length=48.0):
#         self.return_names = return_names
#         self.discretizer = discretizer
#         self.normalizer = normalizer
#         self._period_length = period_length

#         self._dataset_dir = dataset_dir
#         listfile_path = listfile
#         with open(listfile_path, "r") as lfile:
#             self._data = lfile.readlines()
#         self._listfile_header = self._data[0]
#         self.CLASSES = self._listfile_header.strip().split(',')[3:]
#         self._data = self._data[1:]
#         print(self._data[12])


#         self._data = [line.split(',') for line in self._data]
#         self.data_map = {
#         mas[0]: {
#             'labels': list(map(float, mas[3:])),
#             'stay_id': float(mas[2]),
#             'time': float(mas[1]),
#             }
#             for mas in self._data
#         }
#         self.names = list(self.data_map.keys())
#         self.names = [[x[0],x[1]] for x in self.data_map]
#         print(self.data_map[12])
        

#         # import pdb; pdb.set_trace()

#         # self._data = [(line_[0], float(line_[1]), line_[2], float(line_[3])  ) for line_ in self._data]



#         # self.names = [x[0] for x in self.data_map]
#         # print(self.names[0:10])
    
#     def _read_timeseries(self, ts_filename, time_bound=None):
        
#         ret = []
#         with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
#             header = tsfile.readline().strip().split(',')
#             assert header[0] == "Hours"
#             for line in tsfile:
#                 mas = line.strip().split(',')
#                 if time_bound is not None:
#                     t = float(mas[0])
#                     if t > time_bound + 1e-6:
#                         break
#                 ret.append(np.array(mas))
#         return (np.stack(ret), header)
    
#     def read_by_file_name(self, index, time, time_bound=None):
#         print("index", index)
#         if 'length-of-stay' or 'decompensation' in self._dataset_dir:
#             # for x in self.data_map:
#             #     if x[0] == index:
#             #         entry = x 
#             #         break
#             entry = next((x for x in self.data_map if x[0] == index and x[1] ==  time), None)
#             print("entry", entry)
#             if entry is None:
#                 raise ValueError(f"Entry with name {index} not found")
            
#             t = float(entry[1])  # time is the second element in the tuple
#             stay_id = int(entry[2])  # stay_id is the third element
#             y = entry[3]  # labels are the fourth element
#             print("this is entry 3", y)
#             (X, header) = self._read_timeseries(index, time_bound=time_bound if time_bound is not None else t)
#         else:     
#             t = self.data_map[index]['time'] if time_bound is None else time_bound
#             y = self.data_map[index]['labels']
#             stay_id = self.data_map[index]['stay_id']
#             (X, header) = self._read_timeseries(index, time_bound=time_bound)

#         return {"X": X,
#                 "t": t,
#                 "y": y,
#                 'stay_id': stay_id,
#                 "header": header,
#                 "name": index}

#     def get_decomp_los(self, index, time_bound=None):
#         # name = self._data[index][0]
#         # time_bound = self._data[index][1]
#         # ys = self._data[index][3]

#         # (data, header) = self._read_timeseries(index, time_bound=time_bound)
#         # data = self.discretizer.transform(data, end=time_bound)[0] 
#         # if (self.normalizer is not None):
#         #     data = self.normalizer.transform(data)
#         # ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
#         # return data, ys

#         # data, ys = 
#         return self.__getitem__(index, time_bound)


#     def __getitem__(self, index,  time_bound=None):
#         if 'length-of-stay' or 'decompensation' in self._dataset_dir: 
#             if isinstance(index, int):
#                 index, time = self.names[index]
#         else:
#             if isinstance(index, int):
#                 index = self.names[index]
#                 time = None
                
#         ret = self.read_by_file_name(index, time, time_bound)
#         data = ret["X"]
#         ts = ret["t"] if ret['t'] > 0.0 else self._period_length
#         ys = ret["y"]
#         print("this is ys" , ys)
#         names = ret["name"]
#         data = self.discretizer.transform(data, end=ts)[0] 
#         if (self.normalizer is not None):
#             data = self.normalizer.transform(data)
#         if 'length-of-stay' in self._dataset_dir:
#             ys = np.array(ys, dtype=np.float32) if len(ys) > 1 else np.array(ys, dtype=np.float32)[0]
#         else:
#             ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
#         return data, ys

    
#     def __len__(self):
#         return len(self.names)


# def get_datasets(discretizer, normalizer, args):
#     train_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/train_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/train'))
#     val_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/val_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/train'))
#     test_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_dir}/{args.task}/test_listfile.csv', os.path.join(args.ehr_data_dir, f'{args.task}/test'))
#     return train_ds, val_ds, test_ds

# def get_data_loader(discretizer, normalizer, dataset_dir, batch_size):
#     train_ds, val_ds, test_ds = get_datasets(discretizer, normalizer, dataset_dir)
#     train_dl = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=my_collate, pin_memory=True, num_workers=16)
#     val_dl = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=my_collate, pin_memory=True, num_workers=16)

#     return train_dl, val_dl
        
# def my_collate(batch):
#     x = [item[0] for item in batch]
#     x, seq_length = pad_zeros(x)
#     targets = np.array([item[1] for item in batch])
#     return [x, targets, seq_length]

# def pad_zeros(arr, min_length=None):

#     dtype = arr[0].dtype
#     seq_length = [x.shape[0] for x in arr]
#     max_len = max(seq_length)
#     ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
#            for x in arr]
#     if (min_length is not None) and ret[0].shape[0] < min_length:
#         ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
#                for x in ret]
#     return np.array(ret), seq_length

In [25]:
# def get_datasets(discretizer, normalizer, args):
#     transform = None
#     # print(f'{args.ehr_data_root}/{args.task}/train_listfile.csv')
#     # print(os.path.join(args.ehr_data_root, f'{args.task}/train'))
#     train_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/train_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'))
#     val_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/val_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/train'))
#     test_ds = EHRdataset(discretizer, normalizer, f'{args.ehr_data_root}/{args.task}/test_listfile.csv', os.path.join(args.ehr_data_root, f'{args.task}/test'))
#     return train_ds, val_ds, test_ds
# ehr_train_ds, ehr_val_ds, ehr_test_ds = get_datasets(discretizer, normalizer, args)
# print(len(ehr_train_ds[1][0]))
# print(ehr_train_ds[1])

In [26]:
# print(ehr_test_ds[8])
# for i in range(0, 10):  # Print labels for the first 10 entries, or less if the dataset is smaller
#     _, labels = ehr_val_ds[i]
#     print(f"Labels for entry {i}: {float(labels)}")

In [27]:
# get_item = ehr_test_ds[2]
# get_item

In [29]:
class Clip(object):
    """Transformation to clip image values between 0 and 1."""

    def __call__(self, sample):
        return torch.clip(sample, 0, 1)

    

class RandomCrop(object):
    "Randomly crop an image"
    
    def __call__(self, sample):
        resize = 256
        #print(np.random.uniform(0.4*resize,resize,1))
        random_crop_size = int(np.random.uniform(0.6*resize,resize,1))
        sample=transforms.RandomCrop(random_crop_size)(sample)
        return sample
    
    
class RandomColorDistortion(object):
    "Apply random color distortions to the image"
    
    def __call__(self, sample):
        resize=256

        # Random color distortion
        strength = 1.0 # 1.0 imagenet setting and CIFAR uses 0.5
        brightness = 0.8 * strength 
        contrast = 0.8 * strength
        saturation = 0.8 * strength
        hue = 0.2 * strength
        prob = np.random.uniform(0,1,1) 
        if prob < 0.8:
            sample=transforms.ColorJitter(brightness, contrast, saturation, hue)(sample)

        # Random Grayscale
        sample=transforms.RandomGrayscale(p=0.2)(sample)

        # Gaussian blur also based on imagenet but not used for CIFAR
        #prob = np.random.uniform(0,1,1)
        #if prob < 0.3:
        #    sample=transforms.GaussianBlur(kernel_size=resize//10)(sample)
        #    sample=transforms.Pad(0)(sample)
        return sample 
    

def get_transforms_simclr(args):
    normalize = transforms.Normalize([0.5, 0.456, 0.406], [0.229, 0.224, 0.225])
    
    train_transforms = []
    # Resize all images to same size, then randomly crop and resize again
    train_transforms.append(transforms.Resize([args.resize, args.resize]))
    # Random affine
    train_transforms.append(transforms.RandomAffine(degrees=(-45, 45), translate=(0.1,0.1), scale=(0.7, 1.5), shear=(-25, 25)))
    # Random crop
    train_transforms.append(RandomCrop())
    # Resize again
    # train_transforms.append(transforms.Resize([args.resize, args.resize], interpolation=3))
    train_transforms.append(transforms.Resize([224, 224], interpolation=3))
    # Random horizontal flip 
    train_transforms.append(transforms.RandomHorizontalFlip())
    # Random color distortions
    train_transforms.append(RandomColorDistortion())
    # Convert to tensor
    train_transforms.append(transforms.ToTensor())
    # Clip values between 0 and 1 and normalize
    #train_transforms.append(Clip())
    #train_transforms.append(normalize)      

    test_transforms = []
    # Resize all images to same size, then center crop and resize again
    test_transforms.append(transforms.Resize([args.resize, args.resize]))
    crop_proportion=0.875
    test_transforms.append(transforms.CenterCrop([int(0.875*args.resize), int(0.875*args.resize)]))
    # test_transforms.append(transforms.Resize([args.resize, args.resize], interpolation=3))
    test_transforms.append(transforms.Resize([224, 224], interpolation=3))
    #Convert to tensor
    test_transforms.append(transforms.ToTensor())
    # Clip values between 0 and 1 and normalize
    #test_transforms.append(Clip())
    #test_transforms.append(normalize)

    return train_transforms, test_transforms

In [30]:
class MIMICCXR(Dataset):
    def __init__(self, paths, args, transform=None, split='train'):
        self.data_dir = args.cxr_data_root
        self.args = args
        self.CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
       'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
       'Pneumonia', 'Pneumothorax', 'Support Devices']
        self.filenames_to_path = {path.split('/')[-1].split('.')[0]: path for path in paths}
        # for item in self.filenames_to_path.items():
        #     print (item)
        #     break

        metadata = pd.read_csv(f'{self.data_dir}/mimic-cxr-2.0.0-metadata.csv')
        # print(metadata.keys())
        labels = pd.read_csv(f'{self.data_dir}/mimic-cxr-2.0.0-chexpert.csv')
        labels[self.CLASSES] = labels[self.CLASSES].fillna(0)
        labels = labels.replace(-1.0, 0.0)
        # print(labels.head)
        # print(labels.keys())
        
        splits = pd.read_csv(f'{self.data_dir}/mimic-cxr-ehr-split.csv')


        metadata_with_labels = metadata.merge(labels[self.CLASSES+['study_id'] ], how='inner', on='study_id')
        # print(metadata_with_labels.keys())


        self.filesnames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[self.CLASSES].values))
        # for item in self.filesnames_to_labels.items():
        #     print (item)
        #     break
        self.filenames_loaded = splits.loc[splits.split==split]['dicom_id'].values
        # print(self.filenames_loaded[1])
        # exclude any files included in the split.csv file but not in the chexpert.csv (i.e. with no labels)
        self.filenames_loaded = [filename  for filename in self.filenames_loaded if filename in self.filesnames_to_labels]
        # print(self.filenames_loaded[1])
        self.transform = transform

    def __getitem__(self, index):
        # print(index)
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            labels = torch.tensor(self.filesnames_to_labels[index]).float()
            if self.transform is not None:
                img = self.transform(img)
            return img, labels
          
        
        filename = self.filenames_loaded[index]
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')
        labels = torch.tensor(self.filesnames_to_labels[filename]).float()

        if self.transform is not None:
            img = self.transform(img)
        return img, labels
    
    def __len__(self):
        return len(self.filenames_loaded)

In [35]:
def get_cxr_datasets(args):
    if args.transforms_cxr=='simclrv2':
        # print("Appling SimCLR image transforms...")
        train_transforms, test_transforms = get_transforms_simclr(args)

    data_dir = args.cxr_data_root
    # print(data_dir)
    filepath = f'{args.cxr_data_root}/new_paths.npy'
    if os.path.exists(filepath):
        paths = np.load(filepath)
        # print(len(paths))
    else:
        paths = glob.glob(f'{data_dir}/resized/**/*.jpg', recursive = True)
        np.save(filepath, paths)
    
    dataset_train = MIMICCXR(paths, args, split='train', transform=transforms.Compose(train_transforms))
    dataset_validate = MIMICCXR(paths, args, split='validate', transform=transforms.Compose(test_transforms),)
    dataset_test = MIMICCXR(paths, args, split='test', transform=transforms.Compose(test_transforms),)

    return dataset_train , dataset_validate, dataset_test

cxr_train_ds, cxr_val_ds, cxr_test_ds = get_cxr_datasets(args)
# cxr_train_ds =  get_cxr_datasets(args)
# get_item = next(iter((cxr_train_ds)))


In [36]:
CLASSES = [
       'Acute and unspecified renal failure', 'Acute cerebrovascular disease',
       'Acute myocardial infarction', 'Cardiac dysrhythmias',
       'Chronic kidney disease',
       'Chronic obstructive pulmonary disease and bronchiectasis',
       'Complications of surgical procedures or medical care',
       'Conduction disorders', 'Congestive heart failure; nonhypertensive',
       'Coronary atherosclerosis and other heart disease',
       'Diabetes mellitus with complications',
       'Diabetes mellitus without complication',
       'Disorders of lipid metabolism', 'Essential hypertension',
       'Fluid and electrolyte disorders', 'Gastrointestinal hemorrhage',
       'Hypertension with complications and secondary hypertension',
       'Other liver diseases', 'Other lower respiratory disease',
       'Other upper respiratory disease',
       'Pleurisy; pneumothorax; pulmonary collapse',
       'Pneumonia (except that caused by tuberculosis or sexually transmitted disease)',
       'Respiratory failure; insufficiency; arrest (adult)',
       'Septicemia (except in labor)', 'Shock'
    ]

class MIMIC_CXR_EHR(Dataset):
    def __init__(self, args, metadata_with_labels, ehr_ds, cxr_ds, split='train'):
        
        self.CLASSES = CLASSES
        
        self.metadata_with_labels = metadata_with_labels
        
        self.cxr_files_paired = self.metadata_with_labels.dicom_id.values
        self.ehr_files_paired = (self.metadata_with_labels['stay'].values)
        self.time_diff = self.metadata_with_labels.time_diff
        
        self.cxr_files_all = cxr_ds.filenames_loaded
        self.ehr_files_all = ehr_ds.names
        
        self.ehr_files_unpaired = list(set(self.ehr_files_all) - set(self.ehr_files_paired))
        
        self.ehr_ds = ehr_ds
        self.cxr_ds = cxr_ds
        
        self.args = args
        self.split = split
        self.data_ratio = self.args.data_ratio if split=='train' else 1.0

    def __getitem__(self, index):
        if self.args.data_pairs == 'paired':
            cxr_data, labels_cxr = self.cxr_ds[self.cxr_files_paired[index]]
            ehr_data, labels_ehr = self.ehr_ds[self.ehr_files_paired[index]]
            time_diff = self.metadata_with_labels.iloc[index].time_diff                      
            return ehr_data, cxr_data, labels_ehr, labels_cxr
  
    def __len__(self):
        if self.args.data_pairs == 'paired':
            return len(self.ehr_files_paired)

In [39]:
def loadmetadata(args):

    data_dir = args.cxr_data_root
    cxr_metadata = pd.read_csv(f'{data_dir}/mimic-cxr-2.0.0-metadata.csv')
    print('Number of CXR images=', len(cxr_metadata))
    icu_stay_metadata = pd.read_csv(f'{args.ehr_data_root}/root/all_stays.csv')
    print('Number of ICU stays=', len(icu_stay_metadata))
    columns = ['subject_id', 'stay_id', 'intime', 'outtime']
    
    # only common subjects with both icu stay and an xray
    # Note that inner merge includes rows if a chest X-ray is associated with multiple stays
    cxr_merged_icustays = cxr_metadata.merge(icu_stay_metadata[columns], how='inner', on='subject_id')
    # print(cxr_merged_icustays.keys())
    print('Number of CXR associated with ICU stay based on subject ID=', len(cxr_merged_icustays))
    print('Number of unique CXR dicoms=', len(cxr_merged_icustays.dicom_id.unique()))
    print('Number of unique CXR study id=', len(cxr_merged_icustays.study_id.unique()))
        
    # combine study date time
    # just changing the format and combining study date and study time to study datetime
    cxr_merged_icustays['StudyTime'] = cxr_merged_icustays['StudyTime'].apply(lambda x: f'{int(float(x)):06}' )
    cxr_merged_icustays['StudyDateTime'] = pd.to_datetime(cxr_merged_icustays['StudyDate'].astype(str) + ' ' + cxr_merged_icustays['StudyTime'].astype(str) ,format="%Y%m%d %H%M%S")
    print(cxr_merged_icustays.head)
    # note that study datetime is for cxr images and intime/outtime are for the icu stays
    cxr_merged_icustays.intime=pd.to_datetime(cxr_merged_icustays.intime)
    cxr_merged_icustays.outtime=pd.to_datetime(cxr_merged_icustays.outtime)
    end_time = cxr_merged_icustays.outtime
    
    cxr_merged_icustays['time_diff'] = cxr_merged_icustays.StudyDateTime-cxr_merged_icustays.intime
    cxr_merged_icustays['time_diff'] = cxr_merged_icustays['time_diff'].apply(lambda x: np.round(x.total_seconds()/60/60,3))
    
    # either only take the latest CXR related with an icu stay or take all CXR images associated with icu stays
    # For LE/ FT  (evaluation datasets)
    if (args.dataset!='all'):
        cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&((cxr_merged_icustays.StudyDateTime<=end_time))]

        if args.task == 'decompensation':
            print('Only include the last CXR before the current time of prediction')
            cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime<=cxr_merged_icustays.intime)]
        
        if args.task == 'in-hospital-mortality':
            print("Excluding chest X-rays beyond 48 hours for in-hospital mortality")
            end_time = cxr_merged_icustays.intime + pd.DateOffset(hours=48)
            cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&((cxr_merged_icustays.StudyDateTime<=end_time))]
        
        # select cxrs with the ViewPosition == 'AP'
        cxr_merged_icustays_AP = cxr_merged_icustays_during[cxr_merged_icustays_during['ViewPosition'] == 'AP']
        # print(cxr_merged_icustays_AP.loc[cxr_merged_icustays_AP.stay_id==30001947])
        # returns the indices of records with the same stay_id
        groups = cxr_merged_icustays_AP.groupby('stay_id')
        # print(groups.groups)

        groups_selected = []
        for group in groups:
            # select the latest cxr for the icu stay
            selected = group[1].sort_values('StudyDateTime').tail(1).reset_index()
            groups_selected.append(selected)
        groups = pd.concat(groups_selected, ignore_index=True)
        # print(groups.head)
    
    # For SIMCLR pretraining (large dataset)
    else:
        print(cxr_merged_icustays.ViewPosition.unique())
        cxr_merged_icustays_AP = cxr_merged_icustays[cxr_merged_icustays['ViewPosition'] == 'AP']
        print("Number of CXR associated with ICU stay and in AP view=", len(cxr_merged_icustays_AP))
        groups = cxr_merged_icustays_AP
        
    print("Mean time cxr - intime= ", groups.time_diff.mean())
    print("Minimum time =", groups.time_diff.min())
    print("Maximum time =", groups.time_diff.max())

#     plt.hist(groups.time_diff.apply(lambda x: x.days).astype("float64"))
#     plt.xlabel('Time difference in days')
#     plt.show()

    #print(groups.iloc[0])
    return groups

def my_collate(batch):
    x = [item[0] for item in batch]
    pairs = [False if item[1] is None else True for item in batch]
    img = torch.stack([torch.zeros(3, 224, 224) if item[1] is None else item[1] for item in batch])
    x, seq_length = pad_zeros(x)
    targets_ehr = np.array([item[2] for item in batch])
    targets_cxr = torch.stack([torch.zeros(14) if item[3] is None else item[3] for item in batch])
    return [x, img, targets_ehr, targets_cxr, seq_length, pairs]
    

def pad_zeros(arr, min_length=None):
    dtype = arr[0].dtype
    seq_length = [x.shape[0] for x in arr]
    max_len = max(seq_length)
    ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
           for x in arr]
    if (min_length is not None) and ret[0].shape[0] < min_length:
        ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
               for x in ret]
    return np.array(ret), seq_length

In [40]:
def load_cxr_ehr(args, ehr_train_ds, ehr_val_ds, cxr_train_ds, cxr_val_ds, ehr_test_ds, cxr_test_ds):
    
    # Load cxr and ehr groups
    cxr_merged_icustays = loadmetadata(args) 
    
    # Add the labels based on the EHR splits 
    splits_labels_train = pd.read_csv(f'{args.ehr_data_root}/{args.task}/train_listfile.csv')
    splits_labels_val = pd.read_csv(f'{args.ehr_data_root}/{args.task}/val_listfile.csv')
    splits_labels_test = pd.read_csv(f'{args.ehr_data_root}/{args.task}/test_listfile.csv')
    
    # split the groups from cxr_merged_icustays based on the labels from EHR 
    #TODO: investigate why total size of cxr_merged_icustays drops after the three steps below
    train_meta_with_labels = cxr_merged_icustays.merge(splits_labels_train, how='inner', on='stay_id')#change dataset size here
    # print(len(train_meta_with_labels))
    val_meta_with_labels = cxr_merged_icustays.merge(splits_labels_val, how='inner', on='stay_id')
    test_meta_with_labels = cxr_merged_icustays.merge(splits_labels_test, how='inner', on='stay_id')
    
    # Get rid of chest X-rays that don't have radiology reports
    # get the x-ray images and their lables, split them based on the prev split on ehr 
    metadata = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-metadata.csv')
    labels = pd.read_csv(f'{args.cxr_data_root}/mimic-cxr-2.0.0-chexpert.csv')
    metadata_with_labels = metadata.merge(labels[['study_id']], how='inner', on='study_id').drop_duplicates(subset=['dicom_id'])
    train_meta_with_labels = train_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    val_meta_with_labels = val_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    test_meta_with_labels = test_meta_with_labels.merge(metadata_with_labels[['dicom_id']], how='inner', on='dicom_id')
    
    print("Excluding CXR with missing radiology reports = ",len(train_meta_with_labels))

    # Multimodal class
    train_ds = MIMIC_CXR_EHR(args, train_meta_with_labels, ehr_train_ds, cxr_train_ds)
    print(len(train_ds))
    val_ds = MIMIC_CXR_EHR(args, val_meta_with_labels, ehr_val_ds, cxr_val_ds, split='val')
    print(len(val_ds))
    test_ds = MIMIC_CXR_EHR(args, test_meta_with_labels, ehr_test_ds, cxr_test_ds, split='test')
    print(len(test_ds))
    
    collate = my_collate
    
    # Multimodal data loader 
    train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, collate_fn=collate, drop_last=True) #pin_memory=True, num_workers=16,
    val_dl = DataLoader(val_ds, args.batch_size, shuffle=False, collate_fn=collate, drop_last=False) #pin_memory=True, num_workers=16,
    test_dl = DataLoader(test_ds, args.batch_size, shuffle=False, collate_fn=collate, drop_last=False) # pin_memory=True,num_workers=16,

    return train_dl, val_dl, test_dl

train_dl, val_dl, test_dl = load_cxr_ehr(args, ehr_train_ds, ehr_val_ds, cxr_train_ds, cxr_val_ds, ehr_test_ds, cxr_test_ds)


Number of CXR images= 377110
Number of ICU stays= 59372
Number of CXR associated with ICU stay based on subject ID= 368350
Number of unique CXR dicoms= 181195
Number of unique CXR study id= 122087
<bound method NDFrame.head of                                             dicom_id  subject_id  study_id  \
0       02aa804e-bde0afdd-112c0b34-7bc16630-4e384014    10000032  50414267   
1       174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962    10000032  50414267   
2       2a2277a9-b0ded155-c0de8eb9-c124d10e-82c5caab    10000032  53189527   
3       e084de3b-be89b11e-20fe3f9f-9c8d8dfe-4cfd202c    10000032  53189527   
4       68b5c4b1-227d0485-9cc38c3f-7b84ab51-4b472714    10000032  53911762   
...                                              ...         ...       ...   
368345  ee9155f3-944c056b-c76c73d0-3f792f2c-92ae461e    19999442  58497551   
368346  16b6c70f-6d36bd77-89d2fef4-9c4b8b0a-79c69135    19999442  58708861   
368347  58766883-376a15ce-3b323a28-6af950a0-16b793bd    19999987  55368

In [None]:
ehr_data, cxr_data, ehr_labels, cxr_labels, seq_lengths, pairs = next(iter(train_dl))
print(ehr_data.shape)
print(seq_lengths)
print(cxr_labels)

In [3]:
# Import Pytorch 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, Callback, TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision
import math

# Import other useful libraries
from sklearn.linear_model import LogisticRegression as LR
from sklearn.neural_network import MLPClassifier
import pickle
from flash.core.optimizers import LARS
import os
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from copy import deepcopy
from tqdm import tqdm

# Import custom libraries/functions
from encoders import LSTM, CXRModels
import load_tasks as tasks
from fusion_models import Fusion
class SimCLR(pl.LightningModule):

    def __init__(self, args, train_dl):

        super().__init__()
        assert args.temperature > 0.0, 'The temperature must be a positive float!'
        self.warmup_epochs= 10 #int(0.05*max_epochs) (10 as in SimCLR)
        self.automatic_optimization = False
        
        self.num_train_batches=len(train_dl)
        self.batch_size=args.batch_size
        hidden_dim=args.hidden_dim
        self.args=args
        self.LABEL_COLUMNS = tasks.load_labels(args.task)
        self.task = args.task
        
        # Load the architecture based on args
        self.model = Fusion(args)
        self.load_weights()
        self.freeze_weights()
        
               
    def load_weights(self):
        # loads both encoders for simclr
        load_dir_simclr = self.args.save_dir

        if self.args.load_state is not None: 
            # why this condition?
            # if 'LC' not in self.args.load_state:
            #     if 'mortality' in self.args.save_dir:
            #         load_dir_simclr = load_dir_simclr.replace('mortality', 'phenotyping')
                    
            if 'epoch' in self.args.load_state:
                model_dir='/'+self.args.load_state.split('_epoch')[0] + '/'
                if self.args.tag == 'eval_epoch':
                    checkpoint = torch.load(load_dir_simclr + model_dir + self.args.load_state+".ckpt", map_location="cpu")
                else:
                    checkpoint = torch.load(load_dir_simclr + model_dir + self.args.load_state+".ckpt")    
            else:
                if self.args.tag == 'eval_epoch':
                    checkpoint = torch.load(os.path.join(load_dir_simclr, self.args.load_state+".ckpt"), map_location="cpu")
                else:
                    checkpoint = torch.load(os.path.join(load_dir_simclr, self.args.load_state+".ckpt"))
            own_state = self.model.state_dict()
            own_keys = list(own_state.keys())
            checkpoint_keys = list(checkpoint['state_dict'].keys())
            
            print('Total number of checkpoint params = {}'.format(len(checkpoint_keys)))
            print('Total number of current model params = {}'.format(len(own_keys)))

            count = 0
            changed = []
            for name in own_keys:
                if name not in checkpoint_keys:
                    # double check if name exists in a different format
                    for x in checkpoint_keys:
                        if name in x:
                            param=checkpoint['state_dict'][x]
                            if isinstance(param, torch.nn.Parameter):
                                param=param.data
                            own_state[name].copy_(param)
                            count+=1
                else:
                    param=checkpoint['state_dict'][name]
                    if isinstance(param, torch.nn.Parameter):
                        param=param.data
                    own_state[name].copy_(param)
                    count+=1
            print('Total number params loaded for model weights = {}'.format(count))
        
    def freeze_weights(self):
        if self.args.finetune:
            if 'ehr' not in self.args.fusion_type:
                print("freezing cxr projection head")
                self.freeze(self.model.cxr_model_g)
            if 'cxr' not in self.args.fusion_type: 
                print("freezing ehr projection head")   
                self.freeze(self.model.ehr_model_g)
        else: 
            if 'lineareval' in self.args.fusion_type:
                print('freezing encoders')
                if 'ehr' not in self.args.fusion_type:
                    self.freeze(self.model.cxr_model)
                    self.freeze(self.model.cxr_model_g)
                if 'cxr' not in self.args.fusion_type:
                    self.freeze(self.model.ehr_model)
                    self.freeze(self.model.ehr_model_g) 

    def freeze_weights_2(self):
        if self.args.finetune:
            if 'ehr' not in self.args.fusion_type:
                print("finetuning")
        else: 
            if 'lineareval' in self.args.fusion_type:
                print('freezing encoders')
                if 'ehr' not in self.args.fusion_type and 'cxr' not in self.args.fusion_type:
                    self.freeze(self.model.cxr_model)
                    self.freeze(self.model.ehr_model)
                if 'ehr' not in self.args.fusion_type:
                    self.freeze(self.model.cxr_model)
                if 'cxr' not in self.args.fusion_type:
                    self.freeze(self.model.ehr_model)

    def freeze(self, model):
        for p in model.parameters():
            p.requires_grad = False     
    
    def configure_optimizers(self):
        
        if self.args.fusion_type == 'None':
            # Scaled learning rate in case of multiple GPUs
            if self.args.num_gpu > 1:
                effective_batchsize = self.args.batch_size*self.args.num_gpu
                scaled_lr = self.args.lr*effective_batchsize/self.args.batch_size
            else:
                scaled_lr = self.args.lr 
                        
            # Optimizer
            optimizer = LARS(self.parameters(), lr=scaled_lr, momentum=0.9, weight_decay=self.args.weight_decay)
            
            # Note that the order of the below affects the initial starting learning rate, hence do not change.
            # Main scheduler
            mainscheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 500, verbose=False)
            # Learning rate warmup
            lambda1= lambda epoch : (epoch+1)/self.warmup_epochs
            warmupscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda1, verbose=False)
                         
            return [optimizer], [mainscheduler, warmupscheduler]
        
        
        else:
            optimizer_adam = optim.AdamW(self.parameters(), lr=self.args.lr) #, weight_decay=self.args.weight_decay)
            lr_scheduler_adam = optim.lr_scheduler.MultiStepLR(optimizer_adam,milestones=[int(self.args.epochs*0.6),
                                                       int(self.args.epochs*0.8)],gamma=0.1)
            return [optimizer_adam], [lr_scheduler_adam]
                
    def logging_status(self, mode):
        if mode == 'train':
            on_step=True
            on_epoch=True
        else:
            on_step=False # Report for the sake of naming but it's not useful
            on_epoch=True
        return on_step, on_epoch
    
#     # TODO: Make this more efficient
#     def accuracy_top_k(self, k, temp):
#         temp = temp.argsort(dim=1, descending=True)[:, :k]
#         batchsize=temp.shape[0]
#         b_idx = np.arange(0,batchsize)
#         tot=0
#         for j in range(0, batchsize):
#             tot+= b_idx[j] in temp[j]
#         return tot*100/batchsize
    
    def bce_loss(self, preds, y, mode='train'):
        
        loss = nn.BCELoss()(preds, y)
        
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()
            
        auroc = np.round(roc_auc_score(y, preds.detach().cpu()), 4)
        auprc = np.round(average_precision_score(y, preds.detach().cpu()), 4)
        
        on_step=False
        on_epoch=True
        #self.log(mode + '_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        self.log(mode + '_auroc', auroc, on_step=on_step, on_epoch=on_epoch) #, logger=True)
        self.log(mode + '_auprc', auprc, on_step=on_step, on_epoch=on_epoch) #, logger=True)
        
        return loss 
    
    
    def info_nce_loss(self, feats_ehr, feats_img, mode='train'):
        # Calculate cosine similarity matrix
        cos_sim = F.cosine_similarity(feats_img[:,None,:], feats_ehr[None,:,:], dim=-1)
        #print(cos_sim.size())
        cos_sim = cos_sim /  self.args.temperature
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)
        #print(self_mask.size())
        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
        #temp_1=torch.reshape(cos_sim, (cos_sim.shape[0],cos_sim.shape[1]))
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0) 
        #temp_2=torch.reshape(cos_sim_2, (cos_sim_2.shape[0],cos_sim_2.shape[1]))
        
        # Total loss 
        loss = -(nll_1 + nll_2).mean()
            
        # Logging ranking metrics
        #self.log(mode+'_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        on_step, on_epoch = self.logging_status(mode)
        #self.log(mode+'_acc_top1', self.accuracy_top_k(1, temp_1), on_step=on_step, on_epoch=on_epoch) #, logger=True)
        #self.log(mode+'_acc_top5', self.accuracy_top_k(5, temp_1), on_step=on_step, on_epoch=on_epoch) #, logger=True)
                     
        return loss
    
    
    
    def modified_info_nce_loss(self, feats_ehr, feats_img, time_diff, mode='train'):
        # Calculate cosine similarity matrix
        cos_sim = F.cosine_similarity(feats_img[:,None,:], feats_ehr[None,:,:], dim=-1)
        cos_sim = cos_sim /  self.args.temperature
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool,  device=cos_sim.device)
        cos_sim_negative = torch.clone(cos_sim)
        cos_sim_negative.masked_fill_(self_mask, -9e15)
        
        # Compute the values of beta
        k = 1
        time_diff = torch.FloatTensor(time_diff)
        beta = torch.exp(-k*time_diff).to(self.device)
        
        # Compute based on img->ehr
        nll_1 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=1)
        nll_1 = beta*nll_1
        
        # Compute based on ehr->img
        nll_2 = cos_sim[self_mask] - torch.logsumexp(cos_sim_negative, dim=0)
        nll_2 = beta*nll_2
        
        # Total loss 
        loss = -(nll_1 + nll_2).mean()
            
        # Logging ranking metrics
        #self.log(mode+'_loss', loss, on_step=on_step, on_epoch=on_epoch, logger=True)
        on_step, on_epoch = self.logging_status(mode)
       
        return loss 
    
    
    def off_diagonal(self,x):
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    
    def vicreg_loss(self, feats_ehr, feats_img, mode='train'):
        x = feats_ehr
        y = feats_img
        repr_loss = F.mse_loss(x, y)

        #x = torch.cat(FullGatherLayer.apply(x), dim=0) #
        #y = torch.cat(FullGatherLayer.apply(y), dim=0) #
        
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)

        std_loss_x = torch.mean(F.relu(1 - std_x)) / 2 
        std_loss_y = torch.mean(F.relu(1 - std_y)) / 2
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2
                
        cov_x = (x.T @ x) / (self.args.batch_size - 1)
        cov_y = (y.T @ y) / (self.args.batch_size - 1)
        
        num_features = len(cov_x) #TODO as arg
                
        cov_loss_x = self.off_diagonal(cov_x).pow_(2).sum().div(num_features)
        cov_loss_y = self.off_diagonal(cov_y).pow_(2).sum().div(num_features)
        cov_loss = self.off_diagonal(cov_x).pow_(2).sum().div(num_features) + self.off_diagonal(cov_y).pow_(2).sum().div(num_features)

        loss = (
            self.args.sim_coeff * repr_loss
            + self.args.std_coeff * std_loss
            + self.args.cov_coeff * cov_loss
        )
        on_step, on_epoch = self.logging_status(mode)
        
        return loss, std_loss_x, std_loss_y, cov_loss_x, cov_loss_y, repr_loss
        
        
    
    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        mode = 'train'
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False) & (self.args.vicreg == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            # print("ehr", ehr, "seq_lengths" , seq_lengths , "imgs", imgs)
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            # print(feats_ehr, feats_img)
            # Compute and log infoNCE loss
            loss = self.info_nce_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
            
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == True) & (self.args.vicreg == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log infoNCE loss
            loss = self.modified_info_nce_loss(feats_ehr, feats_img, time_diff, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == False) & (self.args.vicreg == True)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log vicreg loss
            loss, std_loss_x, std_loss_y, cov_loss_x, cov_loss_y, repr_loss = self.vicreg_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_std_ehr', std_loss_x, on_step=True, on_epoch=True) 
            self.log(mode+'_std_img', std_loss_y, on_step=True, on_epoch=True)
            self.log(mode+'_cov_ehr', cov_loss_x, on_step=True, on_epoch=True) 
            self.log(mode+'_cov_img', cov_loss_y, on_step=True, on_epoch=True)  
            self.log(mode+'_repr_loss', repr_loss, on_step=True, on_epoch=True)  


            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
            
                        
        else:
            
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
                # output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                    # print ("linear eval" , output)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            print(preds)
            exit()

            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            self.log(mode+'_loss', loss, on_step=True, on_epoch=True) #, logger=True)
    
        # import pdb; pdb.set_trace()
        # print("loss" , loss.shape)
        # Backpropagate
        self.manual_backward(loss)
        # Optimizer step
        opt.step()
        # Learning rate step
        if self.args.fusion_type=='None':
            mainscheduler, warmupscheduler = self.lr_schedulers()
            if (self.trainer.is_last_batch) and (self.trainer.current_epoch < self.warmup_epochs-1):
                warmupscheduler.step()
            elif (self.trainer.is_last_batch) and (self.trainer.current_epoch >= self.warmup_epochs-1):
                mainscheduler.step()
                
#             if (batch_idx==self.num_train_batches-1) & (self.trainer.current_epoch < self.warmup_epochs-1):
#                 warmupscheduler.step()
#             elif (batch_idx==self.num_train_batches-1) & (self.trainer.current_epoch >= self.warmup_epochs-1):
#                 mainscheduler.step()
            

            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
        else:
            return {'loss': loss}

    
    def validation_step(self, batch, batch_idx):
        mode='val'
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs) 
            
            # Compute and log infoNCE loss
            loss = self.info_nce_loss(feats_ehr, feats_img, mode)
            self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
        
        elif ((self.args.fusion_type=='None') & (self.args.beta_infonce == True)):
            ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            feats_ehr, feats_img = self.model(ehr, seq_lengths, imgs)
            
            # Compute and log infoNCE loss
            loss = self.modified_info_nce_loss(feats_ehr, feats_img, time_diff, mode)
            self.log(mode+'_loss', loss, on_step=False, on_epoch=True) #, logger=True)
            return {'loss': loss, 'feats_ehr': feats_ehr.detach().cpu(), 'feats_img': feats_img.detach().cpu(), 'y_ehr':y_ehr}
            
        else:
            
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            # print(preds)
            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            # Compute and log BCE loss
            #loss = self.bce_loss(batch, mode='val')
            self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
            return {'loss': loss}
        
    def test_step(self, batch, batch_idx):
        mode='test'
        
        
        # Forward pass for SimCLR
        if ((self.args.fusion_type=='None') & (self.args.beta_infonce == False)):
            if self.args.beta_infonce == True:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs, time_diff = batch
            else:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch
            ehr = torch.from_numpy(ehr).float()
            ehr = ehr.to(self.device)
            
            # At test time of SIMCLR, always return all the layer features
            if self.args.mode == 'eval':
                feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3 = self.model(ehr, seq_lengths, imgs) 
            
                # Compute and log infoNCE loss
                if self.args.beta_infonce == True:
                    loss = self.modified_info_nce_loss(feats_ehr_3, feats_img_3, time_diff, mode)
                else:
                    loss = self.info_nce_loss(feats_ehr_3, feats_img_3, mode)
                self.log(mode+'_loss_epoch', loss, on_step=False, on_epoch=True) #, logger=True)
            
                return {'loss': loss,   'feats_ehr_0': feats_ehr_0.detach().cpu(), 
                                        'feats_ehr_3': feats_ehr_3.detach().cpu(), 
                                        'feats_img_0': feats_img_0.detach().cpu(), 
                                        'feats_img_3': feats_img_3.detach().cpu(), 
                                        'y_ehr':y_ehr}        
        
        else:
            if self.args.finetune:
                ehr, imgs, y_ehr, y_cxr, seq_lengths, pairs = batch

                ehr = torch.from_numpy(ehr).float()
                ehr = ehr.to(self.device)
                imgs = imgs.to(self.device)
                output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)
                y_ehr = torch.from_numpy(y_ehr)
        
            else: # Features are already processed for linear classifier
                seq_lengths=None
                if 'ehr' in self.args.fusion_type:
                    ehr, y_ehr = batch
                    ehr = ehr.to(self.device)
                    output = self.model(x=ehr,seq_lengths=seq_lengths)
                elif 'cxr' in self.args.fusion_type:
                    imgs, y_cxr, y_ehr = batch
                    imgs = imgs.to(self.device)
                    output = self.model(img=imgs)
                else:
                    ehr, imgs, y_ehr, y_cxr = batch
                    ehr = ehr.to(self.device)
                    imgs = imgs.to(self.device)
                    output = self.model(x=ehr, seq_lengths=seq_lengths, img=imgs)


            y = y_ehr.float()
            y = y.to(self.device)

            preds = output[self.args.fusion_type].squeeze()
            
            # print(y.shape, preds.shape)
            # Compute and log BCE loss
            loss = self.bce_loss(preds, y, mode)
            #loss = self.bce_loss(batch, mode=mode)
            return {'loss': loss, 'preds': preds, 'y_ehr': y}
                
    
    def process_features(self, outputs, mode):
        y = []
        if self.args.mode=='eval':
            feats_ehr_0=[]
            feats_ehr_3=[]
            feats_img_0=[]
            feats_img_3=[]
        elif mode == 'test':
            preds = []
        else:
            feats_ehr = []
            feats_img = []
        # Iterate through batches and append
        i=0
        for output in outputs:
            if i ==0:
                if mode == 'test':
                    preds = output['preds'].detach().cpu()
                elif self.args.mode == 'eval':
                    feats_ehr_0 = output['feats_ehr_0'].detach().cpu()
                    feats_ehr_3 = output['feats_ehr_3'].detach().cpu()
                    feats_img_0 = output['feats_img_0'].detach().cpu()
                    feats_img_3 = output['feats_img_3'].detach().cpu()
                else: 
                    feats_ehr = output['feats_ehr'].detach().cpu()
                    feats_img = output['feats_img'].detach().cpu()
                y = output['y_ehr'].tolist()
                
            else:
                if mode == 'test':
                    preds = torch.cat((preds, output['preds'].detach().cpu()))
                elif self.args.mode == 'eval':
                    feats_ehr_0 = torch.cat((feats_ehr_0, output['feats_ehr_0'].detach().cpu()))
                    feats_ehr_3 = torch.cat((feats_ehr_3, output['feats_ehr_3'].detach().cpu()))
                    feats_img_0 = torch.cat((feats_img_0, output['feats_img_0'].detach().cpu()))
                    feats_img_3 = torch.cat((feats_img_3, output['feats_img_3'].detach().cpu()))
                else:
                    feats_ehr = torch.cat((feats_ehr, output['feats_ehr'].detach().cpu()))
                    feats_img = torch.cat((feats_img, output['feats_img'].detach().cpu()))
                y.extend(output['y_ehr'].tolist())
            i+=1
        if mode =='test':
            return y, preds
        elif self.args.mode=='eval':
            return feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3, y
        else:
            return feats_ehr, feats_img, y
    
    def save_features(self, x, descrip, mode):
        model_path = self.args.save_dir+'/simclr_lr/'+self.args.file_name
        if not os.path.exists(model_path):
          os.makedirs(model_path)
        
        torch.save(x, model_path+'/{}_{}_epoch_{}.pt'.format(mode, descrip, self.current_epoch))
    
    def training_epoch_end(self, outputs):
        mode='train'
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            feats_ehr, feats_img, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr, 'feats_ehr', mode)
            self.save_features(feats_img, 'feats_img', mode)      
            self.save_features(y, 'y', mode)   
        
    def validation_epoch_end(self, outputs):
        mode='val'
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            feats_ehr, feats_img, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr, 'feats_ehr', mode)
            self.save_features(feats_img, 'feats_img', mode)      
            self.save_features(y, 'y', mode)
            

    def test_epoch_end(self, outputs):
        if ((self.args.fusion_type=='None') & (self.args.save_features == True)):
            mode = self.args.eval_set
            feats_ehr_0, feats_ehr_3, feats_img_0, feats_img_3, y = self.process_features(outputs, mode)
            self.save_features(feats_ehr_0, 'feats_ehr_0', mode)
            self.save_features(feats_ehr_3, 'feats_ehr_3', mode)
            self.save_features(feats_img_0, 'feats_img_0', mode)
            self.save_features(feats_img_3, 'feats_img_3', mode)      
            self.save_features(y, 'y', mode)
        else:
            if self.task =='phenotyping':
                mode = 'test'
                y, preds = self.process_features(outputs, mode)

                auroc_per_label = np.round(roc_auc_score(y, preds, average=None), 4)
                auprc_per_label = np.round(average_precision_score(y, preds, average=None), 4)


                auroc_label={}
                auprc_label={}
                for i, name in enumerate(self.LABEL_COLUMNS):
                    auroc_label[name]=auroc_per_label[i].item()
                    auprc_label[name]=auprc_per_label[i].item()
                    #print(name, auroc_per_label[i], auprc_per_label[i])

                self.log('auroc_label', auroc_label)
                self.log('auprc_label', auprc_label)
            
    def calculate_auroc_epoch(self, outputs, mode):
        labels = []
        predictions = []
        auroc_label={}
        outputs=outputs[self.args.fusion_type].squeeze()
        
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)

        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(self.LABEL_COLUMNS):
            class_roc_auc = roc_auc_score(labels[:, i], predictions[:, i])
            auroc_label[name]=class_roc_auc
            
        auroc = roc_auc_score(labels, predictions)
        auprc = average_precision_score(labels, predictions)
        
        return auroc, auprc 

In [4]:
# Load the model, weights (if any), and freeze layers (if any)
print("Loading model...")
if args.pretrain_type == 'simclr':
    model = SimCLR(args, train_dl)
    
model

Loading model...


NameError: name 'args' is not defined

In [None]:
from copy import deepcopy

# Prepare data features for downstream tasks 
@torch.no_grad()
def prepare_data_features(device, model, data_loader, bs, fusion_layer, fusion_type):
    print(fusion_layer)
    # Prepare model
    network = deepcopy(model)
    if 'ehr' not in fusion_type:
        network.model.cxr_model.vision_backbone.fc = nn.Identity() # Removing projection head g(.) 
     
    if 'cxr' not in fusion_type:
        network.model.ehr_model.dense_layer = nn.Identity() # Removing projection head g(.)
    
    network.eval()
    network.to(device)

    # Encode all images
    feats_ehr, feats_imgs, labels_ehr, labels_imgs = [], [], [], []
    
    for batch_ehr, batch_imgs, batch_ehr_labels, batch_cxr_labels, seq_lengths, pairs in data_loader:
        labels_ehr.append(torch.from_numpy(batch_ehr_labels).detach())
        #time_diff.append(torch.from_numpy(np.array(batch_time)).detach())
        
        if 'cxr' not in fusion_type:
            batch_ehr = torch.from_numpy(batch_ehr).float().to(device)
            #batch_ehr = batch_ehr.to(device)
            batch_ehr_feats = network.model.ehr_model(batch_ehr, seq_lengths)
            if fusion_layer == 3:
                batch_ehr_feats = network.model.ehr_model_g(batch_ehr_feats)
                
            print('ehr batch shape', np.shape(batch_ehr_feats))
            #batch_ehr_feats=torch.reshape(batch_ehr_feats, (1, np.shape(batch_ehr_feats)[0])) #TODO need this for other code
            feats_ehr.append(batch_ehr_feats.detach().cpu()) 

        if 'ehr' not in fusion_type:
            batch_imgs = batch_imgs.to(device)
            batch_imgs_feats = network.model.cxr_model(batch_imgs)
            if fusion_layer == 3:
                batch_imgs_feats = network.model.cxr_model_g(batch_imgs_feats)
                
            print('cxr batch shape', np.shape(batch_imgs_feats))
            feats_imgs.append(batch_imgs_feats.detach().cpu())
            labels_imgs.append(batch_cxr_labels)
    
    labels_ehr = torch.cat(labels_ehr, dim=0)
    #time_diff = torch.cat(time_diff, dim=0)
    
    print('shape ehr', np.shape(feats_ehr))
    print('shape imgs', np.shape(feats_imgs))
    
    print('type ehr', type(feats_ehr))
    print('type cxr', type(feats_imgs))
    
    print(type(feats_ehr[0]))
    print(type(feats_ehr[0][0]))
    print(feats_ehr[0][0])
    
    print(type(feats_imgs[0]))
    print(type(feats_imgs[0][0]))
    print(feats_imgs[0][0])
    
    if 'cxr' not in fusion_type:
        #if len(feats_ehr) == len(labels_ehr):
        #    feats_ehr=torch.as_tensor(feats_ehr)
        feats_ehr = torch.cat(feats_ehr, dim=0)
        

    if 'ehr' not in fusion_type:
        feats_imgs = torch.cat(feats_imgs, dim=0)
        labels_imgs = torch.cat(labels_imgs, dim=0)

    if 'cxr' in fusion_type:
        return data.DataLoader(data.TensorDataset(feats_imgs, labels_imgs, labels_ehr), batch_size=bs, shuffle=False, drop_last=False)
    elif 'ehr' in fusion_type:
        return data.DataLoader(data.TensorDataset(feats_ehr, labels_ehr), batch_size=bs, shuffle=False, drop_last=False)
    else:
        return data.DataLoader(data.TensorDataset(feats_ehr, feats_imgs, labels_ehr, labels_imgs), batch_size=bs, shuffle=False, drop_last=False)

    

In [None]:
# Set cuda device
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  

print('Using {} device...'.format(device)) 

print("Processing features for linear evaluation...")
train_dl = prepare_data_features(device, model, train_dl, args.batch_size, args.fusion_layer, args.fusion_type) 
val_dl = prepare_data_features(device, model, val_dl, args.batch_size, args.fusion_layer, args.fusion_type)
test_dl = prepare_data_features(device, model, test_dl, args.batch_size, args.fusion_layer, args.fusion_type)
    

Using cuda:0 device...
Processing features for linear evaluation...
0
