# Question generation: TyDi dataset
In this notebook, we will see how to fine-tune and evaluate a question generation model on the TyDiQA dataset.

## Configuration

We start by setting some parameters to configure the process. Depending on the memory of the GPU being used, you may need to adjust the train and eval batch sizes.

In [2]:
model_name_or_path="google/mt5-small"
modality="passage"
dataset_name="tydiqa"
max_len=200
target_max_len=40
output_dir="models/qg/tydi/"
learning_rate=0.0001
num_train_epochs=2
per_device_train_batch_size=8
per_device_eval_batch_size=32
evaluation_strategy='epoch'
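
If the training batch size does not fit in GPU memory, one common option is to lower `per_device_train_batch_size` and raise `gradient_accumulation_steps` (a standard `Seq2SeqTrainingArguments` parameter that this notebook leaves at its default of 1) so that the effective batch size stays the same. The sketch below only does the arithmetic; the helper name `accumulation_steps` is hypothetical and not part of PrimeQA or transformers.

```python
def accumulation_steps(target_batch_size: int,
                       per_device_batch_size: int,
                       num_devices: int = 1) -> int:
    """Return the gradient accumulation steps needed to reach the target
    effective batch size with a smaller per-device batch size."""
    samples_per_forward_pass = per_device_batch_size * num_devices
    if target_batch_size % samples_per_forward_pass != 0:
        raise ValueError(
            "target_batch_size must be a multiple of "
            "per_device_batch_size * num_devices"
        )
    return target_batch_size // samples_per_forward_pass


# Keep the effective batch size of 8 used in this notebook while
# only fitting 2 examples on the GPU at a time (assumed single GPU).
steps = accumulation_steps(target_batch_size=8, per_device_batch_size=2)
print(steps)
```

The resulting value could then be passed as `gradient_accumulation_steps` alongside the reduced `per_device_train_batch_size` when building the training arguments.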

In [3]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=evaluation_strategy,  # set in the configuration cell above
    learning_rate=learning_rate,
    # Generate questions during evaluation so that ROUGE can be computed;
    # this requires predictions, so prediction_loss_only must be False.
    predict_with_generate=True,
    prediction_loss_only=False,
    # Do not drop dataset columns that the model's forward() does not accept.
    remove_unused_columns=False,
    )


## Loading the Model

Here we load the model based on the model_name_or_path and modality parameters set above. For TyDiQA we keep modality='passage'; the other option is modality='table'.

In [4]:
from primeqa.qg.models.qg_model import QGModel

qg_model = QGModel(model_name_or_path, modality=modality)

Loaded NER model for  Arabic
Loaded NER model for  English
Loaded NER model for  Finnish
Loaded NER model for  Russian


## Loading Data

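Here we load the TyDiQA dataset using Hugging Face's datasets library. (Other supported datasets are: 'wikisql', 'squad', 'squad_v2'.) To keep the run short, we use only the first 100 training examples and the first 50 validation examples.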

In [5]:
from primeqa.qg.processors.data_loader import QGDataLoader

qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    modality=modality,
    dataset_name=dataset_name,
    input_max_len=max_len,
    target_max_len=target_max_len
    )

train_dataset = qgdl.create(dataset_split="train[:100]")
valid_dataset = qgdl.create(dataset_split="validation[:50]")

## Train using QGTrainer
Here we create a QGTrainer with the training arguments defined above and use it to train on the TyDiQA training data (or any custom data following the same format).

In [6]:
from primeqa.qg.trainers.qg_trainer import QGTrainer
from primeqa.qg.metrics.generation_metrics import rouge_metrics
from primeqa.qg.utils.data_collator import T2TDataCollator

compute_metrics = rouge_metrics(qg_model.tokenizer)

trainer = QGTrainer(
    model=qg_model.model,
    tokenizer=qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=T2TDataCollator(),
    compute_metrics=compute_metrics
    )

train_results = trainer.train()
trainer.save_model()
print(train_results.metrics)

***** Running training *****
  Num examples = 100
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 26


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,13.397782,0.0,0.0,0.0,0.0
2,No log,12.8077,0.0,0.0,0.0,0.0


***** Running Evaluation *****
  Num examples = 50
  Batch size = 32
***** Running Evaluation *****
  Num examples = 50
  Batch size = 32


Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 4.1157, 'train_samples_per_second': 48.594, 'train_steps_per_second': 6.317, 'total_flos': 41307217920000.0, 'train_loss': 21.04597590519832, 'epoch': 2.0}
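
The `Total optimization steps = 26` line in the training log follows directly from the configuration. The short calculation below reproduces it, assuming a single device (consistent with the reported total train batch size of 8) and that the last, partially filled batch of each epoch is kept.

```python
import math

num_examples = 100                # "train[:100]"
per_device_train_batch_size = 8
num_devices = 1                   # assumption: single GPU
gradient_accumulation_steps = 1   # as reported in the log
num_train_epochs = 2

# 100 examples in batches of 8 gives 12 full batches plus one partial batch.
batches_per_epoch = math.ceil(
    num_examples / (per_device_train_batch_size * num_devices)
)
# One optimizer update per batch when no gradients are accumulated.
updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
total_optimization_steps = updates_per_epoch * num_train_epochs

print(total_optimization_steps)  # 26
```

With so few steps, the `No log` entries in the Training Loss column are most likely because the trainer's default logging interval is never reached during this short run.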


## Evaluation

Here we evaluate the trained model on the validation set.

In [7]:
metrics = trainer.evaluate()
print(metrics)

***** Running Evaluation *****
  Num examples = 50
  Batch size = 32


{'eval_loss': 12.807700157165527, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 0.4931, 'eval_samples_per_second': 101.39, 'eval_steps_per_second': 4.056, 'epoch': 2.0}
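
To make the `eval_rouge1` value easier to interpret, here is a deliberately simplified sketch of a ROUGE-1 F1 score: the unigram overlap between a generated question and its reference. It is an illustration only, not the implementation behind `rouge_metrics`; the function name `rouge1_f1` and the whitespace tokenization are assumptions made for this example.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: harmonic mean of unigram precision and recall,
    using lowercased whitespace tokens (real implementations tokenize and
    normalize differently)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most as often
    # as it appears in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# A generated question that shares some words with the reference
# scores between 0 and 1; one with no shared words scores 0.
partial = rouge1_f1("who wrote hamlet ?", "who is the author of hamlet ?")
no_overlap = rouge1_f1("a b", "c d")
```

A score of 0.0, as in the run above, means the generated text shared no unigrams with the references under the metric's tokenization, which is not surprising for a model fine-tuned on only 100 examples for 2 epochs.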
