In [None]:
import argparse
import pandas as pd
import numpy as np
import json
import os
import glob
#from data_utils import *
from sklearn.preprocessing import OneHotEncoder
from transformers import AutoTokenizer, AutoModel
import warnings
import pickle
from torch.utils.data import Dataset, DataLoader
import torch
from PIL import Image
import torchvision.transforms as transforms
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer


# create_encoders fits OneHotEncoders on plain NumPy arrays (no feature names),
# while preprocess_timeseries transforms DataFrame slices; silence the resulting
# scikit-learn UserWarning about the feature-name mismatch.
warnings.filterwarnings('ignore', category=UserWarning, message='.*OneHotEncoder was fitted without feature names.*')


class TextDataset(Dataset):
    """Tokenized clinical-text dataset yielding ``(input_ids, attention_mask, labels)``.

    Each item tokenizes the text in column ``type_text`` of one dataframe row,
    padded/truncated to ``max_token_len`` tokens, and pairs it with that row's
    ``target_cols`` values as a float tensor.

    Args:
        dataframe: Source table, one row per sample.
        target_cols: Column name(s) holding the label values.
        tokenizer: Hugging Face-style tokenizer exposing ``encode_plus``.
        type_text: Name of the column that holds the raw text.
        max_token_len: Fixed token length after padding/truncation.
    """

    def __init__(self, dataframe, target_cols, tokenizer, type_text, max_token_len=512):
        self.tokenizer = tokenizer
        self.type_text = type_text
        self.dataframe = dataframe
        self.labels = dataframe[target_cols].values
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text_data = self.dataframe.iloc[idx][self.type_text]
        # A missing note comes back from pandas as NaN (a float), which the
        # tokenizer rejects; encode it as empty text instead of crashing.
        if not isinstance(text_data, str):
            text_data = "" if pd.isna(text_data) else str(text_data)

        encoding = self.tokenizer.encode_plus(
            text_data,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields shape (1, max_len); flatten to (max_len,).
        text = encoding['input_ids'].flatten()
        att_mask = encoding['attention_mask'].flatten()
        labels = torch.tensor(self.labels[idx], dtype=torch.float)
        return text, att_mask, labels


class DemographicsDataset(Dataset):
    """Tabular demographics dataset yielding ``(features, labels)`` float tensors.

    Args:
        dataframe: Preprocessed feature table; only its ``.values`` matrix is kept.
        labels: Row-indexable label container aligned with ``dataframe``.
    """

    def __init__(self, dataframe, labels):
        # Plain array view of the features; labels are stored exactly as given.
        self.features, self.labels = dataframe.values, labels

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        row_features = self.features[idx]
        row_labels = self.labels[idx]
        return (
            torch.tensor(row_features, dtype=torch.float),
            torch.tensor(row_labels, dtype=torch.float),
        )

class MedicalImageDataset(Dataset):
    """Image dataset yielding ``(image, labels)`` for each dataframe row.

    Args:
        dataframe: Source table, one row per sample.
        target_cols: Column name(s) whose values become a float label tensor.
        img_col: Column holding each image's file path.
        transform: Optional callable applied to the RGB ``PIL.Image``
            (e.g. a torchvision pipeline); skipped when falsy.
    """

    def __init__(self, dataframe, target_cols, img_col, transform=None):
        self.dataframe = dataframe
        self.img_col = img_col
        self.labels = dataframe[target_cols].values
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        target = torch.tensor(self.labels[idx], dtype=torch.float)
        path = self.dataframe.iloc[idx][self.img_col]
        # Force three channels so grayscale and colour files have the same layout.
        picture = Image.open(path).convert('RGB')
        output = self.transform(picture) if self.transform else picture
        return output, target

    
class TimeSeriesDataset(Dataset):
    """Dataset over an in-memory sequence of time-series arrays (no labels).

    Each item is converted to a float32 tensor on access.
    """

    def __init__(self, time_series_data):
        self.time_series_data = time_series_data

    def __len__(self):
        return len(self.time_series_data)

    def __getitem__(self, idx):
        return torch.tensor(self.time_series_data[idx], dtype=torch.float)

def regular_sample_sequence(seq, fixed_length):
    """Resample or zero-pad ``seq`` along axis 0 to exactly ``fixed_length`` steps.

    Longer sequences are down-sampled by picking ``fixed_length`` evenly spaced
    rows (the first and last rows are always included). Shorter or equal-length
    sequences are zero-padded at the end.

    Args:
        seq: Array-like with time on axis 0. It may be 1-D or N-D; trailing
            axes are preserved.
        fixed_length: Number of time steps in the output.

    Returns:
        ``np.ndarray`` of shape ``(fixed_length, *seq.shape[1:])``. The padding
        branch always returns a new float64 array.
    """
    seq = np.asarray(seq)
    n = len(seq)
    if n > fixed_length:
        # Select indices at regular intervals.
        indices = np.round(np.linspace(0, n - 1, fixed_length)).astype(int)
        return seq[indices]
    # Use every trailing dimension, so 1-D input works (seq.shape[1] used to
    # raise IndexError) and higher-rank input keeps its shape.
    padded_seq = np.zeros((fixed_length,) + seq.shape[1:])
    padded_seq[:n] = seq
    return padded_seq



def preprocess_timeseries(file_path, config, encoders): #, fixed_length
    """Load one episode CSV, then impute, one-hot encode and standardize it.

    Args:
        file_path: Path to a time-series CSV whose columns include the channels
            listed in ``config["id_to_channel"]``.
        config: Discretizer-style config with ``"id_to_channel"`` and
            ``"is_categorical_channel"`` entries.
        encoders: Mapping ``channel -> fitted OneHotEncoder`` (see
            ``create_encoders``). Categorical channels without an encoder are
            dropped from the output.

    Returns:
        DataFrame with the file's non-categorical columns in their original
        order (continuous channels standardized), followed by the one-hot
        columns.

    Note:
        The ``StandardScaler`` is fitted on this single file, so every episode
        is scaled with its own statistics. NOTE(review): confirm that this
        per-episode normalization is intended.
    """
    categorical_cols = [col for col in config["id_to_channel"] if config["is_categorical_channel"][col]]
    continuous_cols = [col for col in config["id_to_channel"] if not config["is_categorical_channel"][col]]

    ts = pd.read_csv(file_path)

    # Impute missing values: carry the last observation forward, then zero-fill
    # leading gaps. `.ffill()` replaces `fillna(method='ffill')`, which is
    # deprecated since pandas 2.1.
    ts[continuous_cols] = ts[continuous_cols].ffill().fillna(0)
    ts[categorical_cols] = ts[categorical_cols].fillna('missing')

    # One-hot encode each categorical column that has a fitted encoder.
    encoded_data = pd.DataFrame(index=ts.index)
    for col in categorical_cols:
        if col in encoders:
            encoded_col = encoders[col].transform(ts[[col]])
            encoded_col_df = pd.DataFrame(encoded_col, columns=encoders[col].get_feature_names_out([col]), index=ts.index)
            encoded_data = pd.concat([encoded_data, encoded_col_df], axis=1)

    ts = ts.drop(columns=categorical_cols)
    ts = pd.concat([ts, encoded_data], axis=1)

    # Standardize continuous channels (statistics from this file only).
    scaler = StandardScaler()
    ts[continuous_cols] = scaler.fit_transform(ts[continuous_cols])

    return ts


def pad_sequence(seq, maxlen, n_features):
    """Zero-pad (or truncate) ``seq`` into a new ``(maxlen, n_features)`` float array."""
    out = np.zeros((maxlen, n_features))
    kept = seq[:maxlen]
    out[:len(seq)] = kept
    return out


def preprocess_demo(df, categorical_encoders=None):
    """Build a numeric demographics feature table from an admissions/patients frame.

    Steps:
    1. Select the demographic columns.
    2. Replace ``admittime`` with hour / day-of-week / month features.
    3. Impute missing values: mean for numeric columns, most frequent value for
       categorical ones.
    4. One-hot encode the categorical columns.
    5. Standardize the numeric columns detected from the input dtypes.

    Args:
        df: Frame containing at least the columns in ``demographic_cols``.
        categorical_encoders: Optional mapping ``column -> fitted OneHotEncoder``.
            When ``None``, new encoders are fitted on ``df`` and returned so
            other splits can reuse them.

    Returns:
        ``(features, categorical_encoders)``. ``features`` keeps ``df``'s index.

    Note:
        The imputers and the ``StandardScaler`` are re-fitted on every call, so
        each split is imputed/scaled with its own statistics; only the one-hot
        encoders are shared. NOTE(review): confirm this is intended.
    """
    demographic_cols = ['admittime', 'admission_type', 'admission_location',
                        'insurance', 'language', 'marital_status', 'ethnicity',
                        'gender', 'anchor_age', 'anchor_year', 'anchor_year_group']

    df = df[demographic_cols].copy()

    # Column roles are inferred from the dtypes of *this* frame.
    continuous_cols = df.select_dtypes(include=['int64', 'float64']).columns.tolist()
    categorical_cols = [col for col in df.select_dtypes(include=['object']).columns if col != 'admittime']
    print(categorical_cols)

    # Replace the raw timestamp with calendar features.
    df['admittime'] = pd.to_datetime(df['admittime'])
    df['admission_hour'] = df['admittime'].dt.hour
    df['admission_dayofweek'] = df['admittime'].dt.dayofweek
    df['admission_month'] = df['admittime'].dt.month
    df = df.drop(columns=['admittime'])

    # Impute missing values for continuous and categorical columns.
    imputer_continuous = SimpleImputer(strategy='mean')
    imputer_categorical = SimpleImputer(strategy='most_frequent')
    df[continuous_cols] = imputer_continuous.fit_transform(df[continuous_cols])
    for col in categorical_cols:
        # The imputer returns an (n, 1) array; a single column needs 1-D values.
        df[col] = imputer_categorical.fit_transform(df[[col]]).ravel()

    fit_new = categorical_encoders is None
    if fit_new:
        categorical_encoders = {}

    for col in categorical_cols:
        if fit_new:
            try:
                encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
            except TypeError:  # scikit-learn < 1.2 names this argument `sparse`
                encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
            categorical_encoders[col] = encoder.fit(df[[col]])
        else:
            encoder = categorical_encoders[col]
        encoded = pd.DataFrame(
            encoder.transform(df[[col]]),
            columns=encoder.get_feature_names_out([col]),
            # Use df's own index: with a default RangeIndex, concat misaligns
            # rows (and appends NaN rows) whenever the caller's index is not 0..n-1.
            index=df.index,
        )
        df = pd.concat([df.drop(columns=[col]), encoded], axis=1)

    scaler = StandardScaler()
    df[continuous_cols] = scaler.fit_transform(df[continuous_cols])

    return df, categorical_encoders

def create_demo_general(path_to_core):
    """Join MIMIC-IV ``admissions.csv`` with ``patients.csv`` and keep demographic fields.

    Args:
        path_to_core: Directory containing ``patients.csv`` and
            ``admissions.csv``. A trailing separator is optional.

    Returns:
        The inner join of admissions and patients on ``subject_id``, restricted
        to the id, admission and patient demographic columns listed below.
    """
    pts = pd.read_csv(os.path.join(path_to_core, "patients.csv"))
    admissions = pd.read_csv(os.path.join(path_to_core, "admissions.csv"))
    demo = admissions.merge(pts, on="subject_id")
    return demo[['subject_id', 'hadm_id', 'admittime',
                 'admission_type', 'admission_location',
                 'insurance', 'language', 'marital_status', 'ethnicity', 'gender', 'anchor_age',
                 'anchor_year', 'anchor_year_group']]

def create_time_series_dataset(file_paths, config, encoders, fixed_length):
    """Preprocess every episode file and wrap the fixed-length arrays in a ``TimeSeriesDataset``.

    Args:
        file_paths: Iterable of episode CSV paths.
        config: Passed through to ``preprocess_timeseries``.
        encoders: Passed through to ``preprocess_timeseries``.
        fixed_length: Number of time steps each sample is resampled/padded to.

    Returns:
        ``TimeSeriesDataset`` over the resampled arrays, in ``file_paths`` order.
    """
    # Process one file at a time so only the resampled arrays are kept in memory.
    sampled_data = [
        regular_sample_sequence(preprocess_timeseries(path, config, encoders).values, fixed_length)
        for path in file_paths
    ]
    # Debug aid: shape of the first sample. The previous extra print of
    # sampled_data[20] raised IndexError whenever fewer than 21 files were given.
    if sampled_data:
        print(sampled_data[0].shape)

    return TimeSeriesDataset(sampled_data)

def create_demo_dataset(pre_processed_data, target_cols):
    """Wrap preprocessed demographic features and their labels in a ``DemographicsDataset``.

    Args:
        pre_processed_data: Feature DataFrame, e.g. the first output of ``preprocess_demo``.
        target_cols: Despite the name, the label *values* aligned row-wise with
            ``pre_processed_data`` (callers pass ``frame[CLASSES].values``).
            A pandas DataFrame/Series is converted to a NumPy array so that
            ``labels[idx]`` selects a row rather than a column label.
    """
    labels = target_cols
    if isinstance(labels, (pd.DataFrame, pd.Series)):
        labels = labels.to_numpy()
    return DemographicsDataset(pre_processed_data, labels)

def create_image_dataset(data, target_cols, img_col, transform):
    """Build a ``MedicalImageDataset`` that reads image paths from ``data[img_col]``."""
    dataset = MedicalImageDataset(
        dataframe=data,
        target_cols=target_cols,
        img_col=img_col,
        transform=transform,
    )
    return dataset

def create_text_dataset(data, target_cols, type_text, tokenizer=None):
    """Build a ``TextDataset`` over ``data[type_text]``.

    Args:
        data: Source DataFrame.
        target_cols: Label column name(s).
        type_text: Column containing the note text.
        tokenizer: Optional pre-loaded tokenizer. When omitted, the
            ``medicalai/ClinicalBERT`` tokenizer is loaded with
            ``AutoTokenizer.from_pretrained``. Pass one in to avoid reloading
            it for every split.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
    return TextDataset(data, target_cols, tokenizer, type_text)

def determine_categories_and_length(ehr_data_dir, task, phase, data, config):
    """Scan episode files for categorical values and a representative sequence length.

    Args:
        ehr_data_dir: Root data directory.
        task: Task sub-directory (e.g. ``"phenotyping"``).
        phase: Split sub-directory (e.g. ``"train"``).
        data: DataFrame whose ``'filename'`` column names episode CSVs located
            in ``{ehr_data_dir}/{task}/{phase}/``.
        config: Discretizer-style config with ``"id_to_channel"`` and
            ``"is_categorical_channel"``.

    Returns:
        ``(all_categories, fixed_length)``:
        - ``all_categories`` maps each categorical channel to the set of
          non-null values observed across the files.
        - ``fixed_length`` is the mean number of rows per file, truncated to ``int``.

    Raises:
        ValueError: if ``data`` lists no files, so no length can be computed.
    """
    categorical_cols = [col for col in config["id_to_channel"] if config["is_categorical_channel"][col]]

    all_categories = {col: set() for col in categorical_cols}
    lengths = []
    for file_path in data['filename']:
        full_path = f'{ehr_data_dir}/{task}/{phase}/' + file_path
        # Only the categorical columns are needed; the row count gives the length.
        ts = pd.read_csv(full_path, usecols=categorical_cols)
        lengths.append(len(ts))
        for col in categorical_cols:
            all_categories[col].update(ts[col].dropna().unique())

    if not lengths:
        raise ValueError("determine_categories_and_length: 'data' lists no files, "
                         "so no sequence length can be computed")

    fixed_length = int(np.mean(lengths))

    return all_categories, fixed_length

def create_encoders(unique_categories):
    """Fit one dense ``OneHotEncoder`` per categorical column from its known values.

    Args:
        unique_categories: Mapping ``column -> collection`` of observed values
            (e.g. the first output of ``determine_categories_and_length``).
            Columns with no values are skipped and a message is printed.

    Returns:
        dict ``column -> fitted OneHotEncoder``. Each encoder ignores unknown
        values at transform time.
    """
    encoders = {}
    for col, categories in unique_categories.items():
        if not categories:
            print(f"No unique values found for column '{col}'. Skipping this column.")
            continue

        # Sets iterate in hash order, which for strings varies between Python
        # processes; sorting makes the one-hot column order reproducible, and
        # scikit-learn requires numeric categories to be sorted anyway.
        try:
            unique_values = sorted(categories)
        except TypeError:  # mixed, mutually non-comparable types
            unique_values = sorted(categories, key=str)
        try:
            encoder = OneHotEncoder(categories=[unique_values], handle_unknown='ignore', sparse_output=False)
        except TypeError:  # scikit-learn < 1.2 names this argument `sparse`
            encoder = OneHotEncoder(categories=[unique_values], handle_unknown='ignore', sparse=False)
        encoder.fit(np.array(unique_values).reshape(-1, 1))

        encoders[col] = encoder

    return encoders


def save_dataset(dataset, file_path):
    """Pickle ``dataset`` to ``file_path``, replacing any existing file.

    The result should only be loaded back from trusted locations, since
    unpickling can execute arbitrary code.
    """
    with open(file_path, mode='wb') as handle:
        pickle.dump(obj=dataset, file=handle)



In [None]:
# Input locations (replace YOUR_PATH with the local PhysioNet download root).
ehr_data_dir = "YOUR_PATH/physionet.org/files/mimiciv/1.0/"
cxr_data_dir = "YOUR_PATH/physionet.org/files/mimic-cxr-jpg/2.0.0"
discharge_path = "YOUR_PATH/mimic_iv/note/discharge.csv"
rad_path = "YOUR_PATH/mimic_iv/note/radiology.csv"
core_dir = "YOUR_PATH/physionet.org/files/mimiciv/1.0/core/"
# Discretizer channel config taken from the MedFuse repo
# (mimic4extract/mimic3models/resources/).
config_file_path = "discretizer_config.json"

# Channel metadata (id_to_channel, is_categorical_channel, ...) used by the helpers above.
with open(config_file_path) as config_fh:
    config = json.load(config_fh)

In [None]:
# ICU-stay fields attached to every chest X-ray below.
columns = ['subject_id', 'stay_id', 'intime', 'outtime', 'hadm_id']
data_dir = cxr_data_dir
# CXR image metadata (StudyDate, StudyTime, ViewPosition, dicom_id are used below)
# and the per-subject ICU stay table.
cxr_metadata = pd.read_csv(f'{data_dir}/mimic-cxr-2.0.0-metadata.csv')
icu_stay_metadata = pd.read_csv(f'{ehr_data_dir}/per_subject/all_stays.csv')


In [None]:
# Preview: pair every CXR with every ICU stay of the same subject (inner join on subject_id).
cxr_merged_icustays = cxr_metadata.merge(icu_stay_metadata[columns ], how='inner', on='subject_id')
cxr_merged_icustays

In [None]:
# combine study date time
# Re-run the CXR / ICU-stay join from the previous cell so this cell stands alone.
cxr_merged_icustays = cxr_metadata.merge(icu_stay_metadata[columns ], how='inner', on='subject_id')
# StudyTime is parsed through float -> int and zero-padded to 6 digits so it can be read as HHMMSS.
cxr_merged_icustays['StudyTime'] = cxr_merged_icustays['StudyTime'].apply(lambda x: f'{int(float(x)):06}' )
cxr_merged_icustays['StudyDateTime'] = pd.to_datetime(cxr_merged_icustays['StudyDate'].astype(str) + ' ' + cxr_merged_icustays['StudyTime'].astype(str) ,format="%Y%m%d %H%M%S")

cxr_merged_icustays.intime=pd.to_datetime(cxr_merged_icustays.intime)
cxr_merged_icustays.outtime=pd.to_datetime(cxr_merged_icustays.outtime)
# Window end: the ICU discharge time (a task-specific cutoff is sketched below).
end_time = cxr_merged_icustays.outtime
#if task == 'in-hospital-mortality':
#    end_time = cxr_merged_icustays.intime + pd.DateOffset(hours=48)

# Keep only images taken while the patient was in the ICU stay (intime <= StudyDateTime <= end_time).
cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&((cxr_merged_icustays.StudyDateTime<=end_time))]

# cxr_merged_icustays_during = cxr_merged_icustays.loc[(cxr_merged_icustays.StudyDateTime>=cxr_merged_icustays.intime)&((cxr_merged_icustays.StudyDateTime<=cxr_merged_icustays.outtime))]
# select CXRs with ViewPosition == 'AP'
cxr_merged_icustays_AP = cxr_merged_icustays_during[cxr_merged_icustays_during['ViewPosition'] == 'AP']

groups = cxr_merged_icustays_AP.groupby('stay_id')

groups_selected = []
for group in groups:
    # select the latest cxr for the icu stay
    # `group` is a (stay_id, sub-frame) pair. reset_index() without drop=True keeps
    # the old row label as an 'index' column. NOTE(review): if two images share the
    # latest StudyDateTime, which one tail(1) keeps depends on the sort's tie order.
    selected = group[1].sort_values('StudyDateTime').tail(1).reset_index()
    groups_selected.append(selected)
# One row per ICU stay; pd.concat raises ValueError if no stay had an AP image in the window.
groups = pd.concat(groups_selected, ignore_index=True)

# Keep only stays whose resized JPEG actually exists on disk.
paths = glob.glob(cxr_data_dir + "/resized/" + '*.jpg')
groups["dicom_id_path"] = cxr_data_dir + "/resized/" + groups["dicom_id"] + ".jpg"
groups = groups[groups["dicom_id_path"].isin(paths)]

In [None]:
# Load the discharge-summary notes.
ds = pd.read_csv(discharge_path)

In [None]:
# Keep only the part of each discharge note that precedes the "Discharge Diagnosis"
# section (presumably so diagnosis text cannot leak the phenotype labels; confirm).
# The vectorized .str accessor leaves missing notes as NaN instead of raising
# AttributeError on float.split; for string notes the result is unchanged.
ds['text'] = ds['text'].str.split('Discharge Diagnosis', n=1).str[0]

In [None]:
# Load the radiology reports.
rad = pd.read_csv(rad_path)

In [None]:
# Peek at the text of the first radiology report (row label 0).
rad.head()["text"][0]

In [None]:
# Admissions joined with patient demographics (see create_demo_general).
demo = create_demo_general(core_dir)
demo.head()

In [None]:

task = "phenotyping"
# Stay listfiles of the original test and validation splits (read in that order).
test_data, val_data = (
    pd.read_csv(f'{ehr_data_dir}/{task}/{split}_listfile.csv') for split in ("test", "val")
)


In [None]:
# Pool the original test and val listfiles; index labels are not reset, so they repeat.
data = pd.concat([test_data, val_data])

In [None]:
# Inner-join each modality: the selected CXR per stay, then discharge notes, radiology
# reports and demographics per admission. Rows lacking any modality are dropped.
# ds and rad both carry a 'text' column, so pandas' default merge suffixes produce
# 'text_x' (discharge) and 'text_y' (radiology), the names used by the text cells below.
data = data.merge(groups, on = ["stay_id"])
data= data.merge(ds, on = ["subject_id", "hadm_id"])
data= data.merge(rad, on = ["subject_id", "hadm_id"])
data= data.merge(demo, on = ["subject_id", "hadm_id"])

In [None]:
# Replace the in-memory `data` with a saved all-modalities table
# (presumably a stored copy of the merge above; confirm).
data = pd.read_csv('YOUR_PATH/physionet.org/files/mimiciv/1.0/finetune_pheno_data_all_modalities.csv')
data.head()

In [None]:
# Phenotype label columns. Their order defines the label-vector order of every
# dataset built below. Keep this a list: pandas treats a tuple as one key.
CLASSES = [
    'Acute and unspecified renal failure',
    'Acute cerebrovascular disease',
    'Acute myocardial infarction',
    'Cardiac dysrhythmias',
    'Chronic kidney disease',
    'Chronic obstructive pulmonary disease and bronchiectasis',
    'Complications of surgical procedures or medical care',
    'Conduction disorders',
    'Congestive heart failure; nonhypertensive',
    'Coronary atherosclerosis and other heart disease',
    'Diabetes mellitus with complications',
    'Diabetes mellitus without complication',
    'Disorders of lipid metabolism',
    'Essential hypertension',
    'Fluid and electrolyte disorders',
    'Gastrointestinal hemorrhage',
    'Hypertension with complications and secondary hypertension',
    'Other liver diseases',
    'Other lower respiratory disease',
    'Other upper respiratory disease',
    'Pleurisy; pneumothorax; pulmonary collapse',
    'Pneumonia (except that caused by tuberculosis or sexually transmitted disease)',
    'Respiratory failure; insufficiency; arrest (adult)',
    'Septicemia (except in labor)',
    'Shock',
]

In [None]:
# Split by ICU stay rather than by row, so no stay contributes rows to two splits.
# NOTE(review): a subject with several stays can still appear in several splits.
stay_ids = data['stay_id'].unique()

# Shuffle with a fixed seed so the 70/15/15 split is reproducible across runs
# (the previous unseeded np.random.shuffle produced a different split each time).
SPLIT_SEED = 42
rng = np.random.default_rng(SPLIT_SEED)
rng.shuffle(stay_ids)

train_split_index = int(0.7 * len(stay_ids))  # 70% for training
val_split_index = int(0.85 * len(stay_ids))  # next 15% for validation, remaining 15% for test

train_stay_ids = stay_ids[:train_split_index]
val_stay_ids = stay_ids[train_split_index:val_split_index]
test_stay_ids = stay_ids[val_split_index:]

# Filter rows per split and reset each index to start at 0.
train_data = data[data['stay_id'].isin(train_stay_ids)].reset_index(drop=True)
val_data = data[data['stay_id'].isin(val_stay_ids)].reset_index(drop=True)
test_data = data[data['stay_id'].isin(test_stay_ids)].reset_index(drop=True)

In [None]:
# Persist the new split.
# NOTE(review): these filenames ("finetune_*_data.csv") differ from the ones the next
# cell reads ("finetune_pheno_*_data.csv"). Run top to bottom, the split created above
# is therefore replaced by a previously saved one. Confirm which files are canonical.
train_data.to_csv("finetune_train_data.csv", index=False)
val_data.to_csv("finetune_val_data.csv", index=False)
test_data.to_csv("finetune_test_data.csv", index=False)

In [None]:
# Load the saved phenotyping split; this overrides train_data/val_data/test_data in memory.
# NOTE(review): the preceding cell writes "finetune_*_data.csv" (no "pheno"), so these are
# not the files that cell produced. Confirm the intended source.
train_data = pd.read_csv("finetune_pheno_train_data.csv")
val_data = pd.read_csv("finetune_pheno_val_data.csv")
test_data = pd.read_csv("finetune_pheno_test_data.csv")

In [None]:
# NOTE(review): len() of the per-class sum Series is always len(CLASSES); the per-class
# totals themselves are shown by the next cell.
len(val_data[CLASSES].sum())

In [None]:
# Column-wise sum of each label column in the validation split.
val_data[CLASSES].sum()

In [None]:
# A separate training table; below it is used only to fit the demographic one-hot encoders
# (presumably the contrastive-pretraining split; confirm).
train_data_contrast = pd.read_csv("train_data.csv")

In [None]:
# CODE ADAPTED FROM MEDFUSE REPO
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import platform
import pickle
import json
import os


class Discretizer:
    """Convert an irregular episode time series into fixed-width time bins.

    Adapted from the MedFuse repository. Channel metadata (``id_to_channel``,
    ``is_categorical_channel``, ``possible_values``, ``normal_values``) is read
    from a JSON config. ``transform`` does the following:
    - buckets rows into bins of ``timestep`` (in the units of the ``Hours`` column);
    - one-hot encodes categorical channels;
    - imputes bins that have no observation;
    - optionally appends a per-channel observation mask.

    Args:
        timestep: Bin width, in the units of the ``Hours`` column.
        store_masks: Append one 0/1 column per channel marking observed bins.
        impute_strategy: ``'zero'`` (leave empty bins as 0), ``'normal_value'``,
            ``'previous'`` or ``'next'``.
        start_time: ``'zero'`` (bins start at t=0) or ``'relative'`` (bins
            start at the first row's time).
        config_path: Path to the discretizer JSON config.
    """

    def __init__(self, timestep=0.8, store_masks=True, impute_strategy='zero', start_time='zero',
                 config_path= 'discretizer_config.json'):

        with open(config_path) as f:
            config = json.load(f)
            self._id_to_channel = config['id_to_channel']
            self._channel_to_id = dict(zip(self._id_to_channel, range(len(self._id_to_channel))))
            self._is_categorical_channel = config['is_categorical_channel']
            self._possible_values = config['possible_values']
            self._normal_values = config['normal_values']

        self._header = ["Hours"] + self._id_to_channel
        self._timestep = timestep
        self._store_masks = store_masks
        self._start_time = start_time
        self._impute_strategy = impute_strategy

        # for statistics
        self._done_count = 0
        self._empty_bins_sum = 0
        self._unused_data_sum = 0

    def transform(self, X, header=None, end=None):
        """Discretize ``X`` and return ``(data, new_header)``.

        Args:
            X: Sequence of rows. ``row[0]`` is the time and the remaining
                entries are values (``""`` means missing), in ``header`` order.
            header: Column names of ``X``; defaults to ``["Hours"] + id_to_channel``.
            end: Optional end time. Rows after it are skipped. Defaults to the
                last row's time.

        Returns:
            ``(data, new_header)``:
            - ``data`` is a float array with one row per bin;
            - ``new_header`` is the comma-joined column names, with one-hot
              columns as ``channel->value`` and mask columns as ``mask->channel``.

        Raises:
            ValueError: for an invalid ``start_time`` or ``impute_strategy``, or
                (from ``list.index``) a categorical value not listed in
                ``possible_values``.
            AssertionError: if row times decrease.
        """
        if header is None:
            header = self._header
        assert header[0] == "Hours"
        eps = 1e-6

        N_channels = len(self._id_to_channel)
        ts = [float(row[0]) for row in X]
        for i in range(len(ts) - 1):
            assert ts[i] < ts[i+1] + eps

        if self._start_time == 'relative':
            first_time = ts[0]
        elif self._start_time == 'zero':
            first_time = 0
        else:
            raise ValueError("start_time is invalid")

        if end is None:
            max_hours = max(ts) - first_time
        else:
            max_hours = end - first_time

        N_bins = int(max_hours / self._timestep + 1.0 - eps)

        # Column layout: each categorical channel occupies len(possible_values)
        # one-hot columns, each continuous channel a single column.
        cur_len = 0
        begin_pos = [0 for i in range(N_channels)]
        end_pos = [0 for i in range(N_channels)]
        for i in range(N_channels):
            channel = self._id_to_channel[i]
            begin_pos[i] = cur_len
            if self._is_categorical_channel[channel]:
                end_pos[i] = begin_pos[i] + len(self._possible_values[channel])
            else:
                end_pos[i] = begin_pos[i] + 1
            cur_len = end_pos[i]

        data = np.zeros(shape=(N_bins, cur_len), dtype=float)
        mask = np.zeros(shape=(N_bins, N_channels), dtype=int)
        # Raw string value per (bin, channel), used as the source for 'previous'/'next' imputation.
        original_value = [["" for j in range(N_channels)] for i in range(N_bins)]
        total_data = 0
        unused_data = 0

        def write(data, bin_id, channel, value, begin_pos):
            # Store one value into its column(s): one-hot for categorical, float otherwise.
            channel_id = self._channel_to_id[channel]
            if self._is_categorical_channel[channel]:
                category_id = self._possible_values[channel].index(value)
                N_values = len(self._possible_values[channel])
                one_hot = np.zeros((N_values,))
                one_hot[category_id] = 1
                for pos in range(N_values):
                    data[bin_id, begin_pos[channel_id] + pos] = one_hot[pos]
            else:
                data[bin_id, begin_pos[channel_id]] = float(value)

        for row in X:
            t = float(row[0]) - first_time
            if t > max_hours + eps:
                continue
            bin_id = int(t / self._timestep - eps)
            assert 0 <= bin_id < N_bins

            for j in range(1, len(row)):
                if row[j] == "":
                    continue
                channel = header[j]
                channel_id = self._channel_to_id[channel]

                total_data += 1
                # A second value in the same bin overwrites the first; count the loss.
                if mask[bin_id][channel_id] == 1:
                    unused_data += 1
                mask[bin_id][channel_id] = 1

                write(data, bin_id, channel, row[j], begin_pos)
                original_value[bin_id][channel_id] = row[j]

        # impute missing values

        if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
            raise ValueError("impute strategy is invalid")

        # Forward pass: fill unobserved bins with the normal value or, for
        # 'previous', with the last observed value (normal value before the first one).
        if self._impute_strategy in ['normal_value', 'previous']:
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if self._impute_strategy == 'normal_value':
                        imputed_value = self._normal_values[channel]
                    if self._impute_strategy == 'previous':
                        if len(prev_values[channel_id]) == 0:
                            imputed_value = self._normal_values[channel]
                        else:
                            imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        # Backward pass: fill unobserved bins with the next observed value
        # (normal value after the last one).
        if self._impute_strategy == 'next':
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins-1, -1, -1):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if len(prev_values[channel_id]) == 0:
                        imputed_value = self._normal_values[channel]
                    else:
                        imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        # Running statistics reported by print_statistics().
        empty_bins = np.sum([1 - min(1, np.sum(mask[i, :])) for i in range(N_bins)])
        self._done_count += 1
        self._empty_bins_sum += empty_bins / (N_bins + eps)
        self._unused_data_sum += unused_data / (total_data + eps)

        if self._store_masks:
            data = np.hstack([data, mask.astype(np.float32)])

        # create new header
        new_header = []
        for channel in self._id_to_channel:
            if self._is_categorical_channel[channel]:
                values = self._possible_values[channel]
                for value in values:
                    new_header.append(channel + "->" + value)
            else:
                new_header.append(channel)

        if self._store_masks:
            for i in range(len(self._id_to_channel)):
                channel = self._id_to_channel[i]
                new_header.append("mask->" + channel)

        new_header = ",".join(new_header)

        return (data, new_header)

    def print_statistics(self):
        """Print the average share of overwritten values and of empty bins.

        The averages are taken over all ``transform`` calls so far. This
        divides by that count, so it raises ZeroDivisionError before the
        first ``transform``.
        """
        print("statistics of discretizer:")
        print("\tconverted {} examples".format(self._done_count))
        print("\taverage unused data = {:.2f} percent".format(100.0 * self._unused_data_sum / self._done_count))
        print("\taverage empty  bins = {:.2f} percent".format(100.0 * self._empty_bins_sum / self._done_count))


class Normalizer:
    """Column-wise standardization ``(x - mean) / std`` with streaming statistics.

    ``_feed_data`` accumulates count, sum and sum of squares. ``_save_params``
    turns them into means and sample standard deviations and pickles them.
    ``load_params`` restores a saved state.

    Args:
        fields: Optional column indices to standardize; all columns when ``None``.
    """

    def __init__(self, fields=None):
        self._means = None
        self._stds = None
        self._fields = None if fields is None else list(fields)

        self._sum_x = None
        self._sum_sq_x = None
        self._count = 0

    def _feed_data(self, x):
        """Add the rows of ``x`` to the running count, sum and sum of squares."""
        arr = np.array(x)
        self._count += arr.shape[0]
        col_sum = np.sum(arr, axis=0)
        col_sum_sq = np.sum(arr ** 2, axis=0)
        if self._sum_x is None:
            self._sum_x = col_sum
            self._sum_sq_x = col_sum_sq
        else:
            self._sum_x += col_sum
            self._sum_sq_x += col_sum_sq

    def _save_params(self, save_file_path):
        """Finalize means/stds from the fed data and pickle them (protocol 2)."""
        eps = 1e-7
        with open(save_file_path, "wb") as save_file:
            n = self._count
            self._means = 1.0 / n * self._sum_x
            centered_sq_sum = self._sum_sq_x - 2.0 * self._sum_x * self._means + n * self._means**2
            self._stds = np.sqrt(1.0 / (n - 1) * centered_sq_sum)
            # Floor tiny deviations so transform() never divides by (almost) zero.
            self._stds[self._stds < eps] = eps
            pickle.dump(obj={'means': self._means, 'stds': self._stds}, file=save_file, protocol=2)

    def load_params(self, load_file_path):
        """Restore ``means``/``stds`` previously written by ``_save_params``."""
        with open(load_file_path, "rb") as load_file:
            # Under Python 3, latin1 lets pickles written by Python 2 (NumPy arrays) load.
            extra = {} if platform.python_version()[0] == '2' else {'encoding': 'latin1'}
            params = pickle.load(load_file, **extra)
        self._means = params['means']
        self._stds = params['stds']

    def transform(self, X):
        """Return a float copy of ``X`` with the selected columns standardized."""
        columns = range(X.shape[1]) if self._fields is None else self._fields
        out = 1.0 * X
        for c in columns:
            out[:, c] = (X[:, c] - self._means[c]) / self._stds[c]
        return out

In [None]:
def read_timeseries(ehr_data_dir, task, filename='10002430_episode1_timeseries.csv', phase='train'):
    """Read one episode CSV as a 2-D array of strings (header row excluded).

    Used here only to feed a sample episode to ``Discretizer.transform`` so the
    discretized column names can be derived.

    Args:
        ehr_data_dir: Root data directory.
        task: Task sub-directory.
        filename: Episode file to read. Defaults to the episode the notebook
            originally hard-coded.
        phase: Split sub-directory containing ``filename``.

    Returns:
        ``np.ndarray`` of strings with one row per data line.

    Raises:
        AssertionError: if the first header column is not ``"Hours"``.
    """
    path = f'{ehr_data_dir}/{task}/{phase}/{filename}'
    with open(path, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        assert header[0] == "Hours"
        rows = [np.array(line.strip().split(',')) for line in tsfile]
    return np.stack(rows)

In [None]:
# Settings:
# - bins of 1.0 time unit;
# - mask columns appended;
# - empty bins forward-filled ('previous'), with the normal value before the first observation;
# - bins counted from t=0.
discretizer = Discretizer(
    timestep=1.0,
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

In [None]:

# Task sub-directory used by the remaining cells.
task = "phenotyping"


In [None]:
# Discretize one sample episode only to obtain the expanded column names
# (continuous channels, channel->value one-hot columns, mask->channel columns).
discretizer_header = discretizer.transform(read_timeseries(ehr_data_dir, task))[1].split(',')

In [None]:
# Indices of plain continuous columns; one-hot and mask columns all contain "->".
cont_channels = [idx for idx, name in enumerate(discretizer_header) if "->" not in name]

In [None]:
# Display the continuous-column indices.
cont_channels

In [None]:
# Standardize only the continuous columns; one-hot and mask columns are left as-is.
normalizer = Normalizer(fields=cont_channels)

In [None]:
# Pre-computed normalization statistics. The file name suggests timestep 1.0,
# 'previous' imputation and start_time 'zero'; confirm they match the discretizer above.
normalizer_state = "ph_ts1.0.input_str_previous.start_time_zero.normalizer"
normalizer.load_params(normalizer_state)

In [None]:
import pandas as pd
import os
import numpy as np
import h5py

def _read_timeseries(ts_filename, dataset_dir, time_bound=None):
    """Read an episode CSV as ``(rows, header)``, each row a list of strings.

    When ``time_bound`` is given, reading stops at the first row whose
    ``Hours`` value exceeds it (with a 1e-6 tolerance).
    """
    rows = []
    with open(os.path.join(dataset_dir, ts_filename), "r") as handle:
        header = handle.readline().strip().split(',')
        assert header[0] == "Hours"
        for raw in handle:
            fields = raw.strip().split(',')
            if time_bound is not None and float(fields[0]) > time_bound + 1e-6:
                break
            rows.append(fields)  # kept as strings; the discretizer parses them
    return (rows, header)

def preprocess_and_save(discretizer, normalizer, listfile, base_dataset_dir, output_dir):
    """Discretize and normalize every episode in ``listfile`` and store each one as HDF5.

    Each episode named in the ``'stay'`` column is looked up first in
    ``base_dataset_dir/train``, then in ``base_dataset_dir/test``. Files that
    are not found are reported and skipped. Results are written to
    ``output_dir/<stay basename>.h5`` under the dataset key ``'data'``.

    Args:
        discretizer: Object with ``transform(rows, header=...)`` returning ``(array, header)``.
        normalizer: Object with ``transform(array)``.
        listfile: CSV path with a ``'stay'`` column.
        base_dataset_dir: Task directory containing ``train``/``test`` sub-directories.
        output_dir: Destination directory; created if it does not exist yet.
    """
    df = pd.read_csv(listfile)
    # Writing an HDF5 file into a directory that does not exist fails, so create it first.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Loop-invariant search directories; 'train' wins if a file exists in both.
    train_dir = os.path.join(base_dataset_dir, 'train')
    test_dir = os.path.join(base_dataset_dir, 'test')
    for filename in df['stay']:
        if os.path.exists(os.path.join(train_dir, filename)):
            dataset_dir = train_dir
        elif os.path.exists(os.path.join(test_dir, filename)):
            dataset_dir = test_dir
        else:
            print(f"File not found in both train and test directories: {filename}")
            continue

        data, header = _read_timeseries(filename, dataset_dir, None)
        print(f"Original length of {filename}: {len(data)}")

        data_preprocessed = discretizer.transform(data, header=header)[0]
        data_normalized = normalizer.transform(data_preprocessed)
        print(f"Length after preprocessing {filename}: {len(data_normalized)}")

        output_path = os.path.join(output_dir, os.path.basename(filename).replace('.csv', '.h5'))
        with h5py.File(output_path, 'w') as hf:
            hf.create_dataset('data', data=data_normalized)

In [None]:
class EHRdatasetFinetune(Dataset):
    """Dataset over preprocessed episode HDF5 files and their label vectors.

    Only rows of ``listfile`` whose ``<stay basename>.h5`` exists in
    ``preprocessed_dir`` are kept. Each item is ``(data, label)``:
    ``data`` is the array stored under ``'data'``, and ``label`` is the row's
    ``classes`` values.

    Args:
        listfile: CSV with a ``'stay'`` column and one column per class.
        preprocessed_dir: Directory written by ``preprocess_and_save``.
        classes: Ordered list of label column names.
    """

    def __init__(self, listfile, preprocessed_dir, classes):
        self.data_files = []
        self.labels = []

        df = pd.read_csv(listfile)
        # Take labels from the class-column block instead of iterrows() rows.
        # iterrows() yields object-dtype Series when the frame mixes strings
        # ('stay') and numbers, and torch.tensor cannot convert object arrays.
        label_matrix = df[classes].to_numpy()
        for stay, label in zip(df['stay'], label_matrix):
            preprocessed_file = os.path.join(preprocessed_dir, os.path.basename(stay).replace('.csv', '.h5'))
            if os.path.exists(preprocessed_file):
                self.data_files.append(preprocessed_file)
                self.labels.append(label)

    def __getitem__(self, index):
        try:
            with h5py.File(self.data_files[index], 'r') as hf:
                data = hf['data'][:]
        except KeyError as e:
            # Raised when the file has no 'data' dataset.
            print(f"Error loading data from file {self.data_files[index]}: {e}")
            raise
        return data, self.labels[index]

    def __len__(self):
        return len(self.data_files)

In [None]:
# Preprocess the training episodes into per-stay HDF5 files.
dataset_dir = os.path.join(ehr_data_dir, task)
train_listfile = "finetune_pheno_train_ts_df.csv"
output_dir = 'YOUR_PATH/timeseries_data_finetune_pheno'

preprocess_and_save(discretizer, normalizer, listfile=train_listfile,
                    base_dataset_dir=dataset_dir, output_dir=output_dir)

In [None]:
# Same preprocessing for the test episodes (shared output directory).
test_listfile = "finetune_pheno_test_ts_df.csv"
output_dir = 'YOUR_PATH/timeseries_data_finetune_pheno'

preprocess_and_save(discretizer, normalizer, listfile=test_listfile,
                    base_dataset_dir=dataset_dir, output_dir=output_dir)

In [None]:
# Same preprocessing for the validation episodes (shared output directory).
val_listfile = "finetune_pheno_val_ts_df.csv"
output_dir = 'YOUR_PATH/timeseries_data_finetune_pheno'
preprocess_and_save(discretizer, normalizer, listfile=val_listfile,
                    base_dataset_dir=dataset_dir, output_dir=output_dir)

In [None]:
# Wrap the preprocessed HDF5 files of each split with their labels
# (stays without an .h5 file in output_dir are skipped).
train_dataset = EHRdatasetFinetune(train_listfile, output_dir, CLASSES)
test_dataset = EHRdatasetFinetune(test_listfile, output_dir, CLASSES)
val_dataset = EHRdatasetFinetune(val_listfile, output_dir, CLASSES)

In [None]:
# Per-class label sums over the validation split
# (positive counts, presuming 0/1 labels).
transposed_labels = list(zip(*val_dataset.labels))
ones_count_per_index = [sum(per_class) for per_class in transposed_labels]

ones_count_per_index

In [None]:
from torch.nn.utils.rnn import pad_sequence
import torch


def my_collate(batch, verbose=False):
    """Collate ``(sequence, label)`` pairs into a zero-padded batch.

    Uses ``torch.nn.utils.rnn.pad_sequence``, imported in this cell; that
    import shadows the file's own ``pad_sequence`` helper.

    Args:
        batch: Items ``(data, label)``, where ``data`` is a ``(time, features)``
            array and ``label`` is an array of class values.
        verbose: Print sequence lengths before/after padding. This is a debug
            aid; the prints previously ran for every batch.

    Returns:
        ``(data_padded, lengths, labels)``:
        - ``data_padded``: float tensor of shape ``(batch, max_time, features)``;
        - ``lengths``: the original sequence lengths as ``long``;
        - ``labels``: the stacked labels as ``long``.
    """
    data = [torch.tensor(item[0], dtype=torch.float) for item in batch]
    lengths = torch.tensor([len(x) for x in data], dtype=torch.long)
    # Going through np.asarray accepts object-dtype label arrays and avoids
    # torch.tensor's slow list-of-ndarrays conversion.
    labels = torch.as_tensor(np.asarray([item[1] for item in batch], dtype=np.int64))

    if verbose:
        print("Before padding:", lengths.tolist())
    data_padded = pad_sequence(data, batch_first=True, padding_value=0.0)
    if verbose:
        # Every sequence in the batch now has this length.
        print("After padding:", data_padded.shape[1])

    return data_padded, lengths, labels


In [None]:
# Training loader with variable-length padding; shuffle=False keeps listfile order.
train_loader = DataLoader(train_dataset, batch_size=32, collate_fn=my_collate, shuffle=False)

In [None]:
#this is done first for train so one-hot-encoding is the same for val and for testing as it is for train
unique_categories, fixed_length = determine_categories_and_length(ehr_data_dir, task, "train", train_data, config)
encoders = create_encoders(unique_categories)
categorical_cols = ['admission_type', 'admission_location', 'insurance', 'language', 'marital_status', 'ethnicity', 'gender', 'anchor_year_group']
imputer_continuous = SimpleImputer(strategy='mean')
imputer_categorical = SimpleImputer(strategy='most_frequent')
#train_data[continuous_cols] = imputer_continuous.fit_transform(train_data[continuous_cols])
df = train_data
for col in categorical_cols:
    print(col)
    transform = imputer_categorical.fit_transform(train_data[[col]])
    print(np.array(transform).shape)
    print(transform)
    print(transform.ravel())
    print(transform.ravel().shape)
    print(df[col])
    df[col] = transform.ravel()
    #end

In [None]:
#this is only relevant for the demographic modality
# Fit the demographic one-hot encoders on train_data_contrast; only the encoders are kept.
_, demo_train_encoders = preprocess_demo(train_data_contrast, None)

In [None]:
# Inspect the fitted encoders (column -> OneHotEncoder).
demo_train_encoders

In [None]:
# Encode the fine-tuning train split with the shared encoders and pickle the dataset.
demo_train_data, _ = preprocess_demo(train_data, demo_train_encoders)
demo_dataset = create_demo_dataset(demo_train_data, train_data[CLASSES].values)
save_dataset(demo_dataset, f'{ehr_data_dir}/{task}/train_finetune_demo_dataset.pkl')

In [None]:
# Val/test demographics reuse the same encoders so their one-hot columns match train's.
demo_val_data, demo_test_data = (
    preprocess_demo(split, demo_train_encoders)[0] for split in (val_data, test_data)
)

demo_dataset_val, demo_dataset_test = (
    create_demo_dataset(features, split[CLASSES].values)
    for features, split in ((demo_val_data, val_data), (demo_test_data, test_data))
)


In [None]:
# Pickle the val/test demographic datasets next to the train one.
save_dataset(demo_dataset_val, f'{ehr_data_dir}/{task}/val_finetune_demo_dataset.pkl')
save_dataset(demo_dataset_test, f'{ehr_data_dir}/{task}/test_finetune_demo_dataset.pkl')

In [None]:
# Resize to 224x224 and normalize with the commonly used ImageNet channel mean/std.
# This `transform` is reused for the val/test image datasets below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img_dataset = create_image_dataset(train_data, CLASSES, 'dicom_id_path', transform)
save_dataset(img_dataset, f'{ehr_data_dir}/{task}/train_finetune_img_dataset.pkl')

In [None]:
# Validation image dataset (paths only; images are loaded lazily on access).
img_dataset = create_image_dataset(val_data, CLASSES, 'dicom_id_path', transform)
save_dataset(img_dataset, f'{ehr_data_dir}/{task}/val_finetune_img_dataset.pkl')

In [None]:
# Test image dataset (paths only; images are loaded lazily on access).
img_dataset = create_image_dataset(test_data, CLASSES, 'dicom_id_path', transform)
save_dataset(img_dataset, f'{ehr_data_dir}/{task}/test_finetune_img_dataset.pkl')

In [None]:
# Discharge-note ('text_x') dataset for the train split.
text_dataset = create_text_dataset(train_data, CLASSES, "text_x")
save_dataset(text_dataset, f'{ehr_data_dir}/{task}/train_finetune_text_ds_dataset.pkl')


In [None]:
# Discharge-note ('text_x') datasets for the val and test splits.
text_dataset_val = create_text_dataset(val_data,CLASSES, "text_x")
save_dataset(text_dataset_val, f'{ehr_data_dir}/{task}/val_finetune_text_ds_dataset.pkl')

text_dataset_test = create_text_dataset(test_data,CLASSES, "text_x")
save_dataset(text_dataset_test, f'{ehr_data_dir}/{task}/test_finetune_text_ds_dataset.pkl')

In [None]:
# Radiology-report ('text_y') dataset for the train split.
text_dataset_rad = create_text_dataset(train_data,CLASSES, "text_y")
save_dataset(text_dataset_rad, f'{ehr_data_dir}/{task}/train_finetune_text_rad_dataset.pkl')


In [None]:
# Radiology-report ('text_y') datasets for the val and test splits.
text_dataset_val = create_text_dataset(val_data,CLASSES, "text_y")
save_dataset(text_dataset_val, f'{ehr_data_dir}/{task}/val_finetune_text_rad_dataset.pkl')

text_dataset_test = create_text_dataset(test_data,CLASSES, "text_y")
save_dataset(text_dataset_test, f'{ehr_data_dir}/{task}/test_finetune_text_rad_dataset.pkl')

In [None]:
# Pickle the time-series datasets. They store HDF5 file paths rather than arrays,
# so the .h5 files in output_dir must remain in place for the pickles to be usable.
save_dataset(train_dataset, f'{ehr_data_dir}/{task}/train_finetune_ts_dataset.pkl')
save_dataset(val_dataset, f'{ehr_data_dir}/{task}/val_finetune_ts_dataset.pkl')
save_dataset(test_dataset, f'{ehr_data_dir}/{task}/test_finetune_ts_dataset.pkl')