# Table Question Generation: WikiSQL Dataset
In this notebook, we will see how to fine-tune and evaluate a question generation model on the WikiSQL dataset.

## Configuration

We start by setting the parameters that configure training and evaluation. Depending on the GPU you use, you may need to adjust the batch sizes.

In [2]:
model_name_or_path="t5-small"
modality="table"
dataset_name="wikisql"
max_len=200
target_max_len=40
output_dir="../../models/qg/wikisql_nb"
learning_rate=0.0001
num_train_epochs=2
per_device_train_batch_size=8
per_device_eval_batch_size=32
evaluation_strategy='epoch'

In [3]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=evaluation_strategy,  # renamed eval_strategy in newer transformers releases
    learning_rate=learning_rate,
    # Do not drop dataset columns before they reach the data collator.
    remove_unused_columns=False,
    # Generate full sequences during evaluation so that ROUGE can be computed;
    # metrics are only computed when prediction_loss_only is False.
    predict_with_generate=True,
    prediction_loss_only=False,
    )
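If the batch size has to be reduced to fit in GPU memory, one common way to keep the number of examples per optimizer update unchanged is gradient accumulation, controlled by the `gradient_accumulation_steps` argument of `Seq2SeqTrainingArguments` (left at its default of 1 in this notebook). The helpers below are a hypothetical sketch of that arithmetic, assuming data-parallel training in which every device processes its own batch.

```python
import math


def effective_batch_size(per_device_batch_size: int,
                         gradient_accumulation_steps: int = 1,
                         num_devices: int = 1) -> int:
    """Number of examples that contribute to each optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def accumulation_steps_for(target_batch_size: int,
                           per_device_batch_size: int,
                           num_devices: int = 1) -> int:
    """Smallest accumulation count whose effective batch reaches target_batch_size."""
    examples_per_step = per_device_batch_size * num_devices
    return math.ceil(target_batch_size / examples_per_step)


# Example: keep an effective batch of 8 after halving the per-device batch to 4.
steps = accumulation_steps_for(target_batch_size=8, per_device_batch_size=4)
print(steps)                           # 2
print(effective_batch_size(4, steps))  # 8
```

The resulting value could be passed as `gradient_accumulation_steps=steps` when building `training_args`. Accumulation trades memory for extra forward/backward passes per update, so each update takes longer.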

## WikiSQL Data
Here we load one WikiSQL training instance and display its table, question, and SQL query. This step is for illustration only and is not needed to train the model.

In [4]:
from datasets import load_dataset
from tabulate import tabulate

def print_wikisql_instance(train_instance):
    table = train_instance['table']
    print('Table:\n',tabulate(table['rows'], headers=table['header'], tablefmt='grid'))

    print('Question = ',train_instance['question'])
    print('SQL = ', train_instance['sql']['human_readable'])

train_instance = load_dataset('wikisql', split='train[11:12]')[0]
print_wikisql_instance(train_instance)



Table:
 +-----------------------+-------------------------------+------------------------+--------------------+------------------------+
| Aircraft              | Description                   | Max Gross Weight       | Total disk area    | Max disk Loading       |
+=======================+===============================+========================+====================+========================+
| Robinson R-22         | Light utility helicopter      | 1,370 lb (635 kg)      | 497 ft² (46.2 m²)  | 2.6 lb/ft² (14 kg/m²)  |
+-----------------------+-------------------------------+------------------------+--------------------+------------------------+
| Bell 206B3 JetRanger  | Turboshaft utility helicopter | 3,200 lb (1,451 kg)    | 872 ft² (81.1 m²)  | 3.7 lb/ft² (18 kg/m²)  |
+-----------------------+-------------------------------+------------------------+--------------------+------------------------+
| CH-47D Chinook        | Tandem rotor helicopter       | 50,000 lb (22,680 kg)  | 5,655 ft² (526 m²) | 8.8 lb/ft² (43 kg/m²)  |
+-----------------------+-------------------------------+------------------------+--------------------+------------------------+

The SQL query is then converted into a linearized string, which is fed to the generator as input to produce the question.

In [5]:
from primeqa.qg.processors.table_qg.wikisql_processor import WikiSqlDataset

data = WikiSqlDataset()
processed_data = data.preprocess_data_for_qg('train[11:12]')
print('Question = ', processed_data['question'][0])
print('\nInput to generator = ', processed_data['input'][0])


Question =  What is the max gross weight of the Robinson R-22?

Input to generator =  select <<sep>> Max Gross Weight <<sep>> Aircraft <<cond>> equal <<cond>> Robinson R-22 <<answer>> 1,370 lb (635 kg) <<header>> Aircraft <<hsep>> Description <<hsep>> Max Gross Weight <<hsep>> Total disk area <<hsep>> Max disk Loading
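The real conversion happens inside `WikiSqlDataset.preprocess_data_for_qg`. To make the format explicit, the function below is a hypothetical sketch that reproduces the single-condition string printed above: the selected column, the `column <<cond>> operator <<cond>> value` condition, the `<<answer>>` marker, and the `<<header>>` list joined by `<<hsep>>`. How aggregations and multiple conditions are rendered is not visible in this output, so the `<<sep>>` join between several conditions is an assumption.

```python
from typing import List, Tuple


def linearize_sql(select_column: str,
                  conditions: List[Tuple[str, str, str]],
                  answer: str,
                  header: List[str]) -> str:
    """Hypothetical re-creation of the generator input format shown above."""
    # The selected column follows the "select" keyword (aggregations are not handled).
    parts = ["select", select_column]
    # Each condition becomes "column <<cond>> operator <<cond>> value".
    # ASSUMPTION: several conditions would be joined with <<sep>>; only one appears above.
    parts += [f"{col} <<cond>> {op} <<cond>> {val}" for col, op, val in conditions]
    query = " <<sep>> ".join(parts)
    # The answer and the table header are appended after their own markers.
    return f"{query} <<answer>> {answer} <<header>> " + " <<hsep>> ".join(header)


header = ["Aircraft", "Description", "Max Gross Weight", "Total disk area", "Max disk Loading"]
text = linearize_sql("Max Gross Weight",
                     [("Aircraft", "equal", "Robinson R-22")],
                     "1,370 lb (635 kg)",
                     header)
print(text)
```

For this instance the sketch prints the same string as the "Input to generator" line above.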

## Loading the Model

Here we load the model and tokenizer specified by `model_name_or_path`, configured for the `modality` set above. For WikiSQL we use `modality='table'`; the other option is `modality='passage'`.

In [6]:
from primeqa.qg.models.qg_model import QGModel

qg_model = QGModel(model_name_or_path, modality=modality)

## Loading Data

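Here we load WikiSQL and convert it into tokenized training and validation sets, with inputs limited to `max_len` tokens and targets to `target_max_len` tokens. To keep this demo quick, we use only the first 100 training and the first 50 validation examples.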
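`input_max_len` and `target_max_len` fix the length of every tokenized input and target. The plain-Python sketch below shows what truncating and padding a list of token ids to a fixed length means. It is not the `QGDataLoader` code: the pad id of 0 (T5's pad token id) and the use of -100 for padded label positions (the value Hugging Face models ignore in the loss by default) are assumptions about how such a loader typically prepares data.

```python
from typing import List, Tuple


def pad_or_truncate(ids: List[int], max_len: int,
                    pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Cut ids to max_len, then right-pad; return (ids, attention_mask)."""
    # Note: real tokenizers usually keep the end-of-sequence token when
    # truncating; this sketch simply drops everything past max_len.
    ids = ids[:max_len]
    num_pad = max_len - len(ids)
    mask = [1] * len(ids) + [0] * num_pad
    return ids + [pad_id] * num_pad, mask


def make_labels(target_ids: List[int], target_max_len: int,
                pad_id: int = 0) -> List[int]:
    """Pad targets like inputs, but mark padding with -100 so the loss skips it."""
    padded, mask = pad_or_truncate(target_ids, target_max_len, pad_id)
    return [tok if keep == 1 else -100 for tok, keep in zip(padded, mask)]


print(pad_or_truncate([5, 6, 7], max_len=5))     # ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0])
print(make_labels([9, 8], target_max_len=4))     # [9, 8, -100, -100]
```

Using the attention mask rather than comparing against `pad_id` keeps a genuine token whose id happens to equal the pad id from being masked out of the labels.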

In [7]:
from primeqa.qg.processors.data_loader import QGDataLoader

qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    dataset_name=dataset_name,
    input_max_len=max_len,
    target_max_len=target_max_len
    )

train_dataset = qgdl.create("train[:100]")
valid_dataset = qgdl.create("validation[:50]")

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.

## Training with QGTrainer
Here we create a `QGTrainer` with the training arguments defined above and use it to train on the WikiSQL training data (or on any custom data in the same format).

In [None]:
from primeqa.qg.trainers.qg_trainer import QGTrainer
from primeqa.qg.metrics.generation_metrics import rouge_metrics
from primeqa.qg.utils.data_collator import T2TDataCollator

# rouge_metrics takes the tokenizer and returns the compute_metrics callable for the trainer.
compute_metrics = rouge_metrics(qg_model.tokenizer)

trainer = QGTrainer(
    model=qg_model.model,
    tokenizer=qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    # Standard Trainer keyword; evaluate() falls back to it when called without a dataset.
    eval_dataset=valid_dataset,
    data_collator=T2TDataCollator(),
    compute_metrics=compute_metrics,
    )

train_results = trainer.train()
trainer.save_model()  # saves to training_args.output_dir
print(train_results.metrics)
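With the small slices used here, the run is short. The sketch below estimates the number of optimizer updates for `train[:100]` with a batch size of 8 and 2 epochs, assuming a single device, no gradient accumulation, and that the last partial batch is kept (the Trainer default, since `dataloader_drop_last` is False). It is a simplified estimate, not the Trainer's own bookkeeping.

```python
import math


def total_training_steps(num_examples: int, per_device_batch_size: int,
                         num_epochs: int, gradient_accumulation_steps: int = 1,
                         num_devices: int = 1, drop_last: bool = False) -> int:
    """Simplified estimate of optimizer updates (ignores max_steps)."""
    examples_per_batch = per_device_batch_size * num_devices
    # Batches per epoch; a final partial batch counts unless it is dropped.
    if drop_last:
        batches = num_examples // examples_per_batch
    else:
        batches = math.ceil(num_examples / examples_per_batch)
    # One update per gradient_accumulation_steps batches; leftover batches are ignored.
    updates_per_epoch = max(batches // gradient_accumulation_steps, 1)
    return updates_per_epoch * num_epochs


print(total_training_steps(100, 8, 2))  # 13 batches per epoch x 2 epochs = 26
```

The same arithmetic gives two evaluation batches for the 50 validation examples at `per_device_eval_batch_size=32`.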

## Evaluation

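Here we evaluate the trained model on the validation set. Because `predict_with_generate=True`, the trainer generates questions and `compute_metrics` scores them against the reference questions with ROUGE.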

In [None]:
metrics = trainer.evaluate()
print(metrics)
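For intuition about what the ROUGE scores measure, here is a simplified ROUGE-1 F1 (clipped unigram overlap) in plain Python. It is not the implementation behind `rouge_metrics`: it only lowercases and splits on whitespace, whereas standard ROUGE implementations apply their own tokenization and optional stemming, so its numbers will not match the reported metrics exactly. The prediction in the example is made up for illustration and is not an output of the model trained above.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1 between a generated and a reference question."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps min(count_in_pred, count_in_ref) per token.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


reference = "What is the max gross weight of the Robinson R-22?"
prediction = "what is the max gross weight of robinson r-22?"  # made-up example
print(round(rouge1_f1(prediction, reference), 3))  # 0.947
```

Here all 9 predicted tokens appear in the 10-token reference, giving precision 1.0, recall 0.9, and F1 = 18/19.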