#  Unsupervised learning addition model with generator

Author: S. Menary [sbmenary@gmail.com]

Date: 11/4/2023  (last update: 15/4/2023)

Overview: Train a `sequence -> sequence` model where the input sequence is a text representation of a simple sum $\sum_{i=1}^N A_i$ for a configurable number $N$ of integers $A_i\in\mathbb{Z}$. The output is a sequence of logits over the token vocabulary at each position of the output sequence, i.e. unnormalised log-probabilities for each token. Integers may have a configurable number of digits. At inference time, text is generated auto-regressively until the end-of-sequence token is reached. The loss function is the sparse categorical cross-entropy.

---

## 1. Set up program

###  Import

All imports go here at the top of the notebook

In [1]:
##=========================##
##   All imports go here   ##
##=========================##

##  Import entire python stdlib packages
import logging, os, sys

##  Import entire pypi packages
import tensorflow as tf

##  Remove tensorflow INFO messages
tf.get_logger().setLevel('WARNING')

##  Add the directory above the current working directory to the system path, to expose the
##  mathsformer package location.
##  os.path.dirname is used rather than splitting on "/" so that this also works where the path
##  separator is not "/" (e.g. Windows). The membership check stops repeated executions of this
##  cell from appending the same entry over and over.
_parent_dir = os.path.dirname(os.getcwd())
if _parent_dir not in sys.path :
    sys.path.append(_parent_dir)

##  Import individual modules/objects from local packages
from tensorflow.keras.optimizers import Adam, AdamW, SGD
from mathsformer import config, data, transformers, utils
from mathsformer import selfsupervised_learning_addition_model_backend as backend


###  Configure run

Set configuration variables for the entire program

In [2]:
##==============================##
##   Set custom config values   ##
##==============================##

##  Custom config values for this run. In the next cell these are loaded on top of
##  backend.DEFAULT_CONFIG via cfg.load_dict(...), so any key omitted here keeps its default value
##  (e.g. "model > learning_rate" appears in the logged config without being set here).
custom_config = {
    ##  Program-wide settings
    "global" : {
        "base_seed"        : -1,    ## NOTE(review): presumably -1 requests a randomly chosen seed - TODO confirm
        ##  [problem_tag], [model_tag] and [date] are substituted into the directory name
        ##  (see the "Working directory created at ..." log output)
        "working_dir"      : "SSL_addition_generator_notebook_[problem_tag]_[model_tag]_[date]",
        ##  Tag mirrors the training set below: integer lengths 1,2,3 and numbers of integers 1,2,4
        "problem_tag"      : "int123_num124",
        ##  NOTE(review): "width160" does not obviously match ndim_embedding / encoder / decoder ndim (64)
        ##  set in the model section - confirm the tag is not stale
        "model_tag"        : "width160_enc6_loop1_dec6_loop1_post3_adam",
        "log_lvl_iostream" : logging.INFO,
        "log_lvl_fstream"  : logging.DEBUG,
    },
    ##  Data generation and tokenisation settings.
    ##  Each generator produces num_batches batches of batch_size sums; each sum has a number of integers
    ##  taken from num_ints, with integer lengths (digits) taken from int_lengths (compare the
    ##  "Generator of ... integers of length ... in ... batches of size ..." log output).
    "data" : {
        "train_data" : {
            "int_lengths"      : [1, 2, 3],
            "num_ints"         : [1, 2, 4],
            "batch_size"       : 32,
            "num_batches"      : 2000,
            "gen_base_seed"    : 101,
            "gen_reproducible" : False, 
        },
        ##  Validation uses the training integer lengths but different numbers of integers per sum
        "val_data" : {
            "int_lengths"      : [1, 2, 3],
            "num_ints"         : [3, 6, 7],
            "batch_size"       : 32,
            "num_batches"      : 100,
            "gen_base_seed"    : 102,
            "gen_reproducible" : True,
        },
        ##  Test uses integer lengths (4, 5) that do not appear in the training or validation sets
        "test_data" : {
            "int_lengths"      : [4, 5],
            "num_ints"         : [6, 7],
            "batch_size"       : 32,
            "num_batches"      : 100,
            "gen_base_seed"    : 103,
            "gen_reproducible" : True,
        },
        ##  Token vocabulary: a character's position in this list is its token index
        ##  (see the tokeniser dictionary in the log output, e.g. 'M' -> 0, 'B' -> 1)
        "characters"              : ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-'],
        ##  Special characters (mask / sequence start / sequence end / negative sign)
        "mask_char"               : 'M',
        "seq_start_char"          : 'B',
        "seq_end_char"            : 'E',
        "negative_char"           : 'N',
        "dtype"                   : "int32",
    },
    ##  Model architecture and optimiser settings
    "model" : {
        ##  Path to a saved model to load instead of building one from this config (None = build from scratch)
        "load_pretrained_model" : None,
        "name"                  : "mathsformer_LLM",
        "dtype"                 : "float32",
        "dropout"               : 0,
        "jit_compile"           : False,
        "use_old_loss"          : True,
        ##  Optimizer class and (presumably) the keyword arguments used to construct it
        "optimizer"             : Adam,
        "optimizer_args"        : {"learning_rate":1e-4},
        "positional_encoding" : {
            "num_freqs"         : 32,
            "min_period"        : 4,
            "max_period"        : 400,
            "learnable"         : True,
        },
        "ndim_embedding"        : 64,
        ##  Token embeddings and positional encodings are combined with an Average layer
        ##  (the "..._emb_and_pos (Average)" layers in the model summary)
        "comb_type"             : 'average',
        ##  NOTE(review): presumably num_layers=-1 disables the pre-encoder / pre-decoder - TODO confirm
        "pre_encoder"           : {
            "num_layers"        : -1,
            "ndim"              : 256,
            "skip_connect"      : True,
        },
        "pre_decoder" : {
            "num_layers"        : -1,
            "ndim"              : 256,
            "skip_connect"      : True,
        },
        ##  6 encoder blocks (encoder_block_1 ... encoder_block_6 in the model summary)
        "encoder" : {
            "num_blocks"        : 6,
            "num_loops"         : 1,
            "num_heads"         : 8,
            "ndim"              : 64,
            "ndim_att_hidden"   : 128,
            "ndim_ff_hidden"    : 512,
            "skip_connect"      : True,
            "share_weights"     : True,
        },
        ##  6 decoder blocks (decoder_block_1 ... decoder_block_6 in the model summary)
        "decoder" : {
            "num_blocks"        : 6,
            "num_loops"         : 1,
            "num_heads"         : 8,
            "ndim"              : 64,
            "ndim_att_hidden"   : 128,
            "ndim_ff_hidden"    : 512,
            "skip_connect"      : True,
            "share_weights"     : False,
        },
        "post_decoder" : {
            "num_layers"        : 3,
            "ndim"              : 512,
        },
    },
    ##  Training loop and callback settings
    "training" : {
        "train"          : True,
        ##  Presumably a large cap so that early stopping, not max_epochs, ends training
        "max_epochs"     : 100000,
        "log_after_epoch" : {
            "do"          : True,
            "log_lvl"     : logging.DEBUG,
        },
        ##  Keys mirror the arguments of keras.callbacks.EarlyStopping
        "early_stopping" : {
            "do"                   : True,
            "patience"             : 6,
            "monitor"              : "val_loss",
            "mode"                 : "min",
            "restore_best_weights" : True,
        },
        ##  Checkpoint filename is a format string filled with the epoch number and val_loss
        "model_checkpoint" : {
            "do"       : True,
            "filename" : "model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5",
        },
        ##  batch_frequency equals train_data num_batches (2000), i.e. presumably one record per epoch
        "layer_weights_record" : {
            "do"               : True,
            "batch_frequency"  : 2000,
            "recursive"        : True,
        },
        ##  Note: monitors the training "loss", whereas early stopping monitors "val_loss"
        "adaptive_learning_rate" : {
            "do"                 : True,
            "decay_factor"       : 0.3,
            "monitor"            : "loss",
            "mode"               : "min",
            "patience"           : 2,
            "log_lvl"            : logging.DEBUG,
        },
        "print_tables_during_training" : {
            "do"        : True,
            "num_print" : 10,
        },
    },
    ##  Post-training evaluation settings
    "evaluate" : {
        "num_print"            : 20,
        "save_model"           : True,
        "plot_weights"         : False,
        "plot_training_curves" : True,
    },
}


In [3]:
##====================================================##
##   Build the full config and check that it is valid  ##
##====================================================##

##  Start from the backend defaults, then overlay the custom values from the previous cell
cfg = config.Config(backend.DEFAULT_CONFIG)
cfg.load_dict(custom_config)

##  Run the backend's validation on the merged config, then report success
backend.validate_config(cfg)
print(utils.fancy_message("Config created"))

##  Convenience handles on each top-level section of the config
cfg_global, cfg_data, cfg_model, cfg_training, cfg_evaluate = (
    cfg["global"],
    cfg["data"],
    cfg["model"],
    cfg["training"],
    cfg["evaluate"],
)


===   Config created   ===


##  2. Set up environment

- Create working directory
- Create logger
- Log package versions for reproducibility
- Log config values for reproducibility
- Set random seeds for reproducibility


In [4]:
##============================================================##
##   Initialise program: working dir, logger, seeds, logging  ##
##============================================================##

##  Collect the config-driven keyword arguments for program initialisation
init_kwargs = dict(
    working_dir      = cfg_global["working_dir"],
    cfg              = cfg,
    base_seed        = cfg_global["base_seed"],
    log_lvl_iostream = cfg_global["log_lvl_iostream"],
    log_lvl_fstream  = cfg_global["log_lvl_fstream" ],
)

##  Create the working directory and logger, log package versions and config values, and set the
##  random seeds. Returns the resolved working directory, the logger and the seeds actually used.
working_dir, logger, base_seed, np_seed, tf_seed = utils.initialise_program(
    "unsupervised_learning_addition_model_generator (notebook)",
    **init_kwargs,
)


===   Working directory created at SSL_addition_generator_notebook_int123_num124_width160_enc6_loop1_dec6_loop1_post3_adam_2023_05_23   ===
   INFO initialise_logging: Begin logging on 2023-05-23 at 10:13:35
   INFO initialise_program: Program description: unsupervised_learning_addition_model_generator (notebook)
   INFO initialise_program: Working directory: SSL_addition_generator_notebook_int123_num124_width160_enc6_loop1_dec6_loop1_post3_adam_2023_05_23
   INFO log_versions: ------------------------------------------------------+-----------------------------------------------------------------------------------
   INFO log_versions:                                              PACKAGE  |  VERSION
   INFO log_versions: ------------------------------------------------------+-----------------------------------------------------------------------------------
   INFO log_versions:                                               Python  |  3.10.11 | packaged by conda-forge | (main, May 10 2

   INFO log_versions:                                           ptyprocess  |  0.7.0
   INFO log_versions:                                            pure_eval  |  0.2.2
   INFO log_versions:                                    pure_eval.version  |  0.2.2
   INFO log_versions:                                               pydevd  |  2.9.5
   INFO log_versions:                                             pygments  |  2.15.1
   INFO log_versions:                                            pyparsing  |  3.0.9
   INFO log_versions:                                                   re  |  2.2.1
   INFO log_versions:                                             requests  |  2.31.0
   INFO log_versions:                                 requests.__version__  |  2.31.0
   INFO log_versions:                                                 idna  |  3.4
   INFO log_versions:                                        idna.idnadata  |  15.0.0
   INFO log_versions:                                    idna.p

   INFO initialise_program: Registered config value data > mask_char: M
   INFO initialise_program: Registered config value data > seq_start_char: B
   INFO initialise_program: Registered config value data > seq_end_char: E
   INFO initialise_program: Registered config value data > negative_char: N
   INFO initialise_program: Registered config value data > dtype: int32
   INFO initialise_program: Registered config value model > load_pretrained_model: None
   INFO initialise_program: Registered config value model > name: mathsformer_LLM
   INFO initialise_program: Registered config value model > dtype: float32
   INFO initialise_program: Registered config value model > dropout: 0
   INFO initialise_program: Registered config value model > learning_rate: 0.001
   INFO initialise_program: Registered config value model > jit_compile: False
   INFO initialise_program: Registered config value model > positional_encoding > num_freqs: 32
   INFO initialise_program: Registered config value mode

##  3. Create training data

###  Create tokeniser

The tokeniser object handles the transformation from strings to tensors and back again.

In [5]:
##=====================================================##
##   Create tokeniser from the data section of config  ##
##=====================================================##

##  Build the string <-> token tensor transform from the configured characters and special chars
token_transform = data.TokenTransform.from_dictionary(cfg_data)

##  Send the tokeniser summary (vocabulary, special characters, lookup tables) to the logger
token_transform.summary(print_fn = logger.info)


   INFO summary: TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO summary: Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)
   INFO summary: Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}
   INFO summary: Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


###  Create data generators for train/val/test sets

Data generators create the model's tensor inputs/outputs on the fly.


In [6]:
##==============================================##
##   Create train / val / test data generators  ##
##==============================================##

##  Character configured to mark negative numbers in the text representation (presumably; used later)
negative_char = cfg_data.get("negative_char")

##  The backend builds all generators at once: training, a second training generator
##  (reproducible, judging by its name), validation and test
(train_gen, train_gen_reproducible,
 val_gen, test_gen) = backend.get_data_generators(cfg_data, token_transform)


   INFO get_data_generators: Training data generator created with the following config: Generator of [1, 2, 4] integers of length [1, 2, 3] in 2000 batches of size 32 (base_seed=101, reproducible=False)
   INFO get_data_generators: Output shapes for a test batch are ((32, 18), (32, 5)), (32, 5)
   INFO get_data_generators: Validation data generator created with the following config: Generator of [3, 6, 7] integers of length [1, 2, 3] in 100 batches of size 32 (base_seed=102, reproducible=True)
   INFO get_data_generators: Output shapes for a test batch are ((32, 31), (32, 6)), (32, 6)
   INFO get_data_generators: Test data generator created with the following config: Generator of [6, 7] integers of length [4, 5] in 100 batches of size 32 (base_seed=103, reproducible=True)
   INFO get_data_generators: Output shapes for a test batch are ((32, 45), (32, 8)), (32, 8)


##  4.  Create model

Create the Keras model object that handles sequence-to-sequence transformations of already-tokenised data.

In [7]:
##===================================================##
##   Load or create self-supervised learning model   ##
##===================================================##

##  Filename of a pretrained model to load; None means build a fresh model from the config
fname = cfg_model.get("load_pretrained_model", None)

##  Load model if fname is not None, otherwise create from scratch
if fname is not None :
    logger.info   (f"Loading model from: {fname}")
    logger.warning("Loading a pretrained model will disregard model config!")
    model = backend.load_text_to_text_model(fname)

    ##  Reset the LR to the configured value. The LR is set through optimizer_args (which is what
    ##  a freshly created model is configured with), so prefer that over the top-level
    ##  "model > learning_rate" entry. That entry keeps its default (0.001) unless explicitly
    ##  overridden, and would otherwise silently replace the configured 1e-4.
    optimizer_args = cfg_model.get("optimizer_args", None) or {}
    learning_rate  = optimizer_args.get("learning_rate", cfg_model.get("learning_rate", None))
    if learning_rate is not None :
        model.optimizer.learning_rate.assign(learning_rate)
    else :
        logger.warning("No learning rate found in model config; keeping the value stored with the loaded model")
else :
    logger.info("Creating new text-to-text model")
    model = backend.create_text_to_text_model_from_config(cfg_model, token_transform)

##  Capture the Keras model summary line-by-line so that it can be routed through the logger
model_summary = []
model.summary(print_fn = model_summary.append)

##  Print model summary
logger.info("Model created with summary:")
for line in model_summary : logger.info(line)


   INFO <module>: Creating new text-to-text model




   INFO <module>: Model created with summary:


INFO:mathsformer:Model created with summary:


   INFO <module>: Model: "mathsformer_LLM"


INFO:mathsformer:Model: "mathsformer_LLM"


   INFO <module>: __________________________________________________________________________________________________


INFO:mathsformer:__________________________________________________________________________________________________


   INFO <module>:  Layer (type)                Output Shape                 Param #   Connected to                  


INFO:mathsformer: Layer (type)                Output Shape                 Param #   Connected to                  






   INFO <module>:  mathsformer_LLM_encoder_in  [(None, None)]               0         []                            


INFO:mathsformer: mathsformer_LLM_encoder_in  [(None, None)]               0         []                            


   INFO <module>:  put_layer (InputLayer)                                                                           


INFO:mathsformer: put_layer (InputLayer)                                                                           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_en  (1, None)                    0         ['mathsformer_LLM_encoder_inpu


INFO:mathsformer: mathsformer_LLM_encoder_en  (1, None)                    0         ['mathsformer_LLM_encoder_inpu


   INFO <module>:  umerate (Enumerate)                                                t_layer[0][0]']               


INFO:mathsformer: umerate (Enumerate)                                                t_layer[0][0]']               


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_em  (None, None, 64)             1024      ['mathsformer_LLM_encoder_inpu


INFO:mathsformer: mathsformer_LLM_encoder_em  (None, None, 64)             1024      ['mathsformer_LLM_encoder_inpu


   INFO <module>:  bedding (Embedding)                                                t_layer[0][0]']               


INFO:mathsformer: bedding (Embedding)                                                t_layer[0][0]']               


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_po  (1, None, 64)                32        ['mathsformer_LLM_encoder_enum


INFO:mathsformer: mathsformer_LLM_encoder_po  (1, None, 64)                32        ['mathsformer_LLM_encoder_enum


   INFO <module>:  sition_encoding (Positiona                                         erate[0][0]']                 


INFO:mathsformer: sition_encoding (Positiona                                         erate[0][0]']                 


   INFO <module>:  lEncoding)                                                                                       


INFO:mathsformer: lEncoding)                                                                                       


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_em  (None, None, 64)             0         ['mathsformer_LLM_encoder_embe


INFO:mathsformer: mathsformer_LLM_encoder_em  (None, None, 64)             0         ['mathsformer_LLM_encoder_embe


   INFO <module>:  b_and_pos (Average)                                                dding[0][0]',                 


INFO:mathsformer: b_and_pos (Average)                                                dding[0][0]',                 


   INFO <module>:                                                                      'mathsformer_LLM_encoder_posi


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_posi


   INFO <module>:                                                                     tion_encoding[0][0]']         


INFO:mathsformer:                                                                    tion_encoding[0][0]']         


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_emb_


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_emb_


   INFO <module>:  ock_1 (EncoderBlock)                                               and_pos[0][0]']               


INFO:mathsformer: ock_1 (EncoderBlock)                                               and_pos[0][0]']               


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


   INFO <module>:  ock_2 (EncoderBlock)                                               k_1[0][0]']                   


INFO:mathsformer: ock_2 (EncoderBlock)                                               k_1[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_in  [(None, None)]               0         []                            


INFO:mathsformer: mathsformer_LLM_decoder_in  [(None, None)]               0         []                            


   INFO <module>:  put_layer (InputLayer)                                                                           


INFO:mathsformer: put_layer (InputLayer)                                                                           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


   INFO <module>:  ock_3 (EncoderBlock)                                               k_2[0][0]']                   


INFO:mathsformer: ock_3 (EncoderBlock)                                               k_2[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_en  (1, None)                    0         ['mathsformer_LLM_decoder_inpu


INFO:mathsformer: mathsformer_LLM_decoder_en  (1, None)                    0         ['mathsformer_LLM_decoder_inpu


   INFO <module>:  umerate (Enumerate)                                                t_layer[0][0]']               


INFO:mathsformer: umerate (Enumerate)                                                t_layer[0][0]']               


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


   INFO <module>:  ock_4 (EncoderBlock)                                               k_3[0][0]']                   


INFO:mathsformer: ock_4 (EncoderBlock)                                               k_3[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_em  (None, None, 64)             1024      ['mathsformer_LLM_decoder_inpu


INFO:mathsformer: mathsformer_LLM_decoder_em  (None, None, 64)             1024      ['mathsformer_LLM_decoder_inpu


   INFO <module>:  bedding (Embedding)                                                t_layer[0][0]']               


INFO:mathsformer: bedding (Embedding)                                                t_layer[0][0]']               


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_po  (1, None, 64)                32        ['mathsformer_LLM_decoder_enum


INFO:mathsformer: mathsformer_LLM_decoder_po  (1, None, 64)                32        ['mathsformer_LLM_decoder_enum


   INFO <module>:  sition_encoding (Positiona                                         erate[0][0]']                 


INFO:mathsformer: sition_encoding (Positiona                                         erate[0][0]']                 


   INFO <module>:  lEncoding)                                                                                       


INFO:mathsformer: lEncoding)                                                                                       


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


   INFO <module>:  ock_5 (EncoderBlock)                                               k_4[0][0]']                   


INFO:mathsformer: ock_5 (EncoderBlock)                                               k_4[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_em  (None, None, 64)             0         ['mathsformer_LLM_decoder_embe


INFO:mathsformer: mathsformer_LLM_decoder_em  (None, None, 64)             0         ['mathsformer_LLM_decoder_embe


   INFO <module>:  b_and_pos (Average)                                                dding[0][0]',                 


INFO:mathsformer: b_and_pos (Average)                                                dding[0][0]',                 


   INFO <module>:                                                                      'mathsformer_LLM_decoder_posi


INFO:mathsformer:                                                                     'mathsformer_LLM_decoder_posi


   INFO <module>:                                                                     tion_encoding[0][0]']         


INFO:mathsformer:                                                                    tion_encoding[0][0]']         


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


INFO:mathsformer: mathsformer_LLM_encoder_bl  (None, None, 64)             266496    ['mathsformer_LLM_encoder_bloc


   INFO <module>:  ock_6 (EncoderBlock)                                               k_5[0][0]']                   


INFO:mathsformer: ock_6 (EncoderBlock)                                               k_5[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_emb_


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_emb_


   INFO <module>:  ock_1 (DecoderBlock)                                               and_pos[0][0]',               


INFO:mathsformer: ock_1 (DecoderBlock)                                               and_pos[0][0]',               


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_2 (DecoderBlock)                                               k_1[0][0]',                   


INFO:mathsformer: ock_2 (DecoderBlock)                                               k_1[0][0]',                   


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_3 (DecoderBlock)                                               k_2[0][0]',                   


INFO:mathsformer: ock_3 (DecoderBlock)                                               k_2[0][0]',                   


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_4 (DecoderBlock)                                               k_3[0][0]',                   


INFO:mathsformer: ock_4 (DecoderBlock)                                               k_3[0][0]',                   


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_5 (DecoderBlock)                                               k_4[0][0]',                   


INFO:mathsformer: ock_5 (DecoderBlock)                                               k_4[0][0]',                   


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_decoder_bl  (None, None, 64)             465856    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_6 (DecoderBlock)                                               k_5[0][0]',                   


INFO:mathsformer: ock_6 (DecoderBlock)                                               k_5[0][0]',                   


   INFO <module>:                                                                      'mathsformer_LLM_encoder_bloc


INFO:mathsformer:                                                                     'mathsformer_LLM_encoder_bloc


   INFO <module>:                                                                     k_6[0][0]']                   


INFO:mathsformer:                                                                    k_6[0][0]']                   


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  mathsformer_LLM_feedfwd_bl  (None, None, 16)             569872    ['mathsformer_LLM_decoder_bloc


INFO:mathsformer: mathsformer_LLM_feedfwd_bl  (None, None, 16)             569872    ['mathsformer_LLM_decoder_bloc


   INFO <module>:  ock_post_attention (FeedFo                                         k_6[0][0]']                   


INFO:mathsformer: ock_post_attention (FeedFo                                         k_6[0][0]']                   


   INFO <module>:  rwardBlock)                                                                                      


INFO:mathsformer: rwardBlock)                                                                                      


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  






   INFO <module>: Total params: 4966096 (18.94 MB)


INFO:mathsformer:Total params: 4966096 (18.94 MB)


   INFO <module>: Trainable params: 4966096 (18.94 MB)


INFO:mathsformer:Trainable params: 4966096 (18.94 MB)


   INFO <module>: Non-trainable params: 0 (0.00 Byte)


INFO:mathsformer:Non-trainable params: 0 (0.00 Byte)


   INFO <module>: __________________________________________________________________________________________________


INFO:mathsformer:__________________________________________________________________________________________________


In [8]:
##===============================================================##
##   Wrap the keras model and its token transform in one object   ##
##===============================================================##

transformer = transformers.Transformer_Text_to_Text(
    model,
    token_transform,
)


In [9]:
##=========================================##
##   Test transformer on data generators   ##
##=========================================##

##  Print prediction tables on the train, validation and test generators (run here before training)
backend.test_transformer(
    transformer,
    train_gen,
    val_gen,
    test_gen,
    negative_char = negative_char,
)


   INFO test_transformer: Running text --> text mathsformer inference on some training data:


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Running text --> text mathsformer inference on some training data:


   INFO print_predictions_table: -----------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:-----------------------------------------------------------------------------------


   INFO print_predictions_table:            INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


INFO:mathsformer.selfsupervised_learning_addition_model_backend:           INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


   INFO print_predictions_table: -----------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:-----------------------------------------------------------------------------------


   INFO print_predictions_table:             N573         N573       23B937 23993773N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:            N573         N573       23B937 23993773N1+011                      ?   


   INFO print_predictions_table:     N5-N22-N9+N6           20       E11437                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:    N5-N22-N9+N6           20       E11437                                     ?   


   INFO print_predictions_table:           819+39          858       E8B937                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:          819+39          858       E8B937                                     ?   


   INFO print_predictions_table:         N983+N65        N1048       E39933                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:        N983+N65        N1048       E39933                                     ?   


   INFO print_predictions_table:   N46-N74-N93-99           22       E11437                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:  N46-N74-N93-99           22       E11437                                     ?   


   INFO print_predictions_table:    N3+38+645+588         1268       E31933                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:   N3+38+645+588         1268       E31933                                     ?   


   INFO print_predictions_table:            N4-N9            5       E39437                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:           N4-N9            5       E39437                                     ?   


   INFO print_predictions_table:           N71-N1          N70       E3B937                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:          N71-N1          N70       E3B937                                     ?   


   INFO print_predictions_table:  N17+753-N173+91         1000       E51N33                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend: N17+753-N173+91         1000       E51N33                                     ?   


   INFO print_predictions_table:        N999+N497        N1496       E3B937                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:       N999+N497        N1496       E3B937                                     ?   


   INFO test_transformer: Running text --> text mathsformer inference on some validation data:


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Running text --> text mathsformer inference on some validation data:


   INFO print_predictions_table: -------------------------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:-------------------------------------------------------------------------------------------------


   INFO print_predictions_table:                          INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


INFO:mathsformer.selfsupervised_learning_addition_model_backend:                         INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


   INFO print_predictions_table: -------------------------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:-------------------------------------------------------------------------------------------------


   INFO print_predictions_table:                       29-N7+22           58       E89937                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:                      29-N7+22           58       E89937                                     ?   


   INFO print_predictions_table:      5+N5-N540-444+179-40+N911         N676       211433 21443376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:     5+N5-N540-444+179-40+N911         N676       211433 21443376N1+011                      ?   


   INFO print_predictions_table:         N45+42-N896+9+8-N540+7         1457       2+4N33 21143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:        N45+42-N896+9+8-N540+7         1457       2+4N33 21143776N1+011                      ?   


   INFO print_predictions_table:        836-602+N161+4+N9+393+1          462       211437 21443376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:       836-602+N161+4+N9+393+1          462       211437 21443376N1+011                      ?   


   INFO print_predictions_table:  N509+240+N56-N8-N829-N996-N51         1559       2+1433 2+143776N0+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend: N509+240+N56-N8-N829-N996-N51         1559       2+1433 2+143776N0+011                      ?   


   INFO print_predictions_table:           80+N443-288+N3+N7-N6         N655       251N33 21143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:          80+N443-288+N3+N7-N6         N655       251N33 21143776N1+011                      ?   


   INFO print_predictions_table:                      154-9+N57           88       E59437                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:                     154-9+N57           88       E59437                                     ?   


   INFO print_predictions_table:        166-N20-N832-N68-N8-987          107       2+1437 21443776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:       166-N20-N832-N68-N8-987          107       2+1437 21443776N1+011                      ?   


   INFO print_predictions_table:                      N192-45-6         N243       E39933                                     ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:                     N192-45-6         N243       E39933                                     ?   


   INFO print_predictions_table:           N770-837-9-N197+N8-3        N1430       211437 21143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:          N770-837-9-N197+N8-3        N1430       211437 21143776N1+011                      ?   


   INFO test_transformer: Running text --> text mathsformer inference on some test data:


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Running text --> text mathsformer inference on some test data:


   INFO print_predictions_table: ---------------------------------------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:---------------------------------------------------------------------------------------------------------------


   INFO print_predictions_table:                                        INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


INFO:mathsformer.selfsupervised_learning_addition_model_backend:                                       INPUT         TRUE   PRED(MASK)      PRED(GEN)      CORRECT     RESIDUAL


   INFO print_predictions_table: ---------------------------------------------------------------------------------------------------------------


INFO:mathsformer.selfsupervised_learning_addition_model_backend:---------------------------------------------------------------------------------------------------------------


   INFO print_predictions_table:    N97671+N8835+N9014+9955+65061+44422-90112       N86194     25143366 2+143376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:   N97671+N8835+N9014+9955+65061+44422-90112       N86194     25143366 2+143376N1+011                      ?   


   INFO print_predictions_table:   N1266+N75283-5768+N10089-27664-5437-N56073       N69434     25443366 2+143376N0+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:  N1266+N75283-5768+N10089-27664-5437-N56073       N69434     25443366 2+143376N0+011                      ?   


   INFO print_predictions_table:    N5382-57189-N68512-N5653-9463-N63176-5275        60032     2+443366 2+143376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:   N5382-57189-N68512-N5653-9463-N63176-5275        60032     2+443366 2+143376N1+011                      ?   


   INFO print_predictions_table:  61773+N9522-N34810-8029+N63152+N17201-78161       N79482     25143346 2+143376N0+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend: 61773+N9522-N34810-8029+N63152+N17201-78161       N79482     25143346 2+143376N0+011                      ?   


   INFO print_predictions_table:    4233-N2588-2146+N9614-N7957+N65478-N32179       N30281     25143376 2+143376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:   4233-N2588-2146+N9614-N7957+N65478-N32179       N30281     25143376 2+143376N1+011                      ?   


   INFO print_predictions_table:          1776-N5476-N6098-N13559+N9463+N9384         8062     2E443N76 2+143376N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:         1776-N5476-N6098-N13559+N9463+N9384         8062     2E443N76 2+143376N1+011                      ?   


   INFO print_predictions_table:   N53908+N8830+N7087+N4936+56984+92040+N9740        64523     2+443366 2+143376N0+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:  N53908+N8830+N7087+N4936+56984+92040+N9740        64523     2+443366 2+143376N0+011                      ?   


   INFO print_predictions_table:         N35202-N7716+7928-N84468-32560-33281         N931     25143N76 21143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:        N35202-N7716+7928-N84468-32560-33281         N931     25143N76 21143776N1+011                      ?   


   INFO print_predictions_table:       30880+3141-2144-79811-N9700+94269-5850        50185     2+443N-6 2+143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:      30880+3141-2144-79811-N9700+94269-5850        50185     2+443N-6 2+143776N1+011                      ?   


   INFO print_predictions_table:           N3389-5913-3625-8679-N35168-N88200       101762     2+143N46 21143776N1+011                      ?   


INFO:mathsformer.selfsupervised_learning_addition_model_backend:          N3389-5913-3625-8679-N35168-N88200       101762     2+143N46 21143776N1+011                      ?   


## 5. Train model

In [10]:
##===================================##
##   Create callbacks for training   ##
##===================================##

##  Note: the monitoring callbacks receive train_gen_reproducible rather than train_gen
callbacks = backend.get_callbacks(
    cfg_training,
    working_dir,
    transformer   = transformer,
    train_gen     = train_gen_reproducible,
    val_gen       = val_gen,
    negative_char = negative_char,
)


   INFO get_callbacks: Registered training callback: LoggerCallback with loglvl=10


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registered training callback: LoggerCallback with loglvl=10


   INFO get_callbacks: Registered training callback: EarlyStopping with monitor=val_loss, mode=min, patience=6, restore_best_weights=True


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registered training callback: EarlyStopping with monitor=val_loss, mode=min, patience=6, restore_best_weights=True


   INFO get_callbacks: Registeried training callback: AdaptiveLearningRate with decay_factor=0.3, patience=2, monitor=loss, mode=min, log_lvl=10


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registeried training callback: AdaptiveLearningRate with decay_factor=0.3, patience=2, monitor=loss, mode=min, log_lvl=10


   INFO get_callbacks: Registeried training callback: ModelCheckpoint with filepath=SSL_addition_generator_notebook_int123_num124_width160_enc6_loop1_dec6_loop1_post3_adam_2023_05_23/model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registeried training callback: ModelCheckpoint with filepath=SSL_addition_generator_notebook_int123_num124_width160_enc6_loop1_dec6_loop1_post3_adam_2023_05_23/model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5


   INFO get_callbacks: Registered training callback: LayerWeightsRecord with batch_frequency=2000, recursive=True


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registered training callback: LayerWeightsRecord with batch_frequency=2000, recursive=True


   INFO get_callbacks: Registered training callback: LambdaCallback for test_transformer with num_print=10, negative_char='N'


INFO:mathsformer.selfsupervised_learning_addition_model_backend:Registered training callback: LambdaCallback for test_transformer with num_print=10, negative_char='N'


In [11]:
##=================##
##   Train model   ##
##=================##

##  Training is on by default; it is skipped only if the training config sets "train" to a falsy value
do_train = cfg_training.get("train", True)

if not do_train :
    logger.warning("Skipping model training following global config instructions")
else :
    max_epochs = cfg_training["max_epochs"]
    logger.info(f"Begin model training with max_epochs={max_epochs}")
    model.fit(
        train_gen,
        epochs          = max_epochs,
        validation_data = val_gen,
        callbacks       = callbacks,
    )


   INFO <module>: Begin model training with max_epochs=100000


INFO:mathsformer:Begin model training with max_epochs=100000
DEBUG:mathsformer.selfsupervised_learning_addition_model_backend:Setting variable to learning_rate:0


Epoch 1/100000


2023-05-23 10:13:59.020747: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


  83/2000 [>.............................] - ETA: 3:44 - loss: 1.9606 - scalar_masked_categorical_accuracy: 0.3897

KeyboardInterrupt: 

In [None]:
##================##
##   Save model   ##
##================##

##  Saving is on by default; it is skipped only if the evaluate config sets "save_model" to a falsy value
do_save = cfg_evaluate.get("save_model", True)

if do_save :
    save_fname = f"{working_dir}/final_model.h5"
    model.save(save_fname)
    logger.info(f"Model saved to file {save_fname}")
else :
    ##  This branch is driven purely by the save_model config flag, not by whether training ran,
    ##  so the warning states the real reason
    logger.warning("Not saving model because save_model is disabled in the evaluate config")


## 6. Evaluate model

In [None]:
  
##  Number of datapoints to print predictions for (falls back to 20 if not configured)
num_print = cfg_evaluate.get("num_print", 20)

##  Print prediction tables for the train, validation and test generators
backend.test_transformer(
    transformer, train_gen, val_gen, test_gen,
    num_print     = num_print,
    negative_char = negative_char,
)


## 7. Additional visualisations

In [None]:
##=============================================##
##   Visualise layer weights during training   ##
##=============================================##

if cfg_evaluate["plot_weights"] :
    ##  NOTE(review): presumably the weights come from the LayerWeightsRecord callback registered
    ##  in get_callbacks - TODO confirm against backend.plot_weights
    logger.info("Plotting weights")
    backend.plot_weights(
        callbacks,
        show    = True,
        close   = True,
        savefig = f"{working_dir}/layer_weights.pdf",
    )
    

In [None]:

if cfg_evaluate["plot_training_curves"] :
    
    ##  In tf.keras the `history` attribute may exist but be None before fit() has run, and an
    ##  interrupted fit() can leave an empty history dict. Check the value, not just whether the
    ##  attribute exists, so that skipped or aborted training is reported instead of crashing.
    model_history = getattr(model, "history", None)
    if model_history is None or not getattr(model_history, "history", None) :
        logger.error("Cannot print training curves because no model history exists - perhaps you skipped training?")
    else :
        logger.info("Plotting training curves")
        backend.plot_training_curves(model_history.history, show=True, close=True, savefig=f"{working_dir}/training_curves.pdf")
    