# Gemma 2 Fine-Tuning for Abstractive Medical Summaries

This notebook fine-tunes the Hugging Face `google/gemma-2-2b` model on the provided abstractive medical summarization dataset and compares it with the un-tuned baseline using ROUGE and BERTScore.


In [None]:
!pip install -q transformers datasets accelerate peft bitsandbytes sentencepiece evaluate rouge_score bert-score

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone


In [None]:
# Sanity check: list the contents of the current working directory.
!ls

sample_data


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and output folders used below live under MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# List the project folder on Drive. The path is relative, so this assumes the working directory is /content.
!ls drive/MyDrive/COLAB/NLP

gemma2_finetune.ipynb  outputs	   train.json	validation.json
nlp_gemma_2	       test.json   train.jsonl	validation.jsonl
nlp_t5_small	       test.jsonl  Untitled


In [None]:
from huggingface_hub import login
from pathlib import Path

# Interactive Hugging Face Hub login (renders a widget in the cell output).
login()

# Folder on the mounted Drive from which the {train,validation,test}.json splits are read.
DATA_DIR = Path("/content/drive/MyDrive/COLAB/NLP/nlp_gemma_2")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import gc
import json
from pathlib import Path
from typing import Dict, List

import torch
from datasets import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Every artefact of this notebook (prediction files, checkpoints, the fine-tuned model)
# is written below this folder; it is created on first run.
OUTPUT_DIR = Path("/content/drive/MyDrive/COLAB/NLP/nlp_gemma_2/outputs/gemma2_abs_sum")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Hub id of the base checkpoint, used for the tokenizer, the baseline and the fine-tuning run.
MODEL_NAME = "google/gemma-2-2b"

# Prompt shared by training and inference; `{question}` and `{article}` are filled via str.format.
# The template starts with a blank line and ends with a newline right after "### Response:".
PROMPT_TEMPLATE = """
### Instruction:
You are a helpful medical summarization assistant. Given a clinical question and an article, provide a short abstractive answer that directly addresses the question.

### Question:
{question}

### Article:
{article}

### Response:
"""

# Caps the tokenized prompt at inference time and the tokenized prompt + target during training.
MAX_INPUT_TOKENS = 2048
# Upper bound on the number of tokens generated per example.
GEN_MAX_NEW_TOKENS = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Fall back to EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

HAS_CUDA = torch.cuda.is_available()

# Precision policy:
#   * GPU whose major compute capability is >= 8 -> bfloat16
#   * any other GPU                              -> float16
#   * no GPU                                     -> float32 (both flags off)
USE_BF16 = HAS_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
USE_FP16 = HAS_CUDA and not USE_BF16

# Dtype the model weights are loaded in, derived from the two flags above.
DTYPE = (
    torch.bfloat16 if USE_BF16
    else torch.float16 if USE_FP16
    else torch.float32
)


In [None]:
def _load_split(split_name: str) -> Dataset:
    """Load one dataset split from ``DATA_DIR/<split_name>.json`` and flatten it.

    The JSON file maps question ids to entries. From each entry this reads the
    ``"question"`` string and the ``"answers"`` mapping (answer id -> answer), and
    from each answer the ``"article"`` and ``"answer_abs_summ"`` strings. One row is
    produced per (question, answer) pair.

    Args:
        split_name: File stem of the split, e.g. ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        A ``Dataset`` with the columns ``id`` (``"<qid>_<aid>"``), ``question``,
        ``article`` and ``target``. Pairs whose article or summary is missing,
        ``null`` or blank are skipped.
    """
    file_path = DATA_DIR / f"{split_name}.json"
    # Decode explicitly as UTF-8 instead of relying on the platform's default encoding.
    raw_data = json.loads(file_path.read_text(encoding="utf-8"))

    items: List[Dict] = []

    # Each entry is grouped by question ID
    for qid, entry in raw_data.items():
        # `or ""` also covers keys that are present but set to JSON null,
        # which `.get(key, "")` alone would pass through as None.
        question_text = (entry.get("question") or "").strip()

        # Each question can have multiple answer articles
        for aid, ans in (entry.get("answers") or {}).items():
            article_text = (ans.get("article") or "").strip()
            target_text = (ans.get("answer_abs_summ") or "").strip()

            # Skip empty articles or summaries
            if not article_text or not target_text:
                continue

            items.append({
                "id": f"{qid}_{aid}",
                "question": question_text,
                "article": article_text,
                "target": target_text,
            })

    return Dataset.from_list(items)


# Read the three splits (train, validation, test - in that order)
train_ds, validation_ds, test_ds = (
    _load_split(name) for name in ("train", "validation", "test")
)

# Show one flattened row and the size of every split as a quick sanity check
print(train_ds[0])
print(f"train={len(train_ds)} | validation={len(validation_ds)} | test={len(test_ds)}")


{'id': '133_133_Answer2', 'question': 'how much oxazepam could cause an overdose?', 'article': "Oxazepam overdose Benzodiazepine overdose Serax overdose Adumbran overdose Serenid Forte overdose Zapex overdose Novoxapam overdose Oxpam overdose Summary Oxazepam is a medicine used to treat anxiety and symptoms of alcohol withdrawal. It belongs to the class of medicines known as benzodiazepines. Oxazepam overdose occurs when someone accidentally or intentionally takes too much of this medicine. Benzodiazepines are the most common prescription drugs used in suicide attempts. This is for information only and not for use in the treatment or management of an actual overdose. DO NOT use it to treat or manage an actual overdose. If you or someone you are with overdoses, call your local emergency number (such as 911), or your local poison center can be reached directly by calling the national toll-free Poison Help hotline (1-800-222-1222) from anywhere in the United States. Poisonous Ingredient O

In [None]:
def format_prompt(example: Dict[str, str]) -> str:
    """Render ``PROMPT_TEMPLATE`` for a single example.

    Only the ``"question"`` and ``"article"`` keys of ``example`` are read; the
    returned string ends with the template's response header.
    """
    fields = {
        "question": example["question"],
        "article": example["article"],
    }
    return PROMPT_TEMPLATE.format(**fields)


def generate_predictions(model, dataset: Dataset, outfile: Path, split_name: str):
    """Generate a summary for every row of ``dataset`` and save the results as JSON.

    Args:
        model: Causal LM; its ``generate`` method, ``device`` and
            ``config._name_or_path`` are used.
        dataset: Rows providing ``id``, ``question``, ``article`` and ``target``.
        outfile: Destination JSON file; missing parent directories are created.
        split_name: Label shown in the progress bar and stored in the payload.

    The written payload has the keys ``model_name``, ``split`` and ``predictions``;
    each prediction holds ``id``, ``question``, ``reference`` and ``prediction``.
    """
    model.eval()
    results = []

    for row in tqdm(dataset, total=len(dataset), desc=f"{split_name} predictions"):
        # Build input prompt
        prompt = format_prompt(row)

        # NOTE(review): prompts longer than MAX_INPUT_TOKENS are truncated here. Assuming the
        # tokenizer truncates on the right (not visible in this notebook), the tail of the
        # article and the "### Response:" header are what gets cut - confirm this is acceptable.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Greedy decoding. `temperature` is a sampling-only flag: combined with
            # do_sample=False it is ignored and only triggers the "generation flags are
            # not valid" warning, so it is deliberately not passed.
            output_ids = model.generate(
                **inputs,
                max_new_tokens=GEN_MAX_NEW_TOKENS,
                do_sample=False,
            )

        # The output starts with the prompt tokens; keep only what was generated after them.
        gen_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]

        prediction_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        results.append({
            "id": row["id"],
            "question": row["question"],
            "reference": row["target"],      # gold summary
            "prediction": prediction_text,   # model's output
        })

    payload = {
        "model_name": model.config._name_or_path,
        "split": split_name,
        "predictions": results,
    }

    # Save predictions
    outfile.parent.mkdir(parents=True, exist_ok=True)
    outfile.write_text(json.dumps(payload, indent=2))
    print(f"Saved {len(results)} predictions to {outfile}")


# ----- Generate baseline predictions with the un-tuned model -----

baseline_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto" if HAS_CUDA else None,
)

# Run on validation and test sets
for split_name, ds in [("validation", validation_ds), ("test", test_ds)]:
    output_path = OUTPUT_DIR / f"baseline_{split_name}_predictions.json"
    generate_predictions(baseline_model, ds, output_path, split_name)

# Free up GPU memory. Deleting the name alone is not enough: the weights stay in
# PyTorch's caching allocator until garbage is collected and the cache is released,
# which matters because another copy of the model is loaded for fine-tuning next.
del baseline_model
gc.collect()
if HAS_CUDA:
    torch.cuda.empty_cache()


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

validation predictions:   0%|          | 0/51 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
validation predictions: 100%|██████████| 51/51 [07:27<00:00,  8.77s/it]


Saved 51 predictions to outputs/gemma2_abs_sum/baseline_validation_predictions.json


test predictions: 100%|██████████| 109/109 [15:52<00:00,  8.74s/it]

Saved 109 predictions to outputs/gemma2_abs_sum/baseline_test_predictions.json





In [None]:
def add_prompt(example):
    """Store the rendered prompt under ``example["prompt"]`` and return the same example."""
    prompt_text = format_prompt(example)
    example["prompt"] = prompt_text
    return example

# Add the "prompt" column to the train and validation splits (in that order)
train_prompt, validation_prompt = (
    split.map(add_prompt) for split in (train_ds, validation_ds)
)


def tokenize_example(example):
    """Tokenize one prompt/target pair for causal-LM fine-tuning.

    The resulting sequence is ``prompt + target + EOS`` and is at most
    ``MAX_INPUT_TOKENS`` long. When the pair does not fit, tokens are dropped from
    the end of the *prompt* so that the reference summary and the EOS token are
    kept. Truncating the concatenated text instead would cut the target first,
    leaving long-article examples without any summary to learn from.

    The prompt is tokenized on its own (with the tokenizer's default special
    tokens), as it is in ``generate_predictions``; the target is appended without
    extra special tokens.

    Args:
        example: Row with ``"prompt"`` and ``"target"`` strings.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (all ones, no padding here) and
        ``labels`` (a copy of ``input_ids``).
    """
    prompt_ids = tokenizer(example["prompt"])["input_ids"]
    target_ids = tokenizer(
        example["target"] + tokenizer.eos_token,
        add_special_tokens=False,
    )["input_ids"]

    # Pathological case: a target that alone fills the window is capped so that at
    # least the first prompt token still fits.
    target_ids = target_ids[: MAX_INPUT_TOKENS - 1]

    # Whatever room is left after the target goes to the prompt.
    prompt_budget = MAX_INPUT_TOKENS - len(target_ids)
    input_ids = prompt_ids[:prompt_budget] + target_ids

    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": list(input_ids),
    }


def _tokenize_split(split):
    """Tokenize every row of ``split`` and drop all of its original columns."""
    return split.map(tokenize_example, remove_columns=split.column_names)


tokenized_train = _tokenize_split(train_prompt)
tokenized_validation = _tokenize_split(validation_prompt)

# Causal-LM collation (mlm=False), i.e. no masked-language-model objective
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Quick check of the columns that reach the Trainer
print(tokenized_train[0].keys())


Map:   0%|          | 0/392 [00:00<?, ? examples/s]

Map:   0%|          | 0/51 [00:00<?, ? examples/s]

Map:   0%|          | 0/392 [00:00<?, ? examples/s]

Map:   0%|          | 0/51 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'attention_mask', 'labels'])


In [None]:
# Load the base model (Gemma-2) that will be fine-tuned
# NOTE(review): when USE_FP16 is True the weights are loaded in float16 *and* fp16 mixed
# precision is enabled below. Mixed-precision training usually expects float32 weights -
# verify this path on a GPU without bf16 support before relying on it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto" if HAS_CUDA else None,
)


# Training configuration
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR / "checkpoints"),
    num_train_epochs=3,
    per_device_train_batch_size=1,
    # Evaluate one sequence at a time as well, so the validation pass does not need
    # more memory per batch than a training step.
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,  # 1 x 4 = 4 examples per optimizer step per device
    learning_rate=2e-5,
    warmup_ratio=0.03,
    logging_steps=5,
    save_steps=50,
    save_total_limit=1,  # keep only the most recent checkpoint
    bf16=USE_BF16,
    fp16=USE_FP16,
    weight_decay=0.01,
    report_to="none",
)


# HuggingFace Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_validation,
    data_collator=data_collator,
)

# Starting the fine-tuning
trainer.train()

# No evaluation schedule is configured above, so without this call the eval_dataset
# passed to the Trainer would never be used. Score the validation split once here.
eval_metrics = trainer.evaluate()
print(f"Validation metrics after fine-tuning: {eval_metrics}")

# Save the final model and tokenizer
save_path = OUTPUT_DIR / "finetuned-model"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"Model saved to: {save_path}")


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
5,1.4858
10,1.575
15,1.4788
20,1.3987
25,1.4575
30,1.3547
35,1.3156
40,1.2517
45,1.3447
50,1.2194


Model saved to: outputs/gemma2_abs_sum/finetuned-model


In [None]:
# Reload the fine-tuned weights from disk
finetuned_dir = OUTPUT_DIR / "finetuned-model"
finetuned_model = AutoModelForCausalLM.from_pretrained(
    finetuned_dir,
    torch_dtype=DTYPE,
    device_map="auto" if HAS_CUDA else None,
)

# Predict on the validation and test splits, one output file per split
held_out_splits = {"validation": validation_ds, "test": test_ds}
for split_name, dataset_split in held_out_splits.items():
    output_path = OUTPUT_DIR / f"finetuned_{split_name}_predictions.json"
    generate_predictions(finetuned_model, dataset_split, output_path, split_name)

print("Finished generating predictions for the fine-tuned model.")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

validation predictions: 100%|██████████| 51/51 [04:17<00:00,  5.05s/it]


Saved 51 predictions to outputs/gemma2_abs_sum/finetuned_validation_predictions.json


test predictions:  60%|█████▉    | 65/109 [06:02<03:08,  4.28s/it]

In [None]:
import evaluate
from bert_score import score as bertscore

# Load the ROUGE metric once through `evaluate`; `show_scores` below reuses this module-level object.
rouge = evaluate.load("rouge")

def show_scores(path: Path):
    """Print ROUGE and mean BERTScore F1 for one predictions file.

    ``path`` must point to a JSON file holding a ``"predictions"`` list whose items
    carry a ``"prediction"`` and a ``"reference"`` entry. Nothing is returned; the
    scores are printed, rounded to four decimals.
    """
    records = json.loads(path.read_text())["predictions"]
    preds = [rec["prediction"] for rec in records]
    refs = [rec["reference"] for rec in records]

    # ROUGE over the whole file
    rouge_scores = rouge.compute(predictions=preds, references=refs)

    # bertscore yields a (P, R, F1) triple; only the mean of F1 is reported
    _, _, f1_values = bertscore(
        preds,
        refs,
        lang="en",
        model_type="bert-base-uncased",
    )
    bert_f1 = float(f1_values.mean())

    print(f"\n=== {path.name} ===")
    print("ROUGE scores:")
    for metric_name, metric_value in rouge_scores.items():
        print(f"  {metric_name}: {round(metric_value, 4)}")

    print(f"BERTScore F1: {round(bert_f1, 4)}")


# Baseline and fine-tuned prediction files, validation first and then test
files_to_check = [
    OUTPUT_DIR / f"{variant}_{split}_predictions.json"
    for split in ("validation", "test")
    for variant in ("baseline", "finetuned")
]

# Score only the files that exist
for file in filter(Path.exists, files_to_check):
    show_scores(file)


Downloading builder script: 0.00B [00:00, ?B/s]


=== baseline_validation_predictions.json ===
ROUGE scores:
  rouge1: 0.1878
  rouge2: 0.0578
  rougeL: 0.1275
  rougeLsum: 0.1291
BERTScore F1: 0.522

=== finetuned_validation_predictions.json ===
ROUGE scores:
  rouge1: 0.2596
  rouge2: 0.1044
  rougeL: 0.1882
  rougeLsum: 0.1889
BERTScore F1: 0.5764

=== baseline_test_predictions.json ===
ROUGE scores:
  rouge1: 0.1995
  rouge2: 0.0704
  rougeL: 0.1334
  rougeLsum: 0.1351
BERTScore F1: 0.5311

=== finetuned_test_predictions.json ===
ROUGE scores:
  rouge1: 0.2618
  rouge2: 0.1055
  rougeL: 0.1846
  rougeLsum: 0.1854
BERTScore F1: 0.5795
