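# Example for the Memory Augmented Generator (MAG) model

This notebook builds the MAG question-answering model from its experiment configuration, then trains, evaluates and queries it on a few hand-written questions. It finishes by benchmarking the dataset pipeline and plotting learning-rate schedules.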

## Imports + model initialization

In [1]:
import logging

import pandas as pd
import tensorflow as tf

from loggers import set_level, add_handler
from models.nlu import infer_to_str
from models.qa import MAG
from utils import set_display_options
from datasets import get_dataset, prepare_dataset, test_dataset_time
from models import get_pretrained, get_models, get_model_history, is_model_name

from experiments_mag import config_from_name, training_config_from_name, testing_config_from_name, predict_config_from_name

# Experiment identifier shared by every cell below (each one builds or reloads this model).
model_name = 'mag_nq_coqa_newsqa_split_off_wt_ib_2_2_mean'

# Project-wide display settings.
set_display_options()

print(f"Tensorflow version : {tf.__version__}")

Tensorflow version : 2.10.0


In [3]:
# Build the MAG model from the configuration registered for `model_name` in `experiments_mag`.
config = config_from_name(model_name)
# 'class' is removed before the remaining entries are passed as keyword arguments to `MAG`.
# `pop(..., None)` avoids a KeyError when a config does not define it.
config.pop('class', None)

# If the referenced pretrained model is unknown, drop it and build from scratch instead.
if 'pretrained_name' in config and not is_model_name(config['pretrained_name']):
    logging.warning('Pretrained model {} does not exist !'.format(config.pop('pretrained_name')))

# Hide every GPU so the model is built on CPU. This persists for the whole kernel session,
# so later cells (including training) also run on CPU. TensorFlow raises a RuntimeError if
# this is called after the runtime has already initialized its devices.
tf.config.set_visible_devices([], 'GPU')
if 'pretrained_name' in config:
    model = MAG.from_pretrained(** config)
else:
    model = MAG(** config)

print(model)

When using token / word-level tokenizer, it can be useful to add 'detach_punctuation' in cleaners


All model checkpoint layers were used when initializing TFBartForConditionalGeneration.

All the layers of TFBartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBartForConditionalGeneration for predictions without further training.
100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [00:19<00:00,  5.00it/s]


Weights transfered successfully !


100%|████████████████████████████████████████████████████████████████████████████████| 160/160 [00:29<00:00,  5.48it/s]


Weights transfered successfully !
Initializing model with kwargs : {'model': <custom_architectures.transformers_arch.mag_wrapper.MAGWrapper object at 0x0000015B17049828>}
Initializing submodel : `model` !
Submodel model saved in pretrained_models\mag_nq_coqa_newsqa_split_off_wt_ib_2_2_mean\saving\model.json !
Model mag_nq_coqa_newsqa_split_off_wt_ib_2_2_mean initialized successfully !

Sub model model
- Inputs 	: unknown
- Outputs 	: unknown
- Number of layers 	: 1
- Number of parameters 	: 406.604 Millions
- Model not compiled

Transfer-learning from : facebook/bart-large
Already trained on 0 epochs (0 steps)

- Language : en
- Vocabulary (size = 50265) : ['<s>', '<pad>', '</s>', '<unk>', '.', 'Ġthe', ',', 'Ġto', 'Ġand', 'Ġof', 'Ġa', 'Ġin', '-', 'Ġfor', 'Ġthat', 'Ġon', 'Ġis', 'âĢ', "'s", 'Ġwith', 'ĠThe', 'Ġwas', 'Ġ"', 'Ġat', 'Ġit', ...]
- Input format : {question}
- Multi input format : {context}
- Output format : {answer}
- Split multi input (key : context) : True
- Max sentences per

In [None]:
# Make the wrapped network trainable, then freeze only the encoder's subsampling layer.
# NOTE(review): `freeze(trainable = True)` presumably marks all layers as trainable — confirm
# against the project's `freeze` helper.
model.model.freeze(trainable = True)
# Access the encoder through the wrapped network, consistently with the summary call below.
model.model.encoder.subsampling_layer.trainable = False
model.summary()
model.model.encoder.summary()

## Training

In [None]:
# Reload the saved model so this section does not depend on earlier cells.
model = get_pretrained(model_name)

# The second argument reports whether the model has already been trained for at least one epoch.
already_trained = model.epochs > 0
config = training_config_from_name(model_name, already_trained)

compile_config = config.pop('compile_config', {})
model.compile(** compile_config)

dataset_config = config.pop('dataset_config', {})
dataset = get_dataset(config['dataset'], ** dataset_config)
train, valid = dataset['train'], dataset['valid']

batch_size = config['batch_size']
print(f"Training samples   : {len(train)} - {len(train) // batch_size} batches")
print(f"Validation samples : {len(valid)} - {len(valid) // (batch_size * 2)} batches")

# Remaining config entries are forwarded as training keyword arguments.
hist = model.train(train, validation_data = valid, valid_batch_size = 2., ** config)

In [None]:
# Plot the recorded training curves, then print the underlying history object as text.
model.plot_history()

training_history = model.history
print(training_history)

## Evaluate

In [None]:
# Identifier of this evaluation run, passed to the testing-config builder.
test_name = 'test'

model = get_pretrained(model_name)

config = testing_config_from_name(model_name, test_name, overwrite = False)

dataset_name   = config['dataset']
dataset_config = config.pop('dataset_config', {})
valid = get_dataset(dataset_name, modes = 'valid', ** dataset_config)

# Remaining config entries are forwarded to `model.test`.
hist = model.test(valid, ** config)

## Prediction

In [None]:
# Identifier of this prediction run, passed to the prediction-config builder.
pred_name = 'pred'

model = get_pretrained(model_name)

config = predict_config_from_name(model_name, pred_name, overwrite = False)

dataset_name   = config['dataset']
dataset_config = config.pop('dataset_config', {})
valid = get_dataset(dataset_name, modes = 'valid', ** dataset_config)

# The prediction config is run through `model.test` on the validation split.
hist = model.test(valid, ** config)

In [None]:
# Relies on `model` loaded by a previous cell (e.g. the prediction cell above).
# Each question is paired with the context at the same index. The contexts use the different
# formats passed to `model.predict`: a single string, a list of sentences, or a list of
# paragraphs, some of which include an unrelated distractor sentence about "42".
cat_paragraph = (
    'The cat is similar in anatomy to the other felid species: it has a strong flexible body, '
    'quick reflexes, sharp teeth and retractable claws adapted to killing small prey. '
    'Its night vision and sense of smell are well developed. '
    'Cat communication includes vocalizations like meowing, purring, trilling, hissing, growling and grunting '
    'as well as cat-specific body language. '
    'A predator that is most active at dawn and dusk (crepuscular), the cat is a solitary hunter but a social species. '
    'It can hear sounds too faint or too high in frequency for human ears, such as those made by mice and other small mammals.[7] '
    'It secretes and perceives pheromones.'
)
# The same paragraph split on '.'. Empty fragments (e.g. after the final '.') are dropped.
cat_sentences = [p.strip() + '.' for p in cat_paragraph.split('.') if p.strip()]

question = [
    'How is the night vision of cat ?',
    'How is the night vision of cat ?',
    'What is the anatomy of a cat ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'What is the origin of life ?'
]
context  = [
    cat_paragraph,
    cat_sentences,
    [cat_paragraph, 'The answer to everything is 42'],
    'A cat is an animal which has 4 paws and whiskers.',
    'A cat is an animal which has 4 paws and whiskers. However, everyone knows that the answer to everything is 42 !',
    ['A cat is an animal which has 4 paws and whiskers.', 'However, everyone knows that the answer to everything is 42 !'],
    'The answer to everything is 42.'
]

# Allow a single question / context to be given as a bare value.
if not isinstance(question, list): question = [question]
if not isinstance(context, list): context = [context]
# `zip` below would silently drop unmatched entries, so fail loudly instead.
assert len(question) == len(context), 'each question needs exactly one context'

answers = model.predict(question, context, title = 'cat', method = 'beam')

for q, c, a in zip(question, context, answers):
    # NOTE(review): each `a` appears to be a pair whose two items `infer_to_str` formats — confirm.
    print("Question : {}\nContext : {}\nAnswer : {}\n".format(q, c, infer_to_str(a[0], a[1])))


## Tests: input-pipeline throughput and learning-rate schedules

In [None]:
model = get_pretrained(model_name)

# Input-pipeline settings to benchmark, applied to the freshly loaded model in this order.
# Optional (currently disabled) override: model.negative_mode = 'doc'
pipeline_overrides = {
    'max_input_texts'     : 4,
    'use_multi_input'     : False,
    'max_input_length'    : 512,
    'max_sentence_length' : 128,
    'merge_multi_inputs'  : True,
}
for attr_name, attr_value in pipeline_overrides.items():
    setattr(model, attr_name, attr_value)

valid = get_dataset('nq', include_document = True, keep_mode = 'all', modes = 'valid')

config = model.get_dataset_config(batch_size = 16, is_validation = False, shuffle_size = 0)
ds = prepare_dataset(valid, ** config, is_rectangular = False)

# Time far fewer steps when multi-input mode is enabled (presumably much slower per step).
n_steps = 10 if model.use_multi_input else 1000
test_dataset_time(ds, steps = n_steps)

In [None]:
from custom_train_objects.optimizers import WarmupScheduler, DivideByStep

# Learning-rate bounds shared by both candidate schedules.
LR_MAX, LR_MIN = 5e-5, 5e-6
# Argument passed to `plot`; presumably the number of training steps displayed.
PLOT_STEPS = 50000

warmup_lr = WarmupScheduler(maxval = LR_MAX, minval = LR_MIN, factor = 32, warmup_steps = 128)
step_lr   = DivideByStep(0.5, maxval = LR_MAX, minval = LR_MIN)

# Only the selected schedule is plotted; set `lr = warmup_lr` to inspect the warm-up schedule.
lr = step_lr
lr.plot(PLOT_STEPS)