# Example for the Answer Generator model

## Imports + model initialization

In [1]:
import time
import numpy as np
import pandas as pd
import tensorflow as tf

from models.qa import AnswerGenerator
from utils import set_display_options
from datasets import get_dataset, prepare_dataset, train_test_split, test_dataset_time

# Helper imported from `utils`; presumably configures how outputs are displayed.
set_display_options()

# `model_name` is the name given to `AnswerGenerator` in the following cells.
# NOTE(review): despite its `bert_` prefix, `bert_base` points to a BART checkpoint.
model_name, bert_base = 'bart_qa', 'facebook/bart-large'

# Report the TensorFlow version used by this session.
print(f"Tensorflow version : {tf.__version__}")

Tensorflow version : 2.10.0


In [None]:
# Keyword arguments forwarded to the `AnswerGenerator` constructor.
# NOTE(review): the exact meaning of each key is defined by `AnswerGenerator` — confirm there.
config = dict(
    lang             = 'en',
    input_format     = ['{question}', '{context}'],
    output_format    = '{answer}',
    text_encoder     = bert_base,
    max_input_length = 512,
    pretrained       = bert_base,
)

# Build the model under the name `model_name`, then print it.
model = AnswerGenerator(nom = model_name, **config)
print(model)

In [None]:
# Inspect the model built above through its `summary()` method.
model.summary()

## Model instantiation + dataset loading

In [None]:
# Instantiate the model by name only (no architecture config is passed here).
model = AnswerGenerator(nom = model_name, max_to_keep = 1)

# Two learning-rate options are written down; the constant assigned last
# overwrites the scheduler config, so 1e-5 is the value actually used.
lr = {
    'name'         : 'WarmupScheduler',
    'maxval'       : 5e-5,
    'minval'       : 1e-5,
    'factor'       : 512,
    'warmup_steps' : 8192
}
lr = 1e-5

model.compile(
    optimizer        = 'adam',
    optimizer_config = {'lr' : lr},
    metrics          = ['TextAccuracy', 'F1']
)
print(model)

In [None]:
# Choose the dataset from the model name: 'nq' if the name contains it, else 'squad'.
datasets = 'nq' if 'nq' in model_name else 'squad'

dataset = get_dataset(
    datasets,
    clean_text      = True,
    skip_impossible = True,
    keep_mode       = 'longest'
)
train = dataset['train']
valid = dataset['valid']

# Report the size of both splits.
print(f"Dataset length :\n  Training set : {len(train)}\n  Validation set : {len(valid)}")

## Training

In [None]:
# Training hyper-parameters.
epochs       = 1
batch_size   = 6 if datasets != 'squad' else 8
shuffle_size = batch_size * 32

# Length limits forwarded to `model.train`.
max_input_length  = 512
max_output_length = 128

# Batch counts use floor division, so a trailing partial batch is not counted.
print(f"Training samples   : {len(train)} - {len(train) // batch_size} batches")
print(f"Validation samples : {len(valid)} - {len(valid) // (batch_size * 2)} batches")

# NOTE(review): the float `2.` for `valid_batch_size` looks like a multiplier of
# `batch_size` (the estimate above divides by `batch_size * 2`) — confirm in `train`.
hist = model.train(
    train,
    validation_data   = valid,
    epochs            = epochs,
    batch_size        = batch_size,
    valid_batch_size  = 2.,
    shuffle_size      = shuffle_size,
    max_input_length  = max_input_length,
    max_output_length = max_output_length
)

In [None]:
# Plot the model's training history, then print the history object itself.
model.plot_history()
print(model.history)

## Evaluate

In [None]:
# Run the model's `test` method on the validation split.
model.test(valid)

## Prediction

In [None]:
# Dataset settings for a small, unshuffled prediction run.
config = model.get_dataset_config(
    batch_size    = 2,
    is_validation = False,
    shuffle_size  = 0
)
# 10 validation rows drawn with a fixed seed, so the same rows are picked on each run.
ds = prepare_dataset(
    valid.sample(10, random_state = 0),
    **config
)

# 10 rows at batch size 2 give at most 5 batches; predict on each of them.
for batch in ds.take(5):
    model.predict_with_target(batch, n_pred = 5)


## Tests

In [None]:
# Build the validation pipeline with batches of 16 and no shuffling.
config = model.get_dataset_config(
    batch_size    = 16,
    is_validation = False,
    shuffle_size  = 0
)
ds = prepare_dataset(valid, **config)

# Time the pipeline with `test_dataset_time`, called with `steps = 100`.
test_dataset_time(ds, steps = 100)

In [None]:
from custom_train_objects.optimizers import WarmupScheduler

# Build a warmup scheduler and plot it via its `plot` method (called with 25000).
lr = WarmupScheduler(
    maxval       = 1e-3,
    minval       = 1e-4,
    factor       = 256,
    warmup_steps = 4096
)
lr.plot(25000)

In [None]:
# Fetch the optimizer's `learning_rate` attribute and overwrite its value with 5e-4.
# NOTE(review): assumes `learning_rate` exposes `assign` (e.g. a tf.Variable);
# this would not work if the optimizer holds a schedule object — confirm.
lr = model.get_optimizer().learning_rate
lr.assign(5e-4)

In [None]:
# Read the optimizer's current `learning_rate` and print it.
lr = model.get_optimizer().learning_rate
print(lr)