In [1]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainerCallback,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    Seq2SeqTrainer
)
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import ftfy
import evaluate
import os

In [2]:
model_id = 't5-base'
model_folder = "model"
model_name = "t5_doc2query"
base_path = "."
batch_size = 6
checkpoints_path = "checkpoints"
gradient_accumulation_steps = 8
epochs = 100

In [3]:
def load_data():
    colnames = ["query", "relevant", "not relevant"]
    df_data = pd.read_csv(f"{base_path}/data/msmarco_triples.train.tiny.tsv", encoding="UTF=8", sep="\t", names=colnames)
    
    df_data["relevant"] = df_data["relevant"].apply(lambda text: ftfy.fix_text(text))
    df_data = df_data.drop("not relevant", axis=1)
    
    return df_data

In [4]:
def split_data(df):
    df_train, df_val = train_test_split(
        df,
        test_size=0.10,
        random_state=42
    )
    
    return df_train, df_val

In [5]:
class Doc2QueryDataset(torch.utils.data.Dataset):
    def __init__(self, ms_df, tokenizer):
        self.tokenized_topics = tokenizer(ms_df['query'].tolist(), return_length=True)
        self.tokenized_passage = tokenizer(ms_df['relevant'].tolist(), return_length=True)

    def __len__(self):
        return len(self.tokenized_topics['input_ids'])


    def __getitem__(self, index):
        return {
            'input_ids': self.tokenized_passage['input_ids'][index],
            'attention_mask': self.tokenized_passage['attention_mask'][index],
            'labels': self.tokenized_topics['input_ids'][index]
        }

In [6]:
class Doc2QueryTrainerCallback(TrainerCallback):
    def __init__(self, best_validation_yet=99999, model=None) -> None:
        super().__init__()

        self.best_validation_metric = best_validation_yet
        self.model = model


    def on_evaluate(self, args, state, control, model=None, metrics=None, **kwargs):
        print("metrics['eval_loss']={}".format(metrics['eval_loss']))
        print("metrics['eval_bleu']={}".format(metrics['eval_bleu']))


        if metrics['eval_bleu'] > self.best_validation_metric:
            self.model.save_pretrained(
                os.path.join(
                    checkpoints_path,
                    "checkpoint-{}-{:.4f}".format(
                        state.global_step,
                        metrics['eval_bleu']
                    )
                )
            )

            self.best_validation_metric = metrics['eval_bleu']

In [7]:
def load_data_set(train_df, validation_df, tokenizer):
    train_dataset = Doc2QueryDataset(train_df, tokenizer)
    eval_dataset = Doc2QueryDataset(validation_df, tokenizer)
    
    return train_dataset, eval_dataset

In [8]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

In [9]:
def compute_metrics(eval_preds):
    metric = evaluate.load("sacrebleu")
    preds, labels = eval_preds

    if isinstance(preds, tuple):
        preds = preds[0]

    print("compute_metrics. preds.shape={}".format(preds.shape))

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    
    result["gen_len"] = np.mean(prediction_lens)
    
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [10]:
def train(tokenizer, model, train_dataset, eval_dataset):
    training_params = Seq2SeqTrainingArguments(
        output_dir=checkpoints_path,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        evaluation_strategy='steps',
        eval_steps=200,
        save_strategy='steps',
        save_steps=1000,
        logging_strategy='steps',
        logging_steps=10,
        save_total_limit=2,
        dataloader_pin_memory=True,
        predict_with_generate=True,
        generation_num_beams=10,
        fp16=True
    )
    
    label_pad_token_id = -100

    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_params.fp16 else None,
    )
    
    trainer_callback = Doc2QueryTrainerCallback(
        best_validation_yet=-1,
        model=model
    )
    
    num_training_steps = epochs * int(len(train_dataset) // (batch_size * gradient_accumulation_steps))

    optimzer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-3)
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimzer,
        0,
        num_training_steps,
        num_cycles=10
    )
    
    
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_params,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=[trainer_callback],
        optimizers=(optimzer, scheduler),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )
    
    trainer.train()

In [11]:
df_data = load_data()

In [12]:
df_train, df_val = split_data(df_data)

In [13]:
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [14]:
train_dataset, eval_dataset = load_data_set(df_train, df_val, tokenizer)

In [15]:
train(tokenizer, model, train_dataset, eval_dataset)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Bleu,Gen Len
200,1.7625,1.540922,19.9192,9.8027
400,1.4825,1.511125,20.4127,9.8727
600,1.2784,1.509185,20.2968,9.6573
800,1.1336,1.500389,21.0097,9.7464
1000,1.2964,1.668191,17.5656,9.6655
1200,1.3182,1.662287,17.1887,9.8264
1400,1.2235,1.648296,16.4835,9.9136
1600,1.2639,1.64638,16.1199,9.9609
1800,1.2519,1.646291,16.0158,9.9845
2000,1.2654,1.64735,16.0103,9.9764


compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.5409224033355713
metrics['eval_bleu']=19.9192
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.511125087738037
metrics['eval_bleu']=20.4127
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.5091851949691772
metrics['eval_bleu']=20.2968
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.5003886222839355
metrics['eval_bleu']=21.0097
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.6681907176971436
metrics['eval_bleu']=17.5656
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.6622869968414307
metrics['eval_bleu']=17.1887
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.6482958793640137
metrics['eval_bleu']=16.4835
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.646379828453064
metrics['eval_bleu']=16.1199
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=1.6462910175323486
metrics['eval_bleu']=16.0158
compute_metrics. pred

compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metrics['eval_bleu']=0.0195
compute_metrics. preds.shape=(1100, 20)
metrics['eval_loss']=nan
metri