In [24]:
import os, pickle, warnings, dataclasses, itertools, argparse
from pathlib import Path
from functools import partial 
from dataclasses import dataclass

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy import stats as st

import torch
from torch import nn
from torch.nn import functional as F

from tqdm import tqdm
from IPython.display import display
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

from typing import Iterable, Tuple

import logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)

import transformers
transformers.logging.set_verbosity_error()

%load_ext autoreload
%autoreload 1
%aimport salford_datasets.salford, salford_datasets.salford_raw, transformer_experiment.utils.finetuning, transformer_experiment.salford_transformer_datasets

from salford_datasets.salford import SalfordData, SalfordFeatures, SalfordPrettyPrint, SalfordCombinations

from transformer_experiment.utils.finetuning import BERTModels

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
class Notebook:
    """Notebook-wide settings read by the cells below."""
    # Location of the Salford dataset files handed to the dataset loaders.
    DATA_DIR = Path('data/Salford/')
    # Root for cached tokenised datasets, model checkpoints and saved predictions.
    CACHE_DIR = Path('models')
    # Forwarded to the dataset loader as its "re-derive" flag.
    RE_DERIVE = False
    # Prefer the GPU whenever torch reports one, otherwise fall back to the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [26]:
from transformers import AutoTokenizer

@dataclasses.dataclass
class SalfordTransformerDataset(torch.utils.data.Dataset):
    """Text-ified Salford records and their labels, tokenised for a transformer.

    Indexing with an integer yields one example as a dict holding ``input_ids``,
    ``attention_mask`` and ``labels``; indexing with a slice yields a new
    ``SalfordTransformerDataset`` restricted to that range.
    """
    # Text per record (output of SalfordData.tabular_to_text when built via from_SalfordData).
    _text: Iterable[str]
    # NOTE(review): from_SalfordData stores an integer array here, not strings.
    _labels: Iterable[str]
    # True where a record has at least one non-null value among the selected columns.
    _avail_idx: Iterable[bool]
    # Tokenizer output, read through its 'input_ids' / 'attention_mask' keys; set by tokenise().
    _text_tz: Iterable[str] = None

    @classmethod
    def from_SalfordData(cls, sal, model_uri, columns=SalfordCombinations.with_services):
        """Build and tokenise a dataset from a Salford dataframe.

        Args:
            sal: Dataframe exposing ``CriticalEvent`` and the requested columns.
            model_uri: Identifier/path handed to ``AutoTokenizer.from_pretrained``.
            columns: Columns to convert to text.

        Returns:
            A tokenised ``SalfordTransformerDataset``.
        """
        has_any_value = sal[columns].notna().any(axis=1)
        record_text = SalfordData(sal).tabular_to_text(columns)
        targets = sal.CriticalEvent.copy().astype(int).values

        dataset = cls(record_text, targets, has_any_value)
        return dataset.tokenise(model_uri)

    def tokenise(self, model_uri):
        """Tokenise ``_text`` with the tokenizer for ``model_uri``; returns ``self``."""
        tokenizer = AutoTokenizer.from_pretrained(model_uri)
        # Truncate to at most 512 tokens; padding=True is applied over the whole list at once.
        self._text_tz = tokenizer(
            list(self._text), truncation=True, padding=True, max_length=512
        )
        return self

    def __getitem__(self, idx):
        """Return one example (integer ``idx``) or a sub-dataset (slice ``idx``)."""
        if not isinstance(idx, slice):
            return {
                'input_ids': self._text_tz['input_ids'][idx],
                'attention_mask': self._text_tz['attention_mask'][idx],
                'labels': self._labels[idx],
            }

        # Slice: carry every field across so the result behaves like the parent dataset.
        return SalfordTransformerDataset(
            _text=self._text[idx],
            _labels=self._labels[idx],
            _avail_idx=self._avail_idx.iloc[idx],
            _text_tz={
                'input_ids': self._text_tz['input_ids'][idx],
                'attention_mask': self._text_tz['attention_mask'][idx],
            },
        )

    def __len__(self):
        """Number of text entries held by the dataset."""
        return len(self._text)

    @property
    def tensors(self):
        """Token ids and attention mask as torch tensors, keyed by field name."""
        return {
            field: torch.tensor(self._text_tz[field])
            for field in ('input_ids', 'attention_mask')
        }

## 4. Fine-Tuned Transformer

 - 4.1. Clinical notes on their own
 - 4.2. `with_services` on its own (text-ified)
 - 4.3. Expanded diagnoses on their own
 - 4.4. `with_services` and clinical notes
 - 4.5. All together

In [27]:
# Feature columns used by each experiment; keys mirror sections 4.1-4.5 above.
# `SalfordFeatures.Text[:-2]` drops the last two entries of the Text feature list.
EXPERIMENT_FEATURE_SETS = {
    # 4.1 - clinical notes on their own
    '41': SalfordFeatures.Text[:-2],
    # 4.2 - `with_services` on its own (text-ified)
    '42': SalfordCombinations.with_services,
    # 4.3 - expanded diagnoses on their own
    '43': SalfordFeatures.Diagnoses,
    # 4.4 - `with_services` and clinical notes
    '44': SalfordCombinations.with_services + SalfordFeatures.Text[:-2],
    # 4.5 - all together
    '45': SalfordCombinations.with_services + SalfordFeatures.Text[:-2] + SalfordFeatures.Diagnoses
}

In [35]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformer_experiment.utils.finetuning import bert_finetuning_metrics
from transformer_experiment.utils.finetuning import split_dict_into_batches, load_dict_to_device

def finetune_note_transformer(sal_tz, model_uri, save_directory="bert-finetuned-notes_fake_delete", batch_size=56):
    """Fine-tune a two-class sequence classifier on a tokenised dataset.

    Args:
        sal_tz: Tokenised dataset. It is split with ``train_test_split`` and must
            expose ``_labels``, which is used to stratify the validation split.
        model_uri: Identifier/path of the base model for ``from_pretrained``.
        save_directory: Sub-directory of ``Notebook.CACHE_DIR`` that receives the
            checkpoints and the final ``best_model`` folder.
        batch_size: Per-device batch size for both training and evaluation.

    Returns:
        The fine-tuned model, switched to eval mode.
    """
    output_dir = Notebook.CACHE_DIR/save_directory

    bert_args = TrainingArguments(
        output_dir,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=10,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model='AP',
        report_to='none',
        optim="adamw_torch",
        disable_tqdm=False
    )

    bert_kwargs = dict(
        num_labels=2, output_attentions=False, output_hidden_states=False, ignore_mismatched_sizes=True
    )

    # Hold out 15% for per-epoch evaluation; fixed seed and stratification keep the split repeatable.
    X_train, X_val = train_test_split(sal_tz, test_size=0.15, random_state=123, stratify=sal_tz._labels)

    model = AutoModelForSequenceClassification.from_pretrained(model_uri, **bert_kwargs)

    trainer = Trainer(
        model,
        bert_args,
        train_dataset=X_train,
        eval_dataset=X_val,
        compute_metrics=bert_finetuning_metrics
    )

    # Training is noisy with library warnings; silence them for the duration of the run only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        trainer.train()

    # With load_best_model_at_end=True the trainer ends up holding the best weights
    # (by 'AP'). Persist them under 'best_model', which is the folder
    # run_inference_4 loads from; previously nothing in this notebook wrote it.
    trainer.save_model(str(output_dir/'best_model'))

    return model.eval()

def finetuned_inference(model, dataset, batch_size=56):
    """Predict the positive-class probability for every example in ``dataset``.

    Args:
        model: Sequence classifier whose output exposes ``.logits``.
        dataset: Object exposing ``.tensors`` (dict of input tensors).
        batch_size: Number of examples per forward pass. Defaults to 56, the same
            default the ``run_*_4`` entry points use, so callers that omit it no
            longer raise ``TypeError``.

    Returns:
        1-D numpy array with the softmax probability of class 1 per example.
    """
    with torch.no_grad():
        batches = split_dict_into_batches(dataset.tensors, batch_size)
        # load_dict_to_device presumably moves each batch onto the model's device - confirm in utils.
        y_pred_logit = torch.concat([
            model(**load_dict_to_device(batch)).logits for batch in tqdm(batches)
        ])

        # Column 1 of the softmax is the probability of the positive class.
        y_pred_proba = F.softmax(y_pred_logit, dim=1)[:, 1]

    return y_pred_proba.cpu().detach().numpy()

In [29]:
from sklearn.model_selection import train_test_split
from transformer_experiment.utils.embeddings import load_sal

def tokenise_dataset(model_uri, feature_set, re_derive=False, debug=False):
    """Load the Salford data, split it in order, and tokenise both parts.

    Args:
        model_uri: Identifier/path of the tokenizer to use.
        feature_set: Columns to convert to text before tokenising.
        re_derive: Forwarded to ``load_sal`` as its first argument.
        debug: When true, work on a random 100-row sample in which 20 randomly
            chosen rows have ``CriticalEvent`` forced to ``True``.

    Returns:
        ``(train_dataset, test_dataset)`` as ``SalfordTransformerDataset`` objects.
    """
    sal = load_sal(re_derive, Notebook.DATA_DIR)

    if debug:
        sal = sal.sample(100)
        # Force some positives so both classes are present in the tiny sample.
        forced_positive = sal.sample(20).index
        sal.loc[forced_positive, 'CriticalEvent'] = True

    # shuffle=False: the first 67% of rows train, the final 33% test.
    partitions = train_test_split(sal, test_size=0.33, shuffle=False)

    logging.info('Tokenising feature set')
    train_dataset, test_dataset = (
        SalfordTransformerDataset.from_SalfordData(partition, model_uri, feature_set)
        for partition in partitions
    )

    return train_dataset, test_dataset

def load_tokenised_dataset_cached(bert_variant, experiment_num, feature_set):
    """Return the tokenised train/test datasets, using an on-disk cache.

    Args:
        bert_variant: Key into ``BERTModels``; also part of the cache filename.
        experiment_num: Experiment identifier; also part of the cache filename.
        feature_set: Columns to tokenise when the cache has to be built.

    Returns:
        ``(sal_bert_train, sal_bert_test)``.
    """
    cache_filepath = Notebook.CACHE_DIR/f'sal_bert_{bert_variant}_{experiment_num}.bin'

    if os.path.isfile(cache_filepath):
        logging.info('Loading tokenised data from cache')
        # SECURITY: pickle executes arbitrary code on load - only use cache files
        # produced by this notebook on a trusted machine.
        with open(cache_filepath, 'rb') as file:
            sal_bert_train, sal_bert_test = pickle.load(file)
        return sal_bert_train, sal_bert_test

    sal_bert_train, sal_bert_test = tokenise_dataset(BERTModels[bert_variant], feature_set)

    # The cache directory may not exist yet (fresh checkout); create it so the
    # expensive tokenisation above is not lost to a FileNotFoundError.
    cache_filepath.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_filepath, 'wb') as file:
        pickle.dump((sal_bert_train, sal_bert_test), file)

    return sal_bert_train, sal_bert_test

In [36]:
from transformers import AutoModelForSequenceClassification

def get_checkpoint_directory(experiment_num='41', bert_variant='BioClinicalBert'):
    """Locate the highest-numbered ``checkpoint-N`` folder of a fine-tuning run.

    Args:
        experiment_num: Experiment identifier used in the run's folder name.
        bert_variant: Model variant used in the run's folder name.

    Returns:
        ``(model_directory, checkpoint_dir)``: the run's folder name (relative to
        ``Notebook.CACHE_DIR``) and the full path of its last checkpoint.
    """
    model_directory = f'bert_{bert_variant}_{experiment_num}'
    run_path = Notebook.CACHE_DIR/model_directory

    def step_number(entry_name):
        # 'checkpoint-1500' -> 1500
        return int(entry_name.split('-')[1])

    checkpoints = [entry for entry in os.listdir(run_path) if 'checkpoint-' in entry]
    # Numeric sort: a plain string sort would put 'checkpoint-100' before 'checkpoint-20'.
    checkpoints.sort(key=step_number)

    return model_directory, run_path/checkpoints[-1]

def run_finetuning_4(experiment_num='41', bert_variant='BioClinicalBert', batch_size=56, debug=False):
    """Fine-tune one model variant on one feature set and save its test predictions.

    Args:
        experiment_num: Key into ``EXPERIMENT_FEATURE_SETS``.
        bert_variant: Key into ``BERTModels``.
        batch_size: Batch size for both fine-tuning and inference.
        debug: When true, tokenise a small debug sample (bypassing the cache) and
            write into a throw-away model directory.

    Side effects:
        Writes ``test_pred_proba.bin`` (pickled predictions) into the run's
        directory under ``Notebook.CACHE_DIR``.
    """
    feature_set = EXPERIMENT_FEATURE_SETS[experiment_num]
    model_uri = BERTModels[bert_variant]

    if debug:
        model_directory = "bert-finetuned-notes_fake_delete"
        # feature_set is a required argument, and the test split must be kept:
        # it is needed for inference below.
        sal_bert_train, sal_bert_test = tokenise_dataset(model_uri, feature_set, debug=True)
    else:
        model_directory = f'bert_{bert_variant}_{experiment_num}'
        sal_bert_train, sal_bert_test = load_tokenised_dataset_cached(bert_variant, experiment_num, feature_set)

    model = finetune_note_transformer(
        sal_bert_train, model_uri, model_directory, batch_size
    )

    y_pred_proba = finetuned_inference(model, sal_bert_test, batch_size)

    output_dir = Notebook.CACHE_DIR/model_directory
    # Normally already created by the training run; harmless if it exists.
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir/'test_pred_proba.bin', 'wb') as file:
        pickle.dump(y_pred_proba, file)


def run_inference_4(experiment_num='41', bert_variant='BioClinicalBert', batch_size=56):
    """Run the saved best model on the test split and store index-aligned predictions.

    Loads the model from ``<run>/best_model`` (not from the latest checkpoint),
    predicts on the cached tokenised test set, and pickles the predictions for
    rows that both have available features and belong to the test index.

    Args:
        experiment_num: Key into ``EXPERIMENT_FEATURE_SETS``; part of the run folder name.
        bert_variant: Model variant; part of the run folder name.
        batch_size: Number of examples per forward pass.

    Returns:
        ``(y_pred_proba, sal_bert_test, sal_test_idx)``: the raw, unmasked
        prediction array, the tokenised test dataset, and the test index.

    Side effects:
        Writes ``test_pred_proba_indexed.bin`` (a pickled ``pd.Series``) into the
        run's directory; ``collate_inference_results`` reads that file.
    """
    feature_set = EXPERIMENT_FEATURE_SETS[experiment_num]

    model_directory = f'bert_{bert_variant}_{experiment_num}'
    run_path = Notebook.CACHE_DIR/model_directory

    _, sal_bert_test = load_tokenised_dataset_cached(bert_variant, experiment_num, feature_set)

    model = AutoModelForSequenceClassification.from_pretrained(run_path/'best_model').to(Notebook.DEVICE).eval()

    # NOTE(review): get_train_test_indexes and load_salford_dataset are not
    # defined or imported in the visible cells - confirm where they come from.
    _, sal_test_idx, _, _ = get_train_test_indexes(load_salford_dataset(Notebook.RE_DERIVE, Notebook.DATA_DIR))

    y_pred_proba = finetuned_inference(model, sal_bert_test, batch_size)

    # Match on the index labels. The previous `.isin` on the Series itself
    # compared the boolean *values* against the test index.
    available = sal_bert_test._avail_idx
    idx_mask = available & available.index.isin(sal_test_idx)

    y_pred_proba_indexed = pd.Series(
        y_pred_proba[idx_mask.to_numpy()], index=idx_mask[idx_mask].index
    )

    # Previously unreachable because of a debugging early return.
    with open(run_path/'test_pred_proba_indexed.bin', 'wb') as file:
        pickle.dump(y_pred_proba_indexed, file)

    return y_pred_proba, sal_bert_test, sal_test_idx


In [37]:
# Run inference with run_inference_4's default arguments and keep its three
# return values for interactive inspection in the cells below.
y_pred_proba, sal_bert_test, sal_test_idx = run_inference_4()

2023-04-12 17:51:52,643 [INFO] Loading tokenised data from cache
2023-04-12 17:55:42,440 [INFO] Loading processed dataset
100%|██████████| 2250/2250 [05:45<00:00,  6.51it/s]


In [53]:
# Boolean mask over the test dataset: the row's `_avail_idx` flag is set AND its
# index label appears in `sal_test_idx`.
idx_mask = sal_bert_test._avail_idx & sal_bert_test._avail_idx.index.isin(sal_test_idx)

In [57]:
# Keep only the masked predictions and label them with the index entries the mask selects.
pd.Series(y_pred_proba[idx_mask], index=idx_mask[idx_mask].index)

SpellSerial
2253671_16    0.006928
2337767_2     0.021146
80401_108     0.006801
235017_35     0.065716
237776_98     0.301751
                ...   
119484_60     0.040629
2481624_32    0.607663
405214_49     0.014788
228825_177    0.158270
2507492_2     0.392159
Length: 36993, dtype: float32

In [8]:
# from transformer_experiment.utils.finetuning import construct_parser

# if __name__ == '__main__':
#     parser = construct_parser()
#     args = parser.parse_args()

#     run_finetuning_4(args.experiment, args.model, args.batch_size, args.debug)
#     #run_inference_4(args.experiment, args.model, args.batch_size)

## Results

In [9]:
def collate_inference_results():
    """Gather the saved test predictions of every model variant / feature set.

    For each ``BERTModels`` key and each experiment listed below, loads
    ``<CACHE_DIR>/bert_<variant>_<experiment>/test_pred_proba_indexed.bin``,
    aligns it to ``SAL_TEST_IDX`` (missing entries become 0) and also selects the
    ``SAL_TEST_IS_UNSEEN`` subset.

    Returns:
        Dict mapping ``(bert_variant, feature_set_name)`` to
        ``(y_pred_proba, y_pred_proba_unseen)``.
    """
    # Experiment 44 is not collated here.
    feature_sets = {
        41: 'Notes',
        42: 'Tabular',
        43: 'Diagnoses',
        45: 'All'
    }
    # Materialise the combinations so the progress bar knows its total, and take
    # the experiment numbers from the dict so the two cannot drift apart.
    combinations = list(itertools.product(BERTModels.keys(), feature_sets))

    y_preds = {}
    for bert_variant, experiment in (pbar := tqdm(combinations)):
        pbar.set_description(f'Loading {bert_variant}:{experiment}')

        # Read from the configured cache directory (where the predictions are
        # written) rather than a hard-coded 'models/' path.
        pred_path = Notebook.CACHE_DIR/f'bert_{bert_variant}_{experiment}'/'test_pred_proba_indexed.bin'
        # SECURITY: pickle executes arbitrary code on load - trusted files only.
        with open(pred_path, 'rb') as file:
            y_pred_proba = pickle.load(file)

        # NOTE(review): SAL_TEST_IDX and SAL_TEST_IS_UNSEEN are not defined in the
        # visible cells - confirm they exist before running on a fresh kernel.
        y_pred_proba = y_pred_proba.reindex(SAL_TEST_IDX).fillna(0)
        y_pred_proba_unseen = y_pred_proba[SAL_TEST_IS_UNSEEN]
        y_preds[(bert_variant, feature_sets[experiment])] = (y_pred_proba, y_pred_proba_unseen)

    return y_preds

# Collate every saved prediction set and persist the resulting dict under the
# cache directory.
TRANSFORMER_Y_PREDS = collate_inference_results()
with open(Notebook.CACHE_DIR/'transformer_finetuning_results.bin', 'wb') as file:
    pickle.dump(TRANSFORMER_Y_PREDS, file)

## Explainability

In [10]:
from Transformer_Explainability.BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from Transformer_Explainability.BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import AutoModelForSequenceClassification


In [11]:
from transformers import AutoTokenizer

# Which fine-tuned run to explain.
experiment_num='41'
bert_variant='BioClinicalBert'
batch_size=56
feature_set = EXPERIMENT_FEATURE_SETS[experiment_num]
model_uri = BERTModels[bert_variant]

# get_checkpoint_directory returns the highest-numbered checkpoint of the run.
model_directory, checkpoint_dir = get_checkpoint_directory(experiment_num, bert_variant)
tz = AutoTokenizer.from_pretrained(model_uri)
# feature_set is a required third argument of load_tokenised_dataset_cached.
_, sal_bert_test = load_tokenised_dataset_cached(bert_variant, experiment_num, feature_set)

# LRP-capable BERT implementation from Transformer_Explainability, loaded from the checkpoint.
model = BertForSequenceClassification.from_pretrained(checkpoint_dir).to(Notebook.DEVICE).eval()

2023-04-04 12:31:28,537 [INFO] Loading tokenised data from cache


In [20]:
def finetuned_explainability(model, dataset):
    """Compute an LRP explanation for every example in ``dataset``.

    Rows are moved to ``Notebook.DEVICE`` one at a time, so the whole tokenised
    dataset never has to sit in device memory alongside the model.

    Args:
        model: Model accepted by the ``Generator`` explanation class.
        dataset: Object exposing ``.tensors`` with ``'input_ids'`` and
            ``'attention_mask'`` entries that iterate row by row.

    Returns:
        List with one numpy array per example: the first element of
        ``generate_LRP``'s result for that example.
    """
    explanations = Generator(model)

    expl = []

    tensors = dataset.tensors
    for input_ids, attention_mask in zip(tensors['input_ids'], tensors['attention_mask']):
        # Add a batch dimension of 1, then transfer just this example to the device.
        relevance = explanations.generate_LRP(
            input_ids=input_ids.unsqueeze(0).to(Notebook.DEVICE),
            attention_mask=attention_mask.unsqueeze(0).to(Notebook.DEVICE),
            start_layer=0,
        )[0]
        # Detach and copy back to host memory so results do not accumulate on the device.
        expl.append(relevance.detach().cpu().numpy())

    return expl

In [24]:
# NOTE(review): `X` is not defined in any visible cell - presumably a (sliced)
# SalfordTransformerDataset such as part of `sal_bert_test`; confirm before re-running.
# The recorded run of this cell ended in a CUDA out-of-memory error.
expl = finetuned_explainability(model, X)

The `device` argument is deprecated and will be removed in v5 of Transformers.


OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB (GPU 0; 11.91 GiB total capacity; 11.14 GiB already allocated; 21.12 MiB free; 11.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [23]:
# Inspect the explanation stored for the first example.
expl[0]

array([0.0000000e+00, 4.8046673e-04, 4.8487942e-04, 4.8961386e-04,
       4.8804042e-04, 6.6239212e-05, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e

In [18]:
# NOTE(review): despite the name, this reads the labels stored on the dataset
# (`_labels`), not model predictions; `X` is not defined in any visible cell.
y_pred = X._labels

In [22]:
# Rescale the 0/1 labels to -1/+1.
y_pred*2-1

array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

In [None]:
# NOTE(review): finetuned_explainability returns a Python list, which has no
# `.shape`; the recorded torch.Size output implies `expl` was rebound to a tensor
# in a cell that is not visible here - confirm before re-running.
expl.shape

torch.Size([5, 149])

In [None]:
# NOTE(review): `input_ids` and `attention_mask` are not defined in any visible
# cell; `.item()` below requires a single example - confirm.
# Softmax over the first element of the model output, then take the arg-max class.
output = F.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# Flip the sign of the attributions when the predicted class is 0.
if classification == 0:
    expl *= (-1)

In [None]:
# Map the token ids back to their token strings for display
# (`input_ids` is not defined in any visible cell - confirm).
tokens = tz.convert_ids_to_tokens(input_ids.flatten())

In [None]:
from captum.attr import (
    visualization
)

# Render the token-level attributions with captum's text visualiser.
# Positional-argument labels below follow captum's VisualizationDataRecord
# signature - confirm against the installed version. The true label, attribution
# label, attribution score and convergence score are hard-coded to 1 here.
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,  # word_attributions
                                output[0][classification],  # pred_prob
                                classification,  # pred_class
                                1,  # true_class
                                1,  # attr_class
                                1,       # attr_score
                                tokens,  # raw_input_ids
                                1)]  # convergence_score
visualization.visualize_text(vis_data_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),1.0,1.0,"[CLS] ; a & e diagnosis is "" superficial injury of head ( disorder ) "" ; t ##ria ##ge note is "" bi ##ba , stayed on sofa all night due to di ##zzi ##ness , got up at 08 : 00 , went dizzy , collapsed , awoke on floor at 09 : 30 , no apparent injuries , had th ##ora ##ci ##c back pains last 6 / 52 which has increase . still experiencing di ##zzi ##ness on standing . "" ; ; ; [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),1.0,1.0,"[CLS] ; a & e diagnosis is "" superficial injury of head ( disorder ) "" ; t ##ria ##ge note is "" bi ##ba , stayed on sofa all night due to di ##zzi ##ness , got up at 08 : 00 , went dizzy , collapsed , awoke on floor at 09 : 30 , no apparent injuries , had th ##ora ##ci ##c back pains last 6 / 52 which has increase . still experiencing di ##zzi ##ness on standing . "" ; ; ; [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]"
,,,,
