In [1]:
import os, pickle, warnings, dataclasses, itertools
from pathlib import Path
from functools import partial 

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy import stats as st

import torch
from torch import nn
from torch.nn import functional as F

from tqdm import tqdm
from IPython.display import display
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

from typing import Iterable, Tuple

import logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)

import transformers
transformers.logging.set_verbosity_error()

%load_ext autoreload
%autoreload 1
%aimport salford_datasets.salford, salford_datasets.salford_raw, transformer_experiment.utils

from salford_datasets.salford import SalfordData, SalfordFeatures, SalfordPrettyPrint, SalfordCombinations
from salford_datasets.utils import DotDict

from transformer_experiment.utils import dict_product

In [2]:
class Notebook:
    """ Notebook-wide settings read by the cells below. """
    # When True, cells recompute their artefacts and overwrite the files on disk;
    # when False, they load the previously saved files instead.
    RE_DERIVE = False
    # Folder containing the Salford HDF5 files.
    DATA_DIR = Path('data/Salford/')
    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Default path: read the already-processed dataset from disk.
# With RE_DERIVE set, rebuild it from the raw table, derive all augmented
# features, and overwrite the processed file.
if not Notebook.RE_DERIVE:
    SAL = SalfordData(pd.read_hdf(Notebook.DATA_DIR/'sal_processed_transformers.h5', 'table'))
else:
    SAL = SalfordData.from_raw(pd.read_hdf(Notebook.DATA_DIR/'raw_v2.h5', 'table')).augment_derive_all()
    SAL.to_hdf(Notebook.DATA_DIR/'sal_processed_transformers.h5', 'table')

In [4]:
# Model identifiers passed to `from_pretrained`, keyed by the short name that
# labels each embedding variant in the experiments below.
BERTModels = DotDict(
    BioClinicalBert="emilyalsentzer/Bio_ClinicalBERT",
    # NOTE(review): the key says "Bert" but the identifier is a DistilBERT checkpoint — confirm this is intended.
    Bert="distilbert-base-uncased",
    PubMedBert="ml4pubmed/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_pub_section"
)

Experiments: 
 1. Tabular Data only
 2. Note Embeddings Only
 3. Tabular & Note Embeddings
    - One model for both
    - Ensemble separate models
 4. Note Transformer
 5. Text-ified record Transformer
 6. Note & Text-ified record Transformer
    - One model for both
    - Ensemble separate models


In [5]:
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.metrics import precision_score, recall_score, roc_auc_score, average_precision_score, fbeta_score, make_scorer
from lightgbm import LGBMClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV

# Scorers evaluated on every held-out fold; each key shows up as a
# `test_<key>` column in the cross-validation results.
CROSS_VALIDATION_METRICS = {
    'Precision': 'precision',
    'Recall': 'recall',
    'AUC': 'roc_auc',
    'AP': 'average_precision',
    'F1': 'f1',
    'F2': make_scorer(fbeta_score, beta=2),
}

# Keyword arguments for LGBMClassifier.
LIGHTGBM_PARAMETERS = {
    'objective': 'binary',
    'random_state': 123,
    'metrics': ['l2', 'auc'],
    'boosting_type': 'gbdt',
    'is_unbalance': True,
    'n_jobs': 1,
}

# Keyword arguments for LogisticRegression.
REGRESSION_PARAMETERS = {
    'max_iter': 5000,
    'solver': 'lbfgs',
    'random_state': 123,
    'penalty': 'l2',
}

# Keyword arguments for CalibratedClassifierCV (isotonic calibration over
# its own 4-fold stratified split).
CALIBRATION_PARAMETERS = {
    'ensemble': True,
    'cv': StratifiedKFold(n_splits=4, shuffle=True, random_state=123),
    'method': 'isotonic',
    'n_jobs': 4,
}

# Keyword arguments for cross_validate (outer 4-fold stratified split).
CROSS_VALIDATION_PARAMETERS = {
    'cv': StratifiedKFold(n_splits=4, shuffle=True, random_state=123),
    'n_jobs': 1,
    'scoring': CROSS_VALIDATION_METRICS,
}

In [6]:
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import make_column_selector, make_column_transformer

# Preprocessing applied before logistic regression: one-hot encode the
# categorical columns, median-impute the numeric columns, and pass every
# other column through unchanged.
REGRESSION_PREPROCESSOR = make_column_transformer(
    # handle_unknown='ignore': when this transformer is refit inside a pipeline on
    # each training fold, a category that only occurs in the held-out fold would
    # otherwise raise at transform time and the fold would be scored as NaN.
    # Unknown categories are encoded as an all-zero row instead; output on the
    # data the encoder was fitted on is unchanged.
    (OneHotEncoder(handle_unknown='ignore'), make_column_selector(dtype_include='category')),
    (SimpleImputer(strategy='median'), make_column_selector(dtype_include=np.number)),
    remainder='passthrough'
)

In [7]:
from sklearn.calibration import IsotonicRegression
def run_shallow_CV_experiments(X_variants, y):
    """ Cross-validates calibrated LightGBM and L2 logistic regression on every feature set.

    Parameters
    ----------
    X_variants : dict
        Maps a feature-set name to its feature matrix (pandas DataFrame or array).
    y : array-like
        Binary target aligned with the rows of every matrix in `X_variants`.

    Returns
    -------
    pandas.DataFrame
        Fold-averaged `cross_validate` output, indexed by (Embedding, Classifier).
    """
    classifiers = {
        'LightGBM': LGBMClassifier(
            **LIGHTGBM_PARAMETERS
        ),
        'LR-L2': LogisticRegression(
            **REGRESSION_PARAMETERS
        )
    }

    experiments = itertools.product(X_variants.items(), classifiers.items())
    # itertools.product has no length, so give tqdm the total explicitly.
    n_experiments = len(X_variants) * len(classifiers)

    results = []
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for (X_name, X), (classifier_name, classifier) in (pbar := tqdm(experiments, total=n_experiments)):
            pbar.set_description(f'Parallel running 4 CV folds of {classifier_name} with {X_name} embeddings..')

            # Keep the loop variable untouched and fit on a separate name.
            X_fit = X
            # The preprocessor picks columns by dtype through make_column_selector,
            # which only accepts pandas DataFrames; plain arrays (e.g. embedding
            # matrices) are handed to the classifier as-is.
            # NOTE: the preprocessor is fitted on all rows before the folds are
            # split, so imputation medians and category lists see held-out rows.
            if classifier_name == 'LR-L2' and isinstance(X, pd.DataFrame):
                X_fit = REGRESSION_PREPROCESSOR.fit_transform(X)

            results.append(pd.DataFrame.from_dict(
                cross_validate(
                    CalibratedClassifierCV(classifier, **CALIBRATION_PARAMETERS),
                    X_fit, y, **CROSS_VALIDATION_PARAMETERS
                )
            ).assign(Embedding=X_name, Classifier=classifier_name))

    return pd.concat(results).groupby(['Embedding', 'Classifier']).mean()

## 1. Tabular Classifier

In [17]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cv_tabular_classifier(sal):
    """ Cross-validates the shallow classifiers on every tabular column combination.

    Takes the `with_services` columns of `sal`, converts string columns to
    categoricals, and builds one feature set per entry in SalfordCombinations.
    The target is `sal.CriticalEvent`.
    """
    features = SalfordData(sal[SalfordCombinations.with_services]).convert_str_to_categorical()

    X_variants = {}
    for key, columns in SalfordCombinations.items():
        X_variants[key] = features[columns]

    return run_shallow_CV_experiments(X_variants, sal.CriticalEvent)

# Experiment 1 results: load the cached table by default, or rerun the
# cross-validation and refresh the cache when RE_DERIVE is set.
if not Notebook.RE_DERIVE:
    RESULTS_1 = pd.read_csv('data/cache/result1.csv').set_index(['Embedding', 'Classifier'])
else:
    RESULTS_1 = cv_tabular_classifier(SAL)
    RESULTS_1.to_csv('data/cache/result1.csv')

display(RESULTS_1)

Unnamed: 0_level_0,Unnamed: 1_level_0,fit_time,score_time,test_Precision,test_Recall,test_AUC,test_AP,test_F1,test_F2
Embedding,Classifier,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
news,LR-L2,8.647021,0.258225,0.0,0.0,0.594588,0.088105,0.0,0.0
news,LightGBM,2.506067,0.99019,0.545704,0.00392,0.690098,0.133228,0.007783,0.004891
with_composites,LR-L2,96.810492,0.252361,0.657983,0.051477,0.78837,0.246132,0.095479,0.063111
with_composites,LightGBM,5.015228,2.101757,0.67793,0.118125,0.846199,0.326153,0.20115,0.141483
with_labs,LR-L2,730.570759,0.352537,0.567155,0.08358,0.842708,0.280948,0.145647,0.100753
with_labs,LightGBM,7.765831,2.990653,0.672319,0.22017,0.918216,0.466192,0.331701,0.254383
with_phenotype,LR-L2,62.047589,0.225753,0.723203,0.039432,0.680067,0.174806,0.074742,0.048619
with_phenotype,LightGBM,4.441151,1.74508,0.697218,0.057443,0.768164,0.23099,0.106084,0.070344
with_services,LR-L2,221.962656,0.234918,0.877467,0.364489,0.913015,0.601378,0.514847,0.412686
with_services,LightGBM,8.059408,3.304551,0.839739,0.486648,0.958504,0.719494,0.616178,0.531323


## 2. Note Embedding Classifier

In [8]:
from transformers import AutoTokenizer, AutoModel
    
def load_tz_to_device(tz_output):
    """ Returns a plain dict with every value of the tokeniser output moved to Notebook.DEVICE """
    return {
        name: tensor.to(Notebook.DEVICE)
        for name, tensor in tz_output.items()
    }

def split_into_batches(Xt, batch_size):
    """ Given a tensor/ndarray and a batch size, splits it into batches of size up to batch_size along the first dimension.

    Returns a list of near-equal-sized batches (as produced by `np.array_split`),
    or an empty list when `Xt` has no elements.
    """
    # np.array_split rejects a section count of 0, which is what an empty input would produce.
    if len(Xt) == 0:
        return []

    # Explicit int: np.ceil returns a float, and the section count is a whole number.
    n_batches = int(np.ceil(len(Xt) / batch_size))
    return np.array_split(Xt, n_batches)

def get_note_embeddings(X, model_uri=BERTModels.BioClinicalBert):
    """ Embeds each text in X with the pretrained model at model_uri.

    Texts are processed in batches with gradients disabled. For every text the
    last hidden state at sequence position 0 is kept (presumably the [CLS]
    token — confirm for the chosen model). Returns one CPU tensor with a row
    per input text.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_uri)
    encoder = AutoModel.from_pretrained(model_uri).to(Notebook.DEVICE).eval()
    tokenizer_options = dict(truncation=True, padding=True, return_tensors='pt')

    def embed_batch(texts):
        # Tokenise on the CPU, run the encoder on the device, bring the result back.
        tokens = load_tz_to_device(tokenizer(list(texts), **tokenizer_options))
        hidden_states = encoder(**tokens)['last_hidden_state']
        return hidden_states[:, 0, :].cpu()

    with torch.no_grad():
        batch_embeddings = []
        for batch in tqdm(split_into_batches(X, 500), desc="Generating embeddings.."):
            batch_embeddings.append(embed_batch(batch))
        emb = torch.cat(batch_embeddings)

    return emb

def get_note_embeddings_all_BERTs(sal):
    """ Computes note embeddings with every model listed in BERTModels.

    Only rows where at least one of the three note columns is present are
    embedded. Returns `(embeddings, avail_idx)`: a dict of model name to
    embedding tensor, and the boolean row mask that selected the embedded rows.
    """
    note_columns = ['AE_TriageNote', 'AE_MainDiagnosis', 'AE_PresentingComplaint']
    has_any_note = sal[note_columns].notna().any(axis=1)
    texts = SalfordData(sal.loc[has_any_note]).tabular_to_text(note_columns).values

    embeddings = {}
    with torch.no_grad():
        for model_name, model_uri in BERTModels.items():
            embeddings[model_name] = get_note_embeddings(texts, model_uri)

    return embeddings, has_any_note

# Note embeddings: load the pickled cache by default, or recompute and
# overwrite it when RE_DERIVE is set.
# NOTE: pickle.load can execute arbitrary code — only open cache files you produced yourself.
if not Notebook.RE_DERIVE:
    with open('data/cache/note_embeddings.bin', 'rb') as file:
        NOTE_EMBEDDINGS, note_avail_idx = pickle.load(file)
else:
    NOTE_EMBEDDINGS, note_avail_idx = get_note_embeddings_all_BERTs(SAL)
    with open('data/cache/note_embeddings.bin', 'wb') as file:
        pickle.dump((NOTE_EMBEDDINGS, note_avail_idx), file)

In [14]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cv_embedding_only_classifier(sal, embeddings_dict, avail_idx):
    """ Cross-validates the shallow classifiers using the note embeddings alone.

    Each embedding tensor in `embeddings_dict` is converted to a NumPy array and
    used as one feature set. The target is `CriticalEvent`, cast to int, for the
    rows selected by `avail_idx`.
    """
    X_variants = {}
    for model_name, embedding in embeddings_dict.items():
        X_variants[model_name] = embedding.numpy()

    target = sal.loc[avail_idx, 'CriticalEvent'].astype(int)
    return run_shallow_CV_experiments(X_variants, target)

# Experiment 2 results: rerun the embedding-only cross-validation and refresh
# the cache when RE_DERIVE is set, otherwise load the cached table.
if Notebook.RE_DERIVE:
    # The function defined in this cell is `cv_embedding_only_classifier`; the
    # previous name `train_embedding_only_classifier` does not exist in this notebook.
    RESULTS_2 = cv_embedding_only_classifier(SAL, NOTE_EMBEDDINGS, note_avail_idx)
    RESULTS_2.to_csv('data/cache/result2.csv')
else:
    RESULTS_2 = pd.read_csv('data/cache/result2.csv').set_index(['Embedding', 'Classifier'])

display(RESULTS_2)

Unnamed: 0_level_0,Unnamed: 1_level_0,fit_time,score_time,test_Precision,test_Recall,test_AUC,test_AP,test_F1,test_F2
Embedding,Classifier,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Bert,LR-L2,481.301529,1.94432,0.660414,0.121317,0.837179,0.328124,0.204883,0.144966
Bert,LightGBM,331.267331,4.63108,0.614847,0.046985,0.82535,0.273014,0.087234,0.057619
BioClinicalBert,LR-L2,935.784982,2.264355,0.702388,0.131871,0.847504,0.346662,0.2219,0.157414
BioClinicalBert,LightGBM,319.884479,4.630379,0.635437,0.069744,0.839576,0.288721,0.125588,0.084831
PubMedBert,LR-L2,962.396765,4.411551,0.637574,0.14325,0.851519,0.342452,0.233855,0.169519
PubMedBert,LightGBM,330.919344,4.579374,0.618268,0.094246,0.847822,0.308226,0.163392,0.113447


## 3. Tabular & Embedding Classifier

### 3.1 One Classifier for Both

In [15]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cv_tabular_and_embedding_classifier(sal, embeddings_dict, avail_idx):
    """ Cross-validates the shallow classifiers on tabular features joined with note embeddings.

    For each embedding model, the `with_services` tabular columns (rows selected
    by `avail_idx`, strings converted to categoricals) are concatenated
    column-wise with that model's embedding, whose columns are prefixed
    `EMBEDDING_` and re-indexed to match the tabular rows.
    """
    tabular = SalfordData(sal.loc[avail_idx, SalfordCombinations.with_services]).convert_str_to_categorical()
    target = sal.loc[avail_idx, 'CriticalEvent']

    def join_with_embedding(embedding):
        # Give the embedding the tabular index so the column-wise concat aligns row by row.
        embedding_frame = pd.DataFrame(embedding).add_prefix('EMBEDDING_').set_index(tabular.index)
        return pd.concat((tabular, embedding_frame), axis=1)

    X_variants = {}
    for transformer, embedding in embeddings_dict.items():
        X_variants[transformer] = join_with_embedding(embedding)

    return run_shallow_CV_experiments(X_variants, target)

# Experiment 3.1 results: load the cached table by default, or rerun the
# cross-validation and refresh the cache when RE_DERIVE is set.
if not Notebook.RE_DERIVE:
    RESULTS_31 = pd.read_csv('data/cache/result31.csv').set_index(['Embedding', 'Classifier'])
else:
    RESULTS_31 = cv_tabular_and_embedding_classifier(SAL, NOTE_EMBEDDINGS, note_avail_idx)
    RESULTS_31.to_csv('data/cache/result31.csv')

display(RESULTS_31)

Unnamed: 0_level_0,Unnamed: 1_level_0,fit_time,score_time,test_Precision,test_Recall,test_AUC,test_AP,test_F1,test_F2
Embedding,Classifier,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Bert,LR-L2,3446.892846,0.930931,0.865875,0.470312,0.936871,0.677538,0.609427,0.517559
Bert,LightGBM,358.325441,31.259331,0.882465,0.493255,0.954317,0.720029,0.632744,0.54095
BioClinicalBert,LR-L2,3449.244502,1.0126,0.866191,0.472424,0.939958,0.685728,0.611361,0.51966
BioClinicalBert,LightGBM,362.900432,31.037868,0.881836,0.493988,0.955296,0.721962,0.633221,0.541623
PubMedBert,LR-L2,3425.869108,0.895057,0.859471,0.473157,0.940401,0.685249,0.610311,0.51989
PubMedBert,LightGBM,357.951831,30.942927,0.883045,0.498577,0.955592,0.723896,0.637251,0.546107


### 3.2 Ensembles

In [16]:
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import VotingClassifier

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cv_tabular_and_embedding_ensemble(sal, embeddings_dict, avail_idx):
    """ Cross-validates soft-voting ensembles of a tabular classifier and an embedding classifier.

    For every embedding model and every ordered pair of classifier types
    (LightGBM / LR-L2 for the tabular part, LightGBM / LR-L2 for the embedding
    part) a two-member VotingClassifier is cross-validated. Each member is a
    pipeline that first selects its own columns from the joint feature frame.

    Parameters
    ----------
    sal : the dataset; must provide the `with_services` columns and `CriticalEvent`.
    embeddings_dict : dict mapping an embedding model name to its embedding matrix.
    avail_idx : row selector (used with `sal.loc`) for the rows that have embeddings.

    Returns
    -------
    pandas.DataFrame
        Fold-averaged `cross_validate` output, indexed by
        (Embedding, Classifier_Data, Classifier_Emb).

    NOTE(review): in the result table stored below this cell, every row whose
    tabular classifier is LR-L2 has NaN scores. One possible cause is the
    one-hot encoder inside REGRESSION_PREPROCESSOR meeting categories at
    scoring time that were absent from the training fold — confirm. Warnings
    are suppressed in the loop, so such failures are silent.
    """
    tabular_columns = SalfordCombinations.with_services
    X = SalfordData(sal.loc[avail_idx, tabular_columns]).convert_str_to_categorical()
    y = sal.loc[avail_idx, 'CriticalEvent']

    # One joint frame per embedding model: tabular columns followed by the
    # embedding columns (prefixed EMBEDDING_), aligned on the tabular index.
    X_variants = {
        transformer: pd.concat((X, pd.DataFrame(embedding).add_prefix(f'EMBEDDING_').set_index(X.index)), axis=1)
        for transformer, embedding in embeddings_dict.items()
    }

    # Column selectors that split the joint frame back into its two parts;
    # both are configured to output pandas DataFrames.
    embedding_selector = make_column_transformer(('passthrough', make_column_selector(pattern='EMBEDDING_'))).set_output(transform='pandas')
    data_selector = make_column_transformer(('passthrough', tabular_columns)).set_output(transform='pandas')

    # Factories rather than instances: each call builds a fresh pipeline around
    # the given selector. The LR pipeline additionally runs REGRESSION_PREPROCESSOR.
    classifier_factory = {
        'LightGBM': lambda selector: make_pipeline(
            selector, 
            CalibratedClassifierCV(
                LGBMClassifier(**LIGHTGBM_PARAMETERS), **CALIBRATION_PARAMETERS
            )),
        'LR-L2': lambda selector: make_pipeline(
            selector, 
            REGRESSION_PREPROCESSOR, 
            CalibratedClassifierCV(
                LogisticRegression(**REGRESSION_PARAMETERS), **CALIBRATION_PARAMETERS
            ))
    }

    # Every embedding variant crossed with every ordered
    # (tabular classifier, embedding classifier) pair.
    experiments = itertools.product(
        X_variants.items(), 
        itertools.product(classifier_factory.items(), repeat=2)
    )

    # Same settings as the shallow experiments, but with n_jobs overridden to 4.
    cross_validation_parameters = CROSS_VALIDATION_PARAMETERS | dict(
        n_jobs=4
    )

    results = []
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for (X_name, X), ((cls_name_data, cls_factory_data), (cls_name_embeddings, cls_factory_embeddings)) in (pbar := tqdm(experiments)):
            pbar.set_description(f'Parallel running 4 CV folds of {cls_name_data}-{cls_name_embeddings} with {X_name} embeddings..')

            # voting='soft' combines the members' predicted probabilities.
            ensemble = VotingClassifier([
                (f'DATA_{cls_name_data}', cls_factory_data(data_selector)),
                (f'EMB_{cls_name_embeddings}', cls_factory_embeddings(embedding_selector)),
            ], voting='soft')

            results.append(pd.DataFrame.from_dict(
                cross_validate(
                    ensemble,
                    X, y, **cross_validation_parameters
                )
            ).assign(Embedding=X_name, Classifier_Data=cls_name_data, Classifier_Emb=cls_name_embeddings))

        # Average the per-fold rows of each experiment (still inside the `with` block).
        return pd.concat(results).groupby(['Embedding', 'Classifier_Data', 'Classifier_Emb']).mean()

# Experiment 3.2 results: load the cached table by default, or rerun the
# ensemble cross-validation and refresh the cache when RE_DERIVE is set.
if not Notebook.RE_DERIVE:
    RESULTS_32 = pd.read_csv('data/cache/result32.csv').set_index(['Embedding', 'Classifier_Data', 'Classifier_Emb'])
else:
    RESULTS_32 = cv_tabular_and_embedding_ensemble(SAL, NOTE_EMBEDDINGS, note_avail_idx)
    RESULTS_32.to_csv('data/cache/result32.csv')

display(RESULTS_32)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,fit_time,score_time,test_Precision,test_Recall,test_AUC,test_AP,test_F1,test_F2
Embedding,Classifier_Data,Classifier_Emb,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Bert,LR-L2,LR-L2,750.592701,0.542523,,,,,,
Bert,LR-L2,LightGBM,606.041267,0.415372,,,,,,
Bert,LightGBM,LR-L2,473.115724,10.148475,0.951664,0.372763,0.945833,0.699402,0.535621,0.424372
Bert,LightGBM,LightGBM,324.754093,11.136703,0.959971,0.353492,0.944035,0.69714,0.516642,0.404595
BioClinicalBert,LR-L2,LR-L2,1329.363483,0.484679,,,,,,
BioClinicalBert,LR-L2,LightGBM,826.312274,0.437503,,,,,,
BioClinicalBert,LightGBM,LR-L2,998.361817,12.950396,0.951735,0.372763,0.947542,0.704964,0.535624,0.424371
BioClinicalBert,LightGBM,LightGBM,332.994291,11.442362,0.959222,0.35964,0.945917,0.699079,0.523088,0.411008
PubMedBert,LR-L2,LR-L2,809.081656,0.537361,,,,,,
PubMedBert,LR-L2,LightGBM,600.526571,0.408377,,,,,,
