In [13]:
from pathlib import Path
from utils.data_loader import load_data

In [61]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [92]:
class DSet(Dataset):
    """
    Base dataset that loads a (x_data, t_data) pair from `dset_path`.

    `load_data` (project-local, utils.data_loader) is expected to return two
    values, which are stored as `self.x_data` and `self.t_data`.
    Subclasses must implement `__len__` and `__getitem__`; the base versions
    raise NotImplementedError instead of silently returning None.
    """
    def __init__(self, dset_path):
        # NOTE(review): presumably x_data / t_data are parallel per-sequence
        # collections (subclasses index both with the same idx) — TODO confirm.
        self.x_data, self.t_data = load_data(dset_path)

    def __len__(self):
        # Abstract: a `pass` body would make len() fail with an opaque TypeError.
        raise NotImplementedError(f"{type(self).__name__} must implement __len__")

    def __getitem__(self, idx):
        # Abstract: a `pass` body would silently yield None items to the DataLoader.
        raise NotImplementedError(f"{type(self).__name__} must implement __getitem__")
        

In [93]:
class DSetMIMIC(DSet):
    """MIMIC split as a map-style dataset: item `idx` is the pair (x_data[idx], t_data[idx])."""

    def __init__(self, dset_path):
        # Loading is handled entirely by the base class.
        super().__init__(dset_path)

    def __len__(self):
        # Dataset size is the number of entries in x_data.
        n_items = len(self.x_data)
        return n_items

    def __getitem__(self, idx):
        x_item = self.x_data[idx]
        t_item = self.t_data[idx]
        return x_item, t_item

In [65]:
# Target device for every collated batch; hard-coded to CPU here.
device = torch.device("cpu")

In [89]:
def collate_fn_categorical_marker(xt_tuples, target_device=None):
    """
    Pad a batch of variable-length (label, time) sequences into dense tensors.

    Input: list of (x_i, t_i) tuples. x_i holds the class labels of sequence i
        (it is flattened, so (len_i,) and (len_i, 1) both work; lists are
        accepted too). t_i is a (len_i, t_dim) array of time features.
    Output: tensors (x, t, m) of shapes (BS, max_len), (BS, max_len, t_dim), (BS, max_len)
        x: int64 class labels, 0 in padded positions
        t: float32 time features, 0 in padded positions
        m: float32 mask, 1. for real steps and 0. for padding
    Note: the x_tensor will contain the class labels, and not one-hot representation

    Args:
        xt_tuples: the batch, as produced by a DataLoader over a dataset
            returning (x_i, t_i).
        target_device: device for the returned tensors; defaults to the
            module-level `device`.

    Raises:
        ValueError: if the batch is empty.
    """
    if not xt_tuples:
        raise ValueError("collate_fn_categorical_marker received an empty batch")
    if target_device is None:
        target_device = device

    x_data = [np.asarray(x).reshape(-1) for x, _ in xt_tuples]
    t_data = [np.asarray(t) for _, t in xt_tuples]

    BS = len(xt_tuples)
    t_dim = t_data[0].shape[1]
    # Sequence lengths come from t; x_i must have the same number of labels.
    seq_len = [len(t) for t in t_data]
    max_seq_len = max(seq_len)

    # Labels are stored as int64 directly rather than round-tripping through float64.
    x_tensor = np.zeros((BS, max_seq_len), dtype=np.int64)
    t_tensor = np.zeros((BS, max_seq_len, t_dim), dtype=np.float32)
    mask_tensor = np.zeros((BS, max_seq_len), dtype=np.float32)

    for idx, (x_i, t_i, n_i) in enumerate(zip(x_data, t_data, seq_len)):
        x_tensor[idx, :n_i] = x_i
        t_tensor[idx, :n_i] = t_i
        mask_tensor[idx, :n_i] = 1.

    return (
        torch.from_numpy(x_tensor).to(target_device),
        torch.from_numpy(t_tensor).to(target_device),
        torch.from_numpy(mask_tensor).to(target_device),
    )


def collate_fn_real_marker(xt_tuples, target_device=None):
    """
    Pad a batch of variable-length sequences with real-valued marks into dense tensors.

    Input: list of (x_i, t_i) tuples. x_i is a (len_i, x_dim) array of
        real-valued marks (a 1-D (len_i,) array is treated as x_dim == 1).
        t_i is a (len_i, t_dim) array of time features.
    Output: tensors (x, t, m) of shapes (BS, max_len, x_dim), (BS, max_len, t_dim), (BS, max_len)
        x: float32 marks, 0 in padded positions
        t: float32 time features, 0 in padded positions
        m: float32 mask, 1. for real steps and 0. for padding

    Args:
        xt_tuples: the batch, as produced by a DataLoader over a dataset
            returning (x_i, t_i).
        target_device: device for the returned tensors; defaults to the
            module-level `device`.

    Raises:
        ValueError: if the batch is empty.
    """
    if not xt_tuples:
        raise ValueError("collate_fn_real_marker received an empty batch")
    if target_device is None:
        target_device = device

    # Keep the feature axis: real-valued marks must not be flattened or cast to int
    # (the previous version did both, contradicting the (BS, max_len, x_dim) contract).
    x_data = [np.asarray(x) for x, _ in xt_tuples]
    x_data = [x[:, None] if x.ndim == 1 else x for x in x_data]
    t_data = [np.asarray(t) for _, t in xt_tuples]

    BS = len(xt_tuples)
    x_dim = x_data[0].shape[1]
    t_dim = t_data[0].shape[1]
    seq_len = [len(t) for t in t_data]
    max_seq_len = max(seq_len)

    x_tensor = np.zeros((BS, max_seq_len, x_dim), dtype=np.float32)
    t_tensor = np.zeros((BS, max_seq_len, t_dim), dtype=np.float32)
    mask_tensor = np.zeros((BS, max_seq_len), dtype=np.float32)

    for idx, (x_i, t_i, n_i) in enumerate(zip(x_data, t_data, seq_len)):
        x_tensor[idx, :n_i] = x_i
        t_tensor[idx, :n_i] = t_i
        mask_tensor[idx, :n_i] = 1.

    return (
        torch.from_numpy(x_tensor).to(target_device),
        torch.from_numpy(t_tensor).to(target_device),
        torch.from_numpy(mask_tensor).to(target_device),
    )

In [94]:
class DLoaderMIMIC(DataLoader):
    """
    DataLoader for MIMIC sequence datasets that pads each batch.

    Batches are collated with `collate_fn_categorical_marker` unless a
    `collate_fn` is passed explicitly.

    Args:
        dset: map-style dataset yielding (x_i, t_i) pairs.
        bs: batch size (default 16). Use this rather than `batch_size`.
        **kwargs: forwarded to `torch.utils.data.DataLoader`, e.g.
            `shuffle=True` for a training split or `collate_fn=...` to
            override the default collate function.
    """
    def __init__(self, dset, bs=16, **kwargs):
        if 'collate_fn' not in kwargs:
            # Resolved only when needed, so a caller-supplied collate_fn takes precedence.
            kwargs['collate_fn'] = collate_fn_categorical_marker
        super().__init__(dset, batch_size=bs, **kwargs)

In [86]:
# Data directory is resolved relative to the notebook's working directory.
# NOTE(review): .pkl file — if load_data unpickles it, only load trusted files.
data_dir = Path('..') / 'data'
dset_train_path = data_dir.joinpath('mimic_train.pkl')

In [95]:
# Training dataset and its padded-batch loader (batch size 16).
dset = DSetMIMIC(dset_path=dset_train_path)
dloader = DLoaderMIMIC(dset, bs=16)

In [96]:
# Number of sequences in the training set (dloader.dataset is dset).
len(dset)

951

In [88]:
# Peek at the first padded batch only: label row, time-feature matrix and mask per sequence.
for x, t, m in dloader:
    for x_row, t_row, m_row in zip(x, t, m):
        print(x_row, t_row, m_row)
    break

tensor([88, 63, 63, 88,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]) tensor([[  1.2770, 127.8160,   5.7198],
        [  1.3302, 559.5646,   6.2364],
        [  1.3654, 369.3514,   1.7805],
        [  1.4246, 622.7903,   2.0215],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000]]) tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([80, 80, 79,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]) tensor([[  1.6435, 127.8160,   9.895