# 01 — Data → BART‑base Baseline

**What/Why**: Fine‑tune `facebook/bart-base` on CNN/DailyMail and establish a baseline with **ROUGE‑1/2/L** and **BERTScore** on the validation split.

Cells are fully commented; each step ends with **How to read results**.

### Literature Grounding
- **Summarizer:** We fine-tune `facebook/bart-base`, a denoising sequence-to-sequence Transformer introduced by Lewis et al. (2020) for abstractive summarization and other generation tasks.
- **Dataset:** CNN/DailyMail v3.0.0 provides news articles paired with human-written highlights (Hermann et al., 2015; See et al., 2017).

The training objective is the token-level negative log-likelihood:
\[
\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}, x),
\]
where \(x\) is the article, \(y_t\) is the \(t\)-th token of the reference summary, \(T\) is the summary length in tokens, and \(p_\theta\) is the BART decoder distribution.
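
To make the objective concrete, here is a minimal worked example in plain Python. The probabilities in `toy_probs` are made up for illustration (they are not model outputs), and `sequence_nll` is a hypothetical helper rather than anything provided by `transformers`. The formula above sums over tokens; training losses are typically reported as the per-token mean, so the sketch prints both.

```python
import math

def sequence_nll(step_probs):
    """Sum of -log p over the probability assigned to each reference token."""
    return -sum(math.log(p) for p in step_probs)

# Assumed probabilities p(y_t | y_<t, x) for a 3-token reference summary
toy_probs = [0.5, 0.25, 0.8]

loss_sum = sequence_nll(toy_probs)       # matches the summed formula above
loss_mean = loss_sum / len(toy_probs)    # per-token average
print(round(loss_sum, 4), round(loss_mean, 4))  # 2.3026 0.7675
```

The summed loss equals \(-\log(0.5 \cdot 0.25 \cdot 0.8) = -\log 0.1\), i.e. the negative log-probability of the whole reference sequence.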



## Imports & seed

In [0]:
import os, json, random, time, pathlib
import numpy as np
import torch

from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq,
                          Seq2SeqTrainingArguments, Seq2SeqTrainer)

# Seeds
SEED = json.load(open('configs/run.json'))['seed']
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

# How to read results:
# We fix seeds and show the device. Baselines will be reproducible modulo GPU nondeterminism.


## Load dataset & tokenizer

In [0]:
cfg = json.load(open('configs/run.json'))
dataset_id = cfg['dataset_id']; dataset_config = cfg['dataset_config']

# Load full dataset
ds = load_dataset(dataset_id, dataset_config)

# Tokenizer/model
model_id = cfg['baseline_model']
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Summarization task: set max lengths (conservative defaults)
MAX_INPUT = 1024
MAX_TARGET = 128

def preprocess(batch):
    # 1) Tokenize articles (encoder inputs), truncating to MAX_INPUT tokens
    inputs = tokenizer(batch['article'], max_length=MAX_INPUT, truncation=True)
    # 2) Tokenize highlights (decoder targets), truncating to MAX_TARGET tokens.
    #    `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=batch['highlights'], max_length=MAX_TARGET, truncation=True)
    # 3) Attach the target token ids as labels for the seq2seq loss
    inputs['labels'] = labels['input_ids']
    return inputs

# Batched map over all splits (pass num_proc=N to ds.map to parallelize across processes)
columns = ['input_ids', 'attention_mask', 'labels']
ds_tokenized = ds.map(preprocess, batched=True, remove_columns=ds['train'].column_names)
ds_tokenized.set_format(type='torch', columns=columns)

print(ds_tokenized)

# How to read results:
# You should see a DatasetDict with train/validation/test splits, each holding input_ids, attention_mask, and labels.
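
Because `truncation=True` silently drops tokens beyond the length limit, it can help to know how much text is affected. The sketch below assumes you have a list of untruncated token counts per article. The `toy_lengths` values are invented for illustration, not measured on CNN/DailyMail, and `truncation_stats` is a hypothetical helper, not part of `datasets` or `transformers`.

```python
def truncation_stats(lengths, max_length):
    """Return (share of sequences truncated, total tokens dropped)."""
    truncated = [n for n in lengths if n > max_length]
    share = len(truncated) / len(lengths)
    tokens_lost = sum(n - max_length for n in truncated)
    return share, tokens_lost

# Assumed untruncated article lengths in tokens (illustrative only)
toy_lengths = [300, 900, 1024, 1500, 2048]

share, tokens_lost = truncation_stats(toy_lengths, max_length=1024)
print(f"{share:.0%} of articles truncated, {tokens_lost} tokens dropped")
```

On real data, the lengths could be gathered by tokenizing a sample of articles without `truncation=True` and taking `len(input_ids)` for each.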


## Data collator & model

In [0]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=None)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)

# How to read results:
# The collator pads inputs to the longest sequence in each batch and pads labels with -100, which the loss ignores.
# The model loads with its pretrained weights.
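
The label padding done by the collator can be sketched in a few lines of plain Python. This is a simplified illustration, not the library's implementation: `pad_labels` is a hypothetical helper, the token ids are made up, and right-side padding is assumed.

```python
LABEL_PAD_ID = -100  # value ignored by PyTorch cross-entropy by default

def pad_labels(batch_labels, pad_id=LABEL_PAD_ID):
    """Right-pad every label sequence to the longest one in the batch."""
    width = max(len(labels) for labels in batch_labels)
    return [labels + [pad_id] * (width - len(labels)) for labels in batch_labels]

# Two toy target sequences of different lengths (ids are illustrative)
toy_batch = [[0, 5, 2], [0, 5, 6, 7, 2]]
padded = pad_labels(toy_batch)
print(padded)  # [[0, 5, 2, -100, -100], [0, 5, 6, 7, 2]]
```

Padding labels with -100 instead of the tokenizer's pad token keeps padded positions out of the loss, so short summaries are not penalized for their padding.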


## Training arguments & Trainer

In [0]:
# Create output dir
out_dir = 'runs/bart-base'
os.makedirs(out_dir, exist_ok=True)

# Reasonable small training config for Colab T4/A100; adjust as needed
args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    evaluation_strategy='steps',
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=1,              # increase to 3+ for a stronger baseline
    lr_scheduler_type='linear',
    warmup_ratio=0.03,
    logging_steps=50,
    eval_steps=200,
    save_steps=200,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET,
    seed=SEED,
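    # bf16 needs an Ampere-or-newer GPU (compute capability >= 8.0, e.g. A100); fall back to fp16 on older GPUs such as T4
    bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
    fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,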
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=ds_tokenized['train'].select(range(2000)),      # small subset for demo; remove select for full training
    eval_dataset=ds_tokenized['validation'].select(range(1000)),  # small subset for demo; remove select for full eval
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print('Trainer ready.')

# How to read results:
# The Trainer is constructed but nothing has run yet. We train on 2,000 and evaluate on 1,000 examples for speed;
# remove the `.select(...)` calls for the full run.
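
It can be useful to work out what the batch and schedule settings above imply before launching a run. The sketch below is an approximation under stated assumptions: a single GPU, the 2,000-example training subset, and step counts rounded the way the comments describe. `schedule_summary` is a hypothetical helper, and the exact counts reported by the Trainer may differ slightly between library versions.

```python
import math

def schedule_summary(num_examples, per_device_batch, grad_accum,
                     epochs, warmup_ratio, num_devices=1):
    """Approximate the optimizer-step arithmetic for a training run."""
    # Examples contributing to each optimizer update
    effective_batch = per_device_batch * grad_accum * num_devices
    # Forward/backward passes per epoch
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    # One optimizer update every `grad_accum` batches
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_updates = math.ceil(updates_per_epoch * epochs)
    # Linear warmup covers this many updates before the decay begins
    warmup_updates = math.ceil(total_updates * warmup_ratio)
    return {
        'effective_batch': effective_batch,
        'total_updates': total_updates,
        'warmup_updates': warmup_updates,
    }

summary = schedule_summary(num_examples=2000, per_device_batch=2, grad_accum=4,
                           epochs=1, warmup_ratio=0.03)
print(summary)
```

Under these assumptions the run performs about 250 optimizer updates with an effective batch size of 8, so `eval_steps=200` and `save_steps=200` would each trigger once, at step 200.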


## Train

In [0]:
train_result = trainer.train()
trainer.save_model(out_dir)
metrics = train_result.metrics
print(metrics)

# Persist metrics
os.makedirs('results', exist_ok=True)
with open('results/train.metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

# How to read results:
# 'results/train.metrics.json' stores training metrics. The checkpoint is in 'runs/bart-base'.


## Validation: generate & compute ROUGE and BERTScore

In [0]:
from tqdm.auto import tqdm
import evaluate

# Generate summaries for a manageable slice (adjust up as needed)
val_ds = ds['validation'].select(range(500))

model.eval()
preds, refs = [], []
for ex in tqdm(val_ds):
    # Reuse the length limits from preprocessing so training and generation stay consistent
    inputs = tokenizer(ex['article'], return_tensors='pt', truncation=True, max_length=MAX_INPUT).to(device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=MAX_TARGET, num_beams=4, length_penalty=2.0)
    pred = tokenizer.decode(out[0], skip_special_tokens=True)
    preds.append(pred)
    refs.append(ex['highlights'])

# Save predictions
with open('results/val.baseline.jsonl', 'w', encoding='utf-8') as f:
    for p, r in zip(preds, refs):
        f.write(json.dumps({'pred': p, 'ref': r}, ensure_ascii=False) + '\n')
print('Wrote results/val.baseline.jsonl')

# ROUGE
rouge = evaluate.load('rouge')
rouge_scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
print('ROUGE:', rouge_scores)
with open('results/val.baseline.rouge.json', 'w') as f:
    json.dump(rouge_scores, f, indent=2)

# BERTScore
bertscore = evaluate.load('bertscore')
bs = bertscore.compute(predictions=preds, references=refs, lang='en')
bertscore_mean = {k: float(np.mean(v)) for k, v in bs.items() if isinstance(v, list)}
print('BERTScore (mean):', bertscore_mean)
with open('results/val.baseline.bertscore.json', 'w') as f:
    json.dump({'per_example': bs, 'mean': bertscore_mean}, f, indent=2)

# How to read results:
# ROUGE and BERTScore are printed and saved. These form your baseline metrics.
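
To help interpret the ROUGE numbers, here is a simplified sketch of ROUGE-1 F1 as unigram overlap. It is not the implementation behind `evaluate.load('rouge')`: it assumes lowercased whitespace tokenization and applies no stemming, so its values will not match the official scores exactly. `rouge1_f1` is a hypothetical helper and the sentences are toy examples.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 with whitespace tokenization and no stemming."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped overlap: each token counts at most as often as it appears in both
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1('the cat sat on the mat', 'the cat lay on the mat')
print(round(score, 4))  # 0.8333
```

Five of the six unigrams match in each direction, so precision and recall are both 5/6. ROUGE-2 applies the same idea to bigrams, and ROUGE-L uses the longest common subsequence instead of n-gram counts.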
