In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import os
from rnn_utils import split_dataset, MYCOLLATE
from rnn_utils import ICareDataset_fast, ICareCOLLATE_fast
from rnn_utils import RNN, train_one_epoch, train_one_epochV2, eval_model, compute_loss, get_prediction_thresholds, outs2df,compute_metrics
from rnn_utils import gen_mask_padded_loss

import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler

from sklearn.model_selection import ParameterGrid, ParameterSampler

import numpy as np
import pandas as pd
import json

from tqdm.notebook import tqdm

from config import Settings; settings = Settings()

from ICDMappings import ICDMappings
icdmap = ICDMappings()

import wandb

idx = pd.IndexSlice

# Reproducibility

In [3]:
# Reproducibility
seed = settings.random_seed

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x7f7fbf559bf0>

# Create dataset

In [4]:
grouping = 'ccs'
batch_size=64

In [5]:
ccs_universe = list(icdmap.icd9_3toccs.data.keys())
dataset_folder = '/home/debian/Simao/master-thesis/data/model_ready_dataset/icare2021_diag_A301'
dataset = ICareDataset_fast(os.path.join(dataset_folder,'dataset.json'),
                            ccs_universe,
                            grouping
                          )

train_dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=ICareCOLLATE_fast(),
                            sampler=RandomSampler(dataset.train_indices)
                           )
val_dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=ICareCOLLATE_fast(),
                            sampler=RandomSampler(dataset.val_indices)
                           )
test_dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=ICareCOLLATE_fast(),
                            sampler=RandomSampler(dataset.test_indices)
                           )

# Nº batches
len(train_dataloader)
len(val_dataloader)
len(test_dataloader)

processing each patient


  0%|          | 0/262811 [00:00<?, ?it/s]

2875

616

616

In [34]:
print('Distribution of target size (nº of diagnostics in target window)')
all_lengths = {}
for e in tqdm(dataset.raw_data):
    lengths = [len(i) for i in dataset.raw_data[e]['ccs']['targets']]
    for l in lengths:
        if l not in all_lengths:
            all_lengths[l] = 1
        else:
            all_lengths[l] +=1
            
(pd.Series(all_lengths).sort_index() / pd.Series(all_lengths).sum())[:15].round(2)

  0%|          | 0/262811 [00:00<?, ?it/s]

0     0.23
1     0.18
2     0.13
3     0.10
4     0.07
5     0.06
6     0.04
7     0.03
8     0.03
9     0.02
10    0.02
11    0.01
12    0.01
13    0.01
14    0.01
dtype: float64

# Train

In [38]:
n_labels = input_size = next(iter(train_dataloader))['target_sequences']['sequence'].shape[2]

criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

In [36]:
hyperparameters = {
    'hidden_size':[100,150],
    'num_layers':[1],
    'lr':[0.01,0.02],
    'model':['rnn','gru']
    
}
meta_parameters = {
    'epochs':1
}

params = ParameterGrid(hyperparameters)
print(f'params:',len(params))

#random_params = ParameterSampler(params.param_grid,n_iter=len(params)-1,random_state=231)
#next(iter(random_params))

params: 8


# Test

In [None]:
from torch import nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence
import torch.nn.functional as F

In [None]:
import torchmetrics

In [None]:
from torchmetrics.functional import recall,precision,f1_score

In [None]:
import pandas
import numpy

In [None]:
param_set = next(iter(params))
config = param_set
model = RNN(input_size=input_size,
              hidden_size=config['hidden_size'],
              num_layers=config['num_layers'],
              n_labels=n_labels,
              model=config['model'])

optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

# Improved masked stuff

In [140]:
def gen_mask_padded_loss(lengths,loss_shape):
    """
    This method creates a mask to later perform loss.masked_fill_(mask,0)
    
    Note: this method is called at each batch so it has been optimized to some extent
    sacrificing some readibility. Hence it may be criptic to understand everything.
    
    Parameters
    ----------
    lengths: list, shape = (batch_size,)
        list with the actual length of each sequence in the batch
    
    loss_shape: tuple/list, shape=(batch_size, max_seq_length, n_labels)
        shape of the loss tensor on a given batch
    """
    idx = pd.IndexSlice
    # i.e. [ (pos_in_batch, seq_size), ...]
    # e.g. imagine batch of two sequences. first has size 2 and second size 6. we get [(0,2),(1,6)]
    seq_size_per_seq = list(zip(range(0,len(lengths)),lengths.numpy()))

    # i.e. [ (pos_in_batch,seq_index),(pos_in_batch,seq_index),...]
    # e.g. imagine batch of two seqs. first has size 2, second has size 1. produces: [[(0,0),(0,1)],[(1,0)]]
    real_seq_pos_per_seq = [list(zip([a[0]]*a[1],range(0,a[1]))) for a in seq_size_per_seq]
    # just flattens the previous list.
    # i.e. (taking the previous example) produces: [(0,0),(0,1),(1,0)]
    real_seq_pos_per_seq = [item for seq in real_seq_pos_per_seq for item in seq] 

    # create a mask that initially has everything as True
    res = (pd.DataFrame(np.ones(shape=(loss_shape[0]*loss_shape[1],loss_shape[2])))
           .assign(seq=np.array([[seq] * loss_shape[1] for seq in range(len(lengths))]).reshape((-1,1)),
                   index=list(range(0,loss_shape[1]))*loss_shape[0]
                  )
           .set_index(['seq','index']) # index is mean to help in the .loc after this cascade
           .astype(bool) # all values of dataframe are set to False now.
          )
    
    # set to False the values we don't want to change (aka: values that are not paddings)
    res.loc[idx[real_seq_pos_per_seq],:] = False

    # stack from (batch_size*max_seq_length,n_labels) to (batch_size,max_seq_length,n_labels)
    mask = torch.tensor(res.to_numpy().reshape(loss_shape))
    
    # now the mask has the same shape as the loss and ready to be applied on torch.masked_fill_
    return mask

In [142]:
from torch.nn.utils.rnn import pad_packed_sequence

In [143]:
for batch in val_dataloader:
    history_sequences, target_sequences = batch['train_sequences']['sequence'], batch['target_sequences']['sequence']
    
    outs = model(history_sequences,target_sequences)
    
    sequences,lengths = pad_packed_sequence(history_sequences,batch_first=True)
    break

In [147]:
lengths.max()
lengths

tensor(40)

tensor([ 9, 24, 32,  5,  5,  5, 10,  4,  7,  6,  3,  1, 13,  5,  5,  7,  3,  2,
         5,  7,  3,  5,  8,  4,  7,  2,  3, 12,  6,  4,  2, 19,  2,  8,  5,  4,
        40,  4,  9,  2, 14, 13,  2,  2,  2,  4, 19, 13,  1,  4, 13,  2, 12,  2,
         7, 19,  2,  1,  5,  3,  5,  4,  6,  3])

In [150]:
seq_size_per_seq = list(zip(range(0,len(lengths)),lengths.numpy()))
real_seq_pos_per_seq = [list(zip([a[0]]*a[1],range(0,a[1]))) for a in seq_size_per_seq]
# just flattens the previous list.
# i.e. (taking the previous example) produces: [(0,0),(0,1),(1,0)]
real_seq_pos_per_seq = [item for seq in real_seq_pos_per_seq for item in seq] 

In [159]:
a = torch.ones(size=(33,10,10))
a[[0,1,5],:] = 0
a

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1., 

In [156]:
outs.view(-1,283)#[
[(e[0] * lengths.max() + e[1]).item() for e in real_seq_pos_per_seq]

tensor([[ -8.3227,  -7.1623,  -6.1103,  ..., -11.4194,  -8.4300, -10.4651],
        [ -7.6110,  -5.9118,  -5.0685,  ..., -12.3436,  -8.8899, -12.1421],
        [ -8.0529,  -6.9389,  -5.7696,  ..., -13.1077,  -9.6372, -12.8298],
        ...,
        [ -0.1980,  -0.0790,  -0.1146,  ...,  -0.1642,  -0.0840,  -0.2897],
        [ -0.1980,  -0.0790,  -0.1146,  ...,  -0.1642,  -0.0840,  -0.2897],
        [ -0.1980,  -0.0790,  -0.1146,  ...,  -0.1642,  -0.0840,  -0.2897]],
       grad_fn=<ViewBackward0>)

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 120,
 121,
 122,
 123,
 124,
 160,
 161,
 162,
 163,
 164,
 200,
 201,
 202,
 203,
 204,
 240,
 241,
 242,
 243,
 244,
 245,
 246,
 247,
 248,
 249,
 280,
 281,
 282,
 283,
 320,
 321,
 322,
 323,
 324,
 325,
 326,
 360,
 361,
 362,
 363,
 364,
 365,
 400,
 401,
 402,
 440,
 480,
 481,
 482,
 483,
 484,
 485,
 486,
 487,
 488,
 489,
 490,
 491,
 492,
 520,
 521,
 522,
 523,
 524,
 560,
 561,
 562,
 563,
 564,
 600,
 601,
 602,
 603,
 604,
 605,
 606,
 640,
 641,
 642,
 680,
 681,
 720,
 721,
 722,
 723,
 724,
 760,
 761,
 762,
 763,
 764,
 765,
 766,
 800,
 801,
 802,
 840,
 841,
 842,
 843,
 844,
 880,
 881,
 882,
 883,
 884,
 885,
 886,
 887,
 920,
 921,
 922,


In [136]:
a = torch.randint(0,10,size=(10,10))

In [139]:
a

tensor([[5, 3, 2, 1, 1, 7, 7, 9, 1, 2],
        [7, 5, 2, 5, 2, 0, 2, 8, 2, 0],
        [4, 8, 9, 7, 0, 1, 8, 1, 3, 7],
        [6, 8, 0, 2, 1, 4, 7, 4, 6, 3],
        [1, 4, 0, 2, 7, 2, 8, 9, 4, 4],
        [8, 9, 6, 6, 9, 4, 9, 3, 9, 8],
        [8, 2, 1, 2, 6, 6, 1, 6, 8, 9],
        [7, 8, 7, 8, 0, 7, 6, 3, 3, 2],
        [2, 3, 3, 3, 5, 1, 0, 2, 3, 0],
        [6, 2, 6, 6, 1, 1, 2, 7, 0, 7]])

In [138]:
a[[1,3,5],:]

tensor([[7, 5, 2, 5, 2, 0, 2, 8, 2, 0],
        [6, 8, 0, 2, 1, 4, 7, 4, 6, 3],
        [8, 9, 6, 6, 9, 4, 9, 3, 9, 8]])

In [None]:
336, 485

In [69]:
for batch in val_dataloader:
    history_sequences, target_sequences = batch['train_sequences']['sequence'], batch['target_sequences']['sequence']
    
    outs = model(history_sequences,target_sequences)
    
    break
    
    

In [95]:
diag_1 = 79
diag_2 = 77

In [105]:
fpr_1, tpr, ths = f.roc(outs.view(-1,283)[:,diag_1],target_sequences.view(-1,283)[:,diag_1])
fpr_2, tpr, ths = f.roc(outs.view(-1,283)[:,diag_2],target_sequences.view(-1,283)[:,diag_2])

In [106]:
target_sequences.view(-1,283)[:,diag_1].sum()
target_sequences.view(-1,283)[:,diag_2].sum()

tensor(24.)

tensor(25.)

In [107]:
outs.view(-1,283)[:,diag_1]
outs.view(-1,283)[:,diag_2]

tensor([-4.1155, -3.4331, -3.0590,  ..., -0.1427, -0.1427, -0.1427],
       grad_fn=<SelectBackward0>)

tensor([-3.9595, -3.6251, -3.6365,  ..., -0.2735, -0.2735, -0.2735],
       grad_fn=<SelectBackward0>)

In [108]:
fpr_1[:5]

tensor([0.0000, 0.8435, 0.8438, 0.8442, 0.8445])

In [109]:
fpr_2[:5]

tensor([0.0000, 0.8438, 0.8444, 0.8447, 0.8451])

In [114]:
fpr_1.shape
fpr_2.shape

torch.Size([494])

torch.Size([494])

In [131]:
res = f.roc(outs.view(-1,283),target_sequences.view(-1,283),num_classes=283)[0]

In [133]:
res[0].shape

torch.Size([494])

In [43]:
from torchmetrics import functional as f

def compute_metricsV3(model, dataloader):
    
    
    recall = list()
    precision = list()
    f1 = list()
    for i,batch in tqdm(enumerate(iter(dataloader))):
        history_sequences, target_sequences = batch['train_sequences']['sequence'],batch['target_sequences']['sequence']
        outs = model(history_sequences,target_sequences)
        
        recall.append(f.recall(outs.view(-1,283),target_sequences.int().view(-1,283),top_k=30,average='samples') * target_sequences.shape[0] * target_sequences.shape[1] / target_sequences.any(dim=-1).sum())
        precision.append(f.precision(outs.view(-1,283),target_sequences.int().view(-1,283),top_k=30,average='samples') * target_sequences.shape[0] * target_sequences.shape[1] / target_sequences.any(dim=-1).sum())
        f1.append(f.f1_score(outs.view(-1,283),target_sequences.int().view(-1,283),top_k=30,average='samples') * target_sequences.shape[0] * target_sequences.shape[1] / target_sequences.any(dim=-1).sum())
    
    return {'recall@30':np.mean(recall),
            'precision@30':np.mean(precision),
            'f1@30':np.mean(f1)
           }
    

In [None]:
def compute_metrics(model_outputs,model_predictions,golden,metrics,mode='adm'):
    """
    all input dataframes must be of the form:
    double index of (<pat_id>,>adm_index>)
    and columns are the diagnostics. eg: diag_0,...,diag_272
    
    returns several metrics in a dataframe
    
    
    Parameters:
    -----------
    
    metrics : list
        ['roc,avgprec','acc','recall','precision','f1']
    """
    
    tqdm.pandas()
    
    accepted = ['roc','avgprec','acc','recall','accuracy','precision','f1','recall@','precision@','f1@']
    
    diag_weights = golden.sum(axis=0)
    adm_weights = golden.sum(axis=1)
    
    if metrics == 'all':
        metrics = accepted
    
    assert len(metrics) > 0
    assert any([e in metrics for e in accepted]) or any([e for e in metrics if 'recall@' in e])
    
    # threshold independent
    diag_metrics = list()
    adm_metrics = list()
    res_metrics = list()
    
    if 'roc' in metrics:
        print('computing roc')
        roc = model_outputs.progress_apply(lambda row: roc_auc_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('roc_adm') if mode=='adm' else model_outputs.progress_apply(lambda col: roc_auc_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('roc_diag')
        #roc_diag = model_outputs.apply(lambda col: roc_auc_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('roc_diag')
        #roc_adm = model_outputs.apply(lambda row: roc_auc_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('roc_adm')
        #diag_metrics.append(roc_diag)
        #adm_metrics.append(roc_adm)
        res_metrics.append(roc)
    
    if 'avgprec' in metrics:
        avgprec_diag = model_outputs.apply(lambda col: average_precision_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('avgprec_diag')
        avgprec_adm = model_outputs.apply(lambda row: average_precision_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('avgprec_adm')
        diag_metrics.append(avgprec_diag)
        adm_metrics.append(avgprec_adm)

    # threshold dependent
    
    if 'accuracy' in metrics:
        accuracy_diag = model_predictions.apply(lambda col: accuracy_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('accuracy_diag')
        accuracy_adm = model_predictions.apply(lambda row: accuracy_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('accuracy_adm')
        diag_metrics.append(accuracy_diag)
        adm_metrics.append(accuracy_adm)

    if 'recall' in metrics:
        recall_diag = model_predictions.apply(lambda col: recall_score(golden[col.name],col,zero_division=0)).rename('recall_diag')
        recall_adm = model_predictions.apply(lambda row: recall_score(golden.loc[row.name],row,zero_division=0),axis=1).rename('recall_adm')
        diag_metrics.append(recall_diag)
        adm_metrics.append(recall_adm)

    if 'precision' in metrics:
        precision_diag = model_predictions.apply(lambda col: precision_score(golden[col.name],col) if any(model_predictions[col.name] == 1) else np.nan).rename('precision_diag')
        precision_adm = model_predictions.apply(lambda row: precision_score(golden.loc[row.name],row) if any(model_predictions.loc[row.name] == 1) else np.nan,axis=1).rename('precision_adm')
        diag_metrics.append(precision_diag)
        adm_metrics.append(precision_adm)

    if 'f1' in metrics:
        f1_diag = model_predictions.apply(lambda col: f1_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('f1_diag')
        f1_adm = model_predictions.apply(lambda row: f1_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('f1_adm')
        diag_metrics.append(f1_diag)
        adm_metrics.append(f1_adm)
    
    # i.e. if recall@k in metrics
    if any(filter(lambda x: re.match('\w+@\d+',x), metrics)):
        
        matches = [e[0] for e in [re.findall('\w+@\d+',e) for e in metrics] if e] # get all <metric>@k in metrics (there may be multiple)
        for match in matches:
            
            k = int(re.findall('\w+@(\d+)',match)[0])
            metric = re.findall('(\w+)@\d+',match)[0]
            
            topk_outputs = model_outputs.apply(lambda row: row.nlargest(k),axis=1)

            # fix missing columns from previous operation
            missing_cols = [col for col in model_outputs.columns if col not in topk_outputs.columns]
            topk_outputs_all_cols = pd.concat([topk_outputs,pd.DataFrame(columns=missing_cols)])
            topk_outputs_all_cols = topk_outputs_all_cols[model_outputs.columns]
            
            ## sometimes k > (#logits>0) so we will turn all 0 logits into nan so that the following lines don't convert them to predictions
            topk_outputs_all_cols = topk_outputs_all_cols.mask(topk_outputs_all_cols == 0,np.nan)
            # done, continuing...

            topk_predictions = np.where(topk_outputs_all_cols.isna(),0,1)
            topk_predictions = pd.DataFrame(data=topk_predictions,columns=model_outputs.columns,index=model_outputs.index)

            if metric == 'recall':
                print(f'computing recall@{k}')
                metric_at_k = (topk_predictions
                               .progress_apply(lambda row: recall_score(golden.loc[row.name],row,zero_division=0),axis=1)
                               .rename(f'recall@{k}_adm') 
                               if mode=='adm' else 
                               topk_predictions
                               .progress_apply(lambda col: recall_score(golden[col.name],col,zero_division=0))
                               .rename(f'recall@{k}_diag')
                              )
                #metric_at_k_diag = topk_predictions.apply(lambda col: recall_score(golden[col.name],col,zero_division=0)).rename(f'recall@{k}_diag')
                #metric_at_k_adm = topk_predictions.apply(lambda row: recall_score(golden.loc[row.name],row,zero_division=0),axis=1).rename(f'recall@{k}_adm')
            
            elif metric == 'precision':
                print(f'computing precision@{k}')
                metric_at_k = (topk_predictions
                                .progress_apply(lambda row: precision_score(golden.loc[row.name],row) 
                                       if any(topk_predictions.loc[row.name] == 1) else np.nan,axis=1)
                                .rename(f'precision@{k}_adm') 
                                if mode=='adm' else 
                                topk_predictions
                                .progress_apply(lambda col: precision_score(golden[col.name],col) 
                                       if any(topk_predictions[col.name] == 1) else np.nan)
                                .rename(f'precision@{k}_diag')
                               )
                #metric_at_k_diag = topk_predictions.apply(lambda col: precision_score(golden[col.name],col) if any(topk_predictions[col.name] == 1) else np.nan).rename(f'precision@{k}_diag')
                #metric_at_k_adm = topk_predictions.apply(lambda row: precision_score(golden.loc[row.name],row) if any(topk_predictions.loc[row.name] == 1) else np.nan,axis=1).rename(f'precision@{k}_adm')
                
            elif metric == 'f1':
                metric_at_k_diag = topk_predictions.apply(lambda col: f1_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename(f'f1@{k}_diag')
                metric_at_k_adm = topk_predictions.apply(lambda row: f1_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename(f'f1@{k}_adm')
            
            else:
                print('what is happening')
                print(metric)

            #diag_metrics.append(metric_at_k_diag)
            #adm_metrics.append(metric_at_k_adm)    
            res_metrics.append(metric_at_k)
    
    # take weighted average
    """
    diag_metrics_wavg = (pd.concat(diag_metrics,axis=1)
                         .multiply(diag_weights,axis=0)
                         .sum(axis=0)
                         .divide(
                             diag_weights.sum()
                         )
                        )
    
    adm_metrics_wavg = (pd.concat(adm_metrics,axis=1)
                        .multiply(adm_weights,axis=0)
                        .sum(axis=0)
                        .divide(
                            adm_weights.sum()
                        )
                       )
    """
    #diag_metrics_wavg = (pd.concat(diag_metrics,axis=1)
    #                     .mean(axis=0)
    #                    )
    
    #adm_metrics_wavg = (pd.concat(adm_metrics,axis=1)
    #                     .mean(axis=0)
    #                    )

    #res = pd.concat([diag_metrics_wavg,adm_metrics_wavg])
    res = pd.concat(res_metrics,axis=1).mean(axis=0)
    res.index.name = 'metrics'
    
    return res

____

In [None]:
model_folder = 'tmp_models/'

In [39]:
#param_set = {
#          'hidden_size':100,
#          'num_layers':1,
##          'lr':0.01,
#          'model':'rnn'
#         }
for idx,param_set in tqdm(enumerate(params)):
    config = {**param_set, 
              **meta_parameters}
    
    wandb.init(
        project="icare", 
        config=config
    )
    
    model = RNN(input_size=input_size,
              hidden_size=config['hidden_size'],
              num_layers=config['num_layers'],
              n_labels=n_labels,
              model=config['model'])
    
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    
    loss = compute_loss(model,train_dataloader)
    wandb.log({'epoch':0,'loss':loss})
    
    print('Training each epoch')
    for epoch in tqdm(range(1,config['epochs']+1)):
        
        loss = train_one_epochV2(model,train_dataloader,epoch,criterion,optimizer);
        wandb.log({'epoch':epoch,'loss':loss})
        
        
    
    train_metrics = compute_metricsV3(model,train_dataloader)
    train_metrics = {f'train_{k}':train_metrics[k] for k in train_metrics}
    
    val_metrics = compute_metricsV3(model,val_dataloader)
    val_metrics = {f'val_{k}':val_metrics[k] for k in val_metrics}    

    log = dict()

    log.update(train_metrics)
    log.update(val_metrics)

    wandb.log(log)
    
    model_name = str(param_set)

    hypp_save_path = os.path.join(model_folder, model_name+'_hyper_parameters.json')

    with open(hypp_save_path, "w") as file:
        json.dump(params, file)

    print('Hyperparameters saved!')
    
    weights_save_path = os.path.join(model_folder,model_name+"_weights")

    torch.save(model.state_dict(), 
               weights_save_path
              )
    print('Model saved!')

0it [00:00, ?it/s]

VBox(children=(Label(value='0.139 MB of 0.139 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Starting to train each batch


0it [00:00, ?it/s]

NameError: name 'compute_metricsV3' is not defined

----

In [19]:
from multiprocessing import Pool
from sklearn.metrics import roc_auc_score

In [None]:
outs_train

In [9]:
# setup

df = pd.DataFrame(np.zeros(shape=(int(1e8),10)))
#df_later = df.copy()
#df_later.loc[int(1e7)] = 1
#df_later = df_later.astype(bool)
#
#df_sooner = df.copy()
#df_sooner.loc[int(1e2)] = 1
#df_sooner = df_sooner.astype(bool)

In [10]:
%%timeit -r 10 -n 1

df_sooner = df.copy()
df_sooner.loc[int(1e2)] = 1
df_sooner = df_sooner.astype(bool)
df_sooner.iloc[:,0].any()

5.47 s ± 89.3 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [11]:
%%timeit -r 10 -n 1

df_later = df.copy()
df_later.loc[int(1e7)] = 1
df_later = df_later.astype(bool)
df_later.iloc[:,0].any()

5.44 s ± 42.9 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [43]:
array.mask(array.index == int(1e2)).fillna(1)#.loc[int(1e2)]

0            0.0
1            0.0
2            0.0
3            0.0
4            0.0
            ... 
999999995    0.0
999999996    0.0
999999997    0.0
999999998    0.0
999999999    0.0
Length: 1000000000, dtype: float64

----

In [30]:
def normal_roc_optimized(logits,golden):
    idx = pd.IndexSlice
    
    logits_ = logits.copy()
    golden_ = golden.copy()
    
    logits_.columns = pd.MultiIndex.from_product([['logits'],logits_.columns])
    golden_.columns = pd.MultiIndex.from_product([['golden'],golden_.columns])
    
    full = logits_.join(golden_,how='inner')
    assert (full.shape[0] == logits.shape[0]) and (full.shape[0] == golden.shape[0]),'oops'
    
    return full.apply(lambda row: roc_auc_score(row.loc[:,idx['golden',:]],row.loc[:,idx['logits']]) if any(
    #return logits.apply(lambda row: roc_auc_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('roc_adm') 

In [31]:
normal_roc(outs_train,golden_train)

270248
270248


Unnamed: 0_level_0,Unnamed: 1_level_0,logits,logits,logits,logits,logits,logits,logits,logits,logits,logits,...,golden,golden,golden,golden,golden,golden,golden,golden,golden,golden
Unnamed: 0_level_1,Unnamed: 1_level_1,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_273,diag_274,diag_275,diag_276,diag_277,diag_278,diag_279,diag_280,diag_281,diag_282
pat_id,adm_index,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
00070121385D5F499BB0D98F48554EF5,1,0.521170,0.487607,0.482274,0.475589,0.552587,0.507825,0.506065,0.470086,0.530118,0.447996,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
00070121385D5F499BB0D98F48554EF5,2,0.509584,0.494575,0.494839,0.497005,0.532272,0.519580,0.491184,0.466202,0.519124,0.462451,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
00070121385D5F499BB0D98F48554EF5,3,0.513210,0.491401,0.484080,0.488004,0.558111,0.507040,0.501613,0.471400,0.531054,0.443273,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0008F5D602267E1E2F679BA745F38A41,1,0.517556,0.494975,0.507667,0.483336,0.539981,0.499710,0.504512,0.480275,0.520952,0.457920,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0008F5D602267E1E2F679BA745F38A41,2,0.515604,0.504581,0.508042,0.504277,0.531513,0.508115,0.496231,0.483177,0.515269,0.446479,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
FFFEF32F9705DCBB22EE8E7CC09E9379,2,0.514425,0.498155,0.501106,0.511381,0.542690,0.518909,0.527698,0.480515,0.523945,0.445996,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCAF4C44303A609ABA747E6E00CEE,1,0.515025,0.496486,0.507078,0.495090,0.541965,0.511540,0.508252,0.473659,0.524212,0.460159,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCAF4C44303A609ABA747E6E00CEE,2,0.511931,0.491205,0.494301,0.516626,0.528295,0.506213,0.500183,0.484189,0.529203,0.454736,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCD298CB71236CEB3470483A4C6A1,1,0.509109,0.493436,0.502984,0.504799,0.543310,0.513838,0.511638,0.478368,0.537129,0.461304,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
for e in tqdm(dataset):
    target = e['target']
    if any([(pd.Series(t).value_counts() >1).any() if t else False for t in target]):
        print('found one')
        break

  0%|          | 0/262811 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
target

[[],
 [258.0, 136.0, 258.0],
 [136.0, 258.0, 47.0, 44.0, 47.0],
 [258.0, 47.0, 44.0, 47.0],
 [47.0, 44.0, 47.0],
 [47.0, 258.0],
 [258.0],
 [],
 []]

In [None]:
from tqdm.auto import tqdm

tqdm.pandas()

In [14]:
df = pd.DataFrame(np.random.randint(0, int(1e8), (10000, 1000)))

In [15]:
df.groupby(0).progress_apply(lambda x: x**2)

  0%|          | 0/10000 [00:00<?, ?it/s]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
0,8861896394421316,2516512275852201,372733605816100,4172308873801744,5427823796622756,762071635866025,292111749490729,9997616142086400,6321334108485124,1271447821372689,...,251821352454400,2339752576838121,3704089009201,3943910838718096,6966029135233081,1289798731957824,1984177646603881,300951681465681,7736057397619600,5319091648002609
1,9155879298676881,6709341313903761,65115732802624,890036469243441,9545029142782225,747917698471056,1806302955388129,1476837230529600,3410535705643264,4293490110523441,...,4681045028294569,2865903717266496,1946334394598400,3823240248510736,2004459901728025,7188269002995600,2243874077724304,6803153467275361,4944864005354529,19988598081321
2,3498268242745600,669675749072836,8038721756248561,1850330143135296,1072158795994896,28563637494016,1021796322871401,5273630994852900,2872036653400996,6139705831875625,...,509823885318400,916343854253584,3164205151607824,5673962587004100,712247208961600,16249839769881,6765721667556049,5029166424891456,30716934105796,9946115385025369
3,5311420841104,26637211265625,1822132330619364,5192115845970244,5280184349425129,363343527111184,1167735767278201,6401713074618436,1904908371305536,1169945083891600,...,8993614065000889,4729322665345041,4174562196165904,1025072216529001,40104241849681,1644630079213521,4089657497184900,2461892336851600,319438981596736,188127202106809
4,7723203564934081,666917970810025,548547784683409,6585170425777969,147144356836,1168733602068516,7887804921289225,2172900609241924,2184455880614569,9670344739843921,...,106726420399104,1490704223710321,2282394441945744,3051335410343236,1148400881668096,5580157529004304,496791675302976,853373183524,5421407382862849,7177021597028416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,5766018068358601,99489433364721,1637230981469521,4248013433289796,56538564139369,45277368491716,259256820914116,4111923029021284,9518272623873225,3034494608545009,...,1205378755659225,166931852284521,1725772410201889,7003729359936,166618090411969,7454134896878224,6223394012257081,6113684857283136,513185321116836,1289090035210000
9996,196378490547121,384595713876544,3465764145889081,7138363746590244,2037127152357025,69772274880400,62955909394576,8010592788667225,12669869394576,1180895698690569,...,90766015494400,775019887394041,5609292930315369,1449463478112400,5197174711993489,17178894536049,234544201706896,39165355117729,2048348658446596,4330366199248324
9997,2961342014949796,678613180542025,4530780144544900,6008298034556944,1743870706663684,5581324261644121,329371536771216,1148995490810896,7708482306949369,2208631159397776,...,730305169657281,320143452857809,7351274721180625,3698201265614244,9472348087524,1782668910928896,6118404224812441,6129201498473089,3095992769589316,52856942197284
9998,3977470202227344,107320835618929,830945279844,3086961604026624,4711842366265321,1437973421184400,234800795350521,1743325367909904,1327628297249956,587601694807009,...,953331328132096,669732526814481,9854895040286976,1790783162981316,1704794320614025,360751712706225,988775384150401,1905336031524025,197825434591849,713561075454096


# Improving dataset and dataloader

In [13]:
dataset.data['0000676389D1EE60EB48AF5693F3F3DE']

{'ccs': {'history': [[670.0], [670.0]],
  'targets': [[670.0], []],
  'extra_features': {'delta_days': [0.0, 162.0],
   'date_last_history': ['2016-02-28', '2016-08-08']}}}

In [8]:
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler,RandomSampler

In [None]:
SubsetRandomSampler()

In [None]:
n = 1000
train_size = 0.70
val_size = 0.15
test_size = 0.15

test = 0.15*n
train_size = 

In [5]:
class ICareDataset_fast(Dataset):
    
    def __init__(self, 
                 diagnoses_file, 
                 universe_grouping, 
                 grouping='ccs', # desired grouping to use (for both input and output currently),
                 train_size:float = 0.70,
                 val_size:float = 0.15,
                 test_size:float = 0.15,
                 shuffle_dataset:bool = True,
                 random_seed :int = 432
                ):
        
        assert train_size+val_size+test_size == 1, 'Oops'

        with open(diagnoses_file,'r') as fp:
            self.raw_data = json.load(fp)

        # list patients
        self.patients = list(self.raw_data.keys())
        
        self.grouping = grouping
        self.universe_grouping=universe_grouping
        
        self.__preprocess()
        
        self.data = {}
        
        print('processing each patient')
        for pat in tqdm(self.raw_data):
            
            history_sequence = self.adms2multihot(self.raw_data[pat][self.grouping]['history'])
            target_sequence = self.adms2multihot(self.raw_data[pat][self.grouping]['targets'])
            
            self.data[pat] = {'history_sequence':history_sequence,
                              'target_sequence':target_sequence
                             }
        
        dataset_size = len(self.patients)
        indices = list(range(dataset_size))
        if shuffle_dataset :
            np.random.seed(random_seed)
            np.random.shuffle(indices)
            
        train_split = int(np.floor(train_size * dataset_size))
        val_split = int(np.floor(val_size * dataset_size))
        
        self.train_indices = indices[:train_split]
        self.val_indices = indices[train_split:train_split+val_split]
        self.test_indices = indices[-(train_split+val_split):]
            
            
    def adms2multihot(self,adms):
        #print(adms)
        #print(self.grouping_data[self.grouping]['code2int'].keys())
        return (torch.stack(
                                [ F.one_hot( # list comprehension
                                    # create a multi-hot of diagnoses of each admission
                                     torch.tensor( 
                                         list(map(lambda code: self.grouping_data[self.grouping]['code2int'][code],
                                             set(admission) # we don't care about repeated codes
                                            ))
                                     ),
                                     num_classes=self.grouping_data[grouping]['n_labels']
                                 )
                                 .sum(dim=0)
                                 .float()
                                 if admission 
                                 else
                                 torch.zeros(size=(self.grouping_data[grouping]['n_labels'],))
                                 for admission in adms
                                ]
                            )
               )
    def __preprocess(self):
        # necessary data of each code_grouping (eg. ccs, chapters) for posterior padding and one_hot_encoding of batches
        self.grouping_data = {}
        for grouping_code in self.raw_data[list(self.raw_data.keys())[0]].keys():
            self.grouping_data[grouping_code] = {}

            # get all codes of this group
            all_data_grouping = self.universe_grouping

            # store n_labels this group
            self.grouping_data[grouping_code]['n_labels'] = len(set(all_data_grouping))

            # store unique sorted codes from dataset
            self.grouping_data[grouping_code]['sorted'] = sorted(set(all_data_grouping))

            # store code2int & int2code
            int2code = dict(enumerate(self.grouping_data[grouping_code]['sorted']))
            code2int = {ch: ii for ii, ch in int2code.items()}

            self.grouping_data[grouping_code]['int2code'] = int2code
            self.grouping_data[grouping_code]['code2int'] = code2int
            self.grouping_data[grouping_code]['int2code_converter'] = lambda idx: self.grouping_data[grouping_code]['int2code'][idx]

    def __str__(self):
        return 'Available groupings: ' +str(self.data[list(self.data.keys())[0]].keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        gets original converted from int2code
        """
        patient_data = self.data[self.patients[idx]]


        return {'train':patient_data['history_sequence'],
                'target':patient_data['target_sequence'],
                'pid':self.patients[idx]
               }
    
    
    
class ICareCOLLATE_fast:
    """
    This collate class gets a dataset in the format of:
    [
    {'train':[[[code1,code2],[code3]],[[etc..],[etc...]]]
      'target:':[[[code1],[code2]],[[etc..],[etc...]]]
    },
     {etc..},
     etc..
    ]
    
    And outputs a pack of train and pad of test sequences
    """
    def __init__(self):
        pass
    
    def __call__(self,batch):
        return {'train_sequences' : dict(sequence=pack_sequence([batch[i]['train'] for i in range(len(batch))],enforce_sorted=False)),
                'target_sequences': dict(sequence=pad_sequence([batch[i]['target'] for i in range(len(batch))],batch_first=True)),
                'pids': [e['pid'] for e in batch]
               }

In [69]:
ccs_universe = list(icdmap.icd9_3toccs.data.keys())
dataset_folder = '/home/debian/Simao/master-thesis/data/model_ready_dataset/icare2021_diag_A301'
dataset_fast = ICareDataset_dev(os.path.join(dataset_folder,'dataset.json'),ccs_universe,grouping)

processing each patient


  0%|          | 0/262811 [00:00<?, ?it/s]

In [74]:
train_dataloader_fast = DataLoader(dataset_fast,batch_size=64,sampler=RandomSampler(dataset_fast.train_indices),collate_fn=ICareCOLLATE_dev())

In [75]:
%%timeit -r 3 -n 1
for batch in train_dataloader_fast:
    pass

6.18 s ± 734 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


In [76]:
%%timeit -r 3 -n 1
for batch in train_dataloader_slow:
    pass

2min 39s ± 32.5 s per loop (mean ± std. dev. of 3 runs, 1 loop each)


In [90]:
_

'ola linda ❤️'

In [82]:
_ = 3

In [83]:
_____________ = 'ola'
aa = 'ola'
a = 'adeus'

bbbbbb

6

In [20]:
train_dataloader = DataLoader(dataset,
           batch_size=batch_size,
           collate_fn=ICareCOLLATE_dev(),
           sampler=SubsetRandomSampler(dataset.train_indices)
          )

In [37]:
a,lengths = pad_packed_sequence(next(iter(train_dataloader))['train_sequences'],batch_first=True)

In [39]:
lengths

tensor([ 7,  2,  7,  6, 21,  7,  6,  2, 20, 13,  2,  5, 19,  3, 12,  2, 11, 11,
         3,  2, 10,  5,  5,  6,  9,  2,  2,  2,  2, 10,  2, 10,  2,  5, 17,  2,
         3, 13,  2,  4,  3,  6,  7,  5,  2,  3,  2,  9, 10, 12, 10,  2,  3,  3,
         2, 24,  1,  2, 21, 11,  5, 13,  8, 10])

In [41]:
train_dataset = IcareDataset(os.path.join(dataset_folder,'train_subset.json'),grouping)
len(train_dataset)
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=ICareCOLLATE(dataset),shuffle=True)

train_dataset_dev = ICareDataset_dev(os.path.join(dataset_folder,'train_subset.json'),grouping)
len(train_dataset_dev)
train_dataloader_dev = DataLoader(train_dataset_dev,batch_size=batch_size,collate_fn=ICareCOLLATE_dev,shuffle=True)

183967

processing each patient


  0%|          | 0/183967 [00:00<?, ?it/s]

KeyError: 258.0

In [24]:
ccs_universe = list(icdmap.icd9_3toccs.data.keys())
dataset_folder = '/home/debian/Simao/master-thesis/data/model_ready_dataset/icare2021_diag_A301'
dataset = IcareDataset_dev(os.path.join(dataset_folder,'dataset.json'),
                           ccs_universe,
                       grouping
                      )

processing each patient


  0%|          | 0/262811 [00:00<?, ?it/s]

In [31]:
batch = [dataset[i]['train'] for i in range(3)]

In [35]:
pad_sequence(batch,batch_first=True).shape

torch.Size([3, 9, 283])

In [37]:
dataset[0]

{'train': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

In [None]:
_pad_

In [30]:
torch.stack([e['train'] for e in batch])

RuntimeError: stack expects each tensor to be equal size, but got [2, 283] at entry 0 and [4, 283] at entry 1

In [None]:
class ICareCOLLATE_dev:
    """
    This collate class gets a dataset in the format of:
    [
    {'train':[[[code1,code2],[code3]],[[etc..],[etc...]]]
      'target:':[[[code1],[code2]],[[etc..],[etc...]]]
    },
     {etc..},
     etc..
    ]
    
    And outputs a pack of train and pad of test sequences
    """
    def __init__(self):
        pass
    
    def __call__(self,batch):
        
        

In [19]:
list(dataset.data.keys())[0]

'0000676389D1EE60EB48AF5693F3F3DE'

In [16]:
dataset.patients[0]

'0000676389D1EE60EB48AF5693F3F3DE'

In [28]:
dataset[0]

{'train': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

In [12]:
list(dataset.raw_data.keys())[0]

'0000676389D1EE60EB48AF5693F3F3DE'

In [3]:
class IcareDataset(Dataset):
    def __init__(self, diagnoses_file, universe_grouping, grouping='ccs' # desired grouping to use (for both input and output currently),
                ):
        
        # load admissions data
        with open(diagnoses_file,'r') as fp:
            self.data = json.load(fp)
        
        # list patients
        self.patients = list(self.data.keys())
        
        self.grouping = grouping
        self.universe_grouping=universe_grouping
        
        # create mappings between codes to one-hot into self.grouping_data
        self.__preprocess()
            
    def __preprocess(self):
        # necessary data of each code_grouping (eg. ccs, chapters) for posterior padding and one_hot_encoding of batches
        self.grouping_data = {}
        for grouping_code in self.data[list(self.data.keys())[0]].keys():
            self.grouping_data[grouping_code] = {}
            
            # get all codes of this group
            all_data_grouping = self.universe_grouping
            
            # store n_labels this group
            self.grouping_data[grouping_code]['n_labels'] = len(set(all_data_grouping))
            
            # store unique sorted codes from dataset
            self.grouping_data[grouping_code]['sorted'] = sorted(set(all_data_grouping))
            
            # store code2int & int2code
            int2code = dict(enumerate(self.grouping_data[grouping_code]['sorted']))
            code2int = {ch: ii for ii, ch in int2code.items()}
            
            self.grouping_data[grouping_code]['int2code'] = int2code
            self.grouping_data[grouping_code]['code2int'] = code2int
            self.grouping_data[grouping_code]['int2code_converter'] = lambda idx: self.grouping_data[grouping_code]['int2code'][idx]
        
    def __str__(self):
        return 'Available groupings: ' +str(self.data[list(self.data.keys())[0]].keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        gets original converted from int2code
        """
        patient_data = self.data[self.patients[idx]][self.grouping]
        
        train = patient_data['history']
        target = patient_data['targets']
        
        # remove duplicates (can happen in low granuality codes such as ccs)
        train = [list(set(admission)) for admission in train]
        target = [list(set(admission)) for admission in target]
        
        return {'train':train,
                'target':target,
                'pid':self.patients[idx]
               }

In [4]:
class ICareCOLLATE:
    """
    This collate class gets a dataset in the format of:
    [
    {'train':[[[code1,code2],[code3]],[[etc..],[etc...]]]
      'target:':[[[code1],[code2]],[[etc..],[etc...]]]
    },
     {etc..},
     etc..
    ]
    
    And outputs a pack of train and pad of test sequences
    """
    def __init__(self,dataset):
        self.dataset = dataset
    
    def __call__(self,batch):
        patients = {'train':{'sequence':[],'original':[],'pids':[]},
                    'target':{'sequence':[],'original':[],'pids':[]}
                   }
        
        grouping_code = self.dataset.grouping
        n_labels = self.dataset.grouping_data[grouping_code]['n_labels']
        code2int = self.dataset.grouping_data[grouping_code]['code2int']
        
        # <Nº admissions - 1> of each patient
        seq_lengths = []
        
        # 1-to-1 correspondence between each admission in {train/target}_admissions_sequenced and the patient's id.
        patients_list = []
        for pat in batch:
            
            pid = pat['pid'] # patient id
            train_admissions_sequenced = []
            target_admissions_sequenced = []
            seq_lengths.append(len(pat['train']))

            # convert each train admission into a multi-hot vector
            for train_admission in pat['train']:
                admission = (F.one_hot(torch.tensor(list(map(lambda code: code2int[code],train_admission))),num_classes=n_labels)
                             .sum(dim=0).float() #one-hot of each diagnose to multi-hot vector of diagnoses
                            )
                train_admissions_sequenced.append(admission)
            
            

            # convert each target admission into a one-hot vector
            for target_admission in pat['target']:
                
                if not target_admission: # target is empty
                    admission = torch.zeros(size=(n_labels,))
                else: #target has at least 1 diagnostic
                    # convert admission to multi-hot vector
                    admission = (F.one_hot(torch.tensor(list(map(lambda code: code2int[code],target_admission))),num_classes=n_labels)
                                 .sum(dim=0).float() #one-hot of each diagnose to multi-hot vector of diagnoses
                                )
                target_admissions_sequenced.append(admission)

            # stack multiple train admissions of a single patient into a single tensor
            if len(train_admissions_sequenced) > 1:
                train_admissions_sequenced = torch.stack(train_admissions_sequenced)
            else:
                train_admissions_sequenced = train_admissions_sequenced[0].view((1,-1))

            # stack multiple target admissions of a single patient into a single tensor
            if len(target_admissions_sequenced) > 1:
                target_admissions_sequenced = torch.stack(target_admissions_sequenced)
            else:
                target_admissions_sequenced = target_admissions_sequenced[0].view((1,-1))

            # store final train and test tensors
            patients['train']['sequence'].append(train_admissions_sequenced)
            patients['target']['sequence'].append(target_admissions_sequenced)
            
            patients['train']['original'].append(pat['train'])
            patients['target']['original'].append(pat['target'])
            
            # repeat pid for each admission they have on target
            pid_train_list = [pid] * len(pat['train'])
            pid_target_list = [pid] * len(pat['target'])
            patients['train']['pids'].extend(pid_train_list)
            patients['target']['pids'].extend(pid_target_list)

        # pad sequences (some patients have more admissions than others)
        patients['train']['sequence'] = pack_sequence(patients['train']['sequence'],enforce_sorted=False)
        patients['target']['sequence'] = pad_sequence(patients['target']['sequence'],batch_first=True)
        
        return {'train_sequences':patients['train'],
                'target_sequences':patients['target'],
                'train_pids':patients['train']['pids'],
                'target_pids':patients['target']['pids']
               }