In [2]:
import pandas as pd
import torch
import transformers
from datasets import Dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)


In [20]:
train_df = pd.read_csv("train_clean.csv")
val_df   = pd.read_csv("val_clean.csv")
test_df  = pd.read_csv("test_clean.csv")

print("Train shape:", train_df.shape)
print("Val shape:",   val_df.shape)
print("Test shape:",  test_df.shape)

print(train_df.head(2))
print(train_df.columns)

Train shape: (21671, 2)
Val shape: (2408, 2)
Test shape: (2676, 2)
                                                body  \
0  <SEX> M <SERVICE> PODIATRY <ALLERGIES> No Know...   
1  <SEX> M <SERVICE> MEDICINE <ALLERGIES> Codeine...   

                                             summary  
0  Mr. ___ was admitted after presenting to the E...  
1  ___ year old with a history of alcoholism, wit...  
Index(['body', 'summary'], dtype='object')


In [21]:
# BART’s Trainer works best with datasets.Dataset objects
train_ds = Dataset.from_pandas(train_df)
val_ds   = Dataset.from_pandas(val_df)
test_ds  = Dataset.from_pandas(test_df)

train_ds, val_ds, test_ds

(Dataset({
     features: ['body', 'summary'],
     num_rows: 21671
 }),
 Dataset({
     features: ['body', 'summary'],
     num_rows: 2408
 }),
 Dataset({
     features: ['body', 'summary'],
     num_rows: 2676
 }))

In [22]:
# model_name = "facebook/bart-base"  # <-- base, not large

# tokenizer = BartTokenizer.from_pretrained(model_name)
# model     = BartForConditionalGeneration.from_pretrained(model_name)

model_name = "./bart_base_mimic_checkpoints/checkpoint-latest"

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained(model_name)

In [10]:
summary_lengths = val_df["summary"].str.split().str.len()
print(summary_lengths.describe())

count    2408.000000
mean      387.260382
std       265.220139
min        43.000000
25%       204.000000
50%       318.000000
75%       506.000000
max      2188.000000
Name: summary, dtype: float64


In [23]:
# we’ll use longer max length for the note
max_input_length = 1024     # BART-base max positions
max_target_length = 512     # updated based on your summary stats to get optimal results

def tokenize_batch(batch):
    # encode the input medical note
    model_inputs = tokenizer(
        batch["body"],
        max_length=max_input_length,
        padding="max_length",   # fixed padding
        truncation=True,
    )

    # encodes the target summary
    labels = tokenizer(
        text_target=batch["summary"],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    # trainer expects labels["input_ids"]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [24]:
train_tok = train_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=train_ds.column_names,
)

val_tok = val_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=val_ds.column_names,
)

test_tok = test_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=test_ds.column_names,
)

train_tok, val_tok, test_tok

Map:   0%|          | 0/21671 [00:00<?, ? examples/s]

Map:   0%|          | 0/2408 [00:00<?, ? examples/s]

Map:   0%|          | 0/2676 [00:00<?, ? examples/s]

(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 21671
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 2408
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 2676
 }))

In [8]:
# copies the latest numbered checkpoint into a stable name

import os
import shutil

def copy_latest_checkpoint(output_dir):
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not checkpoints:
        print("No checkpoints found yet.")
        return
    
    # sorts by step number
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    latest = checkpoints[-1]

    src = os.path.join(output_dir, latest)
    dst = os.path.join(output_dir, "checkpoint-latest")

    # deletes the old stable checkpoint folder
    if os.path.exists(dst):
        shutil.rmtree(dst)

    shutil.copytree(src, dst)
    print(f"Saved latest checkpoint → {dst}")

In [25]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
)

In [26]:
use_fp16 = False   # M2 cannot use CUDA FP16, so we keep this False

training_args = TrainingArguments(
    output_dir="bart_base_mimic_checkpoints",
    save_strategy="epoch",              
    learning_rate=2e-5,                 # good LR for BART fine-tuning
    per_device_train_batch_size=1,      
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    
    # training duration
    num_train_epochs=4,                 
                                          

    #predict_with_generate=True,         
    fp16=use_fp16,                      
    logging_steps=100,

    save_total_limit=1, # keep only the latest checkpoint
    report_to="none",                   
)

In [27]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

#trainer.train() # for first epoch trained, then it trains from the checkpoint
trainer.train(resume_from_checkpoint="./bart_base_mimic_checkpoints/checkpoint-latest")
copy_latest_checkpoint("bart_base_mimic_checkpoints")

  trainer = Trainer(
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


Step,Training Loss
43400,2.5496
43500,3.0817
43600,2.3721
43700,2.2611
43800,2.2285
43900,2.2101
44000,2.4033
44100,2.1858
44200,2.1521
44300,2.0435




Saved latest checkpoint → bart_base_mimic_checkpoints/checkpoint-latest


In [28]:
test_metrics = trainer.evaluate(test_tok)
print(test_metrics)



{'eval_loss': 1.8920406103134155, 'eval_runtime': 395.0544, 'eval_samples_per_second': 6.774, 'eval_steps_per_second': 6.774, 'epoch': 4.0}


In [29]:
import evaluate

rouge = evaluate.load("rouge")

def compute_rouge(trainer, dataset, tokenizer, max_samples=200):
    preds = []
    refs = []

    # force everything to run on CPU 
    device = torch.device("cpu")
    trainer.model.to(device)

    # limit the sample count for speed
    n = min(len(dataset), max_samples)

    for i in range(n):
        item = dataset[i]

        # Converts the model inputs to CPU tensors
        model_inputs = {
            "input_ids": torch.tensor(item["input_ids"]).unsqueeze(0).to(device),
            "attention_mask": torch.tensor(item["attention_mask"]).unsqueeze(0).to(device),
        }

        # Generates the BART summary
        with torch.no_grad():
            generated_ids = trainer.model.generate(
                **model_inputs,
                max_length=max_target_length, 
                num_beams=4,
                early_stopping=True,
            )

        pred_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        preds.append(pred_text)

        # Decoded reference summary
        label_ids = [x for x in item["labels"] if x != -100]
        ref_text = tokenizer.decode(label_ids, skip_special_tokens=True)
        refs.append(ref_text)

    # Computes ROUGE
    scores = rouge.compute(predictions=preds, references=refs)
    # convert numpy float64 -> Python float
    scores = {k: float(v) for k, v in scores.items()}

    return scores


# Runs ROUGE
rouge_scores = compute_rouge(trainer, test_tok, tokenizer, max_samples=200)
# rouge_scores = compute_rouge(trainer, test_tok, tokenizer, max_samples=1)
print(rouge_scores)

{'rouge1': 0.39813846294729405, 'rouge2': 0.1527079656759699, 'rougeL': 0.23280175865510655, 'rougeLsum': 0.23196438433243122}


In [30]:
# picks one example from the test set
sample_idx  = 0  # change this to inspect different notes
sample_body = test_df.iloc[sample_idx]["body"]
sample_ref  = test_df.iloc[sample_idx]["summary"]

print("=== Original note (body) ===")
print(sample_body[:1500], "...")   # truncates for display
print("\n=== Reference summary ===")
print(sample_ref)
print("\n=== Model summary (BART-base) ===")

# tokenizes the input note for generation
inputs = tokenizer(
    sample_body,
    max_length=max_input_length,
    truncation=True,
    return_tensors="pt",
)

# move to GPU if available
model = model.to("cpu")
inputs = {k: v.to("cpu") for k, v in inputs.items()}


# generates summary
with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_length=max_target_length,
        num_beams=4,
        length_penalty=1.0,
        early_stopping=True,
    )

generated_summary = tokenizer.decode(
    generated_ids[0],
    skip_special_tokens=True,
)

print(generated_summary)

=== Original note (body) ===
<SEX> M <SERVICE> SURGERY <ALLERGIES> hydromorphone <ATTENDING> ___. <CHIEF COMPLAINT> Ventral hernia <MAJOR SURGICAL OR INVASIVE PROCEDURE> Ventral hernia repair with component separation and mesh onlay (over posterior sheath) <HISTORY OF PRESENT ILLNESS> ___ man with prior liver resection who now presents for elective repair of asymptomatic incisional hernia. <PAST MEDICAL HISTORY> hypertension, hyperlipidemia, bowel obstructions and gastric ulcer disease. Prior surgeries include sphincterotomy for anal fissure and bilateral inguinal hernia repairs. <SOCIAL HISTORY> ___ <FAMILY HISTORY> Both parents have coronary artery disease. <PERTINENT RESULTS> ___ 03: 04PM BLOOD WBC-7.9# RBC-3.84* Hgb-12.8* Hct-35.8* MCV-93 MCH-33.3* MCHC-35.8 RDW-14.2 RDWSD-48.1* Plt ___ ___ 05: 55AM BLOOD WBC-10.0 RBC-3.73* Hgb-12.3* Hct-36.0* MCV-97 MCH-33.0* MCHC-34.2 RDW-14.5 RDWSD-51.7* Plt ___ ___ 03: 04PM BLOOD Glucose-129* UreaN-19 Creat-0.9 Na-141 K-5.5* Cl-105 HCO3-19* AnG