In [1]:
# Run identifier; interpolated into the trainer's output and logging directory names below.
version = "v1"

# DataFrame column holding the model input text.
inputcol = "bhc_preceding_text"
# DataFrame column holding the target text the model is trained to produce.
outputcol = "brief_hospital_course"

# Fine-tuning BART for DischargeMe Brief Hospital Course Task

---

## Setup

---

In [2]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [3]:
import tiktoken

In [4]:
import sys
sys.path.insert(0,'../')

In [5]:
from prompt_functions import create_pt_prompt_per_service

In [6]:
from preprocessing import data_injection

In [7]:
# Uncomment the line below when running this outside an interactive notebook (e.g. as a script)
# datasets.disable_progress_bar()

In [8]:
# Toggle for Weights & Biases tracking; wandb is only imported when this is enabled.
# Later cells guard their wandb.init()/finish() calls with the same flag.
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

    # Log in once up front so the run created later can start without prompting.
    wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvimig-socrates[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
from pynvml import *

def print_gpu_utilization(device_index=0):
    """Print how much memory is currently in use on one GPU, in MB.

    Args:
        device_index: NVML index of the GPU to query. Defaults to 0, the
            device this helper always inspected before the parameter existed.
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        # info.used is in bytes; integer-divide down to whole megabytes.
        print(f"GPU memory occupied: {info.used//1024**2} MB.", flush=True)
    finally:
        # Balance the nvmlInit() above so repeated calls do not leave NVML initialised.
        nvmlShutdown()


def print_summary(result):
    """Print the training runtime and throughput from ``result.metrics``, then GPU memory use."""
    stats = result.metrics
    for label, key in (
        ("Time", "train_runtime"),
        ("Samples/second", "train_samples_per_second"),
    ):
        print(f"{label}: {stats[key]:.2f}", flush=True)
    print_gpu_utilization()

## Model and tokenizer

---

Download model and tokenizer. Use default parameters or try custom values (see [HF Bart configuration](https://huggingface.co/transformers/_modules/transformers/configuration_bart.html) and [Fairseq Bart](https://github.com/pytorch/fairseq/tree/master/examples/bart)).

In [10]:
model_name = "GanjinZero/biobart-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model parameters or use the default
# print(model.config)

# tokenization
encoder_max_length = 1024
decoder_max_length = 1024

## Data

---

In [11]:
import pandas as pd
from datasets import Dataset


In [12]:
print_gpu_utilization()

GPU memory occupied: 319 MB.


In [13]:
# Load the pre-built train and validation tables from the shared project directory.
# NOTE(review): unpickling can execute arbitrary code from the file; only load pickles from a trusted location.
train_data = pd.read_pickle("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text+structured_data.pickle")
valid_data = pd.read_pickle("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text+structured_data.pickle")

In [14]:
train_data['brief_hospital_course_word_count'].describe()

count    68785.000000
mean       327.462397
std        236.807308
min         10.000000
25%        163.000000
50%        281.000000
75%        440.000000
max       3435.000000
Name: brief_hospital_course_word_count, dtype: float64

In [15]:
# Keep only the admission id plus the configured input/target text columns and wrap each table as a Dataset.
train_ds = Dataset.from_pandas(train_data[['hadm_id', inputcol, outputcol]], split="train")
valid_ds = Dataset.from_pandas(valid_data[['hadm_id', inputcol, outputcol]], split="valid")

### Prepare

**Format the train and validation sets**

In [16]:
def flatten(example):
    """Map one raw record to a generic ``document``/``summary`` pair.

    The document is read from the ``inputcol`` column and the summary from
    the ``outputcol`` column.
    """
    document_text = example[inputcol]
    summary_text = example[outputcol]
    return dict(document=document_text, summary=summary_text)


def list2samples(example):
    """Concatenate per-record sequences of documents and summaries into flat lists.

    A record whose document entry is empty is dropped together with its
    summary entry.
    """
    flat_documents, flat_summaries = [], []
    for doc_items, summary_items in zip(example["document"], example["summary"]):
        if len(doc_items) == 0:
            continue
        flat_documents.extend(doc_items)
        flat_summaries.extend(summary_items)
    return {"document": flat_documents, "summary": flat_summaries}


# Replace the raw columns with the "document"/"summary" fields produced by flatten().
train_dataset_txt = train_ds.map(flatten, remove_columns=['hadm_id', inputcol, outputcol])
valid_dataset_txt = valid_ds.map(flatten, remove_columns=['hadm_id', inputcol, outputcol])


Map:   0%|          | 0/68785 [00:00<?, ? examples/s]

Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

**Preprocess and tokenize**

In [17]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Tokenize a batch of document/summary pairs for seq2seq training.

    Both sides are padded and truncated to their respective maximum length.
    The returned dict holds the tokenizer's source-side fields plus
    ``labels``: the target token ids with every pad token replaced by -100.
    """
    documents = batch["document"]
    summaries = batch["summary"]

    shared_kwargs = dict(padding="max_length", truncation=True)
    encoded_source = tokenizer(documents, max_length=max_source_length, **shared_kwargs)
    encoded_target = tokenizer(summaries, max_length=max_target_length, **shared_kwargs)

    features = dict(encoded_source.items())

    # Pad positions become -100 so that padding is ignored in the loss.
    labels = []
    for target_ids in encoded_target["input_ids"]:
        masked_ids = []
        for token_id in target_ids:
            masked_ids.append(-100 if token_id == tokenizer.pad_token_id else token_id)
        labels.append(masked_ids)
    features["labels"] = labels
    return features


def _tokenize_split(split_txt):
    """Tokenize one document/summary split in batches, dropping its text columns."""
    return split_txt.map(
        lambda batch: batch_tokenize_preprocess(
            batch, tokenizer, encoder_max_length, decoder_max_length
        ),
        batched=True,
        remove_columns=split_txt.column_names,
    )


# Same preprocessing for both splits; only the input dataset differs.
train_data = _tokenize_split(train_dataset_txt)

validation_data = _tokenize_split(valid_dataset_txt)

Map:   0%|          | 0/68785 [00:00<?, ? examples/s]

Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

In [18]:
# Write all the above preprocessing to file so we don't have to read it in again

# train_data.save_to_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text_BioBART_data.hf")
# validation_data.save_to_disk("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text_BioBART_data.hf")

In [19]:
# Reuse the tokenized datasets cached on disk when both exist; otherwise keep the
# in-memory `train_data` / `validation_data` produced by the tokenization cell above.
# This replaces manually commenting the cell out when the cache has not been generated yet.

import os
from datasets import load_from_disk

_TOKENIZED_TRAIN_DIR = "/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/train/discharge_target_with_preceding_text_BioBART_data.hf"
_TOKENIZED_VALID_DIR = "/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text_BioBART_data.hf"

if os.path.isdir(_TOKENIZED_TRAIN_DIR) and os.path.isdir(_TOKENIZED_VALID_DIR):
    train_data = load_from_disk(_TOKENIZED_TRAIN_DIR)
    validation_data = load_from_disk(_TOKENIZED_VALID_DIR)
else:
    print("Tokenized dataset cache not found; using the datasets tokenized in this session.", flush=True)

In [20]:
print("After processing datasets", flush=True)
print_gpu_utilization()

After processing datasets
GPU memory occupied: 319 MB.


In [21]:
def generate_summary(test_samples, model):
    """Generate summaries for a batch of samples with the given model.

    Args:
        test_samples: Object whose ``"document"`` entry holds the source texts.
        model: Model used for generation; the inputs are moved to ``model.device``.

    Returns:
        Tuple ``(outputs, output_str)``: the generated token ids and their
        decoded text with special tokens removed.
    """
    inputs = tokenizer(
        test_samples["document"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    # No early return here: a stray `return inputs` previously made everything
    # below unreachable and handed callers the tokenizer output instead of the
    # (ids, text) pair.
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    # NOTE(review): max_length=1500 is larger than the 1024-token decoder_max_length
    # configured for this notebook — confirm the model supports sequences that long.
    outputs = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, min_length=200, max_length=1500, early_stopping=True)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)


## Training

---

### Metrics

In [22]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

# Fetch the "punkt" resource quietly; postprocess_text below relies on nltk.sent_tokenize.
nltk.download("punkt", quiet=True)

# ROUGE metric object used by compute_metrics.
# NOTE(review): the recorded output of this cell warns that `trust_remote_code=True`
# will become mandatory for this call in the next major `datasets` release.
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Strip each text and rewrite it with one sentence per line.

    rougeLSum expects a newline after each sentence, so both predictions and
    references are sentence-split and re-joined with ``"\\n"``.
    """
    def _one_sentence_per_line(texts):
        stripped = [text.strip() for text in texts]
        return ["\n".join(nltk.sent_tokenize(text)) for text in stripped]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)


def compute_metrics(eval_preds):
    """Compute ROUGE scores and the mean generated length for one evaluation pass.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays. When
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        Dict with one entry per ROUGE key (mid F-measure scaled to 0-100) plus
        ``gen_len``, the mean number of non-pad prediction tokens. All values
        are rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 cannot be decoded. Labels carry it at padded positions; predictions
    # get the same replacement (as in the upstream run_summarization.py example)
    # so a -100 in the predictions cannot break decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Generated length = number of non-pad tokens in each prediction.
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

  metric = datasets.load_metric("rouge")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

### Training arguments

In [23]:
# epochs = 2
# train_size = train_data.num_rows
# train_batch_size = 8
# ga_steps = 1
# virtual_batch_size = train_batch_size * ga_steps   # "invented name" => 256
# per_epoch_steps = int(train_size / virtual_batch_size + 0.5) # round => 121
# total_steps = epochs * per_epoch_steps # => 605

In [24]:
# Training configuration: 2 epochs, evaluation every 250 steps, checkpoints every 500 steps,
# and the checkpoint with the lowest eval_loss reloaded at the end.
training_args = Seq2SeqTrainingArguments(
    output_dir=f"bart-dischargeme-results_{version}",
    num_train_epochs=2, 
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=2,
    # learning_rate is left commented out, so the library default applies.
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    # NOTE(review): no generation_max_length is set, and every evaluation recorded in this
    # notebook reports eval_gen_len = 20.0, so the ROUGE scores appear to be computed on
    # ~20-token generations. Confirm, and consider setting generation_max_length
    # (at the cost of much slower evaluation).
    predict_with_generate=True,
    logging_dir=f"bart-dischargeme-logs_{version}",
    logging_steps=5,
    save_total_limit=2,

    # additional args we added
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=250,
    save_steps=500,
    fp16=True,
    push_to_hub=False,
    # Model selection uses the evaluation loss (lower is better), not ROUGE.
    metric_for_best_model='eval_loss',
    load_best_model_at_end=True,
    greater_is_better=False,
    disable_tqdm=True,
    log_level="info",
    logging_first_step=True,
)

# Collator that batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    # We only evaluate on a fixed, seeded subset of 2000 examples because it is faster for
    # measuring progress and selecting the best model.
    eval_dataset=validation_data.shuffle(seed=42).select(range(2000)),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Early stopping with a patience of 5 evaluations.
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Using auto half precision backend


In [25]:
print("============== After setting up trainer and args ==============", flush=True)
print_gpu_utilization()

GPU memory occupied: 2139 MB.


### Train

Wandb integration

In [26]:
if WANDB_INTEGRATION:
    # Hyperparameters recorded alongside the run.
    run_config = {
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "dataset": "dischargeme preceding_text_only",
    }
    wandb_run = wandb.init(project="bart-dischargeme", config=run_config)

    # Name the run after the current wall-clock time, formatted as HHMMSS.
    now = datetime.now()
    current_time = f"{now:%H%M%S}"
    wandb_run.name = f"run_{current_time}"

Evaluate before fine-tuning

In [27]:
print("============== Before trainer.evaluate() pretrain ==============", flush=True)
print_gpu_utilization()


GPU memory occupied: 2139 MB.


In [28]:
# Run an initial evaluation on the 2,000-example validation subset to get a baseline
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 2
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


{'eval_loss': 7.941129684448242, 'eval_rouge1': 0.2535, 'eval_rouge2': 0.0145, 'eval_rougeL': 0.2245, 'eval_rougeLsum': 0.2506, 'eval_gen_len': 20.0, 'eval_runtime': 426.5652, 'eval_samples_per_second': 4.689, 'eval_steps_per_second': 2.344}


{'eval_loss': 7.941129684448242,
 'eval_rouge1': 0.2535,
 'eval_rouge2': 0.0145,
 'eval_rougeL': 0.2245,
 'eval_rougeLsum': 0.2506,
 'eval_gen_len': 20.0,
 'eval_runtime': 426.5652,
 'eval_samples_per_second': 4.689,
 'eval_steps_per_second': 2.344}

In [29]:
print("============== After trainer.evaluate() pretrain ==============", flush=True)
print_gpu_utilization()


GPU memory occupied: 5311 MB.


Train the model

In [None]:
#%%wandb
# uncomment to display Wandb charts

trainer.train()

***** Running training *****
  Num examples = 68,785
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 17,198
  Number of trainable parameters = 406,291,456


{'loss': 8.3042, 'grad_norm': inf, 'learning_rate': 0.0, 'epoch': 0.0}
{'loss': 8.3542, 'grad_norm': 39.05048751831055, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.0}
{'loss': 8.3989, 'grad_norm': 48.65020751953125, 'learning_rate': 7.000000000000001e-07, 'epoch': 0.0}
{'loss': 7.9502, 'grad_norm': 26.361770629882812, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.0}
{'loss': 7.5405, 'grad_norm': 20.091793060302734, 'learning_rate': 1.7000000000000002e-06, 'epoch': 0.0}
{'loss': 6.9417, 'grad_norm': 14.027737617492676, 'learning_rate': 2.2e-06, 'epoch': 0.0}
{'loss': 6.2932, 'grad_norm': 8.123832702636719, 'learning_rate': 2.7e-06, 'epoch': 0.0}
{'loss': 5.8968, 'grad_norm': 6.940796375274658, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.0}
{'loss': 5.8289, 'grad_norm': 6.318179130554199, 'learning_rate': 3.7e-06, 'epoch': 0.0}
{'loss': 5.5069, 'grad_norm': 3.9666635990142822, 'learning_rate': 4.2000000000000004e-06, 'epoch': 0.01}
{'loss': 5.3415, 'grad_norm': 4.44

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 2


{'loss': 4.0782, 'grad_norm': 2.0047314167022705, 'learning_rate': 2.47e-05, 'epoch': 0.03}
{'eval_loss': 3.7908928394317627, 'eval_rouge1': 7.5955, 'eval_rouge2': 4.5109, 'eval_rougeL': 6.9906, 'eval_rougeLsum': 7.4926, 'eval_gen_len': 20.0, 'eval_runtime': 426.9903, 'eval_samples_per_second': 4.684, 'eval_steps_per_second': 2.342, 'epoch': 0.03}
{'loss': 3.9421, 'grad_norm': 1.8848192691802979, 'learning_rate': 2.5200000000000003e-05, 'epoch': 0.03}
{'loss': 3.9671, 'grad_norm': 2.4494495391845703, 'learning_rate': 2.57e-05, 'epoch': 0.03}
{'loss': 3.8797, 'grad_norm': 2.3967435359954834, 'learning_rate': 2.6200000000000003e-05, 'epoch': 0.03}
{'loss': 4.034, 'grad_norm': 2.2060980796813965, 'learning_rate': 2.6700000000000002e-05, 'epoch': 0.03}
{'loss': 3.9222, 'grad_norm': 1.9037469625473022, 'learning_rate': 2.7200000000000004e-05, 'epoch': 0.03}
{'loss': 3.9342, 'grad_norm': 1.8078581094741821, 'learning_rate': 2.7700000000000002e-05, 'epoch': 0.03}
{'loss': 3.8937, 'grad_norm':

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 2


{'loss': 3.7515, 'grad_norm': 2.314866065979004, 'learning_rate': 4.97e-05, 'epoch': 0.06}


Saving model checkpoint to bart-dischargeme-results_v1/tmp-checkpoint-500
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_eos_token_id': 2}
Configuration saved in bart-dischargeme-results_v1/tmp-checkpoint-500/config.json
Configuration saved in bart-dischargeme-results_v1/tmp-checkpoint-500/generation_config.json


{'eval_loss': 3.5171148777008057, 'eval_rouge1': 7.6906, 'eval_rouge2': 4.5627, 'eval_rougeL': 7.0861, 'eval_rougeLsum': 7.5985, 'eval_gen_len': 20.0, 'eval_runtime': 426.8672, 'eval_samples_per_second': 4.685, 'eval_steps_per_second': 2.343, 'epoch': 0.06}


Model weights saved in bart-dischargeme-results_v1/tmp-checkpoint-500/model.safetensors
tokenizer config file saved in bart-dischargeme-results_v1/tmp-checkpoint-500/tokenizer_config.json
Special tokens file saved in bart-dischargeme-results_v1/tmp-checkpoint-500/special_tokens_map.json


In [None]:
print("============== After trainer.train() ==============", flush=True)
print_gpu_utilization()


Evaluate after fine-tuning

In [None]:
# Run a final evaluation on the 2,000-example validation subset to report the metrics after training
trainer.evaluate()

In [None]:
if WANDB_INTEGRATION:
    wandb_run.finish()

## Evaluation

---

**Generate summaries from the fine-tuned model and compare them with those generated from the original, pre-trained one.**

In [None]:
# Replace the in-memory model with weights loaded from a saved checkpoint.
# NOTE(review): this is a hard-coded absolute path to a `results_v2` checkpoint, while this
# notebook writes to `bart-dischargeme-results_{version}` with version = "v1" — confirm
# this is the intended checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained("/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/template_code/bart-dischargeme-results_v2/checkpoint-16500/")

In [None]:
def generate_summary(test_samples, model):
    """Run beam-search generation over the samples' documents.

    The ``"document"`` texts are tokenized (padded and truncated to
    ``encoder_max_length``), moved to the model's device and passed to
    ``model.generate``. Returns the generated token ids together with their
    decoded text, special tokens removed.
    """
    encoded = tokenizer(
        test_samples["document"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    generation_kwargs = dict(num_beams=4, min_length=200, max_length=1500, early_stopping=True)
    outputs = model.generate(
        encoded.input_ids.to(model.device),
        attention_mask=encoded.attention_mask.to(model.device),
        **generation_kwargs,
    )
    return outputs, tokenizer.batch_decode(outputs, skip_special_tokens=True)


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)


In [None]:
%%time
test_samples = valid_dataset_txt.select(range(3))
x, summaries_before_tuning = generate_summary(test_samples, model_before_tuning)

In [None]:
y, summaries_after_tuning = generate_summary(test_samples, model)

In [None]:
# Show the gold summaries next to the fine-tuned model's output for the sampled examples
comparison_rows = [
    (idx, gold, generated)
    for idx, (gold, generated) in enumerate(
        zip(test_samples["summary"], summaries_after_tuning)
    )
]
print(tabulate(comparison_rows, headers=["Id", "Summary Gold", "Summary BART"]))
# print("\nTarget summaries:\n")
# print(
#     tabulate(list(enumerate(test_samples["summary"])), headers=["Id", "Target summary"])
# )
# print("\nSource documents:\n")
# print(tabulate(list(enumerate(test_samples["document"])), headers=["Id", "Document"]))