In [1]:
!pip3 install transformers datasets torch accelerate trl rouge_score evaluate

Collecting trl
  Downloading trl-0.17.0-py3-none-any.whl.metadata (12 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-many

In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
from datasets import load_dataset
from transformers import (
    GPT2Config,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
)
import evaluate

2025-05-16 09:31:41.735696: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747387901.899010      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747387901.948182      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# %%
# 3) Tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
class GPT2ForSummarization(nn.Module):
    def __init__(self, model_name="gpt2"):
        super().__init__()
        self.config = GPT2Config.from_pretrained(model_name)
        self.transformer = torch.hub.load('huggingface/pytorch-transformers',
                                          'model', 'gpt2', config=self.config) \
                              if False else \
                          __import__('transformers').GPT2Model.from_pretrained(model_name, config=self.config)
        # explicit LM head
        self.lm_head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # 1) get hidden states
        out = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state                   # [B, L, D]
        logits = self.lm_head(hidden)                    # [B, L, V]

        loss = None
        if labels is not None:
            # shift for LM loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
        return {"loss": loss, "logits": logits}


In [5]:
MAX_ART_LEN = 512
MAX_SUM_LEN = 128
MAX_SEQ_LEN = MAX_ART_LEN + 1 + MAX_SUM_LEN + 1  # +1 for eos tokens

def preprocess(example):
    # tokenize article & summary
    art = tokenizer(example["article"], truncation=True, max_length=MAX_ART_LEN, add_special_tokens=False)
    sumy = tokenizer(example["highlights"], truncation=True, max_length=MAX_SUM_LEN, add_special_tokens=False)

    # build input_ids: article + eos + summary + eos
    input_ids = art["input_ids"] + [tokenizer.eos_token_id] + sumy["input_ids"] + [tokenizer.eos_token_id]
    labels    = [-100] * (len(art["input_ids"]) + 1) + sumy["input_ids"] + [tokenizer.eos_token_id]

    # pad/truncate to MAX_SEQ_LEN
    if len(input_ids) < MAX_SEQ_LEN:
        pad_len = MAX_SEQ_LEN - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * pad_len
        labels    += [-100] * pad_len
    else:
        input_ids = input_ids[:MAX_SEQ_LEN]
        labels    = labels[:MAX_SEQ_LEN]

    return {"input_ids": input_ids, "attention_mask": [1 if i!=tokenizer.pad_token_id else 0
                                                       for i in input_ids],
            "labels": labels}

# load splits
train_raw = load_dataset("cnn_dailymail", "3.0.0", split="train[:50000]")
val_raw   = load_dataset("cnn_dailymail", "3.0.0", split="validation[:10000]")

# map preprocess
train_ds = train_raw.map(preprocess, remove_columns=["article","highlights","id"])
val_ds   = val_raw.map(preprocess,   remove_columns=["article","highlights","id"])


README.md:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [6]:
training_args = TrainingArguments(
    output_dir="sft",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),
    report_to=[]
)

# %%
# 7) Trainer Instantiation
model = GPT2ForSummarization("gpt2").to(device)

def collate_fn(batch):
    return {
        "input_ids":      torch.tensor([b["input_ids"] for b in batch], dtype=torch.long),
        "attention_mask": torch.tensor([b["attention_mask"] for b in batch], dtype=torch.long),
        "labels":         torch.tensor([b["labels"] for b in batch], dtype=torch.long),
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    tokenizer=tokenizer,
)

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

  trainer = Trainer(


In [7]:
trainer.train()
model_save_dir = "sft"
model.transformer.save_pretrained(model_save_dir)
torch.save(model.lm_head.state_dict(), f"{model_save_dir}/lm_head.bin")
tokenizer.save_pretrained(model_save_dir)

Epoch,Training Loss,Validation Loss
1,2.8093,3.216357
2,2.4497,2.997135
3,2.1722,2.95212


('sft/tokenizer_config.json',
 'sft/special_tokens_map.json',
 'sft/vocab.json',
 'sft/merges.txt',
 'sft/added_tokens.json',
 'sft/tokenizer.json')

In [8]:
from tqdm.auto import tqdm
MAX_SUM_LEN = 128

device = "cuda" if torch.cuda.is_available() else "cpu"

def generate_summary(article: str) -> str:
    inputs = tokenizer(
        article + tokenizer.eos_token,
        return_tensors="pt",
        truncation=True,
        max_length=512
    ).to(device)
    generated = inputs.input_ids  

    for _ in range(MAX_SUM_LEN):
        outputs = model(
            input_ids=generated,
            attention_mask=(generated != tokenizer.pad_token_id).long()
        )
        logits = outputs["logits"]  
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1,1]
        generated = torch.cat([generated, next_token], dim=1)       # [1, L+1]

        if next_token.item() == tokenizer.eos_token_id:
            break
    summary_ids = generated[0, inputs.input_ids.size(1):].tolist()
    return tokenizer.decode(summary_ids, skip_special_tokens=True).strip()

test_ds = load_dataset("cnn_dailymail", "3.0.0", split="test[:5000]")

rouge = evaluate.load("rouge")
preds, refs = [], []

for ex in tqdm(test_ds, desc="Generating summaries"):
    preds.append(generate_summary(ex["article"]))
    refs.append(ex["highlights"])

# 4) Compute ROUGE scores
scores = rouge.compute(
    predictions=preds,
    references=refs,
    use_stemmer=True
)

# 5) Print formatted results
print("Evaluation ROUGE scores (in %):")
for k, v in scores.items():
    print(f"  {k}: {v * 100:.2f}")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Generating summaries:   0%|          | 0/5000 [00:00<?, ?it/s]

Evaluation ROUGE scores (in %):
  rouge1: 29.95
  rouge2: 10.16
  rougeL: 20.86
  rougeLsum: 27.80
