In [None]:
from datasets import load_dataset
from transformers import BartTokenizerFast, BartForConditionalGeneration

def preprocess_data(example, tokenizer):
    """Tokenize article bodies as model inputs and their summaries as targets.

    Args:
        example: a mapping with "document" and "summary" entries (a batch of
            rows when used with ``Dataset.map(..., batched=True)``).
        tokenizer: a Hugging Face tokenizer; passing the summaries through
            ``text_target`` makes it encode them under a ``labels`` key.

    Returns:
        The object returned by the tokenizer call (a ``BatchEncoding`` for
        Hugging Face tokenizers).
    """
    sources = example["document"]
    targets = example["summary"]
    # Truncation length comes from tokenizer.model_max_length, which this
    # notebook sets from the model config before calling map().
    return tokenizer(sources, text_target=targets, truncation=True)

model_name = "gogamza/kobart-base-v2"
tokenizer = BartTokenizerFast.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

dataset = load_dataset("daekeun-ml/naver-news-summarization-ko")
print(dataset)

# Tie the tokenizer's truncation limit to the model's positional-embedding
# size so truncation=True in preprocess_data yields lengths the model accepts.
tokenizer.model_max_length = model.config.max_position_embeddings
processed_dataset = dataset.map(
    preprocess_data,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
    # Drop every original column; only the tokenizer outputs are kept.
    remove_columns=dataset["train"].column_names
)

# Inspect the first training target. Indexing the row first reads only that
# row instead of materializing the whole "labels" column.
sample_labels = processed_dataset["train"][0]["labels"]
print(sample_labels)
print(tokenizer.decode(sample_labels))

In [None]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Pads each batch to its longest sequence; labels are padded by the collator too.
seq2seq_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    return_tensors="pt",
    padding="longest",
)

# Sanity check: pull one collated batch and show the tensor shapes it contains.
seq2seq_dataloader = DataLoader(
    processed_dataset["train"],
    batch_size=4,
    shuffle=False,
    collate_fn=seq2seq_collator,
)

seq2seq_iterator = iter(seq2seq_dataloader)
seq2seq_batch = next(seq2seq_iterator)
for field_name, tensor in seq2seq_batch.items():
    print(f"{field_name} : {tensor.shape}")

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="text-summarization",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=5e-5,
    num_train_epochs=1,
    # eval_steps only takes effect together with an evaluation strategy; the
    # default strategy is "no", which silently skips evaluation on eval_dataset.
    # (transformers < 4.41 spells this argument `evaluation_strategy`.)
    eval_strategy="steps",
    eval_steps=200,
    logging_steps=200,
    seed=42
)

# Train on a 10k-example subset and evaluate (loss) on 100 validation examples
# to keep one epoch and each periodic evaluation short.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=seq2seq_collator,
    train_dataset=processed_dataset["train"].select(range(10000)),
    eval_dataset=processed_dataset["validation"].select(range(100))
)

trainer.train()

In [None]:
import torch

# Put the fine-tuned model in inference mode and move it to the GPU if present.
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

sample = dataset["test"][0]
document = sample["document"]
# Truncate exactly as the training preprocessing did; otherwise an article
# longer than tokenizer.model_max_length (set from max_position_embeddings)
# would exceed the model's position embeddings.
inputs = tokenizer(document, truncation=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=256,          # upper bound on generated summary tokens
        num_beams=4,             # beam search width
        no_repeat_ngram_size=2,  # a bigram may appear at most once in the output
        early_stopping=True      # stop once num_beams finished candidates exist
    )
print("원문 :", document)
print("정답 요약문 :", sample["summary"])
print("생성 요약문 :", tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
!pip install rouge_score

In [None]:
import evaluate

# Score generated summaries against references on the first 100 test rows.
test_loader = DataLoader(
    processed_dataset["test"].select(range(100)),
    collate_fn=seq2seq_collator,
    batch_size=4,
    shuffle=False
)

generated_summaries = []
# map() keeps row order, so these references line up with the batches above.
true_summaries = dataset["test"].select(range(100))["summary"]

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        # The seq2seq collator also emits `labels`. generate() is not meant to
        # receive them: depending on the transformers version they are ignored
        # or forwarded into the model's loss computation against beam-expanded
        # logits. Only the encoder inputs are needed here.
        batch.pop("labels", None)
        output = model.generate(
            **batch,
            # Summary length cap, matching the single-sample demo; 1026 was the
            # model's positional limit, not a sensible summary length.
            max_length=256,
            num_beams=4,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
        batch_summaries = tokenizer.batch_decode(output, skip_special_tokens=True)
        generated_summaries.extend(batch_summaries)

metric = evaluate.load("rouge")
# rouge_score's default tokenizer lowercases and keeps only [a-z0-9], which
# erases all Hangul; split on whitespace so Korean words are actually scored.
rouge_scores = metric.compute(
    predictions=generated_summaries,
    references=true_summaries,
    tokenizer=lambda text: text.split(),
)
print(rouge_scores)