In [1]:
import os
from typing import Callable

import pickle
import matplotlib.pyplot as plt

import json

from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence
import torch.nn.functional as F

from sklearn.model_selection import ParameterGrid, ParameterSampler

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model, RNN
from rnn_utils import outs2df, compute_loss

from sklearn.metrics import roc_curve,roc_auc_score,average_precision_score,recall_score,precision_score,f1_score,accuracy_score

from tqdm.notebook import tqdm

import pandas as pd
idx = pd.IndexSlice
import numpy as np
from math import ceil

import json


from Metrics import Metrics
from config import Settings; settings = Settings()

# Load dataset and model

In [2]:
model_name = 'pleasant-music-50'
dataset_id = 'diag_only'

model_folder = os.path.join(settings.data_base,settings.models_folder,model_name)
assert os.path.exists(model_folder)

dataset_folder = os.path.join(settings.data_base,settings.model_ready_dataset_folder,dataset_id)
assert os.path.exists(dataset_folder)

In [3]:
# Load dataset
batch_size = 64 # really doesn't matter for this notebook since we will only to inference
grouping = 'ccs'

dataset = DiagnosesDataset(os.path.join(dataset_folder,'dataset.json'),grouping)

train_dataset = DiagnosesDataset(os.path.join(dataset_folder,'train_subset.json'),grouping)
val_dataset = DiagnosesDataset(os.path.join(dataset_folder,'val_subset.json'),grouping)
test_dataset = DiagnosesDataset(os.path.join(dataset_folder,'test_subset.json'),grouping)


print('patients in train split',len(train_dataset))
print('patients in val split',len(val_dataset))
print('patients in test split',len(test_dataset))


train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset))

# Load model

# model hyperparameters path
hypp_save_path = os.path.join(model_folder, 'hyper_parameters.json')
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)

# weights path
weights_save_path = os.path.join(model_folder,"weights")

model = RNN(**params_loaded)
model.load_state_dict(torch.load(weights_save_path))

patients in train split 5249
patients in val split 1125
patients in test split 1125


<All keys matched successfully>

# Compute metrics

In [4]:
out,golden = outs2df(model,val_dataloader,dataset,return_golden=True)

In [5]:
# predictions
from Abstention.utils import get_prediction_thresholds

def make_predictions(model_outputs, golden, prediction_method='roc gm'):
    
    thresholds = get_prediction_thresholds(model_outputs,golden,method=prediction_method)

    def predict(predictions: pd.Series, threshold : float):
        return predictions.apply(lambda x: 1 if x > threshold else 0)

    preds = model_outputs.apply(lambda x: predict(x, thresholds.loc[x.name,'threshold']),axis=0)
    
    return preds

preds = make_predictions(out,golden)

In [6]:
preds.head(2)

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0,0,0,0,0,0,0,0,1,0,...,0,1,0,0,0,0,0,0,0,1
23,1,0,0,0,0,0,0,0,0,1,0,...,1,0,0,0,1,0,0,0,1,1


In [7]:
res = eval_model(model,val_dataloader,dataset)

In [9]:
res[1]

Unnamed: 0,roc_diag,roc_adm,avgprec_diag,avgprec_adm,accuracy_diag,accuracy_adm,recall_diag,recall_adm,precision_diag,precision_adm,f1_diag,f1_adm
median,0.740923,0.931559,0.177212,0.529468,0.762738,0.783088,0.642857,0.7,0.067532,0.115942,0.117647,0.197531
mean,0.739497,0.92265,0.266887,0.522789,0.744657,0.771881,0.598254,0.67544,0.152937,0.12595,0.205356,0.204976
std,0.163173,0.051843,0.260722,0.178667,0.176784,0.086581,0.224672,0.208668,0.200536,0.064315,0.223145,0.090623


In [10]:
from sklearn.metrics import roc_auc_score,average_precision_score,recall_score,precision_score,f1_score

def compute_metrics(model_outputs,model_predictions,golden):
    """
    all input dataframes must be of the form:
    double index of (<pat_id>,>adm_index>)
    and columns are the diagnostics. eg: diag_0,...,diag_272
    
    returns several metrics in a dataframe
    """
    
    diag_weights = golden.sum(axis=0)
    adm_weights = golden.sum(axis=1)
    
    # threshold independent
    roc_diag = model_outputs.apply(lambda col: roc_auc_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('roc_diag')
    roc_adm = model_outputs.apply(lambda row: roc_auc_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('roc_adm')

    avgprec_diag = model_outputs.apply(lambda col: average_precision_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('avgprec_diag')
    avgprec_adm = model_outputs.apply(lambda row: average_precision_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('avgprec_adm')

    # threshold dependent
    
    accuracy_diag = model_predictions.apply(lambda col: accuracy_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('accuracy_diag')
    accuracy_adm = model_predictions.apply(lambda row: accuracy_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('accuracy_adm')

    recall_diag = model_predictions.apply(lambda col: recall_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('recall_diag')
    recall_adm = model_predictions.apply(lambda row: recall_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('recall_adm')

    precision_diag = model_predictions.apply(lambda col: precision_score(golden[col.name],col,zero_division=0) if any(golden[col.name] == 1) else np.nan).rename('precision_diag')
    precision_adm = model_predictions.apply(lambda row: precision_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('precision_adm')

    f1_diag = model_predictions.apply(lambda col: f1_score(golden[col.name],col) if any(golden[col.name] == 1) else np.nan).rename('f1_diag')
    f1_adm = model_predictions.apply(lambda row: f1_score(golden.loc[row.name],row) if any(golden.loc[row.name] == 1) else np.nan,axis=1).rename('f1_adm')
    
    # take weighted average
    diag_metrics_wavg = (pd.concat([roc_diag, avgprec_diag, accuracy_diag, recall_diag, precision_diag, f1_diag],axis=1)
                         .multiply(diag_weights,axis=0)
                         .sum(axis=0)
                         .divide(
                             diag_weights.sum()
                         )
                        )
    
    adm_metrics_wavg = (pd.concat([roc_adm,avgprec_adm,accuracy_adm,recall_adm,precision_adm,f1_adm],axis=1)
                        .multiply(adm_weights,axis=0)
                        .sum(axis=0)
                        .divide(
                            adm_weights.sum()
                        )
                       )

    
    return pd.concat([diag_metrics_wavg,adm_metrics_wavg])

In [28]:
pd.concat([a,b])

roc_diag                               0.752576
avgprec_diag                           0.430799
accuracy_diag                          0.727051
recall_diag                            0.667702
precision_diag                         0.335353
f1_diag                                0.416344
hihihihihihihihihihihiroc_adm          0.919891
hihihihihihihihihihihiavgprec_adm      0.532184
hihihihihihihihihihihiaccuracy_adm     0.762132
hihihihihihihihihihihirecall_adm       0.667702
hihihihihihihihihihihiprecision_adm    0.143320
hihihihihihihihihihihif1_adm           0.228479
dtype: float64

In [11]:
a,b = compute_metrics(out,preds,golden)

In [23]:
b.index = ['hi' + n for n in b.index]
b

hihihihihihihihihihihiroc_adm          0.919891
hihihihihihihihihihihiavgprec_adm      0.532184
hihihihihihihihihihihiaccuracy_adm     0.762132
hihihihihihihihihihihirecall_adm       0.667702
hihihihihihihihihihihiprecision_adm    0.143320
hihihihihihihihihihihif1_adm           0.228479
dtype: float64

In [25]:
c = {}
c.update(a)

In [29]:
c.update({'a':1})

In [31]:
c.update({'b':2},{'c':3})

TypeError: update expected at most 1 argument, got 2

In [24]:
a.to_dict()

{'roc_diag': 0.7525759673633372,
 'avgprec_diag': 0.43079859288139066,
 'accuracy_diag': 0.7270508188651702,
 'recall_diag': 0.6677018633540373,
 'precision_diag': 0.3353526481200552,
 'f1_diag': 0.41634386282459274}

In [148]:
weights = golden.sum(axis=1)
weights.head(2)

pat_id  adm_index
21      1            20.0
23      1             9.0
dtype: float32

In [150]:
pd.concat([acc,roc],axis=1).multiply(weights,axis=0).head(2)

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy_adm,roc_adm
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1
21,1,16.764706,18.77381
23,1,6.816176,8.326996


In [152]:
pd.concat([acc,roc],axis=1).multiply(weights,axis=0).sum(axis=0).divide(weights.sum())

accuracy_adm    0.762132
roc_adm         0.919891
dtype: float64

In [129]:
acc.head(3)

roc.head(3)

diag_0    0.923314
diag_1    0.639218
diag_2    0.593412
Name: accuracy_diag, dtype: float64

diag_0    0.718475
diag_1    0.679087
diag_2    0.608805
Name: roc_diag, dtype: float64

In [125]:
acc.head(1)
weights.head(1)

diag_0    0.923314
Name: accuracy_diag, dtype: float64

diag_0    8.0
dtype: float32

In [124]:
acc.multiply(weights)

diag_0        7.386516
diag_1      244.181163
diag_2      159.034483
diag_3       59.458055
diag_4       30.952136
               ...    
diag_267    113.139475
diag_268     18.392177
diag_269      1.851776
diag_270     12.988163
diag_271     26.922800
Length: 272, dtype: float64

In [147]:
weights = golden.sum(axis=1)
weights

pat_id  adm_index
21      1            20.0
23      1             9.0
61      1            14.0
94      1             7.0
105     1             4.0
                     ... 
99383   3            12.0
99650   1             6.0
        2            24.0
99756   1            10.0
        2            16.0
Length: 1943, dtype: float32

# Update eval

In [114]:
def eval_model(model, dataloader, dataset, only_loss=False, prediction_method='roc gm'):
    """
    return either the loss or the loss and a metrics dataframe.
    
    Returns
    -------
    loss : torch.tensor
    
    metrics : pd.DataFrame, column_names ~ [roc,avgprec,accuracy,recall,precision,f1],
                index = ['median','mean','std']
        each column is followed by either "_diag" or "_adm"
        
    """
    
    model_outputs,golden = outs2df(model, dataloader, dataset, return_golden=True)
    
    loss = compute_loss(model, dataloader)
    
    predictions = make_predictions(model_outputs,golden,prediction_method)
    
    if only_loss:
        return loss
    
    metrics = compute_metrics(model_outputs,predictions,golden)
    return loss,metrics

In [116]:
eval_model(model,val_dataloader,dataset)[1]

Unnamed: 0,roc_diag,roc_adm,avgprec_diag,avgprec_adm,accuracy_diag,accuracy_adm,recall_diag,recall_adm,precision_diag,precision_adm,f1_diag,f1_adm
median,0.740923,0.931559,0.177212,0.529468,0.762738,0.783088,0.642857,0.7,0.067532,0.115942,0.117647,0.197531
mean,0.739497,0.92265,0.266887,0.522789,0.744657,0.771881,0.598254,0.67544,0.152937,0.12595,0.205356,0.204976
std,0.163173,0.051843,0.260722,0.178667,0.176784,0.086581,0.224672,0.208668,0.200536,0.064315,0.223145,0.090623


In [None]:
def eval_model(model, dataloader, dataset, criterion, epoch, name, only_loss=False,level_interest=None,k_interest=None):
    """
    This functions evaluates and computes metrics of a model checkpoint on a dataloader
    
    criterion must be reduction='none'
    """
    
    model.eval()
    # eg:: ccs, icd9, etc..
    code_type = dataset.grouping
    
    int2code = dataset.grouping_data[code_type]['int2code']
    
    result = {'name':name,
              'epoch':epoch
             }
    
    total_loss = 0
    total_seq = 0 #total sequences
    
    all_metrics = None
    with torch.no_grad():
        for i, batch in enumerate(iter(dataloader)):
            
            # get the inputs; data is a list of [inputs, labels]
            history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']

            inputs = history_sequences['sequence']
            outs = model(inputs)
            #print(outs.shape)
            #print(inputs)
            #print(target_sequences['sequence'].shape)

            loss = criterion(outs, target_sequences['sequence'])
            #print(inputs)
            
            # zero-out positions of the loss corresponding to padded inputs
            # if a sequence has all zeros it is considered to be a padding.
            # Comment: safer way to do this would be a solution using the lengths...
            sequences,lengths = pad_packed_sequence(inputs,batch_first=True)
            mask = ~sequences.any(dim=2).unsqueeze(2).repeat(1,1,sequences.shape[-1])
            
            #print(mask.shape)
            #print(loss.shape)
            #print(len(loss))
            #print(loss)
            loss.masked_fill_(mask, 0)
        
            loss = loss.sum() / (lengths.sum()*sequences.shape[-1])

            # compute loss
            n = target_sequences['sequence'].size(0)
            total_seq += n
            total_loss += loss.item() * n
            
            # compute other metrics

            _,lengths = pad_packed_sequence(history_sequences['sequence'])
            
            preds = outs2pred(outs,int2code)
            
            if all_metrics is None:
                all_metrics = compute_metrics(preds,target_sequences['original'],level_interest, k_interest)
            else:
                new_metrics = compute_metrics(preds,target_sequences['original'],level_interest, k_interest)
                concat_metrics(all_metrics,new_metrics)

        result['loss'] = total_loss / total_seq
        if only_loss:
            return result
        for level in all_metrics:
            if level not in result:
                result[level] = {}
            for metric in all_metrics[level]:
                if metric not in result[level].keys():
                    result[level][metric] = {}
                result[level][metric] = {'mean':np.mean(all_metrics[level][metric]),
                                         'std':np.std(all_metrics[level][metric]),
                                         'n': len(all_metrics[level][metric])
                                        }
    return result