In [1]:
import os, pickle, warnings, itertools
from pathlib import Path
from functools import partial 

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn import functional as F

from tqdm import tqdm
from IPython.display import display
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

import logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)

import transformers
transformers.logging.set_verbosity_error()

%load_ext autoreload
%autoreload 1
%aimport salford_datasets.salford, salford_datasets.salford_raw, transformer_experiment.utils.embeddings, transformer_experiment.utils.finetuning, transformer_experiment.utils.shallow_classifiers, transformer_experiment.utils.plots
%aimport acd_experiment.base_dataset, acd_experiment.salford_adapter, acd_experiment.models, acd_experiment.sci, acd_experiment.systematic_comparison

from salford_datasets.salford import SalfordData, SalfordFeatures, SalfordPrettyPrint, SalfordCombinations
from acd_experiment.salford_adapter import SalfordAdapter
from transformer_experiment.utils.embeddings import BERTModels

In [2]:
class Notebook:
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    DATA_DIR = Path('data/Salford')
    CACHE_DIR = Path('data/cache')
    IMAGE_DIR = Path('images/shallow')
    SYSTEMATIC_COMPARISON_DIR = Path('data/systematic_comparison/')
    RE_DERIVE = False

In [3]:
from transformer_experiment.utils.shallow_classifiers import load_salford_dataset, get_train_test_indexes

SAL = load_salford_dataset(Notebook.RE_DERIVE, Notebook.DATA_DIR)
SAL_TRAIN_IDX, SAL_TEST_IDX, SAL_TEST_UNSEEN_IDX, SAL_TEST_IS_UNSEEN = get_train_test_indexes(SAL)
Y_TRUES = {
    'Complete': SAL.CriticalEvent.loc[SAL_TEST_IDX],
    'Unseen': SAL.CriticalEvent.loc[SAL_TEST_UNSEEN_IDX],
}

2023-04-12 17:18:53,699 [INFO] Loading processed dataset


## Embedding Generation

In [101]:
from transformers import AutoTokenizer, AutoModel
    
def load_tz_to_device(tz_output):
    """ Given the direct output of the tokeniser, loads the tokens to the GPU """
    return dict(map(
        lambda _: (_[0], _[1].to(Notebook.DEVICE)), tz_output.items()
    ))

def split_into_batches(Xt, batch_size):
    """ Given a tensor/ndarray and a batch size, splits it into batches of size up to batch_size along the first dimension """
    return np.array_split(
        Xt, np.ceil(len(Xt)/batch_size)
    )

def get_note_embeddings(X, model_uri=BERTModels.BioClinicalBert):
    tz, model = AutoTokenizer.from_pretrained(model_uri), AutoModel.from_pretrained(model_uri).to(Notebook.DEVICE).eval()
    tz_kwargs = dict(truncation=True, padding=True, return_tensors='pt')

    get_batch_embedding = lambda x: (
        model(
            **load_tz_to_device(tz(list(x), **tz_kwargs))
        )['last_hidden_state'][:, 0, :].cpu()
    )

    with torch.no_grad():
        emb = torch.cat([
            get_batch_embedding(_) for _ in tqdm(split_into_batches(X, 500), desc="Generating embeddings..")
        ])
    
    return emb

def get_note_embeddings_all_BERTs(sal):
    columns = ['AE_TriageNote', 'AE_MainDiagnosis', 'AE_PresentingComplaint']
    avail_idx = sal[columns].notna().any(axis=1)
    X = SalfordData(sal.loc[avail_idx]).tabular_to_text(columns).values

    with torch.no_grad():
        result = {
            model_name: get_note_embeddings(X, model_uri).numpy() for model_name, model_uri in BERTModels.items()
        }

    result = {
        model_name: pd.DataFrame(
            embeddings, index=avail_idx[avail_idx].index
        ).reindex(index=sal.index) for model_name, embeddings in result.items()
    }
    return result

if Notebook.RE_DERIVE:
    NOTE_EMBEDDINGS = get_note_embeddings_all_BERTs(SAL)
    with open('data/cache/note_embeddings.bin', 'wb') as file:
        pickle.dump(NOTE_EMBEDDINGS, file)
else:
    with open('data/cache/note_embeddings.bin', 'rb') as file:
        NOTE_EMBEDDINGS = pickle.load(file)

## Note-Only Classifiers

In [103]:
from acd_experiment.models import Estimator_LightGBM, Estimator_L2Regression

ESTIMATORS = {_._name: _ for _ in [
    Estimator_LightGBM,
    Estimator_L2Regression,
]}
STUDY_GRID = list(itertools.product(ESTIMATORS.keys(), NOTE_EMBEDDINGS.keys()))

In [111]:
from sklearn.calibration import CalibratedClassifierCV
import optuna
from acd_experiment.systematic_comparison import get_xy, PipelineFactory
from acd_experiment.salford_adapter import SalfordAdapter

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def run_joint_estimator(sal, embeddings, estimator_name, cv_jobs=4):
    estimator = ESTIMATORS[estimator_name]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        X, y = SalfordAdapter(sal).xy(
            x=SalfordCombinations['with_services'],
            imputation = estimator._requirements['imputation'],
            fillna = estimator._requirements['fillna'],
            ordinal_encoding = estimator._requirements['ordinal'],
            onehot_encoding = estimator._requirements['onehot']
        )
    if estimator._requirements['fillna']:
        embeddings = embeddings.fillna(0.0)
    X = pd.concat((X, embeddings.add_prefix('EMBEDDING__')), axis=1)
    X_train, X_test = SalfordAdapter(X.loc[SAL_TRAIN_IDX]), SalfordAdapter(X.loc[SAL_TEST_IDX])
    y_train, y_test = sal.CriticalEvent.loc[SAL_TRAIN_IDX], sal.CriticalEvent.loc[SAL_TEST_IDX]

    params = optuna.load_study(
        study_name =f'{estimator_name}_None_Within-1_with_notes_and_labs', storage=f'sqlite:///{Notebook.SYSTEMATIC_COMPARISON_DIR}/{estimator_name}.db'
    ).best_params

    pipeline_factory = PipelineFactory(
        estimator=estimator, resampler=None, X_train=X_train, y_train=y_train,
    )

    model = CalibratedClassifierCV(
        pipeline_factory(**params), cv=cv_jobs, method="isotonic", n_jobs=cv_jobs,
    ).fit(X_train, y_train)

    y_pred_proba = model.predict_proba(X_test)[:, 1]
    y_pred_proba_unseen = y_pred_proba[SAL_TEST_IS_UNSEEN]

    return y_pred_proba, y_pred_proba_unseen

if Notebook.RE_DERIVE:
    RESULTS = {}
    for estimator_name, bert_variant in (pbar := tqdm(STUDY_GRID)):
        pbar.set_description(f'Training {estimator_name} on {bert_variant}')
        RESULTS[(estimator_name, bert_variant, 'only')] = run_joint_estimator(SAL, NOTE_EMBEDDINGS[bert_variant], estimator_name)
        with open(Notebook.CACHE_DIR/'transformer_shallow_results_embonly.bin', 'wb') as file:
            pickle.dump(RESULTS, file)
else:
    with open(Notebook.CACHE_DIR/'transformer_shallow_results_embonly.bin', 'rb') as file:
        RESULTS = pickle.load(file)

In [120]:
from sklearn.ensemble import VotingClassifier

def run_ensemble_estimator(sal, embeddings, estimator_tabular_name, estimator_embedding_name, cv_jobs=4):
    estimator_tabular, estimator_embedding = (
        ESTIMATORS[estimator_tabular_name], 
        ESTIMATORS[estimator_embedding_name]
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        X_tabular, y = SalfordAdapter(sal).xy(
            x=SalfordCombinations['with_services'],
            imputation = estimator_tabular._requirements['imputation'],
            fillna = estimator_tabular._requirements['fillna'],
            ordinal_encoding = estimator_tabular._requirements['ordinal'],
            onehot_encoding = estimator_tabular._requirements['onehot']
        )
    if estimator_embedding._requirements['fillna']:
        embeddings = embeddings.fillna(0.0)

    y_train, y_test = sal.CriticalEvent.loc[SAL_TRAIN_IDX], sal.CriticalEvent.loc[SAL_TEST_IDX]

    y_pred_probas, y_pred_probas_unseen = [], []
    for estimator, X in [(estimator_tabular, X_tabular), (estimator_embedding, embeddings)]:
        X_train, X_test = SalfordAdapter(X.loc[SAL_TRAIN_IDX]), SalfordAdapter(X.loc[SAL_TEST_IDX])
        params = optuna.load_study(
            study_name =f'{estimator._name}_None_Within-1_with_notes_and_labs', storage=f'sqlite:///{Notebook.SYSTEMATIC_COMPARISON_DIR}/{estimator._name}.db'
        ).best_params

        pipeline = PipelineFactory(
            estimator=estimator, resampler=None, X_train=X_train, y_train=y_train,
        )

        model = CalibratedClassifierCV(
            pipeline(**params), cv=cv_jobs, method="isotonic", n_jobs=cv_jobs,
        ).fit(X_train, y_train)

        y_pred_proba = model.predict_proba(X_test)[:, 1]
        y_pred_probas.append(y_pred_proba)
        y_pred_probas_unseen.append(y_pred_proba[SAL_TEST_IS_UNSEEN])

    y_pred_proba = np.mean(np.array(y_pred_probas), axis=0)
    y_pred_proba_unseen = np.mean(np.array(y_pred_probas_unseen), axis=0)

    return y_pred_proba, y_pred_proba_unseen

STUDY_GRID = list(itertools.product(ESTIMATORS.keys(), ESTIMATORS.keys(), NOTE_EMBEDDINGS.keys()))
if Notebook.RE_DERIVE or True:
    RESULTS = {}
    for estimator_name_tab, estimator_name_emb, bert_variant in (pbar := tqdm(STUDY_GRID)):
        pbar.set_description(f'Training {estimator_name_tab}-{estimator_name_emb} on {bert_variant}')
        RESULTS[(estimator_name_tab, estimator_name_emb, bert_variant, 'ensemble')] = run_ensemble_estimator(SAL, NOTE_EMBEDDINGS[bert_variant], estimator_name_tab, estimator_name_emb)
        with open(Notebook.CACHE_DIR/'transformer_shallow_results_ensemble.bin', 'wb') as file:
            pickle.dump(RESULTS, file)
else:
    with open(Notebook.CACHE_DIR/'transformer_shallow_results_ensemble.bin', 'rb') as file:
        RESULTS = pickle.load(file)

Training L2Regression-L2Regression on PubMedBert: 100%|██████████| 12/12 [30:55<00:00, 154.63s/it]     


In [123]:
from transformer_experiment.utils.shallow_classifiers import get_discriminative_metrics
Y_TRUES = {
    'Complete': SAL.CriticalEvent.loc[SAL_TEST_IDX],
    'Unseen': SAL.CriticalEvent.loc[SAL_TEST_UNSEEN_IDX],
}
def get_full_metrics_tables(results):
    metrics = {
        'Complete': [],
        'Unseen': [],
    }
    for (estimator_name_tab, estimator_name_emb, feature_group_name, _), y_preds in results.items():
        for y_pred_proba, (y_true_name, y_true) in zip(y_preds, Y_TRUES.items()):

            metrics[y_true_name].append(dict(
                Estimator_Tab = estimator_name_tab,
                Estimator_Emb = estimator_name_emb,
                Features = feature_group_name,
            ) | get_discriminative_metrics(
                y_true, y_pred_proba
            ))
    
    for y_true_name, y_true in Y_TRUES.items():
        metrics[y_true_name].append(dict(
            Estimator='NEWS2',
            Features='Reference'
        ) | get_discriminative_metrics(
            y_true, SAL.NEWS_Score_Admission.loc[y_true.index]
        ))

    return {
        y_true_name: pd.DataFrame(metric_list) for y_true_name, metric_list in metrics.items()
    }

METRICS = get_full_metrics_tables(RESULTS)

In [126]:
METRICS['Complete'].groupby(['Estimator_Tab', 'Estimator_Emb']).mean()

The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.


Unnamed: 0_level_0,Unnamed: 1_level_0,AUROC,AUROC_Upper,AUROC_Lower,AP,AP_Upper,AP_Lower
Estimator_Tab,Estimator_Emb,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
L2Regression,L2Regression,0.892936,0.899386,0.884566,0.454751,0.476971,0.429238
L2Regression,LightGBM,0.891827,0.898627,0.883246,0.4577,0.480329,0.431261
LightGBM,L2Regression,0.924961,0.930581,0.918546,0.552789,0.572307,0.527996
LightGBM,LightGBM,0.925593,0.93131,0.919624,0.55645,0.575067,0.531869


## Fine-tuned transformers

In [10]:
def load_transformer_y_preds():
    y_preds = {}
    for bert_variant, experiment in (pbar:= tqdm(itertools.product(BERTModels.keys(), [41, 42, 43, 45]))):
        pbar.set_description(f'Loading {bert_variant}:{experiment}')
        with open(f'models/bert_{bert_variant}_{experiment}/test_pred_proba_indexed.bin', 'rb') as file:
            return pickle.load(file)
    return y_preds

TRANSFORMER_Y_PREDS = load_transformer_y_preds()

Loading BioClinicalBert:41: : 0it [00:00, ?it/s]


In [11]:
TRANSFORMER_Y_PREDS

Series([], dtype: float32)