In [None]:
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments)
import evaluate
import torch
import numpy as np
import random


In [None]:
seed = 413
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
inputs = {"raw": "dialogue", "resolved": "resolved_text"}
data = load_from_disk("samsum_new")
model_checkpoint = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [None]:
def preprocess(data, input):
    inputs = data[input]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    labels = tokenizer(text_target=data["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [None]:
rouge = evaluate.load("rouge")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.array(predictions)
    predictions = predictions.astype(np.int64)
    predictions = np.clip(predictions, 0, tokenizer.vocab_size - 1)
    predictions = predictions.tolist()
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    if isinstance(labels, tuple):
        labels = labels[0]
    labels = np.array(labels)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    labels = labels.astype(np.int64).tolist()
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {k: round(v * 100, 4) for k, v in result.items()}


In [28]:
for version, input_ in inputs.items():

    tokenized_data = data.map(lambda x: preprocess(x, input_), batched=True, remove_columns=data["train"].column_names)

    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
    for name, param in model.named_parameters():
        if any(f"encoder.block.{i}." in name for i in range(3)):
          param.requires_grad = False

    training_args = Seq2SeqTrainingArguments(
        output_dir=f"/content/t5-base-{version}-finetuned",
        eval_strategy="epoch",
        learning_rate=0.0001,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        save_total_limit=1,
        num_train_epochs=3,
        predict_with_generate=True,
        fp16=True,
        push_to_hub=False,
        logging_dir=f"/content/logs-{version}",
        logging_strategy="no",
        logging_steps=1000000,
        report_to="none",
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_data["train"],
        eval_dataset=tokenized_data["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(f"/content/t5-base-{version}-finetuned")

    metrics = trainer.evaluate()

    print(f"\nMetrics for {version} input:")
    for key in ["eval_loss", "eval_rouge1", "eval_rouge2", "eval_rougeL"]:
        label_map = {
            "eval_loss": "Loss",
            "eval_rouge1": "ROUGE1",
            "eval_rouge2": "ROUGE2",
            "eval_rougeL": "ROUGEL",
        }
        if key in metrics:
            print(f"{label_map[key]:<10}: {metrics[key]:.4f}")


{'eval_loss': 1.4028204679489136, 'eval_rouge1': 47.856, 'eval_rouge2': 23.9294, 'eval_rougeL': 40.0097, 'eval_rougeLsum': 40.0419, 'eval_runtime': 63.5825, 'eval_samples_per_second': 12.865, 'eval_steps_per_second': 1.62, 'epoch': 1.0}
{'eval_loss': 2.844597578048706, 'eval_rouge1': 28.1794, 'eval_rouge2': 10.1408, 'eval_rougeL': 24.7284, 'eval_rougeLsum': 24.7015, 'eval_runtime': 61.5508, 'eval_samples_per_second': 13.29, 'eval_steps_per_second': 1.673, 'epoch': 2.0}
{'eval_loss': 2.8258883953094482, 'eval_rouge1': 27.3721, 'eval_rouge2': 9.4496, 'eval_rougeL': 23.9784, 'eval_rougeLsum': 23.9712, 'eval_runtime': 62.4779, 'eval_samples_per_second': 13.093, 'eval_steps_per_second': 1.649, 'epoch': 3.0}
{'train_runtime': 1709.0635, 'train_samples_per_second': 25.86, 'train_steps_per_second': 3.233, 'train_loss': 2.4396065191820484, 'epoch': 3.0}
{'eval_loss': 2.8258883953094482, 'eval_rouge1': 27.3721, 'eval_rouge2': 9.4496, 'eval_rougeL': 23.9784, 'eval_rougeLsum': 23.9712, 'eval_runti

In [None]:
def summary(sample_idx, data, model, tokenizer, model_resolved, tokenizer_resolved, device):
    text = data["test"][sample_idx]["dialogue"]
    resolved_text = data["test"][sample_idx]["resolved_text"]
    reference = data["test"][sample_idx]["summary"]

    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
    output = model.generate(**inputs, max_length=128)
    summary = tokenizer.decode(output[0], skip_special_tokens=True)

    inputs_resolved = tokenizer_resolved(resolved_text, return_tensors="pt", max_length=512, truncation=True).to(device)
    output_resolved = model_resolved.generate(**inputs_resolved, max_length=128)
    summary_resolved = tokenizer_resolved.decode(output_resolved[0], skip_special_tokens=True)

    print("Input")
    print(text)
    print()
    print("Summary from Raw Model")
    print(summary)
    print()
    print("Summary from Anaphora Resolution Model")
    print(summary_resolved)
    print()
    print("Reference Summary")
    print(reference)
    print()


In [None]:
tokenizer = AutoTokenizer.from_pretrained("/content/t5-base-raw-finetuned")
model = AutoModelForSeq2SeqLM.from_pretrained("/content/t5-base-raw-finetuned").to(device)

tokenizer_resolved = AutoTokenizer.from_pretrained("/content/t5-base-resolved-finetuned")
model_resolved = AutoModelForSeq2SeqLM.from_pretrained("/content/t5-base-resolved-finetuned").to(device)


In [None]:
summary(
    sample_idx=17,
    data=data,
    model=model,
    tokenizer=tokenizer,
    model_resolved=model_resolved,
    tokenizer_resolved=tokenizer_resolved,
    device=device
)


Input
Igor: Shit, I've got so much to do at work and I'm so demotivated. 
John: It's pretty irresponsible to give that much work to someone on their notice period.
Igor: Yeah, exactly! Should I even care?
John: It's up to you, but you know what they say...
Igor: What do you mean?
John: Well, they say how you end things shows how you really are...
Igor: And now how you start, right?
John: Gotcha! 
Igor: So what shall I do then? 
John: It's only two weeks left, so grit your teeth and do what you have to do. 
Igor: Easy to say, hard to perform.
John: Come on, stop thinking, start doing! 
Igor: That's so typical of you!  ;)  

Summary from Raw Model
John has been working for his job.

Summary from Anaphora Resolution Model
Igor has two weeks left to finish work.

Reference Summary
Igor has a lot of work on his notice period and he feels demotivated. John thinks he should do what he has to do nevertheless. 

