#  Loopy decoder

Author: S. Menary [sbmenary@gmail.com]

Date: 17/5/2023

Overview: Train an `integer -> sequence` model where the output sequence is a (logit probabilities over the tokens of a) text representation of the input integer. The inputs are first embedded into a higher dimensional space using a `Fourier` positional encoding with fixed frequencies spanning a logarithmic series, allowing them to encode over many orders of magnitude. The decoder is composed of a simple decoder block, which may be applied many times in over to simulate inductive logic, followed by a feed-forward block to derive the final logits. This decoder may be frozen and used in future `encoder-decoder` networks.

Note: we do not use skip connections when propagating through loops

Note: think about best architecture design with fixed-vector embedding (not a sequence of feature-vectors)

---

## 1. Set up program

###  Import

All imports go here at the top of the notebook

In [1]:
##=========================##
##   All imports go here   ##
##=========================##

##  Import entire python stdlib packages
import logging, os, sys

##  Import entire pypi packages
import tensorflow as tf

##  Remove tensorflow INFO messages
tf.get_logger().setLevel('WARNING')

##  Add directory above this to system path to expose mathsformer package location
sys.path.append("/".join(os.getcwd().split("/")[:-1]))

##  Import individual modules/objects from local packages
from tensorflow.keras.optimizers import Adam, SGD
from mathsformer import config, data, transformers, utils
from mathsformer import selfsupervised_learning_addition_model_backend as backend


In [2]:
##===========================================================================##
##   Additional imports go here - to be removed after migration to backend   ##
##===========================================================================##

import numpy as np

from tensorflow.keras.layers     import Add, Average, Concatenate, Embedding, Input
from tensorflow.keras.models     import Model
from tensorflow.keras.optimizers import Adam

from mathsformer.tf_objects import (DecoderBlock, EncoderBlock, Enumerate, FeedForwardBlock, LearnableMixture, MaskedCategoricalAccuracy,
                                    MaskedSparseCategoricalCrossentropy, PositionalEncoding)

from collections.abc import Callable


In [3]:
##==============================##
##   Set custom config values   ##
##==============================##

custom_config = {
    "global" : {
        "base_seed"        : -1,
        "working_dir"      : "baseline_decoder_[problem_tag]_[model_tag]_[date]",
        "problem_tag"      : "int123467810",
        "model_tag"        : "baseline",
        "log_lvl_iostream" : logging.INFO,
        "log_lvl_fstream"  : logging.DEBUG,
    },
    "data" : {
        "train_data" : {
            "int_lengths"      : [1, 2, 3, 4, 6],
            "batch_size"       : 32,
            "num_batches"      : 1000,
            "gen_base_seed"    : 100,
            "gen_reproducible" : False, 
        },
        "val_data" : {
            "int_lengths"      : [5, 7],
            "batch_size"       : 32,
            "num_batches"      : 50,
            "gen_base_seed"    : 101,
            "gen_reproducible" : True,
        },
        "test_data" : {
            "int_lengths"      : [8, 9, 10],
            "batch_size"       : 32,
            "num_batches"      : 100,
            "gen_base_seed"    : 102,
            "gen_reproducible" : True,
        },
        "characters"              : ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-'],
        "mask_char"               : 'M',
        "seq_start_char"          : 'B',
        "seq_end_char"            : 'E',
        "negative_char"           : 'N',
        "dtype_int"               : "int64",
        "dtype"                   : "int32",
    },
    "model" : {
        "load_pretrained_model" : None,
        "name"                  : "decoder_model",
        "dtype"                 : "float32",
        "dropout"               : 0.1,
        "jit_compile"           : False,
        "use_old_loss"          : True,
        "optimizer"             : Adam,
        "optimizer_args"        : {"learning_rate":1e-4},
        "int_positional_encoding" : {
            "num_freqs"         : 128,
            "min_period"        : 3,
            "max_period"        : 1e13,
            "learnable"         : True,
        },
        "str_positional_encoding" : {
            "num_freqs"         : 32,
            "min_period"        : 4,
            "max_period"        : 400,
            "learnable"         : True,
        },
        "ndim_embedding"        : 32,
        "pre_decoder" : {
            "num_layers"        : -1,
            "ndim"              : 256,
            "skip_connect"      : True,
        },
        "decoder" : {
            "num_blocks"        : 6,
            "num_loops"         : 1,
            "num_heads"         : 8,
            "ndim"              : 32,
            "ndim_att_hidden"   : 128,
            "ndim_ff_hidden"    : 512,
            "skip_connect"      : True,
        },
        "post_decoder" : {
            "num_layers"        : 3,
            "ndim"              : 512,
        },
    },
    "training" : {
        "train"          : True,
        "max_epochs"     : 100000,
        "log_after_epoch" : {
            "do"          : True,
            "log_lvl"     : logging.DEBUG,
        },
        "early_stopping" : {
            "do"                   : True,
            "patience"             : 5,
            "monitor"              : "val_loss",
            "mode"                 : "min",
            "restore_best_weights" : True,
        },
        "model_checkpoint" : {
            "do"       : True,
            "filename" : "model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5",
        },
        "layer_weights_record" : {
            "do"               : True,
            "batch_frequency"  : 2000,
            "recursive"        : True,
        },
        "adaptive_learning_rate" : {
            "do"                 : True,
            "decay_factor"       : 0.3,
            "monitor"            : "loss",
            "mode"               : "min",
            "patience"           : 1,
            "log_lvl"            : logging.DEBUG,
        },
    },
    "evaluate" : {
        "num_print"            : 20,
        "save_model"           : True,
        "plot_weights"         : False,
        "plot_training_curves" : True,
    },
}


In [4]:
##===================================##
##   Load and validate full config   ##
##===================================##

##  Create config object containing default values
cfg = config.Config(backend.DEFAULT_CONFIG)

##  Override with custom values
cfg.load_dict(custom_config)

##  Validate config
backend.validate_config(cfg)

##  Print success
print(utils.fancy_message(f"Config created"))

##  For convenience, split configs for different sections
cfg_global   = cfg["global"  ]
cfg_data     = cfg["data"    ]
cfg_model    = cfg["model"   ]
cfg_training = cfg["training"]
cfg_evaluate = cfg["evaluate"]


===   Config created   ===


In [5]:
##==============================##
##   Create working directory   ##
##==============================##

##  Report success
working_dir, logger, base_seed, np_seed, tf_seed = utils.initialise_program(
    "unsupervised_learning_addition_model_generator (notebook)", 
    working_dir       = cfg_global["working_dir"], 
    cfg               = cfg,
    base_seed         = cfg_global["base_seed"],
    log_lvl_iostream  = cfg_global["log_lvl_iostream"],
    log_lvl_fstream   = cfg_global["log_lvl_fstream" ],
)


===   Working directory created at baseline_decoder_int123467810_baseline_2023_06_14   ===
   INFO initialise_logging: Begin logging on 2023-06-14 at 16:31:01
   INFO initialise_program: Program description: unsupervised_learning_addition_model_generator (notebook)
   INFO initialise_program: Working directory: baseline_decoder_int123467810_baseline_2023_06_14
   INFO log_versions: ------------------------------------------------------+-----------------------------------------------------------------------------------
   INFO log_versions:                                              PACKAGE  |  VERSION
   INFO log_versions: ------------------------------------------------------+-----------------------------------------------------------------------------------
   INFO log_versions:                                               Python  |  3.10.11 | packaged by conda-forge | (main, May 10 2023, 19:01:19) [Clang 14.0.6 ]
   INFO log_versions:                                              

   INFO log_versions:                                               pydevd  |  2.9.5
   INFO log_versions:                                             pygments  |  2.15.1
   INFO log_versions:                                            pyparsing  |  3.0.9
   INFO log_versions:                                                   re  |  2.2.1
   INFO log_versions:                                             requests  |  2.31.0
   INFO log_versions:                                 requests.__version__  |  2.31.0
   INFO log_versions:                                                 idna  |  3.4
   INFO log_versions:                                        idna.idnadata  |  15.0.0
   INFO log_versions:                                    idna.package_data  |  3.4
   INFO log_versions:                                              urllib3  |  1.26.15
   INFO log_versions:                                     urllib3._version  |  1.26.15
   INFO log_versions:                                   urlli

   INFO initialise_program: Registered config value data > dtype: int32
   INFO initialise_program: Registered config value data > dtype_int: int64
   INFO initialise_program: Registered config value model > load_pretrained_model: None
   INFO initialise_program: Registered config value model > name: decoder_model
   INFO initialise_program: Registered config value model > dtype: float32
   INFO initialise_program: Registered config value model > dropout: 0.1
   INFO initialise_program: Registered config value model > learning_rate: 0.001
   INFO initialise_program: Registered config value model > jit_compile: False
   INFO initialise_program: Registered config value model > positional_encoding > num_freqs: 16
   INFO initialise_program: Registered config value model > positional_encoding > min_period: 4
   INFO initialise_program: Registered config value model > positional_encoding > max_period: 250
   INFO initialise_program: Registered config value model > ndim_embedding: 32
   INFO

In [6]:
##======================##
##   Create tokeniser   ##
##======================##

token_transform = data.TokenTransform.from_dictionary(cfg_data)
token_transform.summary(print_fn=logger.info)


   INFO summary: TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO summary: Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)
   INFO summary: Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}
   INFO summary: Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


In [7]:

def create_text_to_text_model(vocab_length:int, 
                              name:str, 
                              do_compile:bool     = True,
                              use_old_loss:bool   = False,
                              dtype_encoded_in    = tf.int32, 
                              dtype_decoder_in    = tf.int32, 
                              dtype               = tf.float32, 
                              dropout:float       = 0.1, 
                              jit_compile:bool    = None,
                              optimizer           = Adam,
                              optimizer_args:dict = None,
                              encoder_num_freqs:int = 96, encoder_min_period:float = 3, encoder_max_period:float = 1e13, encoder_learnable:bool = False,
                              pos_enc_num_freqs:int = 32, pos_enc_min_period:float = 4, pos_enc_max_period:float = 500 , pos_enc_learnable:bool = False,
                              ndim_embedding:int          = 64,
                              num_decoder_blocks:int      = 6 , ndim_decoder:int            = 64 , skip_connect_decoder:bool  = True,
                              num_heads_decoder:int       = 8 , ndim_att_hidden_decoder:int = 128, ndim_ff_hidden_decoder:int = 128, 
                              num_decoder_loops:int       = 1 ,
                              num_post_layers_decoder:int = 3 , ndim_post_layers_decoder:int = 512, 
                             ) :
    """
    """
    ##  Resolve mutable default args
    if optimizer_args is None :
        optimizer_args = {'learning_rate': 1e-3}
    
    ##=============================================##
    ##===   Input layer - Output shape [B, S]   ===##
    ##=============================================##
    x_in_enc = Input((None,), dtype=dtype_encoded_in, name=f"{name}_encoded_input_layer")
    x_in_dec = Input((None,), dtype=dtype_decoder_in, name=f"{name}_decoder_input_layer")
            
    ##===========================================================================##
    ##===  Token embedding, masking 0s - Output shape [B, S, ndim_embedding]  ===##
    ##===========================================================================##
    x_embed_dec = Embedding(vocab_length, 
                            ndim_embedding, 
                            mask_zero=True, 
                            dtype=dtype, 
                            name=f"{name}_decoder_embedding")(x_in_dec)
    
    ##=========================================================================##
    ##===  Enumerate indices for positional encoding - Output shape [B, S]  ===##
    ##=========================================================================##
    x_pos_dec = Enumerate(name=f"{name}_decoder_enumerate", dtype=dtype)(x_in_dec, minimal_dims=False)
    
    ##========================================================================##
    ##===  Positional encoding - Output shape [B, S, 2*pos_enc_num_freqs]  ===##
    ##========================================================================##
    x_pos_enc = PositionalEncoding(num_freqs  = encoder_num_freqs, 
                                   min_period = encoder_min_period, 
                                   max_period = encoder_max_period, 
                                   learnable  = encoder_learnable,
                                   dtype      = dtype, 
                                   name       = f"{name}_encoder_position_encoding")(x_in_enc)
    
    x_pos_dec = PositionalEncoding(num_freqs  = pos_enc_num_freqs, 
                                   min_period = pos_enc_min_period, 
                                   max_period = pos_enc_max_period, 
                                   learnable  = pos_enc_learnable,
                                   dtype      = dtype, 
                                   name       = f"{name}_decoder_position_encoding")(x_pos_dec)
    
    ##============================================================##
    ##===  Decoder blocks - Output shape [B, S, ndim_decoder]  ===##
    ##============================================================##
    decoder_blocks = []
    for layer_idx in range(num_decoder_blocks) :
        decoder_blocks.append(DecoderBlock(
                                 ndim_decoder, 
                                 num_heads_decoder, 
                                 ndim_att_hidden_decoder, 
                                 ndim_ff_hidden_decoder, 
                                 dropout_mha  = dropout, 
                                 dtype        = dtype, 
                                 layer_norm   = True, 
                                 skip_connect = False, 
                                 name         = f"{name}_decoder_block_{layer_idx+1}"))
        
    x_dec = x_embed_dec
    for loop_idx in range(num_decoder_loops) :
        x_dec = Concatenate(name=f"{name}_decoder_emb_and_pos_loop{loop_idx+1}", dtype=dtype)([x_dec, x_pos_dec])
        for decoder_block in decoder_blocks :
            x_dec = decoder_block([x_dec, x_pos_enc])
        
    ##==================================================================================================##
    ##===  Predict logit probabilities using feed-forward block - Output shape [B, S, vocab_length]  ===##
    ##==================================================================================================##
    ##  - use layer_norm instead of batch_norm because elements in sequence are not independent
    x = FeedForwardBlock(vocab_length, 
                         ndim_hidden       = ndim_post_layers_decoder, 
                         num_hidden_layers = num_post_layers_decoder, 
                         skip_connect      = False, 
                         layer_norm        = True, 
                         batch_norm        = False, 
                         dtype             = dtype, 
                         name              = f"{name}_feedfwd_block_post_attention")(x_dec)
    
    ##  Create model
    model = Model([x_in_enc, x_in_dec], x, name=name)
    
    ##  Compile model with sparse categorical crossentropy loss and accuracy metric
    if do_compile :
        acc  = MaskedCategoricalAccuracy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0)
        loss = MaskedSparseCategoricalCrossentropy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0, from_logits=True)
        model.compile(loss        = loss, 
                      optimizer   = optimizer(**optimizer_args), 
                      metrics     = [acc],
                      jit_compile = jit_compile)
    
    ##  Return model
    return model

In [8]:
##===================================================##
##   Load or create self-supervised learning model   ##
##===================================================##

##  Get filename for load model
fname = cfg_model.get("load_pretrained_model", None)

##  Load model if fname is not None, otherwise create from scratch
if fname is not None :
    logger.info   (f"Loading model from: {fname}")
    logger.warning("Loading a pretrained model will disregard model config!")
    model = backend.load_text_to_text_model(fname)
    model.optimizer.learning_rate.assign(cfg_model["learning_rate"])  ## Reset LR to config value
else :
    logger.info(f"Creating new text-to-text model")
    model = create_text_to_text_model(name=cfg_model["name"], vocab_length=token_transform.vocab_length)

##  Create hack to catch model summary
model_summary = []
model.summary(print_fn = lambda s : model_summary.append(s))

##  Print model summary
logger.info("Model created with summary:")
for s in model_summary : logger.info(s)


   INFO <module>: Creating new text-to-text model




   INFO <module>: Model created with summary:


INFO:mathsformer:Model created with summary:


   INFO <module>: Model: "decoder_model"


INFO:mathsformer:Model: "decoder_model"


   INFO <module>: __________________________________________________________________________________________________


INFO:mathsformer:__________________________________________________________________________________________________


   INFO <module>:  Layer (type)                Output Shape                 Param #   Connected to                  


INFO:mathsformer: Layer (type)                Output Shape                 Param #   Connected to                  






   INFO <module>:  decoder_model_decoder_inpu  [(None, None)]               0         []                            


INFO:mathsformer: decoder_model_decoder_inpu  [(None, None)]               0         []                            


   INFO <module>:  t_layer (InputLayer)                                                                             


INFO:mathsformer: t_layer (InputLayer)                                                                             


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_enum  (None, None)                 0         ['decoder_model_decoder_input_


INFO:mathsformer: decoder_model_decoder_enum  (None, None)                 0         ['decoder_model_decoder_input_


   INFO <module>:  erate (Enumerate)                                                  layer[0][0]']                 


INFO:mathsformer: erate (Enumerate)                                                  layer[0][0]']                 


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_embe  (None, None, 64)             1024      ['decoder_model_decoder_input_


INFO:mathsformer: decoder_model_decoder_embe  (None, None, 64)             1024      ['decoder_model_decoder_input_


   INFO <module>:  dding (Embedding)                                                  layer[0][0]']                 


INFO:mathsformer: dding (Embedding)                                                  layer[0][0]']                 


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_posi  (None, None, 64)             32        ['decoder_model_decoder_enumer


INFO:mathsformer: decoder_model_decoder_posi  (None, None, 64)             32        ['decoder_model_decoder_enumer


   INFO <module>:  tion_encoding (PositionalE                                         ate[0][0]']                   


INFO:mathsformer: tion_encoding (PositionalE                                         ate[0][0]']                   


   INFO <module>:  ncoding)                                                                                         


INFO:mathsformer: ncoding)                                                                                         


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_encoded_inpu  [(None, None)]               0         []                            


INFO:mathsformer: decoder_model_encoded_inpu  [(None, None)]               0         []                            


   INFO <module>:  t_layer (InputLayer)                                                                             


INFO:mathsformer: t_layer (InputLayer)                                                                             


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_emb_  (None, None, 128)            0         ['decoder_model_decoder_embedd


INFO:mathsformer: decoder_model_decoder_emb_  (None, None, 128)            0         ['decoder_model_decoder_embedd


   INFO <module>:  and_pos_loop1 (Concatenate                                         ing[0][0]',                   


INFO:mathsformer: and_pos_loop1 (Concatenate                                         ing[0][0]',                   


   INFO <module>:  )                                                                   'decoder_model_decoder_positi


INFO:mathsformer: )                                                                   'decoder_model_decoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_encoder_posi  (None, None, 192)            96        ['decoder_model_encoded_input_


INFO:mathsformer: decoder_model_encoder_posi  (None, None, 192)            96        ['decoder_model_encoded_input_


   INFO <module>:  tion_encoding (PositionalE                                         layer[0][0]']                 


INFO:mathsformer: tion_encoding (PositionalE                                         layer[0][0]']                 


   INFO <module>:  ncoding)                                                                                         


INFO:mathsformer: ncoding)                                                                                         


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             915648    ['decoder_model_decoder_emb_an


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             915648    ['decoder_model_decoder_emb_an


   INFO <module>:  k_1 (DecoderBlock)                                                 d_pos_loop1[0][0]',           


INFO:mathsformer: k_1 (DecoderBlock)                                                 d_pos_loop1[0][0]',           


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


   INFO <module>:  k_2 (DecoderBlock)                                                 1[0][0]',                     


INFO:mathsformer: k_2 (DecoderBlock)                                                 1[0][0]',                     


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


   INFO <module>:  k_3 (DecoderBlock)                                                 2[0][0]',                     


INFO:mathsformer: k_3 (DecoderBlock)                                                 2[0][0]',                     


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


   INFO <module>:  k_4 (DecoderBlock)                                                 3[0][0]',                     


INFO:mathsformer: k_4 (DecoderBlock)                                                 3[0][0]',                     


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


   INFO <module>:  k_5 (DecoderBlock)                                                 4[0][0]',                     


INFO:mathsformer: k_5 (DecoderBlock)                                                 4[0][0]',                     


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_decoder_bloc  (None, None, 64)             612160    ['decoder_model_decoder_block_


   INFO <module>:  k_6 (DecoderBlock)                                                 5[0][0]',                     


INFO:mathsformer: k_6 (DecoderBlock)                                                 5[0][0]',                     


   INFO <module>:                                                                      'decoder_model_encoder_positi


INFO:mathsformer:                                                                     'decoder_model_encoder_positi


   INFO <module>:                                                                     on_encoding[0][0]']           


INFO:mathsformer:                                                                    on_encoding[0][0]']           


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  


   INFO <module>:  decoder_model_feedfwd_bloc  (None, None, 16)             569872    ['decoder_model_decoder_block_


INFO:mathsformer: decoder_model_feedfwd_bloc  (None, None, 16)             569872    ['decoder_model_decoder_block_


   INFO <module>:  k_post_attention (FeedForw                                         6[0][0]']                     


INFO:mathsformer: k_post_attention (FeedForw                                         6[0][0]']                     


   INFO <module>:  ardBlock)                                                                                        


INFO:mathsformer: ardBlock)                                                                                        


   INFO <module>:                                                                                                   


INFO:mathsformer:                                                                                                  






   INFO <module>: Total params: 4547472 (17.35 MB)


INFO:mathsformer:Total params: 4547472 (17.35 MB)


   INFO <module>: Trainable params: 4547472 (17.35 MB)


INFO:mathsformer:Trainable params: 4547472 (17.35 MB)


   INFO <module>: Non-trainable params: 0 (0.00 Byte)


INFO:mathsformer:Non-trainable params: 0 (0.00 Byte)


   INFO <module>: __________________________________________________________________________________________________


INFO:mathsformer:__________________________________________________________________________________________________


In [28]:
loggers = dict(logging.root.manager.loggerDict.items())

loggers["root"] = logger

for name, _logger in loggers.items() :
    if not hasattr(_logger, "level") : continue
    if len(_logger.handlers) == 0 : continue
    print(name)
    print(_logger, _logger.handlers)
    print()

tornado

IPKernelApp

tensorflow

urllib3

charset_normalizer

requests

mathsformer
<Logger mathsformer (DEBUG)> [<StreamHandler stdout (INFO)>, <FileHandler /Users/Ste/PROJECTS/misc/ML-sandbox/ML-sandbox/Project_Maths_Transformer/3_pretrained_decoder/baseline_decoder_int123467810_baseline_2023_06_14/log.txt (DEBUG)>]

root
<Logger mathsformer (DEBUG)> [<StreamHandler stdout (INFO)>, <FileHandler /Users/Ste/PROJECTS/misc/ML-sandbox/ML-sandbox/Project_Maths_Transformer/3_pretrained_decoder/baseline_decoder_int123467810_baseline_2023_06_14/log.txt (DEBUG)>]



In [None]:
##=================================##
##   RandomNumberGenerator class   ##
##=================================##
##
class RandomNumberGenerator(tf.keras.utils.Sequence) :
    
    def __init__(self, token_transform:data.TokenTransform, int_lengths:list, batch_size:int, num_batches:int, 
                 base_seed:int=-1, reproducible:bool=False, negative_char:str='-', dtype=tf.int32) :
        """
        class RandomNumberGenerator
        
        Data generator used to create individual batches of input/output data on-the-fly for a keras model.
        If reproducible=True then self[i] will always generate the same result, otherwise it will sample new
        data every time it is called.
        WARNING: data are not guaranteed to be unique - different batches may contain the same datapoints!
        Inputs  are integers
        Outputs are tokens representing the string
        
        Inputs:
        
            >  token_transform, TokenTransform
               Method for transforming strings to/from tokenised tensors
               
            >  int_lengths, list
               List of allowed integer-lengths (N.B. will be sampled uniformly, so 1-digit numbers occur with the
               same frequency as N-digit numbers!)
               
            >  batch_size, int
               Number of sequences per batch
               
            >  num_batches, int
               Number of batches to constitute a full epoch
               
            >  base_seed, int, default=-1
               Random seed used to initialise random number generator, if -1 then fall back to system time
        
            >  reproducible, bool, default=False
               If True then re-initialise the rng seed to base_seed + i when calling self[i] for reproducible results
               Otherwise do not re-initialise rng, allowing it to continue generating potentially new datapoints

            >  negative_char, str, default='N'
               Character used to represent a negative number
               
            >  dtype, dtype, default=tf.int32
               Dtype for the integer tensor
        """
        if base_seed < 0 :
            base_seed = int(time.time())
        
        self.token_transform = token_transform
        self.int_lengths     = int_lengths
        self.batch_size      = batch_size
        self.num_batches     = num_batches
        self.base_seed       = base_seed
        self.reproducible    = reproducible
        self.negative_char   = negative_char
        self.dtype           = dtype
        self.reset_rng()

    
    def __getitem__(self, index:int) :
        """
        Returns a new set of tensors (X, Y) with length self.batch_size
        
        Inputs:
        
            >  index, int
               Index of the call, only meaningful if we have self.reproducible = True
        """
        if self.reproducible :
            self.reset_rng(self.base_seed + index)
        Z = [self._generate_x_y() for _ in range(self.batch_size)]
        X = tf.constant([z[0] for z in Z], dtype=self.dtype)[:, tf.newaxis]
        Y = self.token_transform.strings_to_tensor([z[1] for z in Z])
        return [X, Y[:,:-1]], Y[:,1:]
    
    
    def __len__(self) :
        """
        Following generator convention: returns number of batches
        """
        return self.num_batches
    
    
    def __str__(self) :
        """
        Returns a string summarising the generator configuration
        """
        return f"Generator of integers of length {self.int_lengths} in {self.num_batches} batches of size {self.batch_size} (base_seed={self.base_seed}, reproducible={self.reproducible})"
    
    
    def _generate_x_y(self) :
        """
        Returns a pair of (int, string) = (X, Y) where X is a number and Y is its string repr
        The number of digits in X is uniformly sampled from self.int_lengths
        
        """
        length      = self.rng.choice(self.int_lengths)
        sign        = self.rng.choice(["", self.negative_char])
        lead_char   = str(self.rng.randint(1, 10))
        other_chars = "".join([str(self.rng.randint(0, 10)) for i in range(length-1)])
        int_string  = sign + lead_char + other_chars
        int_int     = int(int_string.replace(self.negative_char, "-"))
        return int_int, int_string
    
    
    def get_as_tensors(self, num_batches:int=-1) :
        """
        Create a number of batches and combine their outputs into a single set of tensors
        
        Inputs:
        
            >  num_batches, int, default=-1
               Number of batches to generate, if < 1 then fall back to self.num_batches
        """
        ##  If num batches not set then return all of them
        if num_batches < 1 :
            num_batches = self.num_batches
        
        ##  Containers to stores batches
        X, Y_in, Y_out = [], [] ,[]

        ##  Fill containers with batch results
        for i in range(num_batches) :
            [x, yi], yo = self[i]
            X, Y_in, Y_out = X + [x], Y_in + [yi], Y_out + [yo]

        ##  Find max widths of tensors, which currently have ragged shapes
        len_x, len_yi, len_yo = max([xp.shape[1] for xp in X]), max([xp.shape[1] for xp in Y_in]), max([xp.shape[1] for xp in Y_out])

        ##  Pad all tensors to the same width
        for i in range(len(X)) :
            X    [i] = tf.pad(X    [i], [[0, 0], (0, len_x -X    [i].shape[1])])
            Y_in [i] = tf.pad(Y_in [i], [[0, 0], (0, len_yi-Y_in [i].shape[1])])
            Y_out[i] = tf.pad(Y_out[i], [[0, 0], (0, len_yo-Y_out[i].shape[1])])

        ##  Concatenate batch results into single tensor
        X, Y_in, Y_out = tf.concat(X, axis=0), tf.concat(Y_in, axis=0), tf.concat(Y_out, axis=0)
        
        ##  Return
        return X, Y_in, Y_out
    
    
    def reset_rng(self, seed:int=-1) :
        """
        Set the internal rng with the seed provided
        
        Inputs:
        
            >  seed, int, default=-1
               Random seed, if < 0 then fall back to self.base_seed
        """
        if seed < 0 :
            seed = self.base_seed
        self.rng = np.random.RandomState(seed)
        
        
    def summary(self, print_fn:Callable[[str],None]=None) :
        """
        Print a summary of the generator
        
        Inputs:
        
            >  print_fn, callable with signature print_fn(str), default=print
               Function used to print strings
        """
        if print_fn is None :
            print_fn = print
        print_fn(str(self))

In [None]:

def get_data_generators(cfg_data, token_transform) :
    """
    Create train/val/test data generators

    Inputs:

        >  cfg, Config
           Data configuration

        >  token_transform, TokenTransform
           Tokeniser object
    """
    ##  Create training data generator
    train_gen = RandomNumberGenerator(
                                    token_transform = token_transform, 
                                    int_lengths     = cfg_data["train_data"]["int_lengths"],
                                    batch_size      = cfg_data["train_data"]["batch_size"],
                                    num_batches     = cfg_data["train_data"]["num_batches"],
                                    base_seed       = cfg_data["train_data"]["gen_base_seed"],
                                    reproducible    = cfg_data["train_data"]["gen_reproducible"],
                                    negative_char   = cfg_data["negative_char"],
                                    dtype           = cfg_data.get("dtype_int", tf.int32),)
    
    ##  Create training data generator that has forced reproducible=True
    train_gen_reproducible = RandomNumberGenerator(
                                    token_transform = token_transform, 
                                    int_lengths     = cfg_data["train_data"]["int_lengths"],
                                    batch_size      = cfg_data["train_data"]["batch_size"],
                                    num_batches     = cfg_data["train_data"]["num_batches"],
                                    base_seed       = cfg_data["train_data"]["gen_base_seed"],
                                    reproducible    = True,
                                    negative_char   = cfg_data["negative_char"],
                                    dtype           = cfg_data.get("dtype_int", tf.int32),)
    
    ##  Log a sample training batch
    logger.info(f"Training data generator created with the following config: {train_gen}")
    (X, Y_in), Y_out = train_gen[0]
    logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")

    ##  Create validation data generator
    val_gen = RandomNumberGenerator(
                                    token_transform = token_transform, 
                                    int_lengths     = cfg_data["val_data"]["int_lengths"],
                                    batch_size      = cfg_data["val_data"]["batch_size"],
                                    num_batches     = cfg_data["val_data"]["num_batches"],
                                    base_seed       = cfg_data["val_data"]["gen_base_seed"],
                                    reproducible    = cfg_data["val_data"]["gen_reproducible"],
                                    negative_char   = cfg_data["negative_char"],
                                    dtype           = cfg_data.get("dtype_int", tf.int32),)
    
    ##  Log a sample validation batch
    logger.info(f"Validation data generator created with the following config: {val_gen}")
    (X, Y_in), Y_out = val_gen[0]
    logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")

    ##  Create test data generator
    test_gen = RandomNumberGenerator(
                                    token_transform = token_transform, 
                                    int_lengths     = cfg_data["test_data"]["int_lengths"],
                                    batch_size      = cfg_data["test_data"]["batch_size"],
                                    num_batches     = cfg_data["test_data"]["num_batches"],
                                    base_seed       = cfg_data["test_data"]["gen_base_seed"],
                                    reproducible    = cfg_data["test_data"]["gen_reproducible"],
                                    negative_char   = cfg_data["negative_char"],
                                    dtype           = cfg_data.get("dtype_int", tf.int32),)
    
    ##  Log a sample test batch
    logger.info(f"Test data generator created with the following config: {test_gen}")
    (X, Y_in), Y_out = test_gen[0]
    logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")
    
    ##  Return all three generators
    return train_gen, train_gen_reproducible, val_gen, test_gen


In [None]:
##==============================================================##
##   Create transformer wrapper for model and token_transform   ##
##==============================================================##

transformer = transformers.Transformer_Text_to_Text(model, token_transform)


In [None]:
##============================##
##   Create data generators   ##
##============================##

train_gen, train_gen_reproducible, val_gen, test_gen = get_data_generators(cfg_data, token_transform)


In [None]:
##=========================================##
##   Test transformer on data generators   ##
##=========================================##

negative_char = cfg_data.get("negative_char")
backend.test_transformer(transformer, train_gen, val_gen, test_gen, negative_char=negative_char)


In [None]:
##===================================##
##   Create callbacks for training   ##
##===================================##

callbacks = backend.get_callbacks(cfg_training, working_dir, transformer=transformer, train_gen=train_gen_reproducible, 
                                  val_gen=val_gen, negative_char=negative_char)


In [None]:
##=================##
##   Train model   ##
##=================##

do_train = cfg_training.get("train", True)

if do_train :
    max_epochs = cfg_training["max_epochs"]
    logger.info(f"Begin model training with max_epochs={max_epochs}")
    model.fit(train_gen, 
              epochs          = max_epochs,
              validation_data = val_gen,
              callbacks       = callbacks
             )
else :
    logger.warning("Skipping model training following global config instructions")


In [None]:
[X, Y_in], Y_out = train_gen_reproducible[0]

model.predict([X, Y_in])