# Example for the Memory Augmented Generator (MAG) model

## Imports + model initialization

In [1]:
import pandas as pd
import tensorflow as tf

from loggers import set_level, add_handler
from models.qa import MAG
from utils import set_display_options
from datasets import get_dataset, prepare_dataset, train_test_split, test_dataset_time
from models.model_utils import get_model_history

from experiments_mag import config_from_name, training_config_from_name

# Uncomment to change the logging level to 'time' (presumably enables timing logs)
#set_level('time')
set_display_options()

# Experiment name : parsed by `config_from_name` below, and its substrings ('dense', 'nq', 'doc',
# 'ib', 'split', 'aug') select the training settings in the following cells
model_name = 'm5_nq_coqa_newsqa_mag_split_off_entq_ct_wt_ib_2_2_dense'
# NOTE(review): `bert_base` is not referenced anywhere else in this notebook
bert_base  = 'facebook/bart-large'

print("Tensorflow version : {}".format(tf.__version__))

Tensorflow version : 2.6.2


In [None]:
import logging

from models.model_utils import is_model_name

# Build the constructor arguments from the experiment name; the 'class' entry is removed
# before the remaining entries are forwarded as keyword arguments
config = config_from_name(model_name)
config.pop('class')

# Initialize from a pretrained model only if `is_model_name` recognizes it. The name is popped
# from `config` so it is not passed twice (explicitly AND through `** config`), which raised a
# TypeError. The validated name from the config is used (previously the unchecked
# `model_name.replace('dense', 'mean')` was passed instead).
pretrained_name = config.pop('pretrained_name', None)
if pretrained_name is not None and not is_model_name(pretrained_name):
    logging.warning('Pretrained model {} does not exists !'.format(pretrained_name))
    pretrained_name = None

# Hide every GPU from TensorFlow : the model is built on CPU
tf.config.set_visible_devices([], 'GPU')
if pretrained_name is not None:
    model = MAG.from_pretrained(nom = model_name, pretrained_name = pretrained_name, ** config)
else:
    model = MAG(nom = model_name, ** config)

print(model)

In [None]:
# NOTE(review): `freeze(trainable = True)` presumably sets the `trainable` flag of the inner model's
# layers (i.e. un-freezes them) — confirm against the `freeze` implementation.
# These changes apply to the current `model`; the next section re-creates `model` from scratch.
model.model.freeze(trainable = True)
# Keep the encoder's subsampling layer frozen.
# NOTE(review): accessed as `model.encoder` here but as `model.model.encoder` below — presumably the
# same object exposed through a property; confirm
model.encoder.subsampling_layer.trainable = False
model.summary()
model.model.encoder.summary()

## Model instantiation + dataset loading

In [None]:
# Re-instantiate the model by name (`max_to_keep` presumably bounds the number of kept checkpoints)
model = MAG(nom = model_name, max_to_keep = 2)

# 'dense' models use a learning-rate schedule, the others a constant learning rate.
# Alternative schedule (it used to be assigned then immediately overwritten, so it was never used) :
#   {'name' : 'WarmupScheduler', 'maxval' : 5e-5, 'minval' : 1e-5, 'factor' : 512, 'warmup_steps' : 8192}
if 'dense' in model_name:
    lr = {'name' : 'DivideByStep', 'maxval' : 1e-5, 'minval' : 1e-6, 'factor' : 0.1}
else:
    lr = 1e-5

model.compile(optimizer = 'adam', optimizer_config = {'lr' : lr}, metrics = ['TextAccuracy', 'F1'])
print(model)

In [None]:
# Training data : NQ for NQ-based experiments, SQuAD otherwise
datasets = 'squad' if 'nq' not in model_name else 'nq'
#datasets = ['nq', 'squad']

# Whether full documents are included in the samples. The name-based rule
# `'nq' in datasets and 'doc' in model_name` was computed then immediately overwritten,
# so only the effective value is kept : documents are always included.
use_doc = True

dataset = get_dataset(
    datasets, clean_text = True, skip_impossible = True, shuffle = True, use_long_answer = False,
    include_document = use_doc, keep_mode = 'all'
)
train, valid = dataset['train'], dataset['valid']

print("Dataset length :\n  Training set : {}\n  Validation set : {}".format(
    len(train), len(valid)
))

In [None]:
# These names were used without being imported in this cell (`np` is only imported further down)
import numpy as np

from tqdm import tqdm

from utils import plot

# Distribution of `model.encode_data(row)[1][1]` over (at most) 10k random training samples.
# NOTE(review): presumably an encoded length, compared below with 3 chunks of 64 — confirm.
# `sample` raises if asked for more rows than available, hence the `min`.
n_samples = min(10000, len(train))
freqs = np.array([model.encode_data(row)[1][1] for row in tqdm(train.sample(n_samples).to_dict('records'))])
print(freqs)
# Number of sampled values <= 192 (= 64 * 3)
print(np.sum(freqs <= 64 * 3))
plot(freqs, plot_type = 'hist')


## Training

In [None]:
# Bot token given to the 'telegram' logging handler.
# NOTE(review): with TOKEN = None, confirm that `add_handler` tolerates a missing token
TOKEN = None

add_handler('telegram', token = TOKEN)

In [None]:
# Fine-tuning : smaller batches and, for non-'dense' models, the learning rate is reset to 1e-5
fine_tuning = True

if fine_tuning and 'dense' not in model_name:
    model.get_optimizer().learning_rate.assign(1e-5)

#if 'dense' in model_name and model.epochs == 0:
#    model.encoder.subsampling_layer.trainable = False

# Number of epochs, or a list of epochs : one `model.train` call is made per entry
epochs = 1
if not isinstance(epochs, (list, tuple)): epochs = [epochs]

# Batch size depends on the dataset and, outside SQuAD, on the encoder's subsampling factor.
# Every factor maps to a batch size (factors 4 and > 5 previously left `batch_size` undefined).
if datasets == 'squad':
    batch_size = 8 if fine_tuning else 16
elif model.subsampling_factor < 2:
    batch_size = 3 if fine_tuning else 16
elif model.subsampling_factor == 2:
    batch_size = 4 if fine_tuning else 16
elif model.subsampling_factor <= 4:
    batch_size = 5 if fine_tuning else 16
else:
    batch_size = 6 if fine_tuning else 16

max_negatives = 5

# No shuffling when the total number of epochs (this run + `model.epochs`) is exactly 1
shuffle_size = 0 if sum(epochs) + model.epochs == 1 else batch_size * 32

# Data-augmentation settings (disabled when documents are used, stronger for 'aug' models)
augment_prct = 0. if use_doc else 0.25
nb_mask = 1 if 'aug' not in model_name else 2
min_mask_length = 1
max_mask_length = 1 if 'aug' not in model_name else 2

# Negatives : 'batch' for 'ib' models, 'doc' when documents are used, none otherwise
negative_mode = None
if 'ib' in model_name:
    negative_mode = 'batch'
elif use_doc:
    negative_mode = 'doc'

max_input_length = 512
max_output_length = 128

# Smaller batches with documents / split contexts (presumably larger samples)
if use_doc: batch_size = batch_size // 2
elif 'split' in model_name : batch_size -= 1

print("Training samples   : {} - {} batches".format(len(train), len(train) // batch_size))
print("Validation samples : {} - {} batches".format(len(valid), len(valid) // (batch_size * 2)))

for e in epochs:
    hist = model.train(
        train, validation_data = valid, 
        epochs = e, batch_size = batch_size, valid_batch_size = 2.,
        shuffle_size = shuffle_size, max_input_length = max_input_length, max_output_length = max_output_length,
        negative_mode = negative_mode, max_negatives = max_negatives, 
        is_rectangular = not use_doc,

        augment_prct = augment_prct, nb_mask = nb_mask, min_mask_length = min_mask_length, max_mask_length = max_mask_length
    )

In [None]:
# Plot the model's history, then print it
model.plot_history()
print(model.history)

In [None]:
# Show `model.history.trainings_infos` as a DataFrame (last expression -> rendered as cell output)
pd.DataFrame(model.history.trainings_infos)

In [None]:
from models.model_utils import get_model_history
from utils import time_to_string

# Load the history of another saved model, by name
h = get_model_history('m_nq_mag_off_3_12_mean')
# NOTE(review): only the last expression of a notebook cell is displayed, so this DataFrame is
# built but never shown — wrap it in `print(...)` / `display(...)` if it should be inspected
pd.DataFrame(h.logs)
# Readable form of a duration (presumably 13000 seconds)
time_to_string(13000)

## Evaluate

In [None]:
# Evaluate on the validation set, then save the model
model.test(valid, batch_size = 8)
model.save()

In [None]:
# Evaluation variant : no negatives, rectangular batches, F1 only, without teacher forcing.
# `max_output_length = 32` caps the generated length; `add_loss = False` presumably skips the loss.
model.test(
    valid, batch_size = 12, max_input_length = 512, negative_mode = None, is_rectangular = True,
    max_negatives = 5, max_output_length = 32, add_loss = False, metrics = ['F1'],
    teacher_forcing_eval = False, eval_infer_config = {}, verbose = 0
)

## Prediction

In [None]:
set_level('info')

# Validation-style dataset (batches of 5, no shuffling) on 25 fixed (random_state = 0) samples.
# NOTE: this rebinds the notebook-level `config`.
config = model.get_dataset_config(batch_size = 5, is_validation = True, shuffle_size = 0)
ds = prepare_dataset(valid.sample(25, random_state = 0), ** config, is_rectangular = not use_doc)

# Compare predictions with targets, batch by batch (`n_pred` presumably bounds displayed predictions)
for batch in ds:
    model.predict_with_target(batch, n_pred = 10, debug = False)


In [None]:
# Hand-written questions / contexts to inspect the model's answers. Contexts come in several
# formats : a single string, a list of sentences, and a list of paragraphs (including an
# unrelated 'The answer to everything is 42' paragraph).
question = [
    'How is the night vision of cat ?',
    'How is the night vision of cat ?',
    'What is the anatomy of a cat ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'What is the origin of life ?'
]

# Paragraph reused by the first 3 contexts (previously copy-pasted 3 times)
cat_description = (
    'The cat is similar in anatomy to the other felid species: it has a strong flexible body, '
    'quick reflexes, sharp teeth and retractable claws adapted to killing small prey. Its night vision and sense of smell are well '
    'developed. Cat communication includes vocalizations like meowing, purring, trilling, hissing, growling and grunting as well as cat-'
    'specific body language. A predator that is most active at dawn and dusk (crepuscular), the cat is a solitary hunter but a social species. '
    'It can hear sounds too faint or too high in frequency for human ears, such as those made by mice and other small mammals.[7] It secretes and '
    'perceives pheromones.'
)

context  = [
    cat_description,
    [p.strip() + '.' for p in cat_description.split('.') if len(p) > 0],
    [cat_description, 'The answer to everything is 42'],
    'A cat is an animal which has 4 paws and whiskers.',
    'A cat is an animal which has 4 paws and whiskers. However, everyone knows that the answer to everything is 42 !',
    ['A cat is an animal which has 4 paws and whiskers.', 'However, everyone knows that the answer to everything is 42 !'],
    'The answer to everything is 42.'
]

# Uncomment the line below to only run the n-th example
n = 1
#question, context = question[n], [context[n]]

if not isinstance(question, list): question = [question]
if not isinstance(context, list): context = [context]

answers = model.predict(question, context, title = 'cat', method = 'beam')

for q, c, a in zip(question, context, answers):
    # NOTE(review): the next cell calls `infer_to_str(a[1], a[0])` (swapped) — one of them is wrong
    print("Question : {}\nContext : {}\nAnswer : {}\n".format(q, c, model.infer_to_str(a[0], a[1])))


In [None]:
# Print each raw prediction, then its textual form
for q, c, a in zip(question, context, answers):
    print("Question : {}\nContext : {}\nAnswer : {}\n".format(q, c, a))
    # NOTE(review): arguments are swapped w.r.t. the previous cell (`infer_to_str(a[0], a[1])`) —
    # confirm which order matches the `infer_to_str` signature
    print(model.infer_to_str(a[1], a[0]))

## Model comparison

In [None]:
import os
import numpy as np

from utils import plot_multiple
from models.model_utils import compare_models, get_models, remove_training_checkpoint

def _extract_topk_tests(infos, color_corr):
    top5_prefix = set([c.split('-')[0] for c in infos.columns if 'top5' in c and 'test' in c])
    to_drop     = [c for c in infos.columns if 'top5' in c]
    
    top5_results = {}
    for prefix in top5_prefix:
        top5_results[prefix] = {'x' : {}, 'with_legend' : False}
        for _, row in infos.iterrows():
            if row['nom'].startswith('m4_') and 'squad' not in prefix: continue
            
            k = np.array(sorted([int(c.split('-')[-1]) for c in infos.columns if c.startswith(prefix)]))
            score = {c.split('-')[1] : row[c] for c in infos.columns if c.startswith(prefix)}
            top5_results[prefix]['x'][row['nom']] = {
                'x' : k,
                'y' : np.array([score[str(k)] for k in k]),
                'c' : _colors[row[color_corr]],
                'ls': _styles[row['nom'][:2]] if 'split' not in row['nom'] else '-.'
            }
    return infos.drop(to_drop, axis = 1), top5_results

def plot_and_sort(infos, metric = 'val_loss', color_corr = 'encoder_subsampling_step', shape_corr = None, ascending = True):
    """
        Clean the models' comparison table, plot every column against `metric` and return it sorted.

        Arguments :
            - infos : pd.DataFrame indexed by model name (presumably the output of `compare_models`)
            - metric : column used as correlation target for the plots and as sort key
            - color_corr : column selecting each model's color (ignored by the plot if absent)
            - shape_corr : column selecting each model's marker (ignored by the plot if absent)
            - ascending : sort order of the returned table
        Return :
            - the cleaned table (without top-k test columns) sorted by `metric`

        `infos` itself is left untouched : every change is made on the new frame returned by `drop`.
    """
    to_drop = [c for c in _cols_to_drop if c in infos.columns]
    # 'test_doc' columns without a '-k' suffix are dropped as well
    to_drop += [c for c in infos.columns if 'test_doc' in c and '-' not in c]
    infos = infos.drop(to_drop, axis = 1)

    # Fill missing values and set hyper-parameters that do not apply to a model to -1 / 'none'.
    # `fillna` results are assigned back : chained `col.fillna(..., inplace = True)` calls have no
    # effect under pandas Copy-on-Write.
    if 'negative_mode' in infos.columns:
        infos['negative_mode'] = infos['negative_mode'].fillna('none')
        infos.loc[infos['negative_mode'] != 'doc', 'max_negatives'] = -1

    if 'split_contexts' in infos.columns:
        infos['split_contexts'] = infos['split_contexts'].fillna(False)
        # NOTE(review): by analogy with `max_negatives`, one would expect `max_sent_per_ctx` to be
        # reset when contexts are NOT split (`== False`) — confirm the intended condition
        infos.loc[infos['split_contexts'] != False, 'max_sent_per_ctx'] = -1

    if 'encoder_subsampling_mode' in infos.columns:
        infos['encoder_subsampling_mode'] = infos['encoder_subsampling_mode'].fillna('none')
        if 'encoder_subsampling_step' in infos.columns:
            infos['encoder_subsampling_step'] = infos['encoder_subsampling_step'].fillna(0)
            # A step < 2 means no effective subsampling
            infos.loc[infos['encoder_subsampling_step'] < 2, 'encoder_subsampling_mode'] = 'none'

    infos['nom'] = infos.index
    infos, top_k_tests = _extract_topk_tests(infos, color_corr)

    plot_multiple(
        infos, corr = metric, ** top_k_tests, linewidth = 5,
        color_corr = color_corr if color_corr in infos.columns else None, color_order = _colors,
        shape_corr = shape_corr if shape_corr in infos.columns else None, shape_order = _shapes,
        link_from_to = ('pretrained_name', 'nom'),
        ncols = 4, x_size = 4, y_size = 4#, filename = 'mag_plots/{}.png'.format(metric), show = True
    )

    return infos.sort_values(metric, ascending = ascending)

# Curve color, indexed by the value of the `color_corr` column
_colors = dict(enumerate(['w', 'r', 'cyan', 'g', 'b', 'violet']))
# Marker per encoder subsampling mode
_shapes = dict([('none', 'o'), ('mean', 'x'), ('dense', 'D')])
# Line style per model family (2-character name prefix)
_styles = dict(m3 = 'dotted', m4 = 'solid', m5 = '--')
# Columns removed from the comparison table before plotting
_cols_to_drop = [
    'input_format', 'shuffle_size', 'eval_infer_config', 'augment_prct', 'max_output_length'
]

names = [
    * get_models('m3_nq_mag_off_entq_ct_wt_*'), 'm3_nq',
    * get_models('m5_*'), * get_models('m4_*')
]

# names += [n for n in os.listdir('pretrained_models') if n.startswith('test_mag_')]

infos = compare_models(names, True, True, epoch = 'last', add_training_config = True)

color_corr = 'negative_mode'
if 'encoder_subsampling_step' in infos.columns:
    color_corr = 'encoder_subsampling_step'

plot_and_sort(infos, 'val_loss', color_corr = color_corr, shape_corr = 'encoder_subsampling_mode')

In [None]:
# Plot / sort the models by 'test_F1', best (highest) first
_has_step = 'encoder_subsampling_step' in infos.columns
plot_and_sort(
    infos, 'test_F1', ascending = False,
    color_corr = 'encoder_subsampling_step' if _has_step else 'negative_mode',
    shape_corr = 'encoder_subsampling_mode'
)

In [None]:
# Plot / sort the models by 'test_squad_F1', best (highest) first
_has_step = 'encoder_subsampling_step' in infos.columns
plot_and_sort(
    infos, 'test_squad_F1', ascending = False,
    color_corr = 'encoder_subsampling_step' if _has_step else 'negative_mode',
    shape_corr = 'encoder_subsampling_mode'
)

In [None]:
import os

from utils import load_json, dump_json
from models.model_utils import get_models

# Patch the saved configuration files of every 'm3_nq_mag*' model in place :
# each (relative path, entries) pair is merged into the file's 'config' dict.
names = get_models('m3_nq_mag*')

_config_patches = [
    ('config.json', {'context_offset' : -1, 'encoder_positional_offset' : 128}),
    (os.path.join('saving', 'model.json'), {'encoder_positional_offset' : 128})
]

for name in names:
    for relative_path, patch in _config_patches:
        filename = os.path.join('pretrained_models', name, relative_path)
        config = load_json(filename)
        config['config'].update(patch)
        dump_json(filename, config, indent = 4)

In [None]:
import os

from models.model_utils import get_models, remove_training_checkpoint

# Remove the training checkpoints of every 'm4*' model (deletes data on disk — not reversible).
# NOTE: this rebinds the notebook-level `n`.
for n in get_models('m4*'):
    remove_training_checkpoint(n)

## Dataset analysis

## Tests

In [None]:
# Time the input pipeline on the validation set.
# NOTE: these assignments change the live model's settings (they persist for later cells).
#model.negative_mode = 'doc'
model.max_negatives = 5
model.max_input_length = 512
model.max_sentence_length = 128

config = model.get_dataset_config(batch_size = 16, is_validation = False, shuffle_size = 0)
ds = prepare_dataset(valid, ** config, is_rectangular = False)

# Far fewer steps when documents are used (presumably much slower to prepare)
test_dataset_time(ds, steps = 1000 if not use_doc else 10)

In [None]:
import importlib

import models.qa.mag as mag
import models.qa.answer_generator_split as answer_generator_split

# Reload the edited modules without restarting the kernel (`answer_generator_split` first,
# presumably because `mag` depends on it).
# NOTE: existing objects (e.g. `model`, the `MAG` name imported earlier) keep the old class;
# re-import / re-create them to use the reloaded code.
importlib.reload(answer_generator_split)
importlib.reload(mag)

In [None]:
def get_ds(model):
    """
        Override some of `model`'s data settings, then build a dataset on the notebook-level `valid`.

        The overrides (negative_mode = 'doc', max_negatives = 8, max_input_length = 512,
        max_sentence_length = 128, max_sent_per_ctx = 5) are set directly on `model`,
        so they persist after the call.

        Return : the dataset built by `prepare_dataset` (batch size 1, validation mode, no prefetch)
    """
    overrides = {
        'negative_mode'       : 'doc',
        'max_negatives'       : 8,
        'max_input_length'    : 512,
        'max_sentence_length' : 128,
        'max_sent_per_ctx'    : 5
    }
    for attr_name, value in overrides.items():
        setattr(model, attr_name, value)

    ds_config = model.get_dataset_config(
        batch_size = 1, is_validation = True, prefetch = False, prefetch_size = 0
    )
    return prepare_dataset(valid, ** ds_config, is_rectangular = False)

ds = get_ds(model)

# Inspect the first 20 samples.
# NOTE(review): presumably `inp[1]` is a length and `inp[3]` per-context lengths, so their sum is
# the total encoded length — confirm against the dataset's encoding
for (inp, out) in ds.take(20):
    print("Inputs shape : {}".format([tuple(i.shape) for i in inp[:-2]]))
    print("Total length : {}\n".format(inp[1] + tf.reduce_sum(inp[3], axis = -1)))
    #print("Outputs shape : {}".format([tuple(i.shape) for i in out]))

In [None]:
from custom_train_objects.optimizers import WarmupScheduler, DivideByStep

# Visualize a learning-rate schedule (`plot(50000)` : presumably over 50000 steps).
# Alternative (it used to be built then immediately overwritten, so it was never plotted) :
#   lr = WarmupScheduler(maxval = 5e-5, minval = 5e-6, factor = 32, warmup_steps = 128)
lr = DivideByStep(0.5, maxval = 5e-5, minval = 5e-6)
lr.plot(50000)