Set Up

In [1]:
!pip install -q sentencepiece
!pip install -q transformers
!pip install -q evaluate
!pip install -q rouge_score

import evaluate

In [2]:
!pip install -q transformers datasets
!pip install -q peft
!pip install -q accelerate
!pip install -U bitsandbytes
!pip install sacrebleu



In [3]:
#let's make longer output readable without horizontal scrolling
from pprint import pprint
import os

# List files in the /content directory
print(os.listdir('/content'))

['.config', 'drive', 'results', 'sample_data']


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
import pandas as pd

# Load CSV files (assuming they are in Google Drive or uploaded to Colab)
discharge_df = pd.read_csv('/content/drive/MyDrive/discharge.csv')

In [6]:
discharge_df.head()

Unnamed: 0,note_id,subject_id,hadm_id,note_type,note_seq,charttime,storetime,text
0,10000032-DS-21,10000032,22595853,DS,21,2180-05-07 00:00:00,2180-05-09 15:26:00,\nName: ___ Unit No: _...
1,10000032-DS-22,10000032,22841357,DS,22,2180-06-27 00:00:00,2180-07-01 10:15:00,\nName: ___ Unit No: _...
2,10000032-DS-23,10000032,29079034,DS,23,2180-07-25 00:00:00,2180-07-25 21:42:00,\nName: ___ Unit No: _...
3,10000032-DS-24,10000032,25742920,DS,24,2180-08-07 00:00:00,2180-08-10 05:43:00,\nName: ___ Unit No: _...
4,10000084-DS-17,10000084,23052089,DS,17,2160-11-25 00:00:00,2160-11-25 15:09:00,\nName: ___ Unit No: __...


We will need to subset the data because of the immense number of rows and text contained within each row for ease of processing

In [7]:
def clean_text(text):
    text = text.replace('\n', ' ')
    text = ' '.join(text.split())
    return text

sample_size = 10000

discharge_sample = discharge_df.sample(n=sample_size, random_state=42)
discharge_sample['cleaned_text'] = discharge_sample['text'].apply(clean_text)

print(discharge_sample.head())

               note_id  subject_id   hadm_id note_type  note_seq  \
6292    10202247-DS-15    10202247  28736349        DS        15   
92111   12784119-DS-19    12784119  27383409        DS        19   
209235   16314105-DS-3    16314105  27871553        DS         3   
225051  16805731-DS-23    16805731  24081862        DS        23   
143620  14334225-DS-10    14334225  29709912        DS        10   

                  charttime            storetime  \
6292    2173-11-11 00:00:00  2173-11-15 13:25:00   
92111   2196-11-13 00:00:00  2196-11-14 19:33:00   
209235  2141-05-10 00:00:00  2141-06-02 15:28:00   
225051  2149-10-14 00:00:00  2149-10-14 15:27:00   
143620  2154-09-26 00:00:00  2154-09-26 16:07:00   

                                                     text  \
6292     \nName:  ___                    Unit No:   __...   
92111    \nName:  ___                 Unit No:   ___\n...   
209235   \nName:  ___                   Unit No:   ___...   
225051   \nName:  ___.            

In [8]:
# Looking at the first row entry of the discharge description.
# Note at the bottom of the description is a prescribed discharge diagnosis (instructions) which is what we want to evaluate against
discharge_sample.iloc[0]['cleaned_text']

"Name: ___ Unit No: ___ Admission Date: ___ Discharge Date: ___ Date of Birth: ___ Sex: F Service: MEDICINE Allergies: Anticholinergics,Other / Reglan Attending: ___. Chief Complaint: Abdominal pain Major Surgical or Invasive Procedure: None History of Present Illness: Patient is a ___ yo woman with history of chronic pancreatitis s/p cholecystectomy and sphincterotomy who presents with 1wk of worsening abdominal pain. As per patient, the pain is intermittent, sharp and ___ in quality. It localizes to her mid/righ upper abdomen and radiates up the chest wall and to the back. Patient finds this pain to be very similar to her prior pancreatitis flare-ups. Denies precipitants including alcohol use, abd. trauma, infections or h/o gallstones. She notes that she is on a restricted diet (no caffeine/fatty food/fried food/dairy) as part of her pancreatitis management. Also notes that pain worsens with meals; and that it is alleviated when NPO, or with NSAIDs and dilaudid. . Patient was initall

In [9]:
instructions = discharge_sample.iloc[0]['cleaned_text'].split("Discharge Instructions:")
instructions = ' '.join(instruction.strip() for instruction in instructions[1:] if instruction.strip())
instructions

'You were admitted to the hospital because of abdominal pain and inibility to eat secondary to pain. You were given IV pain medication and fluids. Over the course of your stay you were slowly able to eat more starting with fluids first. You tolerated a low residue diet with your baseline amount of pain. You were seen by the ___ doctors in the hospital who recommended tests to be sent out to look for a cause of your pain. No cause could be found. Medication changes: You were started on Omeprazole 20mg once a day. Remember to avoid dairy products and try to eat a low residue bland diet. Your appointment with Dr. ___ was changed to ___ at 11am. Please return to the hospital or call your doctor if you have temperature greater than 101, shortness of breath, worsening difficulty with swallowing, chest pain, abdominal pain, diarrhea, or any other symptoms that you are concerned about. Followup Instructions: ___'

In [10]:
# Use function to split for entire dataset
import re

def extract_sections(text):
    # Regex to match the discharge instructions section
    discharge_pattern = re.compile(r"Discharge Instructions:(.*)", re.DOTALL)
    match = discharge_pattern.search(text)
    if match:
        discharge_summary = match.group(1).strip()
        # Removing the discharge instructions part from the original text
        rest_of_text = text[:match.start()] + text[match.end():]
        return rest_of_text.strip(), discharge_summary
    return None, None

input_texts = []
output_texts = []

for index, row in discharge_sample.iterrows():
    input_text, discharge_instructions = extract_sections(row['cleaned_text'])
    if input_text and discharge_instructions:
        input_texts.append(input_text)
        output_texts.append(discharge_instructions)

In [11]:
input_texts

["Name: ___ Unit No: ___ Admission Date: ___ Discharge Date: ___ Date of Birth: ___ Sex: F Service: MEDICINE Allergies: Anticholinergics,Other / Reglan Attending: ___. Chief Complaint: Abdominal pain Major Surgical or Invasive Procedure: None History of Present Illness: Patient is a ___ yo woman with history of chronic pancreatitis s/p cholecystectomy and sphincterotomy who presents with 1wk of worsening abdominal pain. As per patient, the pain is intermittent, sharp and ___ in quality. It localizes to her mid/righ upper abdomen and radiates up the chest wall and to the back. Patient finds this pain to be very similar to her prior pancreatitis flare-ups. Denies precipitants including alcohol use, abd. trauma, infections or h/o gallstones. She notes that she is on a restricted diet (no caffeine/fatty food/fried food/dairy) as part of her pancreatitis management. Also notes that pain worsens with meals; and that it is alleviated when NPO, or with NSAIDs and dilaudid. . Patient was inital

In [12]:
output_texts[0]

'You were admitted to the hospital because of abdominal pain and inibility to eat secondary to pain. You were given IV pain medication and fluids. Over the course of your stay you were slowly able to eat more starting with fluids first. You tolerated a low residue diet with your baseline amount of pain. You were seen by the ___ doctors in the hospital who recommended tests to be sent out to look for a cause of your pain. No cause could be found. Medication changes: You were started on Omeprazole 20mg once a day. Remember to avoid dairy products and try to eat a low residue bland diet. Your appointment with Dr. ___ was changed to ___ at 11am. Please return to the hospital or call your doctor if you have temperature greater than 101, shortness of breath, worsening difficulty with swallowing, chest pain, abdominal pain, diarrhea, or any other symptoms that you are concerned about. Followup Instructions: ___'

In [13]:
all_input_texts = ' '.join(input_texts)

In [14]:
# Split into train validation test sets
from datasets import Dataset
dataset = []

for input_text, output_text in zip(input_texts, output_texts):
    dataset.append({
        "input": input_text,
        "output": output_text
    })

dataset = Dataset.from_list(dataset)
train_test_dataset = dataset.train_test_split(test_size=0.2)
test_valid = train_test_dataset['test'].train_test_split(test_size=0.5)
train = train_test_dataset["train"]
valid = test_valid["train"]
test = test_valid["test"]

In [15]:
train

Dataset({
    features: ['input', 'output'],
    num_rows: 7940
})

In [16]:
test

Dataset({
    features: ['input', 'output'],
    num_rows: 993
})

In [17]:
from transformers import BitsAndBytesConfig
import torch

nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)

In [18]:
train[0]['input']

'Name: ___. Unit No: ___ Admission Date: ___ Discharge Date: ___ Date of Birth: ___ Sex: M Service: MEDICINE Allergies: No Known Allergies / Adverse Drug Reactions Attending: ___ Chief Complaint: GERD Major Surgical or Invasive Procedure: ___ Interventional radiology TIPS evaluation History of Present Illness: Pt is a ___ year old ___ male with a history of HCV/Alcoholic cirrhosis, porcelain gallbladder, Type 2 DM, and a past SBP on cipro ppx who presented with vomiting starting two days ago as well as GERD-like symptoms. Pt reports having a decreased appetitie over the past two days, and that belly becomes quite distended after meals until he end up (projectile) vomiting. He states that prior to vomiting, he feels like he is having "heartburn" and there is a pressure on his chest when he lies down to sleep (to the point where it keeps him up/wakes him at night despite three pillows). He has noticed increased fluid accumulation in his belly and his lungs (he has had both tapped and dra

In [19]:
train

Dataset({
    features: ['input', 'output'],
    num_rows: 7940
})

### Import T5 for baseline modeling

In [20]:
from huggingface_hub import login
login(token="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")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [21]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base")
model = model.to('cuda')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained("t5-base")

config

T5Config {
  "_name_or_path": "t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
   

In [40]:
def tokenize_function(examples):
    inputs = examples["input"]
    targets = examples["output"]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    # Tokenize targets with the same parameters
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_train_dataset = train.map(tokenize_function, batched=True)
tokenized_valid_dataset = valid.map(tokenize_function, batched=True)

Map:   0%|          | 0/7940 [00:00<?, ? examples/s]



Map:   0%|          | 0/992 [00:00<?, ? examples/s]

In [27]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [28]:
from transformers import Trainer, TrainingArguments

# Training Arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,  # Set batch size to 1
    per_device_eval_batch_size=1,  # Set batch size to 1
    gradient_accumulation_steps=16,  # Simulate a larger batch size
    eval_accumulation_steps=2,
    do_train=True,
    do_eval=True,
    logging_steps=500,
    save_steps=500,
    eval_steps=500,
    num_train_epochs=2,
    eval_strategy="steps",
    save_total_limit=3,
    fp16=True,  # Use mixed precision training
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [29]:
trainer.train()

Step,Training Loss,Validation Loss


Step,Training Loss,Validation Loss
500,1.6717,1.177593


TrainOutput(global_step=992, training_loss=1.4427244740147744, metrics={'train_runtime': 1923.0773, 'train_samples_per_second': 8.258, 'train_steps_per_second': 0.516, 'total_flos': 9665379638968320.0, 'train_loss': 1.4427244740147744, 'epoch': 1.998992443324937})

In [30]:
results = trainer.evaluate()
print("Evaluation results:", results)

Evaluation results: {'eval_loss': 1.1075973510742188, 'eval_runtime': 46.1377, 'eval_samples_per_second': 21.501, 'eval_steps_per_second': 21.501, 'epoch': 1.998992443324937}


In [None]:
import sacrebleu
PROMPT = 'summarize: '

# Function to generate discharge summary
def generate_discharge_summary(tokenizer, model, input_text, max_input_length=512):
    if not input_text:
        return ""
    # Tokenize the input text with truncation
    inputs = tokenizer(PROMPT + input_text, return_tensors='pt', truncation=True, padding='longest', max_length=max_input_length)
    # Generate summary
    summary_ids = model.generate(inputs["input_ids"], max_length=200, num_beams=4, length_penalty=2.0, early_stopping=True)
    # Decode the generated summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return summary


# Convert Hugging Face Dataset to list of dictionaries
train_list = train.to_dict()

generated_summaries = []
reference_summaries = []

num_examples = 10

for i in range(num_examples):
    input_text = train_list['input'][i]
    reference_summary = train_list['output'][i]

    generated_summary = generate_discharge_summary(tokenizer, model, input_text)

    generated_summaries.append(generated_summary)
    reference_summaries.append(reference_summary)

# Compute BLEU scores
bleu = sacrebleu.corpus_bleu(generated_summaries, [reference_summaries])
print("\n👉  BLEU Score:", bleu.score)

# Compare some examples
for i in range(5):
    print(f"\nExample {i + 1}:")
    print("Input Text:", train_list['input'][i])
    print("Generated Summary:", generated_summaries[i])
    print("Reference Summary:", reference_summaries[i])

In [None]:
# Load ROUGE metric
rouge = evaluate.load('rouge')

# Calculate ROUGE scores
results = rouge.compute(predictions=generated_summaries, references=reference_summaries)

print("\n👉  ROUGE Scores:", results)

# Compare some examples
for i in range(5):
    print(f"\nExample {i + 1}:")
    print("Input Text:", train_list['input'][i])
    print("Generated Summary:", generated_summaries[i])
    print("Reference Summary:", reference_summaries[i])