# pip install

In [None]:
!pip install -q datasets transformers sentencepiece evaluate jiwer rouge-score sacrebleu
!pip install --upgrade accelerate

# General settings

In [None]:
import os
import torch
import numpy as np
import random
from transformers import AutoTokenizer, set_seed

# Global experiment configuration: device selection, checkpoint, hyper-parameters,
# RNG seeding and the output file names shared by the rest of the notebook.

# CUDA environment variables: device ordering and the single visible GPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# training device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# pretrain model
size = 'base' # 'small', 'base'
model_checkpoint = f'google/mt5-{size}'

# training parameters
num_epochs = 20 # 10, 20
batch_size = 8
learning_rate = 2e-5 # 1e-3, 2e-5
linear_layer_lr = 1e-3 # 2e-5
optimizer_name = "adamw_torch" # "adamw_torch", "adafactor"

# seed every RNG used by the notebook (torch, random, numpy, transformers)
seed = 112
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
g = torch.Generator()  # seeded generator; not referenced again in the visible code
g.manual_seed(seed)
set_seed(seed)

# PairConcat parameters
method = 'PairConcat'
max_length = 32
model_name = f'{method}-{size}'

# report files
save_dir = 'models'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(save_dir, exist_ok=True)
# Common prefix encoding the full run configuration; only the suffix differs per file.
run_prefix = (
    f'{save_dir}/{model_name}-seqlen{max_length}-{optimizer_name}'
    f'-lr{learning_rate}-linlr{linear_layer_lr}-{num_epochs}ep-seed{seed}'
)
train_report_file = f'{run_prefix}-train.csv'
test_report_file = f'{run_prefix}-test.csv'
gen_output_file = f'{run_prefix}-gen.txt'

# Print parameters setting
print(f'Training Device                    : {device}')
print('====================')
print('Pre-train')
print(f'Model size                         : {size}')
print(f'Checkpoint                         : {model_checkpoint}')
print('====================')
print('Training parameters')
print(f'Batch size                         : {batch_size}')
print(f'Epochs                             : {num_epochs}')
print(f'Learning rate                      : {learning_rate}')
print(f'Linear layer lr                    : {linear_layer_lr}')
print(f'Optimizer name                     : {optimizer_name}')
print('====================')
print(f'{method} parameters')
print(f'Model name                         : {model_name}')
print(f'I/O length                         : {max_length}/{max_length}')
print('====================')
print(f'Train report                       : {train_report_file}')
print(f'Test report                        : {test_report_file}')
print(f'Generated text                     : {gen_output_file}')
print('====================')

In [None]:
# Shared tokenizer for the chosen checkpoint, used as a module-level global by the
# alignment, preprocessing and metric helpers below. use_fast=False is passed explicitly.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=False)

# Utility

In [None]:
import heapq

def scoring(s1, s2, gap_penalty, match, mismatch):
  """Build the Needleman-Wunsch global-alignment score matrix for s1 vs. s2.

  Args:
    s1, s2: Indexable sequences (strings or token lists) compared element-wise
      with ``==``.
    gap_penalty: Score added for aligning an element against a gap.
    match: Score added when ``s1[i-1] == s2[j-1]``.
    mismatch: Score added when the two elements differ.

  Returns:
    A float ``numpy`` array of shape ``(len(s1)+1, len(s2)+1)`` where entry
    ``[i][j]`` is the best score for aligning ``s1[:i]`` with ``s2[:j]``.
  """
  rows, cols = len(s1) + 1, len(s2) + 1
  matrix = np.zeros((rows, cols))  # matrix[0][0] stays 0: the upper-left corner

  # Borders: a prefix aligned against the empty sequence is all gaps.
  for i in range(1, rows):
    matrix[i][0] = matrix[i-1][0] + gap_penalty
  for j in range(1, cols):
    matrix[0][j] = matrix[0][j-1] + gap_penalty

  # Interior: best of a gap in s1 (left), a gap in s2 (top) or a (mis)match (diagonal).
  for i in range(1, rows):
    for j in range(1, cols):
      left = matrix[i][j-1] + gap_penalty
      top = matrix[i-1][j] + gap_penalty
      step = match if s1[i-1] == s2[j-1] else mismatch
      diagonal = matrix[i-1][j-1] + step
      matrix[i][j] = max(diagonal, top, left)

  return matrix

def traceback(matrix, s1, s2, gap_penalty=-1, match=1, mismatch=-1):
  """Recover one optimal alignment from a Needleman-Wunsch score matrix.

  At every cell the walk follows a predecessor that actually produced the
  cell's score (diagonal first, then a gap in s2, then a gap in s1), instead
  of stepping to the largest neighbouring value, which can leave the optimal
  path.

  Args:
    matrix: Score matrix indexed as ``matrix[i][j]`` with shape
      ``(len(s1)+1, len(s2)+1)``, as returned by ``scoring``.
    s1, s2: The two aligned sequences; their elements must be strings.
    gap_penalty, match, mismatch: The scores the matrix was built with.
      The defaults match the defaults of ``needleman_wunsch``.

  Returns:
    A string in which a matched element appears verbatim, a substitution as
    ``(a,b)``, an element of s1 against a gap as ``(a,-)`` and an element of
    s2 against a gap as ``(-,b)``.
  """
  i, j = len(s1), len(s2)
  pieces = []  # collected back-to-front; reversed once at the end (avoids O(n^2) prepends)

  while i > 0 and j > 0:
    w1, w2 = s1[i-1], s2[j-1]
    current = matrix[i][j]
    step = match if w1 == w2 else mismatch

    if current == matrix[i-1][j-1] + step:
      # match or mismatch
      pieces.append(w1 if w1 == w2 else '({},{})'.format(w1, w2))
      i -= 1
      j -= 1
    elif current == matrix[i-1][j] + gap_penalty:
      # s1 element aligned against a gap
      pieces.append('({},{})'.format(w1, '-'))
      i -= 1
    else:
      # s2 element aligned against a gap
      pieces.append('({},{})'.format('-', w2))
      j -= 1

  # One sequence is exhausted: the rest of the other one is aligned against gaps.
  while i > 0:
    pieces.append('({},{})'.format(s1[i-1], '-'))
    i -= 1

  while j > 0:
    pieces.append('({},{})'.format('-', s2[j-1]))
    j -= 1

  return ''.join(reversed(pieces))

def needleman_wunsch(s1, s2, gap_penalty=-1, match=1, mismatch=-1):
  """Globally align s1 and s2.

  Returns:
    ``(matrix, alignment)``: the score matrix from ``scoring`` and the
    alignment string from ``traceback``, both computed with the same scores.
  """
  matrix = scoring(s1, s2, gap_penalty, match, mismatch)
  alignment = traceback(matrix, s1, s2, gap_penalty, match, mismatch)
  return matrix, alignment

def n_best_alignments(seq1, seq2, n, dp,
                      i=None,
                      j=None,
                      aligned_seq1='',
                      aligned_seq2='',
                      gap_penalty=-1,
                      match=1,
                      mismatch=-1):
    """Enumerate up to ``n`` alignments by walking back through a score matrix.

    Starting from cell ``(i, j)`` (the bottom-right corner when omitted), every
    predecessor whose score explains ``dp[i][j]`` is followed recursively until
    the origin is reached.

    Args:
        seq1, seq2: Sequences of string tokens that were aligned.
        n: Maximum number of alignments to keep.
        dp: Score matrix indexed as ``dp[i][j]``.
        i, j: Current cell; used by the recursion.
        aligned_seq1, aligned_seq2: Alignment suffixes built so far; used by
            the recursion.
        gap_penalty, match, mismatch: Scores ``dp`` was built with. The
            defaults (-1, 1, -1) reproduce the previously hard-coded values;
            with any other scoring the old code found no predecessor and
            returned an empty list.

    Returns:
        A list of ``(score, aligned1, aligned2)`` tuples. The aligned strings
        hold tokens separated (and terminated) by a single space, with ``'-'``
        marking a gap. ``score`` is ``-dp[0][0]``. When more than ``n``
        candidates exist the ``n`` smallest tuples are kept.

    Note:
        Every path is explored before trimming to ``n``, so the work grows
        with the number of distinct traceback paths.
    """
    if i is None or j is None:
        i, j = len(seq1), len(seq2)

    if i == 0 and j == 0:
        return [(-dp[i][j], aligned_seq1, aligned_seq2)]

    current = dp[i][j]
    candidates = []

    # Diagonal move: match or mismatch.
    if i > 0 and j > 0:
        step = match if seq1[i - 1] == seq2[j - 1] else mismatch
        if current == dp[i - 1][j - 1] + step:
            candidates.append(
                ( i - 1, j - 1,
                  seq1[i - 1] + ' ' + aligned_seq1,
                  seq2[j - 1] + ' ' + aligned_seq2 ) )

    # Token of seq1 aligned against a gap.
    if i > 0 and current == dp[i - 1][j] + gap_penalty:
        candidates.append(
            ( i - 1, j,
              seq1[i - 1] + ' ' + aligned_seq1,
              '-' + ' ' + aligned_seq2 ) )

    # Token of seq2 aligned against a gap.
    if j > 0 and current == dp[i][j - 1] + gap_penalty:
        candidates.append(
            ( i, j - 1,
              '-' + ' ' + aligned_seq1,
              seq2[j - 1] + ' ' + aligned_seq2 ) )

    results = []
    for (new_i, new_j, new_seq1, new_seq2) in candidates:
        results.extend(
            n_best_alignments(seq1, seq2, n, dp, new_i, new_j, new_seq1,
                              new_seq2, gap_penalty, match, mismatch))

    if len(results) > n:
        results = heapq.nsmallest(n, results)
    return results


In [None]:
def n_best_align_tokenized_tokens(seq1, seq2, n=3, max_length=32):
  """Align two token-id sequences and return padded id pairs plus masks.

  The ids are converted to token strings with the module-level ``tokenizer``,
  aligned with ``needleman_wunsch`` / ``n_best_alignments`` and converted back
  to ids (gaps are the literal token ``'-'``).

  Args:
    seq1, seq2: Token-id sequences understood by ``tokenizer``.
    n: Maximum number of alignments to return.
    max_length: Length every returned list is right-padded to. Longer
      alignments are returned unpadded and are NOT truncated.

  Returns:
    ``(aligned_ttids, attn_mask)`` where ``aligned_ttids`` is a list of
    ``(ids1, ids2)`` pairs padded with ``tokenizer.pad_token_id`` and
    ``attn_mask`` is a list of masks holding 1 for aligned positions and 0 for
    padding.
  """
  tokens1 = tokenizer.convert_ids_to_tokens( seq1 )
  tokens2 = tokenizer.convert_ids_to_tokens( seq2 )

  dp, _ = needleman_wunsch(tokens1, tokens2)
  alignments = n_best_alignments(tokens1, tokens2, n, dp)

  def _pad(values, fill):
    # Right-pad to max_length; sequences already long enough are left as they are.
    return values + [fill] * max(0, max_length - len(values))

  aligned_ttids = []
  attn_mask = []
  for _, aligned_seq1, aligned_seq2 in alignments:
    ttid1 = tokenizer.convert_tokens_to_ids( aligned_seq1.strip().split(' ') )
    ttid2 = tokenizer.convert_tokens_to_ids( aligned_seq2.strip().split(' ') )
    attn = [1] * len(ttid1)

    aligned_ttids.append((_pad(ttid1, tokenizer.pad_token_id),
                          _pad(ttid2, tokenizer.pad_token_id)))
    # The mask is padded with 0 ("do not attend"), not with the pad token id:
    # the two only coincide for tokenizers whose pad id happens to be 0.
    attn_mask.append(_pad(attn, 0))

  return aligned_ttids, attn_mask

In [None]:
from datasets import load_dataset

def preprocess_function(examples, max_sequence_length=None):
    """Turn a batch of raw CSV rows into model inputs for PairConcat.

    Args:
        examples: Batch dict with the string columns ``answer_segmented``
            (reference), ``asr_segmented`` and ``bangphim_segmented`` (the two
            hypotheses). ``'|'`` segment markers are stripped from all three.
        max_sequence_length: Length that labels, aligned inputs and masks are
            padded to. Previously this argument was accepted but ignored in
            favour of the module-level ``max_length``; it is now honoured, and
            ``None`` (the default) falls back to ``max_length`` so existing
            calls behave as before.

    Returns:
        Dict with ``labels``, ``input_ids`` (the single best alignment of the
        two hypotheses per example; per the original notes, shape
        ``(1, 2, length)`` each before stacking) and ``attention_mask``.
    """
    if max_sequence_length is None:
        max_sequence_length = max_length

    model_inputs = {}

    # Prepare labels
    ref = [s.replace('|', '') for s in examples['answer_segmented']]
    ref_ids = tokenizer(ref, max_length=max_sequence_length, padding='max_length',
                        truncation=True, return_tensors='pt')
    # NOTE(review): padded label positions keep the pad id (they are not set to -100).
    model_inputs['labels'] = ref_ids['input_ids']

    # Prepare inputs
    asr_texts = [s.replace('|', '') for s in examples['asr_segmented']]
    bp_texts = [s.replace('|', '') for s in examples['bangphim_segmented']]

    # Prepare n-best alignment inputs
    asr_ids = tokenizer(asr_texts, add_special_tokens=True)
    bp_ids = tokenizer(bp_texts, add_special_tokens=True)

    input_ids = []
    attention_mask = []
    for s1_ids, s2_ids in zip(asr_ids.input_ids, bp_ids.input_ids): # Each input pair
        alignments, attn_mask = n_best_align_tokenized_tokens(
            s1_ids, s2_ids, n=1, max_length=max_sequence_length)

        input_ids.append( torch.tensor( alignments ) )
        attention_mask.append( torch.tensor( attn_mask ) )

    model_inputs['input_ids'] = torch.stack( input_ids )
    model_inputs['attention_mask'] = torch.stack( attention_mask )

    return model_inputs

def create_tokenized_dataset(input_filepath, validate_filepath=None, test_filepath=None):
    """Load the CSV splits and tokenize them with ``preprocess_function``.

    Args:
        input_filepath: CSV file for the ``train`` split.
        validate_filepath: Optional CSV file for the ``validate`` split.
        test_filepath: Optional CSV file for the ``test`` split.

    Returns:
        The mapped dataset dict, with the raw CSV columns listed below removed.
    """
    optional_splits = {"validate": validate_filepath, "test": test_filepath}
    data_files = {"train": input_filepath}
    data_files.update(
        {split: path for split, path in optional_splits.items() if path is not None}
    )

    raw_dataset = load_dataset("csv", data_files=data_files, delimiter=",")

    # Raw columns dropped after preprocessing.
    columns_to_drop = [
        'Unnamed: 0',
        'answer_segmented',
        'asr_segmented',
        'bangphim_segmented',
        'room_id',
        'alignment',
        'flat_sequence',
        'flat_position',
        'flat_source',
        'label', 'input', 'target',
    ]
    return raw_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=1,
        remove_columns=columns_to_drop,
    )


In [None]:
import evaluate

# Metric objects loaded once at module level and reused by compute_metrics.
wer = evaluate.load("wer")
rouge = evaluate.load('rouge')
sbleu = evaluate.load("sacrebleu")
meteor = evaluate.load('meteor')

def compute_metrics(eval_pred):
    """Compute WER, ROUGE, sacreBLEU and METEOR on space-joined tokens.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id sequences, one
            row per example.

    Returns:
        Dict with keys ``wer``, ``rouge1``, ``rouge2``, ``rougeL``,
        ``sacrebleu`` and ``meteor``, each rounded to 4 decimals.
    """
    predictions, labels = eval_pred

    # Hoisted: evaluating tokenizer.all_special_tokens for every token is wasteful.
    special_tokens = set(tokenizer.all_special_tokens)
    pad_token_id = tokenizer.pad_token_id

    def _to_text(token_ids):
        # Negative ids (e.g. -100 used as ignore/padding index) cannot be mapped
        # to a token; replace them with the pad id, which is filtered out below.
        ids = [pad_token_id if int(t) < 0 else int(t) for t in token_ids]
        tokens = tokenizer.convert_ids_to_tokens(ids)
        return ' '.join(tok for tok in tokens if tok not in special_tokens)

    decoded_preds, decoded_labels = [], []
    for pred_token_ids, label_token_ids in zip(predictions, labels):
        decoded_preds.append(_to_text(pred_token_ids))
        decoded_labels.append(_to_text(label_token_ids))

    wer_score    = wer.compute(predictions=decoded_preds, references=decoded_labels)
    rouge_score  = rouge.compute(predictions=decoded_preds, references=decoded_labels, tokenizer=lambda x: x.split())
    sbleu_score  = sbleu.compute(predictions=decoded_preds, references=decoded_labels)
    meteor_score = meteor.compute(predictions=decoded_preds, references=decoded_labels)

    result = {'wer': wer_score,
              'rouge1': rouge_score['rouge1'],
              'rouge2': rouge_score['rouge2'],
              'rougeL': rouge_score['rougeL'],
              'sacrebleu': sbleu_score['score'],
              'meteor': meteor_score['meteor'],
              }

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
def generate(model, tokenizer, input_ids, decoder_input_ids=None, device='cpu', max_length=20):
  """Greedy decoding loop that encodes once and reuses the encoder output.

  Args:
    model: Callable seq2seq model returning an object with ``logits`` and
      ``encoder_last_hidden_state``; its ``config`` supplies
      ``decoder_start_token_id``, ``eos_token_id`` and ``pad_token_id``.
    tokenizer: Used only to build the start token (``"<pad>"``) when
      ``decoder_input_ids`` is not given.
    input_ids: Encoder input; its first dimension is the batch size.
    decoder_input_ids: Optional decoder prefix of shape ``(batch, k)``.
    device: Device the default decoder start token is moved to.
    max_length: Maximum number of tokens to generate after the prefix.

  Returns:
    The decoder ids (prefix plus generated tokens). Decoding stops early once
    every sequence has produced EOS; sequences that finish earlier than the
    others are filled with the pad token afterwards.
  """
  if decoder_input_ids is None:
    decoder_input_ids = (tokenizer("<pad>", add_special_tokens=False, return_tensors="pt").input_ids).to(device)
    assert decoder_input_ids[0, 0].item() == model.config.decoder_start_token_id, "`decoder_input_ids` should correspond to `model.config.decoder_start_token_id`"
    # One start token per batch element (a no-op for batch size 1).
    decoder_input_ids = decoder_input_ids.expand(input_ids.shape[0], -1)

  eos_token_id = model.config.eos_token_id
  pad_token_id = getattr(model.config, 'pad_token_id', None)
  encoder_outputs = None
  finished = None  # per-sequence "EOS already produced" flags, created on the first step

  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    for _ in range(max_length):
      if encoder_outputs is None:
        # First step: run the encoder and keep its output for the following steps.
        outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
        encoder_outputs = (outputs.encoder_last_hidden_state,)
      else:
        outputs = model(None,
                        encoder_outputs=encoder_outputs,
                        decoder_input_ids=decoder_input_ids,
                        return_dict=True)

      # Greedy choice: highest-scoring token at the last position, shape (batch, 1).
      next_ids = torch.argmax(outputs.logits[:, -1:], dim=-1)
      if finished is None:
        finished = torch.zeros(next_ids.shape[0], dtype=torch.bool, device=next_ids.device)
      if pad_token_id is not None:
        next_ids = next_ids.masked_fill(finished.unsqueeze(-1), pad_token_id)

      decoder_input_ids = torch.cat([decoder_input_ids, next_ids], dim=-1)

      finished = finished | (next_ids.squeeze(-1) == eos_token_id)
      if bool(finished.all()):
        break
  return decoder_input_ids

In [None]:
from transformers import MT5EncoderModel, MT5ForConditionalGeneration
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union

## PairConcatEncoder

In [None]:
import torch.nn as nn

class PairConcatEncoder(MT5EncoderModel):
    """Encoder that fuses two aligned token sequences before the MT5 stack.

    Both sequences are embedded with the wrapped stack's ``embed_tokens``; the
    two embeddings are concatenated along the feature axis and projected back
    to ``d_model`` by ``self.linear``; the result is fed to the stack as
    ``inputs_embeds``.
    """

    def __init__(self, config, encoders):
        """
        Args:
            config: Model config; ``config.d_model`` sizes the projection.
            encoders: Encoder stack to wrap (the caller passes the pre-trained
                model's ``encoder``). It replaces the stack built by the base
                class constructor.
        """
        super().__init__(config)
        # Drop the modules created by the base class; the wrapped stack brings its own.
        del(self.shared)
        del(self.encoder)

        self.linear = nn.Linear(config.d_model * 2, config.d_model)
        self.encoder = encoders

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a pair of aligned sequences.

        Only ``input_ids``, ``attention_mask``, ``head_mask``,
        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        are used; the remaining arguments (including a caller-supplied
        ``inputs_embeds``) are accepted and ignored.

        Args:
            input_ids: Tensor of paired ids; per the original notes
                ``[bs, 1, 2, 32]`` (index 0 on the pair axis is aligned
                sequence 1, index 1 is aligned sequence 2).

        Returns:
            The output of the wrapped encoder stack.
        """
        # Native tensor op instead of np.squeeze on a torch tensor.
        # [bs, 1, 2, 32] -> [bs, 2, 32]; an axis whose size is not 1 is left untouched.
        input_ids = input_ids.squeeze(1)
        embedding = self.encoder.embed_tokens( input_ids )
        # embedding[:, 0] -> aligned seq1
        # embedding[:, 1] -> aligned seq2
        concat_embedding = torch.cat( (embedding[:, 0], embedding[:, 1]), dim=-1 )
        inputs_embeds = self.linear( concat_embedding ) # back to d_model features
        encoder_outputs = self.encoder( # MT5Stack.forward
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return encoder_outputs

class PairConcat(MT5ForConditionalGeneration):
    """Seq2seq model pairing a ``PairConcatEncoder`` with a pre-trained MT5 decoder and LM head.

    The decoder and ``lm_head`` are taken from the checkpoint named by the
    module-level ``model_checkpoint``; the encoder stack of that checkpoint is
    wrapped in a ``PairConcatEncoder``.
    """

    def __init__(self, config):
        """Build the model; loads ``model_checkpoint`` on every construction."""
        super().__init__(config)
        pretrain = MT5ForConditionalGeneration.from_pretrained(model_checkpoint)
        # Remove the modules created by the base constructor before
        # installing the pre-trained ones below.
        del(self.shared)
        del(self.encoder)
        del(self.decoder)
        del(self.lm_head)

        self.encoder = PairConcatEncoder(config, pretrain.encoder)
        self.decoder = pretrain.decoder
        self.lm_head = pretrain.lm_head

    def get_encoder(self):
      """Return the ``PairConcatEncoder`` instance."""
      return self.encoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode (unless ``encoder_outputs`` is given), decode and project to vocabulary logits.

        When ``labels`` is given, ``decoder_input_ids`` defaults to the labels
        shifted right and a cross-entropy loss (ignoring index -100) is
        computed over the logits.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is true, otherwise a
            tuple ``(loss?, logits, *decoder extras, *encoder outputs)``.
        """

        # Fall back to the config defaults when the flags are not given.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # print(encoder_outputs)
        # print(return_dict)
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # A plain tuple was passed in: wrap it so attribute access below works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # First element is the encoder's last hidden state.
        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)
            # print('decoder_input_ids', decoder_input_ids.shape)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Project decoder states to vocabulary logits.
        lm_logits = self.lm_head(sequence_output)
        # print('lm_logits.shape:', lm_logits.shape)
        loss = None
        if labels is not None:
            # Token-level cross entropy; positions labelled -100 are ignored.
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            # Tuple output: logits, remaining decoder outputs, then encoder outputs.
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one decoding step of ``forward``.

        ``input_ids`` here are the decoder ids generated so far; they are
        returned under the key ``decoder_input_ids``. Extra ``kwargs`` are
        accepted and ignored.
        """

        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            # print('prepare_inputs (past not None)', tokenizer.batch_decode( input_ids, skip_special_tokens=False ))
            # print()

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return the labels shifted right, for use as decoder inputs."""
        return self._shift_right(labels)


In [None]:
from transformers import AutoConfig

# Recap the checkpoint / method settings, then build the model and move it to the device.
print('\n'.join([
    'Pre-train',
    f'Model size                         : {size}',
    f'Checkpoint                         : {model_checkpoint}',
    '====================',
    f'{method} parameters',
    f'Model name                         : {model_name}',
    f'I/O length                         : {max_length}/{max_length}',
    '====================',
]))

mt5_gen_conf = AutoConfig.from_pretrained(model_checkpoint)
# Module.to() returns the module itself, so construction and placement can be chained.
model = PairConcat(mt5_gen_conf).to(device)
# model

# Dataset preparation by DataLoader

In [None]:
from torch.utils.data import DataLoader

# Load and tokenize the train / validate / test CSV files.
tokenized_dataset = create_tokenized_dataset('/data/train.csv', '/data/validate.csv', '/data/test.csv')
# Bare expression: shows the dataset summary as the notebook cell output.
tokenized_dataset

# Training

In [None]:
# Pick the train / validation splits, expose their model columns as torch
# tensors, and wrap each split in a DataLoader.
train_tokenized_dataset = tokenized_dataset["train"]
valid_tokenized_dataset = tokenized_dataset["validate"]

for split_dataset in (train_tokenized_dataset, valid_tokenized_dataset):
    split_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_dataloader = torch.utils.data.DataLoader(train_tokenized_dataset, batch_size=batch_size)
valid_dataloader = torch.utils.data.DataLoader(valid_tokenized_dataset, batch_size=batch_size)

In [None]:
from transformers import Seq2SeqTrainer
from transformers.optimization import get_scheduler, Adafactor
from torch.optim import AdamW

class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that trains the fusion ``linear`` layer with its own learning rate."""

    def __init__(self, *args, linear_layer_lr=None, **kwargs):
        """
        Args:
            linear_layer_lr: Learning rate for parameters whose name contains
                ``"linear"``; defaults to ``args.learning_rate``.
            *args, **kwargs: Forwarded to ``Seq2SeqTrainer``.
        """
        super().__init__(*args, **kwargs)
        self.linear_layer_lr = linear_layer_lr if linear_layer_lr is not None else self.args.learning_rate

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the optimizer (four parameter groups) and the LR scheduler.

        Groups are the cross product of {linear layer, everything else} and
        {weight decay, no weight decay}. Previously the no-decay parameters of
        the linear layer (its bias) fell into the generic no-decay group and
        were trained with the base learning rate.

        Args:
            num_training_steps: Total number of optimizer steps, passed to the
                scheduler.

        Returns:
            ``(optimizer, lr_scheduler)``; both are also stored on ``self``.

        Raises:
            ValueError: If ``args.optim`` is neither ``adamw_torch`` nor
                ``adafactor``.
        """
        no_decay = ["bias", "LayerNorm.weight"]

        # Bucket every parameter by (belongs to the linear layer, receives weight decay).
        buckets = {(is_linear, decays): []
                   for is_linear in (True, False) for decays in (True, False)}
        for name, param in self.model.named_parameters():
            is_linear = "linear" in name
            decays = not any(nd in name for nd in no_decay)
            buckets[(is_linear, decays)].append(param)

        optimizer_grouped_parameters = [
            {
                "params": params,
                "weight_decay": self.args.weight_decay if decays else 0.0,
                "lr": self.linear_layer_lr if is_linear else self.args.learning_rate,
            }
            for (is_linear, decays), params in buckets.items()
            if params  # skip groups that ended up empty
        ]

        # Create the optimizer using the parameter groups
        if self.args.optim == 'adamw_torch':
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
        elif self.args.optim == 'adafactor':
            # Adafactor rejects a manual lr unless relative_step is disabled.
            self.optimizer = Adafactor(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                relative_step=False,
                scale_parameter=False,
            )
        else:
            raise ValueError(f"Invalid optimizer_type: {self.args.optim}. Choose 'adamw_torch' or 'adafactor'.")

        # Create the learning rate scheduler
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )

        return self.optimizer, self.lr_scheduler


In [None]:
from transformers import Seq2SeqTrainingArguments

# Training configuration: constant learning rate, per-epoch logging / evaluation /
# checkpointing, and best-checkpoint selection by lowest validation WER.
args = Seq2SeqTrainingArguments(
    output_dir=f'./{model_name}',
    num_train_epochs=num_epochs,

    learning_rate=learning_rate,
    lr_scheduler_type='constant',

    logging_strategy='epoch',
    evaluation_strategy='epoch',
    # evaluation_strategy='steps',
    # eval_steps=100,
    # logging_strategy='steps',
    # logging_steps=100,

    # save_strategy='no',
    save_strategy = 'epoch',
    load_best_model_at_end = True,

    # Lower WER is better, hence greater_is_better=False.
    metric_for_best_model = 'wer',
    greater_is_better = False,

    # metric_for_best_model = 'loss',
    # greater_is_better = False,

    # metric_for_best_model = 'meteor',
    # greater_is_better = True,

    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,

    do_train=True,
    do_eval=True,

    # optim='adafactor',
    # optim='adamw_torch',
    optim=optimizer_name,

    # Metrics are computed on generated sequences rather than teacher-forced logits.
    # NOTE(review): generation_max_length is left unset below, so evaluation-time
    # generation uses the model's default length — confirm it is not shorter than max_length.
    predict_with_generate=True,
    # generation_max_length=max_length,
    # generation_num_beams

    # have to use half=False to avoid loss=0
    # ref:https://stackoverflow.com/questions/65332165/loss-is-nan-when-fine-tuning-huggingface-nli-model-both-roberta-bart
    fp16=False,

    # save the best model and the last
    # https://stackoverflow.com/a/67615225/3027437
    # https://discuss.huggingface.co/t/save-only-best-model-in-trainer/8442/4
    save_total_limit=2,
)

# Instantiate the custom trainer
# trainer = Seq2SeqTrainer(
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validate'],
    compute_metrics=compute_metrics,
    linear_layer_lr=linear_layer_lr,  # separate learning rate for the fusion linear layer
)

# Run the training loop.
trainer.train()

In [None]:
import pandas as pd

# Evaluate on the validation split (the trainer's eval_dataset) and store the
# metrics as a one-row CSV in train_report_file.
eval_metrics = trainer.evaluate()
print(eval_metrics)
eval_df = pd.DataFrame(eval_metrics, index=[0])
eval_df.to_csv(train_report_file, index=False)

# Test

In [None]:
# Test split, exposed as torch tensors.
test_tokenized_dataset = tokenized_dataset["test"]

# prepare dataloader
test_tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# batch_size=1: the generation cell further down processes one example at a time.
dataloader = torch.utils.data.DataLoader(test_tokenized_dataset, batch_size=1)

In [None]:
# Run prediction on the test split; the resulting metrics are reused by the report cell.
test_result = trainer.predict(test_tokenized_dataset)
# Bare expression: shows the metrics dict as the notebook cell output.
test_result.metrics

In [None]:
# Keep only the quality metrics (timing / throughput entries are dropped), strip
# the 'test_' marker, title-case the names and write them as a one-row CSV.
results = {
    k.replace('test_', '').title(): v
    for k, v in test_result.metrics.items()
    if k not in ('test_runtime', 'test_samples_per_second', 'test_steps_per_second')
}
test_df = pd.DataFrame([results], index=[model_name])
test_df.to_csv(test_report_file, index=False)

# Generate results

In [None]:
# For every test example, decode with both the hand-written greedy `generate`
# and `model.generate`, then print the inputs, reference and both outputs and
# mirror the same lines into gen_output_file.
with open(gen_output_file, 'w', encoding='utf-8') as output_file:

    def _emit(text=''):
        # Same line to stdout and to the report file.
        print(text)
        output_file.write(text + '\n')

    for i, batch in enumerate(dataloader, 1):
        _emit(f'test inputs #{i}:')
        _emit('==================')

        input_ids = batch['input_ids'].to(device)
        labels_ids = batch['labels'].to(device)

        # Index 0 / 1 on the pair axis are the two aligned input sequences.
        # Flatten each to (1, length) with the tensor's own reshape instead of
        # np.reshape with a hard-coded length of 32.
        inp1 = batch['input_ids'][:, :, 0].reshape(1, -1)
        inp2 = batch['input_ids'][:, :, 1].reshape(1, -1)

        # Inference only: no gradients needed.
        with torch.no_grad():
            s1 = generate(model, tokenizer, input_ids, decoder_input_ids=None, device=device, max_length=max_length)
            s2 = model.generate(input_ids, max_length=max_length)

        decoded_input_1 = tokenizer.batch_decode( inp1, skip_special_tokens=True )
        decoded_input_2 = tokenizer.batch_decode( inp2, skip_special_tokens=True )

        decoded_preds_1 = tokenizer.batch_decode( s1, skip_special_tokens=True )
        decoded_preds_2 = tokenizer.batch_decode( s2, skip_special_tokens=True )

        decoded_labels = tokenizer.batch_decode( labels_ids, skip_special_tokens=True )

        _emit('input_1       : ' + decoded_input_1[0])
        _emit('input_2       : ' + decoded_input_2[0])
        _emit('labels        : ' + decoded_labels[0])
        _emit('generate      : ' + decoded_preds_1[0])
        _emit('model.generate: ' + decoded_preds_2[0])
        _emit()