# Configuration

We start by setting the parameters that configure the run. Depending on how much memory your GPU has, you may need to adjust the batch size.

In [7]:
model_name_or_path="t5-small"
modality="table"
dataset_name="wikisql"
max_len=200
target_max_len=40
output_dir="models/qg/trials"
learning_rate=0.0001
num_train_epochs=1
per_device_train_batch_size=8
per_device_eval_batch_size=8
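
If a batch size of 8 does not fit in GPU memory, one common option is to lower the per-device batch size and compensate with gradient accumulation, so the effective batch size stays the same. The sketch below is illustrative only: `accumulation_steps` is a hypothetical helper, not part of PrimeQA, and it assumes the result would be passed to the `gradient_accumulation_steps` argument of `TrainingArguments`, which this notebook leaves at its default of 1.

```python
def accumulation_steps(target_batch_size, per_device_batch_size, num_devices=1):
    """Return how many batches to accumulate to reach the target effective batch size.

    Hypothetical helper for illustration; not part of PrimeQA or transformers.
    """
    per_step = per_device_batch_size * num_devices
    if per_step <= 0 or target_batch_size % per_step != 0:
        raise ValueError(
            "target_batch_size must be a positive multiple of "
            "per_device_batch_size * num_devices"
        )
    return target_batch_size // per_step


# Keep the effective batch size at 8 while only fitting 2 examples per step.
print(accumulation_steps(target_batch_size=8, per_device_batch_size=2))
```

With these inputs the helper returns 4, meaning gradients from four batches of 2 would be accumulated before each optimizer update.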


# Loading the Model

Here we load the model based on the model_name_or_path parameter set above. Because modality is set to table, we use a QG model for tables.

In [8]:
from transformers import (
    DataCollator,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from primeqa.qg.processors.data_loader import QGDataLoader
import torch
from dataclasses import dataclass,field
from primeqa.qg.models.qg_model import QGModel
from primeqa.qg.trainers.qg_trainer import QGTrainer
from typing import Optional, List, Dict
from examples.qg.run_qg import T2TDataCollator, ModelArguments, DataTrainingArguments


import json
import logging
import os

seed=42

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=False,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    evaluation_strategy='no',
    learning_rate=learning_rate,
    prediction_loss_only=True,
    remove_unused_columns=False,
    seed=seed
    )

set_seed(training_args.seed)

model_args = ModelArguments(
        model_name_or_path=model_name_or_path,
        modality=modality
    )

data_args = DataTrainingArguments(
    dataset_name = dataset_name,
    max_len = max_len,
    target_max_len = target_max_len
    )

qg_model = QGModel(model_args.model_name_or_path, modality=model_args.modality)




# Loading Data

Here we load the WikiSQL dataset using Hugging Face's datasets library.
Note that this step is optional and is included only to show the sample data format.
The next step, data processing, loads and processes the data via our custom QGDataLoader class.

In [11]:
from tqdm import tqdm
from datasets import load_dataset
from primeqa.qg.utils.constants import SqlOperants, QGSpecialTokens

data = load_dataset('wikisql')
train_dataset = data['train']
print(train_dataset[0])




Using custom data configuration default


Downloading and preparing dataset wiki_sql/default (download: 24.95 MiB, generated: 147.57 MiB, post-processed: Unknown size, total: 172.52 MiB) to /u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d...


Generating test split:   0%|          | 0/15878 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8421 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/56355 [00:00<?, ? examples/s]

Dataset wiki_sql downloaded and prepared to /u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d. Subsequent calls will reuse this data.



{'phase': 1, 'question': 'Tell me what the notes are for South Australia ', 'table': {'header': ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes'], 'page_title': '', 'page_id': '', 'types': ['text', 'text', 'text', 'text', 'text', 'text'], 'id': '1-1000181-1', 'section_title': '', 'caption': '', 'rows': [['Australian Capital Territory', 'blue/white', 'Yaa·nna', 'ACT · CELEBRATION OF A CENTURY 2013', 'YIL·00A', 'Slogan screenprinted on plate'], ['New South Wales', 'black/yellow', 'aa·nn·aa', 'NEW SOUTH WALES', 'BX·99·HI', 'No slogan on current series'], ['New South Wales', 'black/white', 'aaa·nna', 'NSW', 'CPX·12A', 'Optional white slimline series'], ['Northern Territory', 'ochre/white', 'Ca·nn·aa', 'NT · OUTBACK AUSTRALIA', 'CB·06·ZZ', 'New series began in June 2011'], ['Queensland', 'maroon/white', 'nnn·aaa', 'QUEENSLAND · SUNSHINE STATE', '999·TLG', 'Slogan embossed on plate'], ['South Australia', 'black/white', 'Snnn·aaa', 'SOUTH AUS
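
The record printed above keeps the table under the `table` key, with column names in `header` and cell values in `rows`. As a minimal sketch of how to read that layout, the snippet below pairs each row with the header. It uses a hand-copied excerpt of the record above (two rows, and only the fields needed) rather than the full dataset, and `rows_as_dicts` is a hypothetical helper, not a PrimeQA or datasets function.

```python
# Excerpt of the sample record shown above (first two rows only).
record = {
    "question": "Tell me what the notes are for South Australia ",
    "table": {
        "header": ["State/territory", "Text/background colour", "Format",
                   "Current slogan", "Current series", "Notes"],
        "rows": [
            ["Australian Capital Territory", "blue/white", "Yaa·nna",
             "ACT · CELEBRATION OF A CENTURY 2013", "YIL·00A",
             "Slogan screenprinted on plate"],
            ["New South Wales", "black/yellow", "aa·nn·aa",
             "NEW SOUTH WALES", "BX·99·HI", "No slogan on current series"],
        ],
    },
}


def rows_as_dicts(table):
    """Pair every row with the header so cells can be looked up by column name."""
    return [dict(zip(table["header"], row)) for row in table["rows"]]


rows = rows_as_dicts(record["table"])
for row in rows:
    print(row["State/territory"], "->", row["Notes"])
```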

# Data Processing

We process the data with the QGDataLoader class, which converts it into the tokenized format used for QG training.

In [3]:
qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    dataset_name=data_args.dataset_name,
    input_max_len=data_args.max_len,
    target_max_len=data_args.target_max_len
    )

train_dataset = qgdl.create("train")

valid_dataset = qgdl.create("validation")

print(len(train_dataset))
print(train_dataset[0])



Using custom data configuration default
Reusing dataset wiki_sql (/u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 56355/56355 [00:24<00:00, 2286.49it/s]



Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Using custom data configuration default
Reusing dataset wiki_sql (/u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 8421/8421 [00:03<00:00, 2393.96it/s]



56355
{'input_ids': tensor([ 1738, 32100,  2507,     7, 32100, 12892, 22031, 32101,  4081, 32101,
          180,  9744,   566,     3,  6727, 13733, 24933,   188, 32102,   150,
        22031,    30,   750,   939, 32103,  1015,    87,    17, 21301, 10972,
        32104,  5027,    87,  1549,  9232,  3243, 32104, 12439, 32104, 12892,
        22031, 32104, 12892,   939, 32104,  2507,     7,     1,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,  
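
In the tensor above, the token ids end with 1 and are followed by a run of zeros. This is consistent with T5's end-of-sequence id (1) and padding id (0), with the sequence padded out to max_len. The sketch below illustrates that truncate-then-pad step in plain Python. It is not the QGDataLoader implementation, and `pad_or_truncate` is a hypothetical helper written for this illustration.

```python
PAD_ID = 0  # T5 padding token id
EOS_ID = 1  # T5 end-of-sequence token id


def pad_or_truncate(token_ids, max_len, pad_id=PAD_ID, eos_id=EOS_ID):
    """Return (input_ids, attention_mask), both exactly max_len long.

    Hypothetical helper: sequences that are too long are cut and closed
    with EOS; shorter ones are right-padded with pad_id.
    """
    ids = list(token_ids)
    if len(ids) > max_len:
        ids = ids[: max_len - 1] + [eos_id]
    mask = [1] * len(ids)
    padding = max_len - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding


# The first four ids of the example above, followed by EOS.
input_ids, attention_mask = pad_or_truncate([1738, 32100, 2507, 7, 1], max_len=8)
print(input_ids)
print(attention_mask)
```

The attention mask marks real tokens with 1 and padding with 0, so the number of real tokens in a padded example is simply the sum of its mask.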

# Train using QGTrainer
Here we create a QG trainer with the training arguments defined above and use it to train on the WikiSQL training data (or any custom data that follows the same format).

In [5]:
trainer = QGTrainer(
    model=qg_model.model,
    tokenizer = qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    data_collator=T2TDataCollator()
    )

trainer.train(
        model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
trainer.save_model()

***** Running training *****
  Num examples = 56355
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 7045
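
The step count reported above is consistent with dividing the number of examples by the total train batch size and rounding up, then multiplying by the number of epochs. The sketch below reproduces that arithmetic. `optimization_steps` is a hypothetical helper, and it assumes a single device with no gradient accumulation, as in this run; with other settings the trainer's exact rounding may differ slightly.

```python
import math


def optimization_steps(num_examples, total_batch_size, num_epochs=1):
    """Estimate optimizer updates: ceil(examples / batch size) per epoch."""
    steps_per_epoch = math.ceil(num_examples / total_batch_size)
    return steps_per_epoch * num_epochs


# Values taken from the training log: 56355 examples, batch size 8, 1 epoch.
print(optimization_steps(56355, 8, 1))  # 7045
```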


Step,Training Loss
500,2.1484
1000,1.6289
1500,1.4896


Saving model checkpoint to models/qg/trials/checkpoint-500
Configuration saved in models/qg/trials/checkpoint-500/config.json
Model weights saved in models/qg/trials/checkpoint-500/pytorch_model.bin
tokenizer config file saved in models/qg/trials/checkpoint-500/tokenizer_config.json
Special tokens file saved in models/qg/trials/checkpoint-500/special_tokens_map.json
Copy vocab file to models/qg/trials/checkpoint-500/spiece.model
Saving model checkpoint to models/qg/trials/checkpoint-1000
Configuration saved in models/qg/trials/checkpoint-1000/config.json
Model weights saved in models/qg/trials/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in models/qg/trials/checkpoint-1000/tokenizer_config.json
Special tokens file saved in models/qg/trials/checkpoint-1000/special_tokens_map.json
Copy vocab file to models/qg/trials/checkpoint-1000/spiece.model
Saving model checkpoint to models/qg/trials/checkpoint-1500
Configuration saved in models/qg/trials/checkpoint-1500/config.json
