# pip install

In [None]:
# Third-party packages for this notebook (quiet install).
# NOTE(review): the `meteor` metric loaded further down is not covered by this line and
# presumably needs `nltk` - confirm it is available in the target environment.
!pip install -q datasets transformers sentencepiece evaluate jiwer rouge-score sacrebleu
# accelerate is upgraded separately to whatever the latest release is.
!pip install --upgrade accelerate

# General settings

In [None]:
import os
import torch
import numpy as np
import random
from transformers import AutoTokenizer, set_seed

# Enumerate GPUs in PCI bus order and expose only the first one to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# training device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# pretrain model
size = 'base'  # 'small', 'base'
model_checkpoint = f'google/mt5-{size}'

# training parameters
num_epochs = 20  # 10, 20
batch_size = 8
learning_rate = 2e-5  # 1e-3, 2e-5
optimizer_name = "adamw_torch"  # "adamw_torch", "adafactor"

# seed: every random source used here (torch, random, numpy, transformers) gets the same seed
seed = 112
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
g = torch.Generator()  # seeded generator kept available for code that wants an explicit RNG
g.manual_seed(seed)
set_seed(seed)

# Simple2In1 parameters
method = 'Simple2In1'
max_length = 32
model_name = f'{method}-{size}'

# report files
save_dir = 'models'
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(save_dir, exist_ok=True)
# All report files share one prefix that encodes the run configuration.
report_prefix = (
    f'{save_dir}/{model_name}-seqlen{max_length}-{optimizer_name}'
    f'-lr{learning_rate}-{num_epochs}ep-seed{seed}'
)
train_report_file = f'{report_prefix}-train.csv'
test_report_file = f'{report_prefix}-test.csv'
gen_output_file = f'{report_prefix}-gen.txt'

# Print parameters setting
print(f'Training Device                    : {device}')
print('====================')
print('Pre-train')
print(f'Model size                         : {size}')
print(f'Checkpoint                         : {model_checkpoint}')
print('====================')
print('Training parameters')
print(f'Batch size                         : {batch_size}')
print(f'Epochs                             : {num_epochs}')
print(f'Learning rate                      : {learning_rate}')
print(f'Optimizer name                     : {optimizer_name}')
print('====================')
print(f'{method} parameters')
print(f'Model name                         : {model_name}')
print(f'I/O length                         : {max_length}/{max_length}')
print('====================')
print(f'Train report                       : {train_report_file}')
print(f'Test report                        : {test_report_file}')
print(f'Generated text                     : {gen_output_file}')
print('====================')

In [None]:
# Tokenizer matching the pretrained checkpoint, shared as a global by the helper functions
# below. `use_fast=False` asks AutoTokenizer for the non-fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=False)

# Utility

In [None]:
from datasets import load_dataset

def preprocess_function(examples, max_length=32):
    """Tokenize one batch of rows into model inputs and labels.

    Args:
        examples: batch mapping with the text columns ``answer_segmented``,
            ``asr_segmented`` and ``bangphim_segmented``; the ``|`` segment
            markers are removed from every text before tokenization.
        max_length: length every tokenized sequence is padded/truncated to.

    Returns:
        dict with ``labels`` (token ids of the answer) plus ``input_ids`` and
        ``attention_mask`` of the ASR/bangphim pair encoded as one sequence
        (``seq1</s>seq2</s><pad>...<pad>``).

    NOTE(review): padded label positions keep the pad token id; presumably they
    then count toward the training loss unless masked elsewhere - confirm.
    """
    def without_bars(column):
        # Drop the '|' segmentation markers from every text of the column.
        return [text.replace('|', '') for text in examples[column]]

    # Same padding/truncation policy for both the target and the source side.
    encode_options = dict(max_length=max_length, padding='max_length',
                          truncation=True, return_tensors='pt')

    target_encoding = tokenizer(without_bars('answer_segmented'), **encode_options)

    # A list of (asr, bangphim) tuples makes the tokenizer encode each row as a pair.
    source_pairs = list(zip(without_bars('asr_segmented'),
                            without_bars('bangphim_segmented')))
    source_encoding = tokenizer(source_pairs, **encode_options)

    return {
        'labels': target_encoding['input_ids'],
        'input_ids': source_encoding['input_ids'],
        'attention_mask': source_encoding['attention_mask'],
    }

def create_tokenized_dataset(input_filepath, validate_filepath=None, test_filepath=None, max_length=32):
    """Load the CSV splits and tokenize them with ``preprocess_function``.

    Args:
        input_filepath: CSV file of the ``train`` split (always loaded).
        validate_filepath: optional CSV file of the ``validate`` split.
        test_filepath: optional CSV file of the ``test`` split.
        max_length: sequence length forwarded to ``preprocess_function``
            (defaults to that function's own default of 32).

    Returns:
        A ``DatasetDict`` with one tokenized dataset per provided split. The
        raw CSV columns listed below are dropped from the result.
    """
    # Local import keeps this function self-contained; `datasets` is already a
    # dependency of this file.
    from datasets import DatasetDict

    data_files = {"train": input_filepath}
    if validate_filepath is not None:
        data_files["validate"] = validate_filepath
    if test_filepath is not None:
        data_files["test"] = test_filepath

    dataset = load_dataset("csv", data_files=data_files, delimiter=",")

    # Raw columns that must not reach the model.
    raw_columns = (
        'Unnamed: 0',
        'answer_segmented',
        'asr_segmented',
        'bangphim_segmented',
        'room_id',
        'alignment',
        'flat_sequence',
        'flat_position',
        'flat_source',
        'label', 'input', 'target',
    )

    # Map each split on its own so that only the columns actually present in
    # that split are removed; asking `map` to remove a missing column raises.
    return DatasetDict({
        split_name: split.map(
            preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=[col for col in raw_columns if col in split.column_names],
            fn_kwargs={'max_length': max_length},
        )
        for split_name, split in dataset.items()
    })


In [None]:
import evaluate

# Metric handles from the `evaluate` library; compute_metrics below reads them as globals.
wer = evaluate.load("wer")
rouge = evaluate.load('rouge')
sbleu = evaluate.load("sacrebleu")
meteor = evaluate.load('meteor')

def compute_metrics(eval_pred):
    """Score predicted token ids against reference token ids.

    Both sides are converted to space-separated tokenizer tokens (special
    tokens removed), so every metric is computed at sub-word token level.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id sequences. If
            ``predictions`` is a tuple, its first element is used.

    Returns:
        dict with ``wer``, ``rouge1``, ``rouge2``, ``rougeL``, ``sacrebleu``
        and ``meteor``, each rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Resolve these once instead of once per token.
    special_tokens = set(tokenizer.all_special_tokens)
    pad_token_id = tokenizer.pad_token_id

    def to_token_string(token_ids):
        # Negative ids (e.g. the -100 padding/ignore index) are not vocabulary
        # entries; map them to the pad token so they are dropped with the
        # other special tokens instead of crashing the id lookup.
        safe_ids = [pad_token_id if int(tok) < 0 else int(tok) for tok in token_ids]
        pieces = tokenizer.convert_ids_to_tokens(safe_ids)
        return ' '.join(piece for piece in pieces if piece not in special_tokens)

    decoded_preds, decoded_labels = [], []
    for pred_token_ids, label_token_ids in zip(predictions, labels):
        decoded_preds.append(to_token_string(pred_token_ids))
        decoded_labels.append(to_token_string(label_token_ids))

    wer_score = wer.compute(predictions=decoded_preds, references=decoded_labels)
    # Whitespace split keeps ROUGE on the same token units as the other metrics.
    rouge_score = rouge.compute(predictions=decoded_preds, references=decoded_labels,
                                tokenizer=lambda x: x.split())
    sbleu_score = sbleu.compute(predictions=decoded_preds, references=decoded_labels)
    meteor_score = meteor.compute(predictions=decoded_preds, references=decoded_labels)

    result = {
        'wer': wer_score,
        'rouge1': rouge_score['rouge1'],
        'rouge2': rouge_score['rouge2'],
        'rougeL': rouge_score['rougeL'],
        'sacrebleu': sbleu_score['score'],
        'meteor': meteor_score['meteor'],
    }

    return {k: round(v, 4) for k, v in result.items()}


In [None]:
def generate(model, tokenizer, input_ids, decoder_input_ids=None, device='cpu', max_length=20,
             attention_mask=None):
    """Greedy-decode with an encoder-decoder model, one token per step.

    The encoder runs once; its last hidden state is reused for every later
    decoding step. Decoding stops when every sequence of the batch has
    produced EOS, or after ``max_length`` generated tokens.

    Args:
        model: callable seq2seq model exposing ``config.decoder_start_token_id``
            and ``config.eos_token_id`` and returning an output with ``logits``
            and ``encoder_last_hidden_state``.
        tokenizer: used only to build the decoder start token (``"<pad>"``)
            when ``decoder_input_ids`` is not given.
        input_ids: encoder token ids; the first dimension is the batch.
        decoder_input_ids: optional decoder prefix to continue from.
        device: device the default decoder start token is moved to.
        max_length: maximum number of tokens to generate (the prefix is not counted).
        attention_mask: optional encoder attention mask forwarded to the model
            on every step; ``None`` keeps the model's default behaviour.

    Returns:
        The decoder ids including the prefix. Sequences that finished before
        the rest of the batch are padded with ``config.pad_token_id`` (or EOS
        when the config defines no pad id).
    """
    if decoder_input_ids is None:
        start_ids = tokenizer("<pad>", add_special_tokens=False, return_tensors="pt").input_ids.to(device)
        assert start_ids[0, 0].item() == model.config.decoder_start_token_id, "`decoder_input_ids` should correspond to `model.config.decoder_start_token_id`"
        # One start token per input sequence so batches larger than one work.
        decoder_input_ids = start_ids.repeat(input_ids.shape[0], 1)

    eos_token_id = model.config.eos_token_id
    pad_token_id = getattr(model.config, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    encoder_outputs = None  # filled by the first step, reused afterwards
    finished = None         # per-sequence flag: EOS already generated

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_length):
            if encoder_outputs is None:
                outputs = model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                decoder_input_ids=decoder_input_ids,
                                return_dict=True)
                encoder_outputs = (outputs.encoder_last_hidden_state,)
            else:
                outputs = model(encoder_outputs=encoder_outputs,
                                attention_mask=attention_mask,
                                decoder_input_ids=decoder_input_ids,
                                return_dict=True)

            # Greedy choice: most probable token at the last decoder position.
            next_ids = torch.argmax(outputs.logits[:, -1:], dim=-1)
            if finished is None:
                finished = torch.zeros_like(next_ids.squeeze(-1), dtype=torch.bool)
            # Sequences that already ended only receive padding from now on.
            next_ids = next_ids.masked_fill(finished.unsqueeze(-1), pad_token_id)
            decoder_input_ids = torch.cat([decoder_input_ids, next_ids], dim=-1)

            finished = finished | (next_ids.squeeze(-1) == eos_token_id)
            if bool(finished.all()):
                break

    return decoder_input_ids

In [None]:
from transformers import MT5ForConditionalGeneration

# Recap the checkpoint and method settings right before the weights are loaded.
_separator = '===================='
for _summary_line in (
    'Pre-train',
    f'Model size                         : {size}',
    f'Checkpoint                         : {model_checkpoint}',
    _separator,
    f'{method} parameters',
    f'Model name                         : {model_name}',
    f'I/O length                         : {max_length}/{max_length}',
    _separator,
):
    print(_summary_line)

# Load the pretrained checkpoint and place the model on the training device.
model = MT5ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)

# Dataset preparation by DataLoader

In [None]:
from torch.utils.data import DataLoader

# Build the tokenized train / validate / test splits from the CSV files.
tokenized_dataset = create_tokenized_dataset('/data/train.csv', '/data/validate.csv', '/data/test.csv')
# Bare expression: displayed as the cell output when run in a notebook.
tokenized_dataset

# Training

In [None]:
# Splits used for fitting and for per-epoch validation.
train_tokenized_dataset = tokenized_dataset["train"]
valid_tokenized_dataset = tokenized_dataset["validate"]

# Expose the three model columns of both splits as torch tensors (in place).
_model_columns = ['input_ids', 'attention_mask', 'labels']
for _split in (train_tokenized_dataset, valid_tokenized_dataset):
    _split.set_format(type='torch', columns=_model_columns)

# prepare dataloader
train_dataloader = torch.utils.data.DataLoader(train_tokenized_dataset, batch_size=batch_size)
valid_dataloader = torch.utils.data.DataLoader(valid_tokenized_dataset, batch_size=batch_size)

In [None]:
from transformers import Seq2SeqTrainer
from transformers.optimization import get_scheduler, Adafactor
from torch.optim import AdamW

class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that builds its own optimizer with per-group learning rates.

    Parameters whose name contains ``local_alignment_attention`` or
    ``global_alignment_attention`` are trained with ``align_attn_lr``; every
    other parameter uses ``args.learning_rate``. Bias and layer-norm weights
    are excluded from weight decay.

    Args:
        align_attn_lr: learning rate of the alignment-attention parameters;
            defaults to ``args.learning_rate`` when omitted.
        *args, **kwargs: forwarded unchanged to ``Seq2SeqTrainer``.
    """

    # Name fragments selecting the parameters trained with `align_attn_lr`.
    _ALIGN_ATTN_KEYS = ("local_alignment_attention", "global_alignment_attention")
    # Name fragments of parameters that must not be weight-decayed.
    # "layer_norm.weight" covers T5-style module names, which the BERT-style
    # "LayerNorm.weight" pattern alone does not match.
    _NO_DECAY_KEYS = ("bias", "LayerNorm.weight", "layer_norm.weight")

    def __init__(self, *args, align_attn_lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.align_attn_lr = align_attn_lr if align_attn_lr is not None else self.args.learning_rate

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create ``self.optimizer`` and ``self.lr_scheduler``.

        Args:
            num_training_steps: total number of optimizer steps of the run.

        Returns:
            Tuple ``(optimizer, lr_scheduler)``.

        Raises:
            ValueError: if ``args.optim`` is neither ``'adamw_torch'`` nor
                ``'adafactor'``.
        """
        # Sort every parameter into exactly one of four buckets:
        # (alignment vs. base learning rate) x (decayed vs. not decayed).
        align_decay, align_no_decay, base_decay, base_no_decay = [], [], [], []
        for name, param in self.model.named_parameters():
            is_align = any(key in name for key in self._ALIGN_ATTN_KEYS)
            skip_decay = any(key in name for key in self._NO_DECAY_KEYS)
            if is_align:
                (align_no_decay if skip_decay else align_decay).append(param)
            else:
                (base_no_decay if skip_decay else base_decay).append(param)

        optimizer_grouped_parameters = [
            {
                "params": align_decay,
                "weight_decay": self.args.weight_decay,
                "lr": self.align_attn_lr,
            },
            {
                # Alignment biases / norms keep the alignment learning rate.
                "params": align_no_decay,
                "weight_decay": 0.0,
                "lr": self.align_attn_lr,
            },
            {
                "params": base_decay,
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate,
            },
            {
                "params": base_no_decay,
                "weight_decay": 0.0,
                "lr": self.args.learning_rate,
            },
        ]

        # Create the optimizer using the parameter groups
        if self.args.optim == 'adamw_torch':
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        elif self.args.optim == 'adafactor':
            # Adafactor rejects a manual `lr` unless its relative-step mode is
            # switched off, so disable it (and parameter scaling) explicitly.
            self.optimizer = Adafactor(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                scale_parameter=False,
                relative_step=False,
            )
        else:
            raise ValueError(f"Invalid optimizer_type: {self.args.optim}. Choose 'adamw_torch' or 'adafactor'.")

        # Create the learning rate scheduler; get_warmup_steps honours both
        # `warmup_steps` and `warmup_ratio`.
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            self.optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )

        return self.optimizer, self.lr_scheduler


In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Training configuration: log, evaluate and checkpoint once per epoch, then
# reload the checkpoint with the lowest validation WER when training ends.
args = Seq2SeqTrainingArguments(
    output_dir=f'./{model_name}',

    # --- optimisation ---
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    lr_scheduler_type='constant',
    optim=optimizer_name,  # 'adamw_torch' or 'adafactor'
    # Half precision stays off to avoid a degenerate loss, ref:
    # https://stackoverflow.com/questions/65332165/loss-is-nan-when-fine-tuning-huggingface-nli-model-both-roberta-bart
    fp16=False,

    # --- logging / evaluation ---
    # (switch both strategies to 'steps' with eval_steps / logging_steps for
    # a finer-grained schedule)
    do_train=True,
    do_eval=True,
    logging_strategy='epoch',
    evaluation_strategy='epoch',
    predict_with_generate=True,
    # NOTE(review): generation_max_length / generation_num_beams are not set,
    # so generation during evaluation presumably uses the model defaults - confirm.

    # --- checkpointing ---
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Model selection on WER (lower is better); 'loss' (lower is better) and
    # 'meteor' (higher is better) are the alternatives tried so far.
    metric_for_best_model='wer',
    greater_is_better=False,
    # Keep only the best and the most recent checkpoint, see
    # https://stackoverflow.com/a/67615225/3027437
    # https://discuss.huggingface.co/t/save-only-best-model-in-trainer/8442/4
    save_total_limit=2,
)

# Instantiate the custom trainer (drop-in replacement for Seq2SeqTrainer).
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validate'],
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
import pandas as pd

# Evaluate once more with the trainer (on the eval dataset it was constructed
# with) and persist the metrics as a one-row CSV report.
eval_metrics = trainer.evaluate()
print(eval_metrics)
# index=[0] lets pandas build a single-row frame from the metrics dict
# (presumably all values are scalars).
eval_df = pd.DataFrame(eval_metrics, index=[0])
eval_df.to_csv(train_report_file, index=False)

# Test

In [None]:
# Test split, with the three model columns exposed as torch tensors.
test_tokenized_dataset = tokenized_dataset["test"]

# prepare dataloader
test_tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# batch_size=1: the generation cell at the end of the file reads only element [0] of each batch.
dataloader = torch.utils.data.DataLoader(test_tokenized_dataset, batch_size=1)

In [None]:
# Run prediction over the test split; `.metrics` holds the scores, whose keys
# carry a `test_` prefix (stripped again when the report is written).
test_result = trainer.predict(test_tokenized_dataset)
# Bare expression: displayed as the cell output when run in a notebook.
test_result.metrics

In [None]:
# Keep only the quality metrics: skip the timing entries, strip the `test_`
# marker from each remaining key and title-case it for the report header.
_timing_keys = ('test_runtime', 'test_samples_per_second', 'test_steps_per_second')
results = {
    metric_name.replace('test_', '').title(): metric_value
    for metric_name, metric_value in test_result.metrics.items()
    if metric_name not in _timing_keys
}
# One-row frame labelled with the model name; the CSV is written without the index.
test_df = pd.DataFrame([results], index=[model_name])
test_df.to_csv(test_report_file, index=False)

# Generate results

In [None]:
def _decode_input_pair(token_ids):
    """Decode the two source texts packed into one encoder sequence.

    The sequence is split on the tokenizer's EOS id (the separator written
    after each source text) and the first two segments are decoded. Missing
    segments come back as empty strings, so the caller never has to guard
    against an IndexError, and spaces inside a source text do not shift the
    split point.
    """
    eos_id = tokenizer.eos_token_id
    segments, current = [], []
    for token_id in token_ids.tolist():
        if token_id == eos_id:
            segments.append(current)
            current = []
        else:
            current.append(token_id)
    segments.append(current)  # whatever follows the last EOS (padding)

    texts = [tokenizer.decode(segment, skip_special_tokens=True) for segment in segments[:2]]
    texts += [''] * (2 - len(texts))
    return texts


def _emit(output_file, line=''):
    """Write one report line to both stdout and the output file."""
    print(line)
    output_file.write(line + '\n')


# Inference only: make sure dropout is off and no autograd graph is built.
model.eval()
with open(gen_output_file, 'w', encoding='utf-8') as output_file, torch.no_grad():
    for i, batch in enumerate(dataloader, 1):
        _emit(output_file, f'test inputs #{i}:')
        _emit(output_file, '==================')

        input_ids = batch['input_ids'].to(device)
        labels_ids = batch['labels'].to(device)

        # s1: the hand-written greedy loop; s2: the library implementation.
        # NOTE(review): model.generate runs with its default length limit here,
        # while the manual loop is capped at max_length - confirm that is intended.
        s1 = generate(model, tokenizer, input_ids, decoder_input_ids=None, device=device, max_length=max_length)
        s2 = model.generate(input_ids=input_ids)

        # The dataloader yields one example per batch, so only row 0 is reported.
        input_1, input_2 = _decode_input_pair(input_ids[0])

        decoded_preds_1 = tokenizer.batch_decode(s1, skip_special_tokens=True)
        decoded_preds_2 = tokenizer.batch_decode(s2, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        for report_line in (
            'input_1       : ' + input_1,
            'input_2       : ' + input_2,
            'labels        : ' + decoded_labels[0],
            'generate      : ' + decoded_preds_1[0],
            'model.generate: ' + decoded_preds_2[0],
            '',
        ):
            _emit(output_file, report_line)