In [None]:
# any pip installs

! pip install transformers datasets rouge-score evaluate

### Imports

## Data Preparation

In [None]:

from datasets import load_dataset


In [None]:
# Load the arXiv summarization corpus (article/abstract pairs) and show its splits.
dataset = load_dataset(path="ccdv/arxiv-summarization")
dataset

In [None]:
def filter_article(example, max_chars=8000):
    """Return True when the example's article is shorter than ``max_chars`` characters.

    Meant for ``Dataset.filter`` to keep only short articles. Length is measured
    in characters, not tokens.

    Args:
        example: a single dataset row with an ``"article"`` string.
        max_chars: exclusive upper bound on the article length (defaults to
            the notebook's original cutoff of 8000).

    Returns:
        bool: whether the article passes the length filter.
    """
    return len(example["article"]) < max_chars

In [None]:
# Restrict every split to the short articles accepted by filter_article,
# then print a summary of each filtered split.
small_articles_train = dataset["train"].filter(function=filter_article)
small_articles_val = dataset["validation"].filter(function=filter_article)
small_articles_test = dataset["test"].filter(function=filter_article)
print(small_articles_train, small_articles_val, small_articles_test, sep="\n")

In [None]:
def show_samples(dataset, num_samples=2, seed=42):
    """Print a few randomly chosen article/abstract pairs from ``dataset``.

    Args:
        dataset: a dataset with ``shuffle``/``select`` methods whose rows have
            ``"article"`` and ``"abstract"`` fields.
        num_samples: how many rows to print; capped at ``len(dataset)`` so a
            small dataset does not raise.
        seed: shuffle seed, so the same rows are shown on every run.
    """
    # select(range(k)) fails when k exceeds the number of rows; never ask for
    # more rows than the dataset holds.
    num_samples = min(num_samples, len(dataset))
    sample = dataset.shuffle(seed=seed).select(range(num_samples))
    for example in sample:
        print(f"\n'>> Article: {example['article']}'")
        print(f"'>> Abstract: {example['abstract']}'")

In [None]:
# Eyeball a couple of filtered training examples (fixed seed for repeatability).
show_samples(small_articles_train, num_samples=2, seed=42)

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tokenizer and seq2seq model are loaded from the same pretrained checkpoint.
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
max_input_length = 512
# NOTE(review): 30 tokens is likely much shorter than a typical arXiv abstract,
# so reference summaries are heavily truncated — confirm this is intended.
max_target_length = 30

# T5 checkpoints were pretrained with task prefixes; summarization uses "summarize: ".
task_prefix = "summarize: "


def preprocess_function(examples):
    """Tokenize a batch of articles (model inputs) and abstracts (labels).

    Args:
        examples: batch dict from ``Dataset.map(..., batched=True)`` with
            ``"article"`` and ``"abstract"`` lists of strings.

    Returns:
        The tokenizer output for the prefixed articles (truncated to
        ``max_input_length`` tokens), plus a ``"labels"`` entry holding the
        token ids of the abstracts (truncated to ``max_target_length`` tokens).
    """
    inputs = [task_prefix + article for article in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # `text_target=` is the tokenizer's dedicated entry point for label text.
    labels = tokenizer(
        text_target=examples["abstract"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split in batches; no columns are removed, so the raw
# article/abstract text stays available next to the new token columns.
small_articles_train = small_articles_train.map(function=preprocess_function, batched=True)
small_articles_val = small_articles_val.map(function=preprocess_function, batched=True)
small_articles_test = small_articles_test.map(function=preprocess_function, batched=True)
print(small_articles_train, small_articles_val, small_articles_test, sep="\n")

In [None]:
import evaluate

# ROUGE metric from the `evaluate` library; compute_metrics below scores with it.
rouge_score = evaluate.load(path="rouge")

In [None]:
import nltk
from nltk.tokenize import sent_tokenize

# `sent_tokenize` needs the Punkt sentence model. Older NLTK releases read the
# "punkt" resource, while recent ones (3.8.2+) read "punkt_tab" instead and
# raise LookupError without it, so fetch both.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

In [None]:
def three_sentence_summary(text):
    """Lead-3 baseline: the first three sentences of ``text``, one per line."""
    first_three = sent_tokenize(text)[:3]
    return "\n".join(first_three)


# Preview the baseline on one training article.
print(three_sentence_summary(small_articles_train[1]["article"]))

In [None]:
def evaluate_baseline(dataset, metric, **compute_kwargs):
    """Score the lead-3 extractive baseline on ``dataset`` with ``metric``.

    Predictions and references get the same treatment as in ``compute_metrics``:
    sentences separated by newlines and stemming enabled. This keeps baseline
    scores comparable to the fine-tuned model's scores.

    Args:
        dataset: split with ``"article"`` and ``"abstract"`` columns.
        metric: object exposing ``compute(predictions=..., references=..., **kw)``.
        **compute_kwargs: extra keyword arguments for ``metric.compute``;
            ``use_stemmer`` defaults to True and may be overridden.

    Returns:
        Whatever ``metric.compute`` returns.
    """
    compute_kwargs.setdefault("use_stemmer", True)
    summaries = [three_sentence_summary(text) for text in dataset["article"]]
    # rougeLsum splits on newlines, so put each reference sentence on its own line.
    references = ["\n".join(sent_tokenize(ref.strip())) for ref in dataset["abstract"]]
    return metric.compute(
        predictions=summaries, references=references, **compute_kwargs
    )

In [None]:
from transformers import Seq2SeqTrainingArguments


batch_size = 8
num_train_epochs = 8
# Log the training loss roughly once per epoch (one log every
# len(train) // batch_size steps). Clamp to 1 so a tiny training split cannot
# produce logging_steps=0, which TrainingArguments rejects.
logging_steps = max(1, len(small_articles_train) // batch_size)
model_name = "summarizer-hc"

args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned-summarizer",
    # `eval_strategy` replaces the `evaluation_strategy` keyword, which recent
    # transformers releases no longer accept.
    eval_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    # Let evaluation-time generation produce summaries as long as the
    # (truncated) reference labels instead of the model's default length.
    generation_max_length=max_target_length,
    logging_steps=logging_steps,
)


In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Decode generated and reference summaries and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair passed in by the Trainer; with
            ``predict_with_generate=True`` the predictions are generated token ids.

    Returns:
        dict mapping each key reported by ``rouge_score.compute`` to its score
        scaled to 0-100 and rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Predictions may arrive as a tuple; the generated token ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 cannot be decoded. It marks ignored label positions and is also used
    # as the padding value when generated batches of different lengths are
    # concatenated, so map it back to the pad token in both arrays.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE-Lsum expects a newline after each sentence.
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # `evaluate`'s ROUGE already returns one aggregated float per key; there is
    # no `.mid.fmeasure` as with the old `datasets.load_metric` objects.
    return {key: round(float(value) * 100, 4) for key, value in result.items()}

In [None]:
from transformers import DataCollatorForSeq2Seq

# Batch collator for seq2seq training, built from the same tokenizer and model used above.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
from transformers import Seq2SeqTrainer

# Fail fast: an empty training split (e.g. from a too-strict length filter)
# should stop the notebook before the Trainer is built.
assert len(small_articles_train) > 0, "training split is empty after filtering"

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=small_articles_train,
    eval_dataset=small_articles_val,
    data_collator=data_collator,
    # `processing_class` supersedes the deprecated `tokenizer=` argument.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Number of training examples that survived filtering.
print(small_articles_train.num_rows)

In [None]:
# Fine-tune from the pretrained weights (not resuming from a saved checkpoint).
trainer.train(resume_from_checkpoint=None)