# Configuration

We start by setting the parameters that configure the training run. Depending on the GPU being used, you may need to tune the batch sizes (`per_device_train_batch_size` and `per_device_eval_batch_size`).

In [1]:
# Base checkpoint to fine-tune; the training cell also checks whether this is a local directory.
model_name_or_path = "t5-small"
# Input modality handed to the QG model; this notebook generates questions from tables.
modality = "table"
# Dataset name passed to the QG data loader.
dataset_name = "wikisql"
# Maximum token lengths for encoded inputs and for target questions.
max_len, target_max_len = 200, 40
# Directory for trainer outputs (the training cell overwrites it).
output_dir = "models/qg/trials"
# Optimiser settings.
learning_rate = 1e-4
num_train_epochs = 1
# Per-device batch sizes; lower these if the GPU runs out of memory.
per_device_train_batch_size = 8
per_device_eval_batch_size = 8


# Loading the Model

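Here we load the model specified by the `model_name_or_path` parameter set above, as a QG model configured for `modality="table"`. This cell also fixes the random seed and builds the training, model, and data argument objects used by the following cells.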

In [2]:
from transformers import (
    DataCollator,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from primeqa.qg.processors.data_loader import QGDataLoader
import torch
from dataclasses import dataclass,field
from primeqa.qg.models.qg_model import QGModel
from primeqa.qg.trainers.qg_trainer import QGTrainer
from typing import Optional, List, Dict
from primeqa.qg.trainers.qg_trainer_utils import T2TDataCollator, ModelArguments, DataTrainingArguments, QGTrainingArguments, InferenceArguments
from examples.qg.run_qg import TrainingArguments


import json
import logging
import os

# Fixed seed, passed to the trainer arguments and to set_seed below, so runs are repeatable.
seed = 42

# Trainer configuration: training only, with evaluation switched off for this run.
training_args = TrainingArguments(
    # Output location and run identity.
    output_dir=output_dir,
    overwrite_output_dir=True,
    seed=seed,
    # Which stages run.
    do_train=True,
    do_eval=False,
    evaluation_strategy='no',
    # Optimisation hyper-parameters from the configuration cell.
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    # Loss-only predictions, and keep every dataset column.
    # NOTE(review): presumably T2TDataCollator needs columns the model
    # signature does not list — confirm before enabling column removal.
    prediction_loss_only=True,
    remove_unused_columns=False,
)

set_seed(training_args.seed)

# Wrap the notebook-level settings in the project's argument dataclasses.
model_args = ModelArguments(model_name_or_path=model_name_or_path, modality=modality)

data_args = DataTrainingArguments(
    dataset_name=dataset_name,
    max_len=max_len,
    target_max_len=target_max_len,
)

# Build the QG model; later cells use its `tokenizer` and `model` attributes.
qg_model = QGModel(model_args.model_name_or_path, modality=model_args.modality)



# Loading Data

Here we load the WikiSQL dataset using Hugging Face's `datasets` library and encode the train and validation splits with the model's tokenizer.

In [3]:
# Loader that encodes the configured dataset with the model's own tokenizer,
# using the input/target length limits from the configuration cell.
qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    dataset_name=data_args.dataset_name,
    input_max_len=data_args.max_len,
    target_max_len=data_args.target_max_len,
)

# Materialise both splits, train first and then validation.
train_dataset, valid_dataset = (qgdl.create(split) for split in ("train", "validation"))

# Sanity check: number of training examples, then the first encoded example.
print(len(train_dataset))
print(train_dataset[0])



Using custom data configuration default
Reusing dataset wiki_sql (/u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 56355/56355 [00:24<00:00, 2286.49it/s]


  0%|          | 0/57 [00:00<?, ?ba/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Using custom data configuration default
Reusing dataset wiki_sql (/u/jaydesen/.cache/huggingface/datasets/wiki_sql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 8421/8421 [00:03<00:00, 2393.96it/s]


  0%|          | 0/9 [00:00<?, ?ba/s]

56355
{'input_ids': tensor([ 1738, 32100,  2507,     7, 32100, 12892, 22031, 32101,  4081, 32101,
          180,  9744,   566,     3,  6727, 13733, 24933,   188, 32102,   150,
        22031,    30,   750,   939, 32103,  1015,    87,    17, 21301, 10972,
        32104,  5027,    87,  1549,  9232,  3243, 32104, 12439, 32104, 12892,
        22031, 32104, 12892,   939, 32104,  2507,     7,     1,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,  

# Train using QGTrainer
Here we create a QG trainer with the training arguments defined above and use it to train on the WikiSQL training data (or any custom data following the same format).

In [4]:
# Build the trainer from the model, tokenizer, arguments and datasets prepared above.
# NOTE(review): assumes QGTrainer forwards keyword arguments to
# transformers.Trainer, whose evaluation-set parameter is `eval_dataset`
# (a `valid_dataset` keyword is not part of that signature).
trainer = QGTrainer(
    model=qg_model.model,
    tokenizer=qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=T2TDataCollator(),
)

# Resume from `model_name_or_path` only when it is a local directory; for a
# name such as "t5-small", training starts from the weights already loaded
# into qg_model.model. `resume_from_checkpoint` replaces the deprecated
# `model_path` keyword, which Trainer.train maps to the same value.
checkpoint_dir = (
    model_args.model_name_or_path
    if os.path.isdir(model_args.model_name_or_path)
    else None
)
trainer.train(resume_from_checkpoint=checkpoint_dir)

# Persist the trained model (Trainer.save_model defaults to training_args.output_dir).
trainer.save_model()

NameError: name 'sub_train_dataset' is not defined