#  Unsupervised learning addition model with generator

Author: S. Menary [sbmenary@gmail.com]

Date: 5/4/2023  (last update: 9/4/2023)

Overview: Train a `sequence -> sequence` model where the input sequence is a text representation of a simple sum $\sum_{i=1}^N A_i$ for a configurable number $N$ of integers $A_i\in\mathbb{Z}$, and the output is a set of logits representing the probability of each token in the output sequence. Integers may have a configurable number of digits. At inference time, chains of text are generated auto-regressively until the terminate-sequence token is reached. The loss function is a sparse categorical entropy.

---

## 1. Set up program

###  Import

All imports go here at the top of the notebook

In [1]:
##=========================##
##   All imports go here   ##
##=========================##

##  Import entire python stdlib packages
import logging, math, os, pickle, random, sys, time

##  Import entire pypi packages
import numpy      as np
import tensorflow as tf

##  Import individual modules/objects from python stdlib packages
from pathlib import Path

##  Import individual modules/objects from pypi packages
from tensorflow.keras.callbacks  import EarlyStopping, ModelCheckpoint

##  Add directory above this to system path to expose mathsformer package location
sys.path.append("/".join(os.getcwd().split("/")[:-1]))

##  Import individual modules/objects from local packages
from mathsformer.data        import TokenTransform
from mathsformer.tf_objects  import (create_custom_objects_dict, masked_accuracy, masked_sparse_categorical_crossentropy, 
                                     MetricRecord, LayerWeightsRecord, LoggerCallback)
from mathsformer.transformer import create_text_to_text_model, Transformer_Text_to_Text
from mathsformer.utils       import CustomLogLevel, create_working_directory, fancy_message, initialise_program


### Configuation

Set global configuration variables

In [2]:

from tensorflow.keras.callbacks import Callback

class AdaptiveLearningRate(Callback) :
    
    def __init__(self, decay_factor:float, patience:int=1, monitor:str=None, mode:str='min', 
                 variable=None, logger=None, log_lvl:int=logging.DEBUG) :
        """
        """
        ##  Check that a valid mode was provided
        mode = mode.lower()
        if mode not in ["min", "max"] :
            raise ValueError(f"mode must be 'min' or 'max', but '{mode}' provided")
            
        ##  Initialise constant variables
        self.decay_factor = decay_factor
        self.patience     = patience
        self.monitor      = monitor
        self.mode         = mode
        self.variable     = variable
        self.logger       = logger
        self.log_lvl      = log_lvl
        
        
    def on_epoch_end(self, epoch_idx:int, logs:dict) :
        """
        """
        ##  Get latest value
        monitor_val = logs[self.monitor]
        
        ##  Update run variables
        if np.isnan(self.best_val) or (self.mode == 'min' and  monitor_val < self.best_val) or (self.mode == 'max' and  monitor_val > self.best_val) :
            self.best_val = monitor_val
            self.num_itr  = 0
        else :
            self.num_itr += 1
            
        ##  Update learning rate
        if self.num_itr == self.patience :
            self.update()
            self.reset_run_variables()
        
        
    def on_train_begin(self, logs:dict) :
        """
        """
        
        ##  If no variable provided then fall back to model.optimizer.learning_rate
        if not self.variable :
            self.variable = self.model.optimizer.learning_rate
            if self.logger :
                self.logger.log(self.log_lvl, f"Setting variable to {self.variable.name}")
            
        ##  If no metric provided then fall back to self.model.metric_names[0]
        if not self.monitor :
            self.monitor = self.model.metric_names[0]
            if self.logger :
                self.logger.log(self.log_lvl, f"Setting monitor to {self.monitor}")
                
        ##  Initialise run variables
        self.reset_run_variables()
        
    
    def reset_run_variables(self) :
        """
        """
        self.best_val = np.nan
        self.num_itr  = 0
                
        
    def update(self) :
        """
        """
        current_value = self.variable.value()
        new_value     = self.decay_factor * current_value
        if logger :
            logger.log(self.log_lvl, f"Updating variable {self.variable.name} from {current_value.numpy()} to {new_value.numpy()}")
        self.variable.assign(new_value)
        

In [3]:
##===================##
##   Global config   ##
##===================##

##  Create dictionary of config values
##  -  config values to be set here and never changed!
##  -  use nested dictionary as a proxy for namespacing
global_config = {
    "global" : {
        "base_seed"         : -1,
        "working_directory" : "unsupervised_learning_addition_model_generator_[tag]_[date]",
        "tag"               : "baseline",
        "log_lvl_iostream"  : logging.INFO,
        "log_lvl_fstream"   : logging.DEBUG,
    },
    "data" : {
        "train_data" : {
            "int_lengths"      : [1, 2, 3],
            "num_ints"         : [1, 2, 4],
            "batch_size"       : 32,
            "num_batchs"       : 4000,
            "gen_base_seed"    : 100,
            "gen_reproducible" : False, 
        },
        "val_data" : {
            "int_lengths"      : [1, 2, 3],
            "num_ints"         : [3],
            "batch_size"       : 32,
            "num_batchs"       : 500,
            "gen_base_seed"    : 101,
            "gen_reproducible" : True,
        },
        "test_data" : {
            "int_lengths"      : [1, 2, 3],
            "num_ints"         : [3],
            "batch_size"       : 32,
            "num_batchs"       : 1000,
            "gen_base_seed"    : 102,
            "gen_reproducible" : True,
        },
        "characters"              : ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-'],
        "mask_char"               : 'M',
        "seq_start_char"          : 'B',
        "seq_end_char"            : 'E',
        "negative_char"           : 'N',
        "dtype"                   : "int32",
    },
    "model" : {
        "load_pretrained_model" : "",
        "name"                  : "mathsformer_LLM",
        "dtype"                 : "float32",
        "dropout"               : 0.1,
        "learning_rate"         : 1e-2,
        "positional_encoding" : {
            "num_freqs"         : 16,
            "min_period"        : 4,
            "max_period"        : 400,
        },
        "ndim_embedding"        : 32,
        "comb_type"             : 'average',
        "pre_encoder"           : {
            "num_layers"        : -1,
            "ndim"              : 128,
            "skip_connect"      : True,
        },
        "pre_decoder" : {
            "num_layers"        : -1,
            "ndim"              : 128,
            "skip_connect"      : True,
        },
        "encoder" : {
            "num_blocks"        : 4,
            "num_heads"         : 8,
            "ndim"              : 32,
            "ndim_att_hidden"   : 32,
            "ndim_ff_hidden"    : 128,
        },
        "decoder" : {
            "num_blocks"        : 4,
            "num_heads"         : 8,
            "ndim"              : 32,
            "ndim_att_hidden"   : 32,
            "ndim_ff_hidden"    : 128,
        },
        "post_decoder" : {
            "num_layers"        : 3,
            "ndim"              : 256,
        },
    },
    "training" : {
        "train"          : True,
        "max_epochs"     : 100000,
        "log_after_epoch" : {
            "do"          : True,
            "log_lvl"     : CustomLogLevel.DEBUG_LOW,
        },
        "early_stopping" : {
            "do"                   : True,
            "patience"             : 7,
            "monitor"              : "val_masked_accuracy",
            "mode"                 : "max",
            "restore_best_weights" : True,
        },
        "model_checkpoint" : {
            "do"       : True,
            "filename" : "model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5",
        },
        "layer_weights_record" : {
            "do"               : True,
            "batch_frequency"  : 250,
            "recursive"        : True,
        },
        "adaptive_learning_rate" : {
            "do"                 : True,
            "decay_factor"       : 0.3,
            "monitor"            : "val_masked_accuracy",
            "mode"               : "max",
            "patience"           : 1,
            "log_lvl"            : CustomLogLevel.DEBUG_LOW,
        },
    },
    "evaluate" : {
        "num_print" : 20,
    },
}


##  Report success
print(fancy_message(f"Created global_config"))


===   Created global_config   ===


###  Validate config

Look for some obvious confguration errors. WARNING: This is not an exhaustive search and can't be replied upon to catch all misconfigurations!

In [4]:
##===============================##
##   Global config - continued   ##
##===============================##

def validate_config(config) :
    """Raise exceptions in the case of program misconfigurations"""
    mask_char      = global_config["data"]["mask_char"]
    seq_start_char = global_config["data"]["seq_start_char"]
    seq_end_char   = global_config["data"]["seq_end_char"]
    negative_char  = global_config["data"]["negative_char"]
    char_tokens    = global_config["data"]["characters"]
    
    ##  Check that only single character tokens are provided
    for char_token in char_tokens :
        if len(char_token) == 1 : continue
        raise ValueError(f"All character tokens must be single characters but '{char_tokens}' found")
        
    ##  Check mask character is provided
    if len(mask_char) != 1 :
        raise ValueError(f"Mask character must be a single character but '{mask_char}' provided")
        
    ##  Check mask character in character list
    if mask_char not in char_tokens :
        raise ValueError(f"Mask character '{mask_char}' not found in character list: {char_tokens}")
    
    ##  Check that mask character is first in char_tokens list (ensures it's assigned a token of 0)
    if char_tokens[0] != mask_char :
        raise ValueError(f"Mask character '{mask_char}' must be the first in the char_tokens list provided, "
                        +f"instead found list: {char_tokens}")
        
    ##  Check seq_start_char character is provided
    if len(seq_start_char) != 1 :
        raise ValueError(f"Sequence start character must be a single character but '{seq_start_char}' provided")
        
    ##  Check seq_start_char character in character list
    if seq_start_char not in char_tokens :
        raise ValueError(f"Sequence start character '{seq_start_char}' not found in character list: {char_tokens}")
        
    ##  Check seq_end_char character is provided
    if len(seq_end_char) != 1 :
        raise ValueError(f"Sequence end character must be a single character but '{seq_end_char}' provided")
        
    ##  Check seq_start_char character in character list
    if seq_end_char not in char_tokens :
        raise ValueError(f"Sequence end character '{seq_end_char}' not found in character list: {char_tokens}")
        
    ##  Check negative_char character is provided
    if len(negative_char) != 1 :
        raise ValueError(f"Negative symbol character must be a single character but '{negative_char}' provided")
        
    ##  Check negative_char character in character list
    if negative_char not in char_tokens :
        raise ValueError(f"Negative symbol character '{negative_char}' not found in character list: {char_tokens}")
        
    ##  If here then config validated correctly
    print(fancy_message("Config successfully validated"))
    
validate_config(global_config)

===   Config successfully validated   ===


##  2. Set up environment

- Create working directory
- Create logger
- Log package versions for reproducibility
- Log config values for reproducibility
- Set random seeds for reproducibility


In [5]:
##==============================##
##   Create working directory   ##
##==============================##

##  Report success
working_dir, logger, base_seed, np_seed, tf_seed = initialise_program(
    "unsupervised_learning_model notebook", 
    working_dir       = global_config["global"]["working_directory"], 
    global_config     = global_config["global"],
    base_seed         = global_config["global"]["base_seed"],
    log_lvl_iostream  = global_config["global"]["log_lvl_iostream"],
    log_lvl_fstream   = global_config["global"]["log_lvl_fstream" ],
)


===   Working directory created at unsupervised_learning_addition_model_generator_baseline_2023_04_14_v5   ===
   INFO initialise_logging: Begin logging on 2023-04-14 at 16:38:54
   INFO initialise_program: Program description: unsupervised_learning_model notebook
   INFO initialise_program: Working directory: unsupervised_learning_addition_model_generator_baseline_2023_04_14_v5
   INFO log_versions: ------------------------------------------------------+----------------------------------------------------------------------------------
   INFO log_versions:                                              PACKAGE  |  VERSION
   INFO log_versions: ------------------------------------------------------+----------------------------------------------------------------------------------
   INFO log_versions:                                               Python  |  3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
   INFO log_versions:                              

   INFO log_versions:                                 platformdirs.version  |  2.6.0
   INFO log_versions:                                       prompt_toolkit  |  3.0.36
   INFO log_versions:                                               psutil  |  5.9.4
   INFO log_versions:                                           ptyprocess  |  0.7.0
   INFO log_versions:                                            pure_eval  |  0.2.2
   INFO log_versions:                                    pure_eval.version  |  0.2.2
   INFO log_versions:                                               pydevd  |  2.9.1
   INFO log_versions:                                             pygments  |  2.13.0
   INFO log_versions:                                            pyparsing  |  3.0.9
   INFO log_versions:                                                   re  |  2.2.1
   INFO log_versions:                                             requests  |  2.28.1
   INFO log_versions:                                 requests

##  3. Create training data

###  Generate string-string pairs

In [6]:

token_transform = TokenTransform.from_dictionary(global_config["data"])
token_transform.summary(print_fn=logger.info)


   INFO summary: TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO summary: Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)
   INFO summary: Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}
   INFO summary: Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


In [7]:
class RandomDataGenerator(tf.keras.utils.Sequence) :
    
    def __init__(self, token_transform:TokenTransform, int_lengths:list, num_ints:list, batch_size:int, num_batches:int, 
                 base_seed:int=-1, reproducible:bool=False) :
        """
        """
        if base_seed < 0 :
            base_seed = int(time.time())
        
        self.token_transform = token_transform
        self.int_lengths     = int_lengths
        self.num_ints        = num_ints
        self.batch_size      = batch_size
        self.num_batches     = num_batches
        self.base_seed       = base_seed
        self.reproducible    = reproducible
        self.reset_rng()

    
    def __getitem__(self, index:int) :
        """
        Returns ([X, Y_in], Y_out)
        """
        if self.reproducible :
            self.reset_rng(self.base_seed + index)
        x = self.rng.choice(self.num_ints, size=(self.batch_size,))
        y = [self._generate_string(xp, self.int_lengths) for xp in x]
        X, Y = [yp[0] for yp in y], [yp[1] for yp in y]
        X = self.token_transform.strings_to_tensor(X)
        Y = self.token_transform.strings_to_tensor(Y)
        return [X, Y[:,:-1]], Y[:,1:]
    
    
    def __len__(self) :
        """
        """
        return self.num_batches
    
    
    def __str__(self) :
        """
        """
        return f"Generator of {self.num_ints} integers of length {self.int_lengths} in {self.num_batches} batches of size {self.batch_size} (base_seed={self.base_seed}, reproducible={self.reproducible})"
    
    
    def _generate_int_string(self, length) :
        """
        """
        sign        = self.rng.choice(["", "N"])
        lead_char   = str(self.rng.randint(1, 10))
        other_chars = "".join([str(self.rng.randint(0, 10)) for i in range(length-1)])
        return sign + lead_char + other_chars
    
    
    def _generate_int_strings(self, lengths) :
        """
        """
        return np.array([self._generate_int_string(l) for l in lengths])
    
    
    def _generate_string(self, num, lengths) :
        """
        """
        lengths = self.rng.choice(lengths, size=(num,))
        ints    = self._generate_int_strings(lengths)
        out_s, out_i = ints[0], int(ints[0].replace("N","-"))
        for si in ints[1:] :
            f = self.rng.uniform(0, 1)
            i = int(si.replace("N","-"))
            if f < 0.5 :
                out_s += "+" + si
                out_i += i
            else :
                out_s += "-" + si
                out_i -= i
        return out_s, str(out_i).replace("-","N")
    
    
    def get_as_tensors(self, num_batches:int=-1) :
        """
        """
        ##  If num batches not set then return all of them
        if num_batches < 1 :
            num_batches = len(self)
        
        ##  Containers to stores batches
        X, Y_in, Y_out = [], [] ,[]

        ##  Fill containers with batch results
        for i in range(num_batches) :
            [x, yi], yo = self[i]
            X, Y_in, Y_out = X + [x], Y_in + [yi], Y_out + [yo]

        ##  Find max widths of tensors, which currently have ragged shapes
        len_x, len_yi, len_yo = max([xp.shape[1] for xp in X]), max([xp.shape[1] for xp in Y_in]), max([xp.shape[1] for xp in Y_out])

        ##  Pad all tensors to the same width
        for i in range(len(X)) :
            X    [i] = tf.pad(X    [i], [[0, 0], (0, len_x -X    [i].shape[1])])
            Y_in [i] = tf.pad(Y_in [i], [[0, 0], (0, len_yi-Y_in [i].shape[1])])
            Y_out[i] = tf.pad(Y_out[i], [[0, 0], (0, len_yo-Y_out[i].shape[1])])

        ##  Concatenate batch results into single tensor
        X, Y_in, Y_out = tf.concat(X, axis=0), tf.concat(Y_in, axis=0), tf.concat(Y_out, axis=0)
        
        ##  Return
        return X, Y_in, Y_out
    
    
    def print_predictions_table(self, transformer, num_print:int, max_tokens:int=-1) :
        """
        """
        X, Y_in, Y_out = self.get_as_tensors(num_batches=math.ceil(num_print/self.batch_size))

        ##  Log table header
        logger.info("-"*80)
        logger.info("INPUT".rjust(32) + "TRUE".rjust(12) + "PRED".rjust(max([max_tokens,12])) + "CORRECT".rjust(12) + "RESIDUAL".rjust(12))
        logger.info("-"*80)

        ##  Get model predictions and log alongside true labels 
        for x, x_str, true_y_str in zip(X[:num_print], 
                                        transformer.token_transform.detokenise_strings(X    [:num_print,:].numpy()),
                                        transformer.token_transform.detokenise_strings(Y_out[:num_print  ].numpy())) :
            pred_y_str = transformer.transform_from_data_tensor(x, max_tokens=max_tokens)
            result     = "  X  " if pred_y_str == true_y_str else ""
            try    : residual = str(int(pred_y_str.replace("N","-")) - int(true_y_str.replace("N","-")))
            except : residual = "N/A"
            logger.info(x_str.rjust(32) + true_y_str.rjust(12) + pred_y_str.rjust(max([max_tokens,12])) + result.rjust(12) + residual.rjust(10))
            
            
    def reset_rng(self, seed:int=-1) :
        """
        """
        if seed < 0 :
            seed = self.base_seed
        self.rng = np.random.RandomState(seed)
        
        
    def summary(self, print_fn=print) :
        """
        """
        print_fn(str(self))

        

In [8]:

train_gen = RandomDataGenerator(token_transform, 
                                global_config["data"]["train_data"]["int_lengths"],
                                global_config["data"]["train_data"]["num_ints"],
                                global_config["data"]["train_data"]["batch_size"],
                                global_config["data"]["train_data"]["num_batchs"],
                                global_config["data"]["train_data"]["gen_base_seed"],
                                global_config["data"]["train_data"]["gen_reproducible"],)

logger.info(f"Training data generator created with the following config: {train_gen}")
(X, Y_in), Y_out = train_gen[0]
logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")

val_gen   = RandomDataGenerator(token_transform, 
                                global_config["data"]["val_data"]["int_lengths"],
                                global_config["data"]["val_data"]["num_ints"],
                                global_config["data"]["val_data"]["batch_size"],
                                global_config["data"]["val_data"]["num_batchs"],
                                global_config["data"]["val_data"]["gen_base_seed"],
                                global_config["data"]["val_data"]["gen_reproducible"],)

logger.info(f"Validation data generator created with the following config: {val_gen}")
(X, Y_in), Y_out = val_gen[0]
logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")

test_gen  = RandomDataGenerator(token_transform, 
                                global_config["data"]["test_data"]["int_lengths"],
                                global_config["data"]["test_data"]["num_ints"],
                                global_config["data"]["test_data"]["batch_size"],
                                global_config["data"]["test_data"]["num_batchs"],
                                global_config["data"]["test_data"]["gen_base_seed"],
                                global_config["data"]["test_data"]["gen_reproducible"],)

logger.info(f"Test data generator created with the following config: {test_gen}")
(X, Y_in), Y_out = test_gen[0]
logger.info(f"Output shapes for a test batch are ({X.shape}, {Y_in.shape}), {Y_out.shape}")


   INFO <module>: Training data generator created with the following config: Generator of [1, 2, 4] integers of length [1, 2, 3] in 4000 batches of size 32 (base_seed=100, reproducible=False)
   INFO <module>: Output shapes for a test batch are ((32, 19), (32, 6)), (32, 6)
   INFO <module>: Validation data generator created with the following config: Generator of [3] integers of length [1, 2, 3] in 500 batches of size 32 (base_seed=101, reproducible=True)
   INFO <module>: Output shapes for a test batch are ((32, 15), (32, 6)), (32, 6)
   INFO <module>: Test data generator created with the following config: Generator of [3] integers of length [1, 2, 3] in 1000 batches of size 32 (base_seed=102, reproducible=True)
   INFO <module>: Output shapes for a test batch are ((32, 15), (32, 6)), (32, 6)


##  4.  Create model

In [9]:
##==================================================##
##   Create supervised learning model - continued   ##
##==================================================##

model = create_text_to_text_model(
                     vocab_length             = token_transform.vocab_length, 
                     name                     = global_config["model"]["name"],
                     do_compile               = True,
                     dtype_in                 = token_transform.dtype,
                     dtype                    = global_config["model"]["dtype"],
                     dropout                  = global_config["model"]["dropout"],
                     optimizer_args           = {"learning_rate": global_config["model"]["learning_rate"]},
                     pos_enc_num_freqs        = global_config["model"]["positional_encoding"]["num_freqs"],
                     pos_enc_min_period       = global_config["model"]["positional_encoding"]["min_period"],
                     pos_enc_max_period       = global_config["model"]["positional_encoding"]["max_period"],
                     ndim_embedding           = global_config["model"]["ndim_embedding"],
                     comb_type                = global_config["model"]["comb_type"],
                     num_pre_layers_encoder   = global_config["model"]["pre_encoder"]["num_layers"],
                     ndim_pre_layers_encoder  = global_config["model"]["pre_encoder"]["ndim"],
                     skip_connect_pre_encoder = global_config["model"]["pre_encoder"]["skip_connect"],
                     num_pre_layers_decoder   = global_config["model"]["pre_decoder"]["num_layers"],
                     ndim_pre_layers_decoder  = global_config["model"]["pre_decoder"]["ndim"],
                     skip_connect_pre_decoder = global_config["model"]["pre_decoder"]["skip_connect"],
                     num_encoder_blocks       = global_config["model"]["encoder"]["num_blocks"],
                     ndim_encoder             = global_config["model"]["encoder"]["ndim"],
                     skip_connect_encoder     = True,
                     num_heads_encoder        = global_config["model"]["encoder"]["num_heads"],
                     ndim_att_hidden_encoder  = global_config["model"]["encoder"]["ndim_att_hidden"],
                     ndim_ff_hidden_encoder   = global_config["model"]["encoder"]["ndim_ff_hidden"],
                     num_decoder_blocks       = global_config["model"]["decoder"]["num_blocks"],
                     ndim_decoder             = global_config["model"]["decoder"]["ndim"],
                     skip_connect_decoder     = True,
                     num_heads_decoder        = global_config["model"]["decoder"]["num_heads"],
                     ndim_att_hidden_decoder  = global_config["model"]["decoder"]["ndim_att_hidden"],
                     ndim_ff_hidden_decoder   = global_config["model"]["decoder"]["ndim_ff_hidden"],
                     num_post_layers_decoder  = global_config["model"]["post_decoder"]["num_layers"],
                     ndim_post_layers_decoder = global_config["model"]["post_decoder"]["ndim"])


##  Load model if requested
load_model_fname = global_config.get("model", {}).get("load_pretrained_model", {})
if load_model_fname :
    logger.info(f"Loading model from file: {load_model_fname}")
    custom_objects = create_custom_objects_dict(model=model)
    custom_objects["masked_sparse_categorical_crossentropy"] = masked_sparse_categorical_crossentropy
    custom_objects["masked_accuracy"] = masked_accuracy
    tf.keras.models.load_model(load_model_fname, custom_objects=custom_objects)

##  Create hack to catch model summary
model_summary = []
model.summary(print_fn = lambda s : model_summary.append(s))

logger.info("Model created with summary:")
for s in model_summary : logger.info(s)


   INFO <module>: Model created with summary:
   INFO <module>: Model: "mathsformer_LLM"
   INFO <module>: __________________________________________________________________________________________________
   INFO <module>:  Layer (type)                   Output Shape         Param #     Connected to                     
   INFO <module>:  mathsformer_LLM_encoder_input_  [(None, None)]      0           []                               
   INFO <module>:  layer (InputLayer)                                                                               
   INFO <module>:                                                                                                   
   INFO <module>:  mathsformer_LLM_encoder_enumer  (1, None)           0           ['mathsformer_LLM_encoder_input_l
   INFO <module>:  ate (Enumerate)                                                 ayer[0][0]']                     
   INFO <module>:                                                                           

   INFO <module>:                                                                                                   
   INFO <module>:  mathsformer_LLM_feedfwd_block_  (None, None, 16)    145680      ['mathsformer_LLM_decoder_block_4
   INFO <module>:  post_attention (FeedForwardBlo                                  [0][0]']                         
   INFO <module>:  ck)                                                                                              
   INFO <module>:                                                                                                   
   INFO <module>: Total params: 619,156
   INFO <module>: Trainable params: 619,156
   INFO <module>: Non-trainable params: 0
   INFO <module>: __________________________________________________________________________________________________


In [10]:

##  Create transformer wrapper for model and token_transform
transformer = Transformer_Text_to_Text(model, token_transform)

##  5.  Train model

In [11]:
##===================================##
##   Create callbacks for training   ##
##===================================##


##  Create list of training callbacks
callbacks = []


##  Add logger callback
logger_callback_config = global_config["training"].get("log_after_epoch", {})
if logger_callback_config.get("do", True) :
    log_lvl = logger_callback_config.get("log_lvl", logging.DEBUG)
    callbacks.append(LoggerCallback(logger, loglvl=log_lvl))
    logger.info(f"Registered training callback: LoggerCallback with loglvl={log_lvl}")


##  Add callback for early stopping
early_stopping_config = global_config["training"].get("early_stopping", {})
if early_stopping_config.get("do", False) :
    monitor              = early_stopping_config.get("monitor"                , "val_loss")
    mode                 = early_stopping_config.get("mode"                   , 'min')
    restore_best_weights = early_stopping_config.get("restore_best_weights"   , True      )
    patience             = early_stopping_config.get("early_stopping_patience", 1         )
    callbacks.append(EarlyStopping(monitor              = monitor, 
                                   mode                 = mode,
                                   patience             = patience, 
                                   restore_best_weights = restore_best_weights))
    logger.info(f"Registered training callback: EarlyStopping with monitor={monitor}, mode={mode}, patience={patience}, restore_best_weights={restore_best_weights}")
    
    
## Adaptive learning rate
adaptive_learning_rate_config = global_config["training"].get("adaptive_learning_rate", {})
if adaptive_learning_rate_config.get("do", False) :
    decay_factor = adaptive_learning_rate_config.get("decay_factor", 0.5)
    patience     = adaptive_learning_rate_config.get("patience"    , 1)
    monitor      = adaptive_learning_rate_config.get("monitor"     , None)
    mode         = adaptive_learning_rate_config.get("mode"        , 'min')
    log_lvl      = adaptive_learning_rate_config.get("log_lvl"     , logging.DEBUG)
    callbacks.append(AdaptiveLearningRate(decay_factor = decay_factor,
                                          patience     = patience,
                                          monitor      = monitor,
                                          mode         = mode,
                                          logger       = logger,
                                          log_lvl      = log_lvl,))
    logger.info(f"Registeried training callback: AdaptiveLearningRate with decay_factor={decay_factor}, patience={patience}, monitor={monitor}, mode={mode}, log_lvl={log_lvl}")
            
    
## Add callback for model checkpointing
model_checkpoint_config = global_config["training"].get("model_checkpoint", {})
if model_checkpoint_config.get("do", False) :
    filename = model_checkpoint_config.get("filename", "model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5")
    filepath = f"{working_dir}/{filename}"
    callbacks.append(ModelCheckpoint(filepath=filepath))
    logger.info(f"Registeried training callback: ModelCheckpoint with filepath={filepath}")


##  Add callback to record layer weights - use recursive=True to monitor all sublayers
layer_weights_record_config = global_config["training"].get("layer_weights_record", {})
if layer_weights_record_config.get("do", False) :
    batch_frequency = layer_weights_record_config.get("batch_frequency", 1000)
    recursive       = layer_weights_record_config.get("recursive"      , True)
    layer_weights_record = LayerWeightsRecord(batch_frequency = batch_frequency, 
                                              recursive       = recursive      )
    callbacks.append(layer_weights_record)
    logger.info(f"Registered training callback: LayerWeightsRecord with batch_frequency={batch_frequency}, recursive={recursive}")


   INFO <module>: Registered training callback: LoggerCallback with loglvl=14
   INFO <module>: Registered training callback: EarlyStopping with monitor=val_masked_accuracy, mode=max, patience=1, restore_best_weights=True
   INFO <module>: Registeried training callback: AdaptiveLearningRate with decay_factor=0.3, patience=1, monitor=val_masked_accuracy, mode=max, log_lvl=14
   INFO <module>: Registeried training callback: ModelCheckpoint with filepath=unsupervised_learning_addition_model_generator_baseline_2023_04_14_v5/model_checkpoint_epoch{epoch}_val_loss_{val_loss:.5}.h5
   INFO <module>: Registered training callback: LayerWeightsRecord with batch_frequency=250, recursive=True


In [12]:
##==================================##
##   Test training data generator   ##
##==================================##

logger.info("Running text --> text mathsformer inference on some training data:")
train_gen.print_predictions_table(transformer, 10, max_tokens=15)


   INFO <module>: Running text --> text mathsformer inference on some training data:
   INFO print_predictions_table: --------------------------------------------------------------------------------
   INFO print_predictions_table:                            INPUT        TRUE           PRED     CORRECT    RESIDUAL
   INFO print_predictions_table: --------------------------------------------------------------------------------
   INFO print_predictions_table:                         212+N362        N150 +0+993299199++                   N/A
   INFO print_predictions_table:                      25+N11-5+N3           6 +0+993299199++                   N/A
   INFO print_predictions_table:                          N951+N7        N958 +0+993299199++                   N/A
   INFO print_predictions_table:                    81+27-N918+30        1056 B0+993299199++                   N/A
   INFO print_predictions_table:                              234         234 +0+99339919999                  

In [13]:
##====================================##
##   Test validation data generator   ##
##====================================##

logger.info("Running text --> text mathsformer inference on some validation data:")
val_gen.print_predictions_table(transformer, 10, max_tokens=15)


   INFO <module>: Running text --> text mathsformer inference on some validation data:
   INFO print_predictions_table: --------------------------------------------------------------------------------
   INFO print_predictions_table:                            INPUT        TRUE           PRED     CORRECT    RESIDUAL
   INFO print_predictions_table: --------------------------------------------------------------------------------
   INFO print_predictions_table:                               N6          N6 -0+99329919999                   N/A
   INFO print_predictions_table:                  64+N579+24+N124        N615 15+993299199++                   N/A
   INFO print_predictions_table:                    N21-N482-3-N1         459 B0+993299199++                   N/A
   INFO print_predictions_table:                                2           2 -0+99369919999                   N/A
   INFO print_predictions_table:                                9           9 -0+99329919999                

In [None]:
##=================##
##   Train model   ##
##=================##

##  Fit the model if configured
if global_config.get("training",{}).get("train",True) :
    max_epochs = global_config["training"]["max_epochs"]
    logger.info(f"Begin model training with max_epochs={max_epochs}")
    model.fit(train_gen, 
              epochs          = max_epochs,
              validation_data = val_gen,
              callbacks       = callbacks
             )
else :
    logger.warning("Skipping model training following global config instructions")


   INFO <module>: Begin model training with max_epochs=100000
Epoch 1/100000


2023-04-14 16:39:06.591838: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/100000


  layer_config = serialize_layer_fn(layer)


 725/4000 [====>.........................] - ETA: 5:36 - loss: 2.3297 - masked_accuracy: 0.2514

In [None]:
##================##
##   Save model   ##
##================##

if global_config.get("training",{}).get("train",True) :
    save_fname = f"{working_dir}/final_model.h5"
    model.save(save_fname)
    logger.info(f"Model save to file {save_fname}")
else :
    logger.warning("Not saving model because no training was done")


## 6.  Evaluate model

In [None]:
  
##  Find out how many datapoints to print predictions for 
num_print = int(global_config.get("evaluate",{}).get("num_print", 20))


In [None]:

##  Print vs val data
logger.info("Running text --> text mathsformer inference on the train set:")
train_gen.print_predictions_table(transformer, num_print)


In [None]:

##  Print vs val data
logger.info("Running text --> text mathsformer inference on the validation set:")
val_gen.print_predictions_table(transformer, num_print)


In [None]:

##  Print vs val data
logger.info("Running text --> text mathsformer inference on the test set:")
test_gen.print_predictions_table(transformer, num_print)


##  7. Additional visualisations

In [None]:
##=============================================##
##   Visualise layer weights during training   ##
##=============================================##

if len(layer_weights_record.batch_indices) == 0 :
    logger.warning("Not plotting layer weights because no data found")
else :
    logger.info("Plotting layer weights")
    layer_weights_record.plot(num_col=7)
    