In [None]:
%load_ext autoreload
%autoreload 2

import os
import json

from src.datasets import IndoSum
from src.common import get_device
from src.indobart.base import get_model, get_tokenizer

import numpy as np
import nltk
import evaluate
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer


from accelerate import Accelerator



In [None]:
# Ask Accelerate which device it resolved to; the value is shown as the cell output.
# NOTE(review): `device` is only displayed here — the training cells below do not reference it.
accelerator = Accelerator()
device = accelerator.device
device

device(type='mps')

### Data Loading

In [3]:
# Load IndoSum and display its splits. Note that `ds` is accessed as an
# attribute (its repr is a DatasetDict), not called as a method.
indosum = IndoSum()
indosum.ds

DatasetDict({
    train: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 14262
    })
    test: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 3762
    })
    validation: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 750
    })
})

In [8]:
# Peek at the first rows of the training split (via `to_pd`).
indosum.to_pd("train").head()

Unnamed: 0,document,id,summary
0,"Jakarta, CNN Indonesia - - Dokter Ryan Thamrin...",1501893029-lula-kamal-dokter-ryan-thamrin-saki...,Dokter Lula Kamal yang merupakan selebriti sek...
1,Selfie ialah salah satu tema terpanas di kalan...,1509072914-dua-smartphone-zenfone-baru-tawarka...,Asus memperkenalkan ZenFone generasi keempat...
2,"Jakarta, CNN Indonesia - - Dinas Pariwisata Pr...",1510613677-songsong-visit-2020-bengkulu-perkua...,Dinas Pariwisata Provinsi Bengkulu kembali men...
3,Merdeka.com - Indonesia Corruption Watch (ICW)...,1502706803-icw-ada-kejanggalan-atas-tewasnya-s...,Indonesia Corruption Watch (ICW) meminta Komis...
4,Merdeka.com - Presiden Joko Widodo (Jokowi) me...,1503039338-pembagian-sepeda-usai-upacara-penur...,Jokowi memimpin upacara penurunan bendera. Usa...


### Load Model

In [5]:
# Load the pretrained IndoBART model and its tokenizer (model first, then tokenizer).
model, tokenizer = get_model(), get_tokenizer()

In [6]:
# Display the model architecture.
model

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (f

In [7]:
# Display the tokenizer configuration (vocab size, special tokens).
tokenizer

IndoNLGTokenizer(name_or_path='indobenchmark/indobart-v2', vocab_size=40004, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>', 'additional_special_tokens': ['<mask>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	40003: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}

### Train Model

In [None]:
# Setup evaluation
# `punkt_tab` is fetched for `nltk.sent_tokenize`, which `compute_metrics` uses to
# put one sentence per line before scoring.
nltk.download("punkt_tab", quiet=True)
# ROUGE metric, loaded once and reused by every evaluation.
metric = evaluate.load("rouge")

#### Preparation

In [None]:
# Prepare and tokenize dataset
def preprocess_function(examples, max_input_length=768, max_target_length=128):
    """Tokenize a batch of IndoSum examples for seq2seq fine-tuning.

    Args:
        examples: batch dict from ``Dataset.map(..., batched=True)`` holding
            ``"document"`` and ``"summary"`` lists of strings.
        max_input_length: token budget for the source document; longer
            documents are truncated. The default stays below the model's
            positional-embedding table (1026 in the model repr).
        max_target_length: token budget for the reference summary; longer
            summaries are truncated.

    Returns:
        The tokenizer output for the documents, plus a ``"labels"`` entry with
        the token ids of the summaries.
    """
    model_inputs = tokenizer(
        examples["document"], max_length=max_input_length, truncation=True
    )

    # `text_target` tokenizes the summaries as decoder targets.
    labels = tokenizer(
        text_target=examples["summary"], max_length=max_target_length, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def compute_metrics(eval_preds):
    """Compute ROUGE scores for generated summaries.

    Passed as ``Seq2SeqTrainer(compute_metrics=...)`` with
    ``predict_with_generate=True``, so ``eval_preds`` holds generated token ids
    and label ids.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair supplied by the trainer.

    Returns:
        dict of ROUGE scores as returned by ``metric.compute``.
    """
    preds, labels = eval_preds
    # Predictions may arrive wrapped in a tuple; the generated ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks ignored positions. It is present in labels (loss masking) and
    # may also appear in predictions when the trainer pads generations of
    # different lengths while concatenating eval batches. It is not a valid
    # token id, so map it to the pad token (dropped by skip_special_tokens).
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum expects a newline after each sentence.
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    decoded_preds = [_one_sentence_per_line(pred) for pred in decoded_preds]
    decoded_labels = [_one_sentence_per_line(label) for label in decoded_labels]

    return metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )

# `indosum.ds` is a DatasetDict attribute (see its repr in the Data Loading
# section), not a method — calling it would raise TypeError, so map over it directly.
tokenized_ds = indosum.ds.map(preprocess_function, batched=True)

# Pads inputs and labels per batch; the model is presumably passed so the collator
# can derive decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

def train_model(
    output_dir,
    per_device_batch_size,
    learning_rate,
    num_train_epochs,
    generation_max_length,
    model=None,
):
    """Fine-tune a seq2seq model on the IndoSum train split.

    Evaluation on the validation split runs every epoch, and a checkpoint is
    saved every epoch.

    Args:
        output_dir: experiment directory; checkpoints go to
            ``<output_dir>/checkpoint`` and logs to ``<output_dir>/logs``.
        per_device_batch_size: train and eval batch size per device.
        learning_rate: optimizer learning rate.
        num_train_epochs: number of training epochs.
        generation_max_length: max length of summaries generated during
            evaluation (``predict_with_generate=True``).
        model: model to fine-tune. Defaults to a freshly loaded pretrained
            model (``get_model()``) so successive experiments do not keep
            training the same weights.

    Returns:
        The ``Seq2SeqTrainer`` after training.
    """
    if model is None:
        # Fresh weights per call: reusing the notebook-global `model` would make
        # each experiment continue from the previous experiment's fine-tuned state.
        model = get_model()

    checkpoint_dir = os.path.join(output_dir, "checkpoint")

    # `resume_from_checkpoint` in TrainingArguments is not acted on by the
    # Trainer itself; the checkpoint has to be passed to `trainer.train()`.
    # Resume only when a previous run of this experiment left a checkpoint.
    last_checkpoint = None
    if os.path.isdir(checkpoint_dir):
        from transformers.trainer_utils import get_last_checkpoint

        last_checkpoint = get_last_checkpoint(checkpoint_dir)

    training_args = Seq2SeqTrainingArguments(
        output_dir=checkpoint_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        weight_decay=0.01,
        num_train_epochs=num_train_epochs,
        # NOTE(review): the device shown earlier is MPS; depending on the
        # transformers/torch versions, fp16 may be rejected on non-CUDA devices — confirm.
        fp16=True,
        predict_with_generate=True,
        generation_max_length=generation_max_length,
        log_level="info",
        logging_first_step=True,
        logging_dir=os.path.join(output_dir, "logs"),
        resume_from_checkpoint=last_checkpoint,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["validation"],
        processing_class=tokenizer,
        # Collator bound to the model actually being trained.
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model),
        compute_metrics=compute_metrics,
    )

    trainer.train(resume_from_checkpoint=last_checkpoint)

    return trainer
    
def evaluate_model(trainer):
    """Score ``trainer``'s model on the held-out IndoSum test split.

    Returns:
        The metrics dict produced by ``trainer.evaluate``.
    """
    test_split = tokenized_ds["test"]
    return trainer.evaluate(eval_dataset=test_split)


def train_and_evaluate(output_dir, per_device_batch_size, learning_rate, num_train_epochs, generation_max_length):
    """Train one configuration, then score it on the test split.

    Returns:
        ``(trainer, eval_results)``: the trained trainer and its test metrics.
    """
    run_config = dict(
        output_dir=output_dir,
        per_device_batch_size=per_device_batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        generation_max_length=generation_max_length,
    )
    trainer = train_model(**run_config)
    return trainer, evaluate_model(trainer)


#### Training & Evaluation

Try several values of `generation_max_length` (60 to 100 tokens) while keeping all other hyperparameters fixed.
Observe the best score and the `generation_max_length` that produced it.

In [None]:
# One configuration per generation_max_length in {60, 70, 80, 90, 100};
# every other hyperparameter is held fixed.
experiments = [
    {
        "output_dir": f"./results/00-indobart/{i:02d}",
        "per_device_batch_size": 8,
        "learning_rate": 3.75e-5,
        "num_train_epochs": 3,
        "generation_max_length": 50 + i * 10,
    }
    for i in range(1, 6)
]

# Metric used to pick the best run. Trainer.evaluate prefixes metric names
# (by default with "eval_"); missing keys rank last instead of crashing.
BEST_METRIC = "eval_rougeLsum"
all_results = []

for exp in experiments:
    os.makedirs(exp["output_dir"], exist_ok=True)

    # Experiment keys match train_and_evaluate's parameter names.
    trainer, eval_results = train_and_evaluate(**exp)

    # print params and the results
    print("=== Results for experiment ===")
    print("-- Params --")
    print(json.dumps(exp, indent=4))
    print("-- Eval results --")
    print(json.dumps(eval_results, indent=4))

    # save mapping between params and results
    with open(os.path.join(exp["output_dir"], "params.json"), "w") as f:
        json.dump(exp, f)

    with open(os.path.join(exp["output_dir"], "eval_results.json"), "w") as f:
        json.dump(eval_results, f)

    all_results.append((exp, eval_results))

# Report the best configuration, as announced in the markdown above.
best_exp, best_results = max(
    all_results, key=lambda pair: pair[1].get(BEST_METRIC, float("-inf"))
)
print(f"=== Best {BEST_METRIC} ===")
print(
    f"generation_max_length={best_exp['generation_max_length']}: "
    f"{best_results.get(BEST_METRIC)}"
)

