In [None]:
import os
import sys
import time
import pickle
import itertools
import numpy as np
import pandas as pd

from itertools import islice
import matplotlib.pyplot as plt
from collections import Counter
from multiprocessing import Pool
from datetime import datetime, timedelta
from astropy.table import QTable, Table, Column
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler


import torch
torch.set_num_threads(6)
from torch.utils.data import DataLoader

### Loading Excel Sheet

In [None]:
# Read the first worksheet of the workbook. Opening the ExcelFile as a context
# manager closes the underlying file handle as soon as parsing is done.
with pd.ExcelFile("path to input excel sheet") as xls:
    data_sheet1 = xls.parse(0)
# Keep only the first 1913 rows (positional slice).
# NOTE(review): presumably the rows after that hold no patient records -- confirm.
data_sheet1 = data_sheet1[:1913]
data_sheet1

### Getting the patients' IDs

In [None]:
# Patient identifier of every row of the sheet; the distinct IDs are printed
# as a quick sanity check.
pat_id_column = "Pat-ID"
pat_id = data_sheet1[pat_id_column]
unique_pat_ids = np.unique(pat_id.values)
print(unique_pat_ids)

### Getting Label

In [None]:
# Day-level label column; per its header, 1 = 'regular day' and 0 = 'irregular day'.
regular_day_column = "Regular day_infectious [1 = 'regular day', 0 = 'irregular day']"
is_regular = data_sheet1[regular_day_column]

### Getting the dates and building a dictionary with key=[date, ID] and value=is_regular

In [None]:
# Date column of the sheet, printed for inspection.
date_column = "Date"
date = data_sheet1[date_column]
print(date)

In [None]:
# Lookup table: str([calendar date, patient ID]) -> regular-day label.
# The ID is coerced to a built-in int because str() of a list applies repr() to
# its elements, and NumPy >= 2 renders integer scalars as "np.int64(5)". Without
# the coercion these keys would never equal a key built from a plain Python int.
Dict = {}
for day, patient, label in zip(date, pat_id, is_regular):
    if pd.isna(day) or pd.isna(patient):
        # A row without a date or a patient ID cannot be matched to a recording.
        continue
    Dict[str([day.date(), int(patient)])] = label

### Getting the cohort [0 = Inpatient, 1 = Outpatient] and building a dictionary with key=[date, ID] and value=cohort

In [None]:
# Cohort column; per its header, 0 = Inpatient and 1 = Outpatient.
cohort_column = "Cohort [0 = Inpatient, 1 = Outpatient]"
cohort = data_sheet1[cohort_column]

In [None]:
# Lookup table: str([calendar date, patient ID]) -> cohort value.
# The ID is coerced to a built-in int because str() of a list applies repr() to
# its elements, and NumPy >= 2 renders integer scalars as "np.int64(5)". Without
# the coercion these keys would never equal a key built from a plain Python int.
Dict_cohort = {}
for day, patient, cohort_value in zip(date, pat_id, cohort):
    if pd.isna(day) or pd.isna(patient):
        # A row without a date or a patient ID cannot be matched to a recording.
        continue
    Dict_cohort[str([day.date(), int(patient)])] = cohort_value

### Functions   

In [None]:
def num_samples(all_intervals):
    """Return the number of intervals contained in ``all_intervals``.

    Parameters
    ----------
    all_intervals : sequence
        Any object supporting ``len()``.

    Returns
    -------
    int
        ``len(all_intervals)``.
    """
    return len(all_intervals)

def max_min_len(all_intervals):
    """Return ``[shortest, longest]`` interval length.

    The length of an interval is ``interval[0].shape[0]``, i.e. the size of the
    first axis of its first element.

    Raises
    ------
    ValueError
        If ``all_intervals`` is empty (``min`` of an empty list).
    """
    sizes = [interval[0].shape[0] for interval in all_intervals]
    shortest = min(sizes)
    longest = max(sizes)
    return [shortest, longest]

def len_samples(all_intervals):
    """Return the length of every interval, in input order.

    The length of an interval is ``interval[0].shape[0]``, i.e. the size of the
    first axis of its first element.
    """
    return [interval[0].shape[0] for interval in all_intervals]

def list_start_hour(all_intervals):
    """Return the ``.hour`` of every interval's second element, in input order.

    ``interval[1]`` must expose an ``hour`` attribute (e.g. a ``datetime``).
    """
    return [interval[1].hour for interval in all_intervals]

### Reading Patients Data

In [None]:
file_path = 'path to .pkl files'

# Collect every file below file_path. Entries are stored relative to file_path
# so that os.path.join(file_path, entry) also resolves files found in
# sub-directories; for files directly inside file_path the entry is the bare
# file name.
InPat_Pkl_list = []
for root, dirs, files in os.walk(file_path):
    for name in files:
        InPat_Pkl_list.append(os.path.relpath(os.path.join(root, name), file_path))

# sorted() returns a new list and leaves its argument untouched; sort in place
# so the order of the list no longer depends on the file system.
InPat_Pkl_list.sort()
InPat_Pkl_list

In [None]:
# Per-file statistics; every list gets one entry per file of InPat_Pkl_list.
length_lst = []            # list of interval lengths
values_lst = []            # number of intervals
max_len_lst = []           # longest interval
min_len_lst = []           # shortest interval
estimated_total_num = []   # summed interval length floor-divided by 3600
s_h_in = []                # list of interval start hours
for pkl_name in InPat_Pkl_list:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(file_path, pkl_name), 'rb') as fileobj:
        [all_intervals, num_skipped_intervals] = pickle.load(fileobj)
    values_lst.append(num_samples(all_intervals))
    shortest, longest = max_min_len(all_intervals)
    max_len_lst.append(longest)
    min_len_lst.append(shortest)
    file_lengths = len_samples(all_intervals)
    length_lst.append(file_lengths)
    s_h_in.append(list_start_hour(all_intervals))
    estimated_total_num.append(np.array(file_lengths).sum() // 3600)

total_intervals = np.array(values_lst).sum()
print(f'Total Number of Intervals:{total_intervals}')

In [None]:
# Sum of the per-file hour estimates.
total_estimated_hours = np.array(estimated_total_num).sum()
print(f'Estimated Total Number of Hours:{total_estimated_hours}')

### Plotting Number of Samples per Hour

In [None]:
# Flatten the per-file start hours and count how many intervals start in each
# hour of the day.
s_h_all = s_h_in
s_h_lst_merged = np.array(list(itertools.chain(*s_h_all)))
num = [float((s_h_lst_merged == hour).sum()) for hour in range(24)]

# Both ends of every label are zero-padded, e.g. '00:00-01:00'.
h_lst = [f'{hour:02d}:00-{hour + 1:02d}:00' for hour in range(24)]

# rcParams only affect artists created afterwards, so the y-tick size has to be
# set before the figure is built to apply on the first run.
plt.rc('ytick', labelsize=45)

# Define plot space
fig, ax = plt.subplots(figsize=(50, 20))

# Create bar plot at explicit positions so the tick labels can be attached to
# fixed tick locations.
positions = np.arange(len(h_lst))
ax.bar(positions, num)
ax.set_xticks(positions)
ax.set_xticklabels(h_lst, rotation=45, ha="right", fontsize=30)

ax.set_ylabel("Number of samples per hour", fontsize=40)
plt.show()

In [None]:
# Key of each file = second '_'-separated token of its name; it is shown in
# the 'Pat-ID' column of the table below.
keys_lst = [pkl_name.split('_')[1] for pkl_name in InPat_Pkl_list]

# Order the hour estimates the same way as the sorted keys.
indices = np.argsort(keys_lst)
sorted_keys_lst = sorted(keys_lst)
sorted_estimated_total_num = [estimated_total_num[idx] for idx in indices]
t = Table(
    [sorted_keys_lst, sorted_estimated_total_num],
    names=('Pat-ID', 'Total number of hours'),
)
print(t)


In [None]:
# Histogram of all interval lengths, pooled over every file.
length_lst_merged = list(itertools.chain.from_iterable(length_lst))
fig_len, ax_len = plt.subplots()
ax_len.hist(length_lst_merged)
ax_len.set_xlabel("length")
ax_len.set_ylabel("frequencies", fontsize=20)
plt.show()

In [None]:
# Count, per file, the intervals whose length reaches the threshold and report
# their share of all intervals.
threshold = 2000.
length_ = [
    int(np.count_nonzero(np.array(file_lengths) >= threshold))
    for file_lengths in length_lst
]

n_long_enough = np.array(length_).sum()
n_intervals = np.array(values_lst).sum()
print(f'Nubmer of Samples whose Length>={threshold}:', n_long_enough)
print('Eligible Intervals(%):', 100*(n_long_enough/n_intervals))

### Adding the is_regular variable to the data and saving them

In [None]:
# Build the regular / irregular sample lists.
# Every sample is [signal[:L, :], ID, label, start_time, end_time]; intervals
# shorter than L rows or without an entry in Dict are dropped.
L = 3000
all_data_regular = []
all_data_irregular = []
for pkl_name in InPat_Pkl_list:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(file_path, pkl_name), 'rb') as fileobj:
        [all_intervals, num_skipped_intervals] = pickle.load(fileobj)
    # Distinct loop variables: the inner loop no longer reuses the outer index.
    for interval in all_intervals:
        signal = interval[0]
        s_t = interval[1]
        e_t = interval[2]
        ID = int(interval[3])

        key = str([s_t.date(), ID])
        if key not in Dict:
            continue
        if len(signal) < L:
            continue

        # Stored under its own name so the `is_regular` column loaded from the
        # sheet is not overwritten by a scalar.
        day_label = Dict[key]
        start_time = [s_t.year, s_t.month, s_t.day, s_t.hour, s_t.minute]
        end_time = [e_t.year, e_t.month, e_t.day, e_t.hour, e_t.minute]
        sample = [signal[:L, :], ID, day_label, np.array(start_time), np.array(end_time)]
        if day_label == 1:
            all_data_regular.append(sample)
        elif day_label == 0:
            all_data_irregular.append(sample)
print(len(all_data_irregular))
print(len(all_data_regular))

In [None]:
# Number of regular-day samples per patient.
# Each sample is [signal, ID, label, start_time, end_time], so index 1 is the
# patient ID. Counting directly avoids collating the whole dataset into one
# batch, also works for an empty list, and does not overwrite `is_regular`.
regular_id_counts = Counter(int(sample[1]) for sample in all_data_regular)
print(sorted(regular_id_counts.items()))

In [None]:
# Number of irregular-day samples per patient.
# Each sample is [signal, ID, label, start_time, end_time], so index 1 is the
# patient ID. Counting directly avoids collating the whole dataset into one
# batch, also works for an empty list, and does not overwrite `is_regular`.
irregular_id_counts = Counter(int(sample[1]) for sample in all_data_irregular)
print(sorted(irregular_id_counts.items()))

### Splitting Regular Data into train/test (90/10%)

In [None]:
# Fixed seed so the shuffled, stratified split is identical on every run.
SEED = 42

# Stratify by patient ID: each sample is [signal, ID, label, start_time,
# end_time], so index 1 (== index -4) is the ID.
labels = [sample[1] for sample in all_data_regular]
train_set, test_set = train_test_split(
    all_data_regular,
    test_size=0.1,
    shuffle=True,
    stratify=labels,
    random_state=SEED,
)
print('len(train_set):', len(train_set))
print('len(test_set):', len(test_set))


# Per-patient sample counts of the training part.
train_set_label = [sample[1] for sample in train_set]
print(sorted(Counter(train_set_label).items()))


# Per-patient sample counts of the test part.
test_set_label = [sample[1] for sample in test_set]
print(sorted(Counter(test_set_label).items()))

In [None]:
path = "path to saved data"

# Persist the three datasets, one pickle file each: the train and test parts of
# the regular data and the irregular samples (written as the "ood" set).
datasets_to_save = [
    (f'{path}/train_set.pickle', train_set),
    (f'{path}/test_set.pickle', test_set),
    (f'{path}/ood_set.pickle', all_data_irregular),
]
for target_file, dataset in datasets_to_save:
    with open(target_file, 'wb') as output:
        pickle.dump(dataset, output)