# Question generation: TyDi dataset
In this notebook, we will see how to fine-tune and evaluate a question generation model on the TyDiQA dataset.

## Configuration

We start by setting some parameters to configure the process. Note that, depending on the GPU being used, you may need to tune the batch size (for example, lower `per_device_train_batch_size` if you run out of GPU memory).

In [9]:
model_name_or_path="google/mt5-small"
modality="passage"
dataset_name="tydiqa"
max_len=200
target_max_len=40
output_dir="models/qg/tydi/"
learning_rate=0.0001
num_train_epochs=2
per_device_train_batch_size=8
per_device_eval_batch_size=32
evaluation_strategy='epoch'
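The batch size also determines how many optimizer steps a run takes. Below is a minimal sketch of that arithmetic. It assumes a single device, no gradient accumulation by default, and that the last partial batch of each epoch is kept (the Hugging Face `Trainer` default, `dataloader_drop_last=False`). The helper name `total_optimization_steps` is hypothetical.

```python
import math


def total_optimization_steps(num_examples: int,
                             per_device_batch_size: int,
                             num_epochs: int,
                             num_devices: int = 1,
                             grad_accum_steps: int = 1) -> int:
    """Estimate the number of optimizer steps for a run (hypothetical helper)."""
    # Examples consumed per forward pass across all devices.
    effective_batch = per_device_batch_size * num_devices
    # Batches per epoch, keeping the final partial batch.
    batches_per_epoch = math.ceil(num_examples / effective_batch)
    # One optimizer step every grad_accum_steps batches (at least one per epoch).
    steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return steps_per_epoch * num_epochs


# 100 training examples (the slice loaded below), batch size 8, 2 epochs
print(total_optimization_steps(100, 8, 2))  # 26
```

With the 100 training examples used in this notebook, this gives 13 batches per epoch and 26 steps over two epochs, which matches the "Total optimization steps" reported in the training log.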

In [30]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=evaluation_strategy,
    learning_rate=learning_rate,
    # Return full predictions (not only the loss) so ROUGE can be computed,
    # and produce them with model.generate() during evaluation.
    prediction_loss_only=False,
    predict_with_generate=True,
    # Keep dataset columns that the model's forward() signature does not list.
    remove_unused_columns=False,
    )




## Loading the Model

Here we load the model based on the `model_name_or_path` and `modality` parameters set above. For TyDiQA we keep `modality='passage'`; the other option is `modality='table'`.

In [10]:
from primeqa.qg.models.qg_model import QGModel

qg_model = QGModel(model_name_or_path, modality=modality)

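## Loading Data

Here we load the TyDiQA dataset using Hugging Face's `datasets` library (other supported datasets are 'wikisql', 'squad', and 'squad_v2'). To keep the run short, we use only the first 100 training examples and the first 50 validation examples.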

In [19]:
from primeqa.qg.processors.data_loader import QGDataLoader

qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    dataset_name=dataset_name,
    input_max_len=max_len,
    target_max_len=target_max_len
    )

train_dataset = qgdl.create("train[:100]")
valid_dataset = qgdl.create("validation[:50]")

Reusing dataset tydiqa (/dccstor/cmv/.cache/tydiqa/secondary_task/1.0.0/b8a6c4c0db10bf5703d7b36645e5dbae821b8c0e902dac9daeecd459a8337148)
100%|██████████| 100/100 [00:00<00:00, 3972.52it/s]


Reusing dataset tydiqa (/dccstor/cmv/.cache/tydiqa/secondary_task/1.0.0/b8a6c4c0db10bf5703d7b36645e5dbae821b8c0e902dac9daeecd459a8337148)
100%|██████████| 50/50 [00:00<00:00, 2936.78it/s]


Number of instances in train 100
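To show what a passage-based QG training pair looks like, here is a minimal sketch that turns a TyDiQA GoldP record (SQuAD-style fields `context`, `question`, and `answers`) into a source/target pair and truncates both sides. The `answer <<sep>> context` input format, the function name `make_qg_pair`, and the word-level truncation are assumptions for illustration only; `QGDataLoader` works with the model's SentencePiece tokenizer and may format inputs differently.

```python
def make_qg_pair(record, input_max_len=200, target_max_len=40, sep="<<sep>>"):
    """Build a (source, target) pair for answer-aware question generation.

    Hypothetical sketch: the real QGDataLoader may use another input format
    and counts subword tokens, not whitespace-separated words.
    """
    answer = record["answers"]["text"][0]  # first gold answer span
    source_words = f"{answer} {sep} {record['context']}".split()
    target_words = record["question"].split()
    # Truncate to the maximum lengths (measured in words in this sketch).
    source = " ".join(source_words[:input_max_len])
    target = " ".join(target_words[:target_max_len])
    return source, target


record = {
    "context": "Paris is the capital and most populous city of France.",
    "question": "What is the capital of France?",
    "answers": {"text": ["Paris"], "answer_start": [0]},
}
src, tgt = make_qg_pair(record)
print(src)  # Paris <<sep>> Paris is the capital and most populous city of France.
print(tgt)  # What is the capital of France?
```

The model learns to map the source (answer plus passage) to the target question; `max_len` and `target_max_len` from the configuration bound the two sides.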


## Train using QGTrainer
Here we create a QG trainer with the training arguments defined above and use it to train on the TyDiQA training data (or any custom data following the same format).
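The trainer batches examples with `T2TDataCollator`. We have not inspected its implementation; the sketch below shows a common pattern for text-to-text collators using plain Python lists and a hypothetical `collate_t2t` function. Each field is padded to the longest sequence in the batch, and label padding is filled with -100, the index that PyTorch's cross-entropy loss ignores by default. The pad id 0 matches T5/mT5 tokenizers.

```python
from typing import Dict, List


def collate_t2t(features: List[Dict[str, List[int]]],
                pad_token_id: int = 0,
                label_pad_id: int = -100) -> Dict[str, List[List[int]]]:
    """Pad tokenized examples into a rectangular batch (hypothetical sketch)."""
    batch: Dict[str, List[List[int]]] = {}
    for key in ("input_ids", "attention_mask", "labels"):
        max_len = max(len(f[key]) for f in features)
        if key == "labels":
            pad = label_pad_id   # ignored when computing the loss
        elif key == "attention_mask":
            pad = 0              # padded positions are masked out
        else:
            pad = pad_token_id   # T5/mT5 pad token id
        batch[key] = [f[key] + [pad] * (max_len - len(f[key])) for f in features]
    return batch


features = [
    {"input_ids": [5, 6, 7, 1], "attention_mask": [1, 1, 1, 1], "labels": [8, 9, 1]},
    {"input_ids": [5, 1], "attention_mask": [1, 1], "labels": [8, 1]},
]
batch = collate_t2t(features)
print(batch["labels"])  # [[8, 9, 1], [8, 1, -100]]
```

A real collator would also convert these lists to tensors; if the data loader already pads to fixed lengths, the collator may only need to stack them.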

In [33]:
from primeqa.qg.trainers.qg_trainer import QGTrainer
from primeqa.qg.metrics.generation_metrics import rouge_metrics
from primeqa.qg.utils.data_collator import T2TDataCollator
import os

compute_metrics = rouge_metrics(qg_model.tokenizer)

trainer = QGTrainer(
    model=qg_model.model,
    tokenizer = qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    data_collator=T2TDataCollator(),
    compute_metrics=compute_metrics
    )

train_results = trainer.train()
trainer.save_model()
print(train_results.metrics)

***** Running training *****
  Num examples = 100
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 26


| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | RougeL | RougeLsum |
|---|---|---|---|---|---|---|
| 1 | No log | 3.812071 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | No log | 3.77309 | 0.0 | 0.0 | 0.0 | 0.0 |


***** Running Evaluation *****
  Num examples = 50
  Batch size = 32


***** Running Evaluation *****
  Num examples = 50
  Batch size = 32




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to models/qg/tydi/
Configuration saved in models/qg/tydi/config.json
Model weights saved in models/qg/tydi/pytorch_model.bin
tokenizer config file saved in models/qg/tydi/tokenizer_config.json
Special tokens file saved in models/qg/tydi/special_tokens_map.json
Copy vocab file to models/qg/tydi/spiece.model


## Evaluation

Here we evaluate the trained model on the validation set.
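The ROUGE scores reported here measure n-gram overlap between generated and reference questions. To make the Rouge1 column concrete, here is a minimal ROUGE-1 F1 sketch over lowercased whitespace tokens. The function name `rouge1_f1` is hypothetical, and `rouge_metrics` presumably relies on a standard ROUGE implementation with its own tokenization, so its scores will not match this sketch exactly.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """ROUGE-1 F1: harmonic mean of unigram precision and recall."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: a unigram matches at most as often as it occurs in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("what is the capital of france", "What is the capital city of France")
print(round(score, 3))  # 0.923
```

Whitespace splitting is a poor fit for languages in TyDiQA such as Japanese and Thai, which do not separate words with spaces, so the tokenizer behind any ROUGE score matters when comparing numbers across languages.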

In [32]:
metrics = trainer.evaluate()
print(metrics)

***** Running Evaluation *****
  Num examples = 50
  Batch size = 32
