In [1]:
# Cell 1: Install dependencies
!pip install -q transformers datasets rouge_score evaluate

  Preparing metadata (setup.py) ... done
     84.1/84.1 kB 3.6 MB/s eta 0:00:00
  Building wheel for rouge_score (setup.py) ... done


In [2]:
# Cell 2: Import libraries
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
)

In [13]:
# Cell 3: Load dataset and manually select small subset
# Subset sizes live in one place so they can be tuned without editing the
# selection logic below. Each split keeps its first N rows.
N_TRAIN = 5000
N_VALIDATION = 200
N_TEST = 200

full_dataset = load_dataset("cnn_dailymail", "3.0.0")
dataset = {
    "train": full_dataset["train"].select(range(N_TRAIN)),                  # first N_TRAIN samples
    "validation": full_dataset["validation"].select(range(N_VALIDATION)),  # first N_VALIDATION samples
    "test": full_dataset["test"].select(range(N_TEST)),                    # first N_TEST samples
}

In [14]:
# Cell 4: Load tokenizer and model (T5-small for speed)
# The tokenizer and the model are both loaded from the same checkpoint name.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [15]:
# Cell 5: Preprocessing parameters
max_input_length = 512  # max_length passed to the tokenizer for the source article
max_target_length = 128  # max_length passed to the tokenizer for the reference highlights

def preprocess(ex):
    """Tokenize one example into model inputs and labels for T5 summarization.

    Args:
        ex: A single dataset row with an ``"article"`` field (a string, or a
            list of strings that is joined with spaces) and a ``"highlights"``
            field holding the reference summary.

    Returns:
        The tokenizer output for the prefixed article (truncated to
        ``max_input_length``) with an added ``"labels"`` entry holding the
        token ids of the highlights (truncated to ``max_target_length``).

    Sequences are deliberately NOT padded here. ``DataCollatorForSeq2Seq`` pads
    each batch to the longest sequence it contains and pads labels with -100,
    so padding every example to the maximum length up front only adds wasted
    compute and makes a manual pad-to--100 replacement unnecessary.
    """
    # Join list of articles into one string if needed
    article = ex["article"]
    if isinstance(article, list):
        article = " ".join(article)

    # "summarize: " is the task prefix expected by T5 for summarization
    model_inputs = tokenizer(
        "summarize: " + article,
        max_length=max_input_length,
        truncation=True,
    )

    # text_target= tells the tokenizer this text is the decoder-side target
    target = tokenizer(
        text_target=ex["highlights"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = target["input_ids"]
    return model_inputs

# Run preprocess() over every split, one example at a time, and drop the
# original columns so only the tokenizer outputs and labels remain
tokenized = {
    split_name: split_ds.map(
        preprocess,
        batched=False,
        remove_columns=split_ds.column_names,
    )
    for split_name, split_ds in dataset.items()
}

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [16]:
# Cell 6: Training setup
batch_size = 4
# Gradient accumulation is its own knob: previously it reused `batch_size`, so
# changing the batch size silently changed the accumulation factor as well.
# Per-device effective train batch = batch_size * grad_accum_steps.
grad_accum_steps = 4
args = Seq2SeqTrainingArguments(
    output_dir="summarizer",
    eval_strategy="no",  # no evaluation during training; ROUGE is computed manually afterwards
    learning_rate=3e-5,
    gradient_accumulation_steps=grad_accum_steps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=6,
    weight_decay=0.01,
    save_total_limit=1,
    logging_steps=10,
    predict_with_generate=True,
    report_to="none",
    eval_accumulation_steps=10,
    generation_max_length=64,
    generation_num_beams=1,
    push_to_hub=False,
    load_best_model_at_end=False,
)

# Pads each batch dynamically; passing the model lets the collator build decoder inputs from labels
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [17]:
# Cell 7: Clear memory
from transformers import TrainerCallback

class ClearMemoryCallback(TrainerCallback):
    """Trainer callback that frees Python and CUDA cached memory after evaluation."""

    def on_evaluate(self, args, state, control, **kwargs):
        """Run the garbage collector, then release PyTorch's cached CUDA memory.

        NOTE(review): this hook is tied to the trainer's evaluation event, so it
        presumably never fires when evaluation is disabled in the training args.
        """
        # Imported lazily so the class can be defined before torch is imported elsewhere
        import gc
        import torch

        gc.collect()
        torch.cuda.empty_cache()

In [18]:
# Cell 8: Train
# Seq2SeqTrainer is the trainer class that pairs with Seq2SeqTrainingArguments;
# the plain Trainer does not use the generation-related arguments set in `args`.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    # `tokenizer=` is deprecated on the trainer; `processing_class=` is its replacement
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,
    callbacks=[ClearMemoryCallback()],
)
trainer.train()

  trainer = Trainer(


Step,Training Loss
10,2.6077
20,2.5155
30,2.2848
40,2.4086
50,2.2054
60,2.2176
70,2.2749
80,2.3097
90,2.2235
100,2.2133


TrainOutput(global_step=1878, training_loss=2.1260751520102015, metrics={'train_runtime': 1367.8907, 'train_samples_per_second': 21.932, 'train_steps_per_second': 1.373, 'total_flos': 4060254044160000.0, 'train_loss': 2.1260751520102015, 'epoch': 6.0})

In [19]:
# Cell 9: Validation
from torch.utils.data import DataLoader
import torch

def evaluate_model(model, tokenizer, dataset, data_collator, batch_size=4,
                   max_length=64, num_beams=1):
    """Generate summaries for a tokenized dataset and score them with ROUGE.

    Args:
        model: Seq2seq model used for generation; it is switched to eval mode
            and left in that mode.
        tokenizer: Tokenizer used to decode generated ids and label ids.
        dataset: Tokenized split whose batches provide ``input_ids``,
            ``attention_mask`` and ``labels``.
        data_collator: Collate function that turns dataset rows into tensors.
        batch_size: Number of examples per generation batch.
        max_length: ``max_length`` passed to ``model.generate``.
        num_beams: ``num_beams`` passed to ``model.generate``.

    Returns:
        Dict mapping each key returned by the ROUGE metric to its score
        multiplied by 100 and rounded to two decimals, as a plain ``float``.

    Note:
        References are decoded from the batch's label ids, so they reflect
        whatever truncation was applied when the labels were tokenized.
    """
    model.eval()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    all_preds = []
    all_labels = []

    for batch in dataloader:
        # Move every tensor in the batch to the device the model lives on
        batch = {k: v.to(model.device) for k, v in batch.items()}

        # Generate summaries without tracking gradients
        with torch.no_grad():
            outputs = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=max_length,
                num_beams=num_beams,
            )

        # Decode predictions and references; -100 marks ignored label positions
        # and is swapped back to the pad id so the tokenizer can decode it
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        labels = batch["labels"].cpu().numpy()
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        all_preds.extend(p.strip() for p in decoded_preds)
        all_labels.extend(l.strip() for l in decoded_labels)

        # Optional: release per-batch tensors and cached CUDA memory
        del batch, outputs
        torch.cuda.empty_cache()

    # Compute ROUGE
    rouge_metric = evaluate.load("rouge")
    rouge_metric.add_batch(predictions=all_preds, references=all_labels)
    results = rouge_metric.compute(use_stemmer=True)
    # Cast to built-in float so the scores print as plain numbers rather than np.float64(...)
    return {k: round(float(v) * 100, 2) for k, v in results.items()}

# Score the fine-tuned model on the tokenized validation split
results = evaluate_model(model, tokenizer, tokenized["validation"], data_collator)
print("Validation ROUGE:", results)

Validation ROUGE: {'rouge1': np.float64(31.1), 'rouge2': np.float64(12.01), 'rougeL': np.float64(22.15), 'rougeLsum': np.float64(22.2)}


In [23]:
# Cell 10: Generate a sample
sample = dataset["test"][74]
src = "summarize: " + sample["article"]

# Tokenize input with the same length cap used during preprocessing
inputs = tokenizer(src, return_tensors="pt", truncation=True, max_length=max_input_length)

# Move model and inputs to same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Set eval mode explicitly so this cell does not depend on an earlier cell
# having already switched the model out of training mode
model.eval()
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate summary
outputs = model.generate(**inputs, max_new_tokens=60)

# Decode and print
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
print("Reference:", sample["highlights"])

Generated: Three people killed and five wounded in clashes with armed assailants. A security guard of the provincial attorney general's office among the dead.
Reference: Three people killed; five wounded in attack on attorney general's office in Balkh province .
Staff and civilians have been rescued as gunmen engaged Afghan security forces .


In [21]:
# Cell 11: Test
# Score the model on the tokenized test split; this rebinds `results`,
# replacing the value assigned by the validation run.
results = evaluate_model(model, tokenizer, tokenized["test"], data_collator)
print("Test ROUGE:", results)

Test ROUGE: {'rouge1': np.float64(30.56), 'rouge2': np.float64(10.74), 'rougeL': np.float64(22.04), 'rougeLsum': np.float64(22.1)}
