In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121  # (nếu có GPU CUDA 12.1)
!pip install transformers datasets evaluate nltk
!pip install sentencepiece  # cần cho tokenizer BART/T5


# 0. Import libraries

In [None]:
import evaluate
import nltk
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    BartForConditionalGeneration,
    BartTokenizerFast,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Punkt sentence-tokenizer models used by nltk.sent_tokenize. Older NLTK releases
# load them from "punkt"; newer releases (3.8.2+) load them from "punkt_tab".
# Downloading both keeps sentence splitting working across NLTK versions.
nltk.download("punkt")
nltk.download("punkt_tab")

# 1. Load dataset

In [None]:
# CNN/DailyMail, config "3.0.0". Every split carries an "article" column (model input),
# a "highlights" column (summary target) and an "id" column.
dataset = load_dataset(path="cnn_dailymail", name="3.0.0")

# 2. Load tokenizer & model

In [None]:
# Tokenizer and seq2seq model come from the same pretrained BART summarization checkpoint.
model_name = "facebook/bart-large-cnn"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizerFast.from_pretrained(model_name)

# 3. Preprocess function

In [None]:
max_input_length = 512
max_target_length = 128

def preprocess_function(examples):
    inputs = [doc for doc in examples["article"]]
    targets = [tgt for tgt in examples["highlights"]]
    model_inputs = tokenizer(
        inputs, max_length=max_input_length, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        targets, max_length=max_target_length, truncation=True, padding="max_length"
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(
    preprocess_function, 
    batched=True, 
    remove_columns=["article", "highlights", "id"]
)


# 4. Data collator & metrics

In [None]:
!pip install rouge_score

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Tính ROUGE
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {key: round(value * 100, 2) for key, value in result.items()}
    return result


# 5. Training config

In [None]:
!pip install --upgrade transformers accelerate


In [None]:
import transformers

# Version of the transformers module loaded in this kernel. A pip upgrade only
# shows up here after the kernel has been restarted.
transformers_version = transformers.__version__
print(transformers_version)


In [None]:
# Fine-tuning hyper-parameters for Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    output_dir="./models/bart_summarizer",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    save_total_limit=2,  # keep at most two checkpoints on disk
    # NOTE(review): assuming the ~287k-article train split of cnn_dailymail 3.0.0,
    # one epoch at batch size 2 is ~143k steps, so 15 epochs is a very long run.
    # Per its model card, facebook/bart-large-cnn is already fine-tuned on this
    # dataset. Confirm 15 epochs is intended.
    num_train_epochs=15,
    predict_with_generate=True,  # evaluate on generated summaries (needed for ROUGE)
    logging_dir="./logs",
    logging_steps=100,
    save_strategy="epoch",
    # Mixed precision requires a supported GPU; requesting fp16 without one makes
    # the arguments fail validation, so only enable it when CUDA is available.
    fp16=torch.cuda.is_available(),
    # This notebook logs runs to Weights & Biases. Say so explicitly instead of
    # relying on the report_to default, which differs between transformers versions.
    report_to="wandb",
)


In [None]:
# Seq2SeqTrainer runs fine-tuning on the train split and evaluation (ROUGE via
# compute_metrics) on the validation split.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `processing_class` replaces the deprecated `tokenizer` argument
    # (transformers >= 4.46; `tokenizer` is removed in v5).
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# 6. Train

In [None]:
import wandb

# Never hardcode the W&B API key: the notebook (and its outputs) get saved and shared.
# wandb.login() uses the WANDB_API_KEY environment variable or stored credentials,
# and otherwise prompts for the key interactively.
# NOTE: a key was previously committed in this cell; revoke it in the W&B settings.
wandb.login()

In [None]:
# Fine-tune, then write the final model weights and the tokenizer to one directory.
final_dir = "./models/bart_summarizer"
trainer.train()
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

# Run evaluation on the validation split after training and show the metrics.
results = trainer.evaluate()
print("ROUGE scores:", results)