In [1]:
version = "v1"

inputcol = "bhc_preceding_text"
outputcol = "brief_hospital_course"

# Fine-tuning BART for DischargeMe Brief Hospital Course Task

---

## Setup

---

In [1]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [10]:
import sys
sys.path.insert(0,'./')

In [11]:
from prompt_functions import create_pt_prompt_per_service

ModuleNotFoundError: No module named 'prompt_functions'

In [12]:
from preprocessing import data_injection

ModuleNotFoundError: No module named 'preprocessing'

In [3]:
datasets.disable_progress_bar()

In [4]:
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

    wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvimig-socrates[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
from pynvml import *

def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.", flush=True)


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}", flush=True)
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}", flush=True)
    print_gpu_utilization()

## Model and tokenizer

---

Download model and tokenizer. Use default parameters or try custom values (see [HF Bart configuration](https://huggingface.co/transformers/_modules/transformers/configuration_bart.html) and [Fairseq Bart](https://github.com/pytorch/fairseq/tree/master/examples/bart)).

In [6]:
model_name = "GanjinZero/biobart-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model parameters or use the default
# print(model.config)

# tokenization
encoder_max_length = 1024  # demo
decoder_max_length = 1024

## Data

---

### Download

For demonstration, we are only using a small portion of the data.

In [8]:
import pandas as pd
from datasets import Dataset


In [9]:
print_gpu_utilization()

GPU memory occupied: 93 MB.


In [22]:
train_data = pd.read_pickle("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text+structured_data.pickle")
valid_data = pd.read_pickle("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text+structured_data.pickle")

In [25]:
train_data['brief_hospital_course_word_count'].describe()

count    68785.000000
mean       327.462397
std        236.807308
min         10.000000
25%        163.000000
50%        281.000000
75%        440.000000
max       3435.000000
Name: brief_hospital_course_word_count, dtype: float64

In [11]:
train_ds = Dataset.from_pandas(train_data[['hadm_id', "bhc_preceding_text", "brief_hospital_course"]], split="train")
valid_ds = Dataset.from_pandas(valid_data[['hadm_id', "bhc_preceding_text", "brief_hospital_course"]], split="valid")

In [12]:
# data = datasets.load_dataset("wiki_lingua", name=language, split="train[:200]")

# Take a look at the data
# print(train_ds[0])

# data['article'][0]

### Prepare

**Format and split into train and validation sets**

In [13]:
def flatten(example):
    return {
        "document": example["bhc_preceding_text"],
        "summary": example["brief_hospital_course"],
    }


def list2samples(example):
    documents = []
    summaries = []
    for sample in zip(example["document"], example["summary"]):
        if len(sample[0]) > 0:
            documents += sample[0]
            summaries += sample[1]
    return {"document": documents, "summary": summaries}


train_dataset_txt = train_ds.map(flatten, remove_columns=['hadm_id', "bhc_preceding_text", "brief_hospital_course"])
# we don't need to do this step because we don't have multiple doc/summary pairs within each example
# train_dataset = train_dataset.map(list2samples, batched=True)

valid_dataset_txt = valid_ds.map(flatten, remove_columns=['hadm_id', "bhc_preceding_text", "brief_hospital_course"])
# valid_dataset = valid_dataset.map(list2samples, batched=True)

# train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1).values()

In [15]:
import tiktoken

In [21]:
train_data

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 68785
})

**Preprocess and tokenize**

In [16]:
# We already did this, so we can just load this dataset

# def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
#     source, target = batch["document"], batch["summary"]
#     source_tokenized = tokenizer(
#         source, padding="max_length", truncation=True, max_length=max_source_length
#     )
#     target_tokenized = tokenizer(
#         target, padding="max_length", truncation=True, max_length=max_target_length
#     )

#     batch = {k: v for k, v in source_tokenized.items()}
#     # Ignore padding in the loss
#     batch["labels"] = [
#         [-100 if token == tokenizer.pad_token_id else token for token in l]
#         for l in target_tokenized["input_ids"]
#     ]
#     return batch


# train_data = train_dataset_txt.map(
#     lambda batch: batch_tokenize_preprocess(
#         batch, tokenizer, encoder_max_length, decoder_max_length
#     ),
#     batched=True,
#     remove_columns=train_dataset_txt.column_names,
# )

# validation_data = valid_dataset_txt.map(
#     lambda batch: batch_tokenize_preprocess(
#         batch, tokenizer, encoder_max_length, decoder_max_length
#     ),
#     batched=True,
#     remove_columns=valid_dataset_txt.column_names,
# )

In [17]:
# train_data.save_to_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text_BioBART_data.hf")
# validation_data.save_to_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text_BioBART_data.hf")

In [27]:
from datasets import load_from_disk
train_data = load_from_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text_BioBART_data.hf")
validation_data = load_from_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text_BioBART_data.hf")

In [28]:
print("After processing datasets", flush=True)
print_gpu_utilization()

After processing datasets
GPU memory occupied: 93 MB.


In [50]:
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples["document"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    return inputs
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, min_length=200, max_length=1500, early_stopping=True)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)


In [56]:
inputs = generate_summary(train_dataset_txt[:2], model)


In [59]:
inputs.input_ids

tensor([[    0,  1437, 50118,  ...,   288,  3226,     2],
        [    0,  1437, 50118,  ...,  6617,   256,     2]])

In [60]:
train_data

input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
outputs = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, min_length=200, max_length=1500, early_stopping=True)
# output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)


In [61]:
outputs

tensor([[    2,  1437, 50249,  ...,     1,     1,     1],
        [    2,  1437, 50249,  ...,   384,  1640,     2]])

## Training

---

### Metrics

In [18]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

nltk.download("punkt", quiet=True)

metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

  metric = datasets.load_metric("rouge")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


### Training arguments

In [19]:
# epochs = 2
# train_size = train_data.num_rows
# train_batch_size = 8
# ga_steps = 1
# virtual_batch_size = train_batch_size * ga_steps   # "invented name" => 256
# per_epoch_steps = int(train_size / virtual_batch_size + 0.5) # round => 121
# total_steps = epochs * per_epoch_steps # => 605

In [20]:
training_args = Seq2SeqTrainingArguments(
    output_dir=f"bart-dischargeme-results_{version}",
    num_train_epochs=2,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=8,  # demo
    per_device_eval_batch_size=2,
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir=f"bart-dischargeme-logs_{version}",
    logging_steps=5,
    save_total_limit=2,

    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=250,
    save_steps=500,
    fp16=True,
    push_to_hub=False,
    metric_for_best_model='eval_loss',
    load_best_model_at_end=True,
    greater_is_better=False,
    disable_tqdm=True,
    log_level="info",
    logging_first_step=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data.shuffle(seed=42).select(range(2000)),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Using auto half precision backend


In [21]:
print("============== After setting up trainer and args ==============", flush=True)
print_gpu_utilization()


GPU memory occupied: 1791 MB.


### Train

Wandb integration

In [25]:
if WANDB_INTEGRATION:
    wandb_run = wandb.init(
        project="bart-dischargeme",
        config={
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "learning_rate": training_args.learning_rate,
            "dataset": "dischargeme preceding_text_only",
        },
    )

    now = datetime.now()
    current_time = now.strftime("%H%M%S")
    wandb_run.name = "run_" + current_time

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112348362803458, max=1.0…

Evaluate before fine-tuning

In [26]:
print("============== Before trainer.evaluate() pretrain ==============", flush=True)
print_gpu_utilization()


GPU memory occupied: 2086 MB.


In [27]:
# trainer.evaluate()

In [28]:
print("============== After trainer.evaluate() pretrain ==============", flush=True)
print_gpu_utilization()


GPU memory occupied: 2086 MB.


Train the model

In [29]:
#%%wandb
# uncomment to display Wandb charts

trainer.train()

Step,Training Loss,Validation Loss




KeyboardInterrupt: 

In [None]:
print("============== After trainer.train() ==============", flush=True)
print_gpu_utilization()


Evaluate after fine-tuning

In [None]:
trainer.evaluate()

In [None]:
if WANDB_INTEGRATION:
    wandb_run.finish()

## Evaluation

---

**Generate summaries from the fine-tuned model and compare them with those generated from the original, pre-trained one.**

In [22]:
model = AutoModelForSeq2SeqLM.from_pretrained("/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/template_code/bart-dischargeme-results_v2/checkpoint-16500/")

loading configuration file /home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/template_code/bart-dischargeme-results_v2/checkpoint-16500/config.json
Model config BartConfig {
  "_name_or_path": "/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/template_code/bart-dischargeme-results_v2/checkpoint-16500/",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
 

In [45]:
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples["document"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, min_length=200, max_length=1500, early_stopping=True)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)


In [44]:
%%time
test_samples = valid_dataset_txt.select(range(3))
x, summaries_before_tuning = generate_summary(test_samples, model_before_tuning)

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}



CPU times: user 3min 34s, sys: 653 ms, total: 3min 34s
Wall time: 3min 36s


In [46]:
y, summaries_after_tuning = generate_summary(test_samples, model)

In [47]:
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            test_samples["summary"],
            summaries_after_tuning,
        ),
        headers=["Id", "Summary Gold", "Summary BART"],
    )
)
# print("\nTarget summaries:\n")
# print(
#     tabulate(list(enumerate(test_samples["summary"])), headers=["Id", "Target summary"])
# )
# print("\nSource documents:\n")
# print(tabulate(list(enumerate(test_samples["document"])), headers=["Id", "Document"]))

  Id  Summary Gold                                                       Summary BART
----  -----------------------------------------------------------------  ----------------------------------------------------------------
   0  Ms. ___ is a ___ year old right-handed female with a               Ms. ___ is a ___ year old lady with lupus anticoagulant,
      significant history of lupus anticoagulant, recurrent PEs ___,     history of recurrent PE (___), and long-standing
      ___, and longstanding anxiety/panic attacks who presented on       anxiety/panic attacks currently on coumadin who
      ___ with acute onset chest pain and right eye blurry               presents today with 1.5 days of chest pain acutely
      vision.                                                            worsening today accompanied by monocular blurry vision
                                                                         out of the right eye.
      # NEURO: She was admitted to Neurology for workup 