In [1]:
import os
import numpy as np
import pickle
import json
import math
from reader import InHospitalMortalityReader, PhenotypingReader, LengthOfStayReader, DecompensationReader
from tqdm import tqdm

In [2]:
# Channel metadata: for categorical channels, ['values'] maps each raw reading string to
# the numeric code stored in the encoded series.
with open('resources/channel_info.json') as channel_info_file:
    series_channel_info = json.load(channel_info_file)

# Discretizer settings: column-index -> channel name, which channels are categorical, and the
# per-channel "normal" reading used for hours before a channel's first observation.
with open('resources/discretizer_config.json') as config_file:
    series_config = json.load(config_file)

id_to_channel = series_config['id_to_channel']
is_categorical_channel = series_config['is_categorical_channel']
normal_values = series_config['normal_values']
possible_values = series_config['possible_values']

In [3]:
def read_chunk(reader, chunk_size):
    """Read `chunk_size` consecutive examples from `reader` and collate them by key.

    Args:
        reader: object exposing ``read_next()``, which returns one example as a dict.
        chunk_size: number of examples to read.

    Returns:
        dict mapping each key to the list of its values across the examples, in read order.
        If a ``"header"`` key is present it is collapsed to the first example's value
        (presumably the header is the same for every example — confirm against the reader).
        An empty dict is returned when ``chunk_size`` is 0 instead of raising KeyError.
    """
    data = {}
    for _ in range(chunk_size):
        for key, value in reader.read_next().items():
            data.setdefault(key, []).append(value)
    if "header" in data:
        # keep a single copy of the header rather than one per example
        data["header"] = data["header"][0]
    return data

In [4]:
# Encode every in-hospital-mortality episode (train, val and test listfiles) into dense
# (period_length, n_channels) value and observation-mask arrays. Gaps between readings are
# forward-filled; hours before a channel's first reading use the channel's normal value.
period_length = 48
# Set MORTALITY_DATA_DIR to point at the dataset on another machine.
path = os.environ.get('MORTALITY_DATA_DIR', '/data1/yzh/codebase/data/MIMIC-III/mortality/')
n_channels = len(id_to_channel)


def encode_value(channel, value):
    """Return the numeric encoding of raw reading `value` for `channel`.

    Categorical channels are mapped through ``series_channel_info[channel]['values']``;
    all other channels are parsed with ``float``.
    """
    if is_categorical_channel[channel]:
        return series_channel_info[channel]['values'][value]
    return float(value)


# Encoded normal value per channel, in column (id_to_channel) order. Computed once here
# instead of being re-derived for every missing reading inside the per-row loop.
normal_encoded = [encode_value(channel, normal_values[channel]) for channel in id_to_channel]


def encode_patient(patient, name):
    """Convert one episode into float32 (values, mask) arrays of shape (period_length, n_channels).

    Args:
        patient: iterable of rows ``[hours, reading_0, reading_1, ...]`` (strings), where ''
            means the channel was not measured in that row. Assumes rows are in time order
            (TODO confirm against the reader; out-of-order rows are not detected).
        name: episode identifier, used only in error messages.

    Returns:
        values: encoded readings, forward-filled over hours without a measurement.
        mask: 1 where the channel was actually measured in that hour, else 0.

    Raises:
        ValueError: if a row's hour (truncated to int) is negative or exceeds period_length.
    """
    data_patient = np.zeros((n_channels, period_length), dtype=np.float32)
    mask_patient = np.zeros((n_channels, period_length), dtype=np.float32)
    last_time = -1
    for row in patient:
        time = int(float(row[0]))
        if time == period_length:
            # readings in [period_length, period_length + 1) are folded into the last hour
            time -= 1
        # negative hours are rejected too: they would otherwise silently index from the end
        if not 0 <= time < period_length:
            raise ValueError(f'Episode {name}: event hour {time} outside [0, {period_length})')
        for index, value in enumerate(row[1:]):
            if value != '':
                mask_patient[index, time] = 1
                data_patient[index, time] = encode_value(id_to_channel[index], value)
            elif mask_patient[index, time] == 0 and time > last_time:
                # Not measured here: fill hours (last_time, time] with the previous row's
                # value, or with the normal value if this is the episode's first row.
                fill = data_patient[index, last_time] if last_time >= 0 else normal_encoded[index]
                data_patient[index, last_time + 1:time + 1] = fill
        last_time = time
    if last_time < 0:
        # Empty episode: nothing to carry forward, so use normal values throughout.
        data_patient[:] = np.asarray(normal_encoded, dtype=np.float32)[:, None]
    elif last_time < period_length - 1:
        # carry the final row's values through to the end of the window
        data_patient[:, last_time + 1:] = data_patient[:, last_time, None]
    return data_patient.T, mask_patient.T


data_all = []
mask_all = []
label_all = []
name_all = []
for mode in ['train', 'val', 'test']:
    # the train and val listfiles both index episodes stored under train/
    reader = InHospitalMortalityReader(dataset_dir=os.path.join(path, 'test' if mode == 'test' else 'train'),
                                       listfile=os.path.join(path, mode + '_listfile.csv'),
                                       period_length=period_length)
    ret = read_chunk(reader, reader.get_number_of_examples())
    label_all += ret["y"]
    name_all += ret["name"]
    for patient, name in zip(ret["X"], ret["name"]):
        data_patient, mask_patient = encode_patient(patient, name)
        data_all.append(data_patient)
        mask_all.append(mask_patient)
print(len(data_all), len(mask_all), len(label_all), len(name_all))

21139 21139 21139 21139


In [5]:
# Stack the per-episode arrays into (n_episodes, period_length, n_channels) tensors and
# z-score each channel using statistics over observed (mask == 1) entries only.
# NOTE(review): the statistics pool the train, val and test episodes, so test data leaks into
# the normalisation; consider computing mean/std on the training split only.
data_all = np.array(data_all)
mask_all = np.array(mask_all)
n_channels = data_all.shape[-1]
# Flatten episodes x hours into rows; entries that were never observed are masked out.
x_masked = np.ma.masked_array(data_all.reshape(-1, n_channels),
                              mask_all.reshape(-1, n_channels) == 0)
mean = np.mean(x_masked, 0)
std = np.std(x_masked, 0)
print(mean, std)
# Guard degenerate channels so none produces inf/NaN: a never-observed channel centres at 0,
# and a zero (or undefined) std is replaced by 1.
center = np.ma.filled(mean, 0.0)
std_filled = np.ma.filled(std, 0.0)
scale = np.where(std_filled > 0, std_filled, 1.0)
# Unobserved (forward-filled) positions are set to 0, i.e. the channel mean after scaling.
data_normalized = np.where(mask_all == 1, (data_all - center) / scale, 0)

[0.13354037267080746 61.46690821598523 0.5394059277377888
 3.119367646871375 5.290858060882255 11.618071512938903 3.1806345816977584
 143.21973519841134 86.29870307167235 168.72015948168453 78.73721156458411
 97.69931626649706 19.29791803628463 120.31002625862565 37.03954583770163
 83.27327296371479 7.28222605621633] [0.3401578185750751 250.4033844396291 0.20066273178085373
 1.262167663202079 1.4043288333340955 3.908595655788102 1.8973968540861121
 69.22505471875131 19.16908521820205 15.020152083998523 154.81962785121664
 1031.0258674132324 6.63096314003484 25.231333122285367 9.536494024362197
 26.057987256006786 2.217912317350937]


In [6]:
# Persist (normalized series, labels, observation masks, episode names) for training.
# The `with` block guarantees the file is flushed and closed even if pickling fails.
with open('data.pkl', 'wb') as f:
    pickle.dump((data_normalized.tolist(), label_all, mask_all.tolist(), name_all), f)

In [7]:
# Fraction of (episode, hour, channel) cells holding a real measurement (mask == 1) rather
# than an imputed value. Vectorised, with a float64 accumulator so the count stays exact
# (the original loop accumulated in a float32 scalar).
mask_array = np.asarray(mask_all)
observed_rate = mask_array.sum(dtype=np.float64) / mask_array.size
print('Observed Rate:', observed_rate)

Observed Rate: 0.43294860164606075


In [None]:
# Mean label over every episode collected from the train/val/test listfiles
# (for 0/1 labels, the fraction of positive episodes).
positive_count = sum(label_all)
positive_count / len(label_all)

0.13231467902928237

In [9]:
# Sanity check: every episode was built as a fixed (period_length, n_channels) array, so the
# min and max episode lengths should both equal period_length.
lens = [len(episode) for episode in data_normalized]
print(len(data_normalized), sum(lens), min(lens), max(lens))

21139 1014672 48 48
