In [1]:
##=========================##
##   All imports go here   ##
##=========================##

##  Import entire python stdlib packages
import logging, os, sys

##  Import entire pypi packages
import tensorflow as tf
import numpy      as np

##  Remove tensorflow INFO messages
tf.get_logger().setLevel('WARNING')

##  Add the parent of the current working directory to the system path so that the mathsformer
##  package can be imported. os.path.dirname is used rather than splitting on "/" so that this
##  also works with Windows path separators (splitting on "/" there would append an empty string).
sys.path.append(os.path.dirname(os.getcwd()))

##  Import individual modules/objects from packages
from matplotlib  import pyplot as plt
from mathsformer import config, data, transformers, utils, tf_objects as tfo
from mathsformer import selfsupervised_learning_addition_model_backend as backend


In [2]:
##==============================##
##   Set custom config values   ##
##==============================##

##  Values here override backend.DEFAULT_CONFIG when merged with cfg.load_dict in the next cell
custom_config = {
    "global" : {
        ##  Key into the "models" section below, selecting which saved model file is loaded
        "model_tag"        : "loop_idem2",
        ##  NOTE(review): -1 presumably asks utils.initialise_program to choose a seed itself — confirm
        "base_seed"        : -1,
        ##  "[date]" is substituted when the directory is created (the cell output shows
        ##  multi_loop_investigation_2023_06_28_v3, i.e. date plus a version suffix)
        "working_dir"      : "multi_loop_investigation_[date]",
        "log_lvl_iostream" : logging.INFO,
        "log_lvl_fstream"  : logging.DEBUG,
    },
    "data" : {
        ##  Not read by the visible cells; presumably records the model's training ranges
        ##  (cf. "int1234_num1245" in the model filenames) — confirm
        "train_data" : {
            "int_lengths"      : [1, 2, 3, 4],
            "num_ints"         : [1, 2, 4, 5],
        },
        ##  Evaluation set: 7 integers of length 3, i.e. more integers than any train_data setting
        "test_data" : {
            "int_lengths"      : [3],
            "num_ints"         : [7],
            "batch_size"       : 32,
            "num_batches"      : 10,
            "gen_base_seed"    : 200,
            "gen_reproducible" : True,
        },
        ##  The token id of each character is its index in this list (so the mask character 'M' is id 0)
        "characters"              : ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-'],
        "mask_char"               : 'M',
        "seq_start_char"          : 'B',
        "seq_end_char"            : 'E',
        ##  Prefix used to write negative integers (e.g. "N309")
        "negative_char"           : 'N',
        "dtype"                   : "int32",
    },
    "models" : {
        ##  Exclusive upper bound: the model-building cell uses range(1, num_loops), so models with
        ##  1 .. num_loops-1 encoder loops are built
        "num_loops"  : 13,
        ##  NOTE(review): this path is prefixed with "save/" (and names "idemm1") while loop_idem2 has no
        ##  prefix — confirm both resolve relative to the notebook's working directory
        "loop_idem0" : "save/SSL_loopy_enc_dec_notebook_int1234_num1245_embed128_enc_2blocks_5loops_width512_dec_2blocks_1loops_width512_post3_width512_idemm1_2023_06_25_v2/final_model.keras",
        "loop_idem2" : "SSL_loopy_enc_dec_notebook_int1234_num1245_embed128_enc_2blocks_5loops_width512_dec_2blocks_1loops_width512_post3_width512_idem2_2023_06_27/model_checkpoint_epoch69_val_loss_0.030081.keras"
    },
    "evaluate" : {
        ##  NOTE(review): the predictions-table cell passes num_print=20 explicitly instead of reading this
        "num_print" : 50,
    },
}


In [3]:
##===================================##
##   Load and validate full config   ##
##===================================##

##  Start from the backend defaults, then overlay the notebook-specific values on top
cfg = config.Config(backend.DEFAULT_CONFIG)
cfg.load_dict(custom_config)

##  Check the merged config before anything is built from it, then announce success
backend.validate_config(cfg)
print(utils.fancy_message(f"Config created"))

##  Short-hand handles for each top-level section of the merged config
(cfg_global, cfg_data, cfg_model, cfg_training, cfg_evaluate) = (
    cfg[section] for section in ("global", "data", "models", "training", "evaluate")
)


===   Config created   ===


In [4]:
##==============================##
##   Create working directory   ##
##==============================##

##  Create the working directory and start logging; judging by the returned values this also
##  fixes the base / numpy / tensorflow seeds.
##  NOTE(review): in later cell outputs every logger.info message also appears a second time as
##  "INFO:mathsformer:...", which looks like the logger propagating to a root-level handler.
##  Consider logger.propagate = False once it is confirmed where the file handler is attached.
working_dir, logger, base_seed, np_seed, tf_seed = utils.initialise_program(
    "multi_loop_investigation (notebook)",
    cfg               = cfg,
    working_dir       = cfg_global["working_dir"],
    base_seed         = cfg_global["base_seed"],
    log_lvl_iostream  = cfg_global["log_lvl_iostream"],
    log_lvl_fstream   = cfg_global["log_lvl_fstream"],
)


===   Working directory created at multi_loop_investigation_2023_06_28_v3   ===
   INFO initialise_logging: Begin logging on 2023-06-28 at 11:02:00
   INFO initialise_program: Program description: multi_loop_investigation (notebook)
   INFO initialise_program: Working directory: multi_loop_investigation_2023_06_28_v3
   INFO log_versions: ------------------------------------------------------+------------------------------------------------------
   INFO log_versions:                                              PACKAGE  |  VERSION
   INFO log_versions: ------------------------------------------------------+------------------------------------------------------
   INFO log_versions:                                               Python  |  3.11.3 (main, May 15 2023, 18:01:31) [Clang 14.0.6 ]
   INFO log_versions:                                              IPython  |  8.14.0
   INFO log_versions:                                 IPython.core.release  |  8.14.0
   INFO log_versions:     

   INFO log_versions:                                               pydevd  |  2.9.5
   INFO log_versions:                                             pygments  |  2.15.1
   INFO log_versions:                                            pyparsing  |  3.1.0
   INFO log_versions:                                                   re  |  2.2.1
   INFO log_versions:                                             requests  |  2.31.0
   INFO log_versions:                                 requests.__version__  |  2.31.0
   INFO log_versions:                                                 idna  |  3.4
   INFO log_versions:                                        idna.idnadata  |  15.0.0
   INFO log_versions:                                    idna.package_data  |  3.4
   INFO log_versions:                                              urllib3  |  1.26.16
   INFO log_versions:                                     urllib3._version  |  1.26.16
   INFO log_versions:                                   urlli

   INFO initialise_program: Registered config value data > test_data > gen_reproducible: True
   INFO initialise_program: Registered config value data > characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO initialise_program: Registered config value data > mask_char: M
   INFO initialise_program: Registered config value data > seq_start_char: B
   INFO initialise_program: Registered config value data > seq_end_char: E
   INFO initialise_program: Registered config value data > negative_char: N
   INFO initialise_program: Registered config value data > dtype: int32
   INFO initialise_program: Registered config value model > load_pretrained_model: None
   INFO initialise_program: Registered config value model > name: mathsformer_LLM
   INFO initialise_program: Registered config value model > dtype: float32
   INFO initialise_program: Registered config value model > dropout: 0.1
   INFO initialise_program: Registered config value model > lear

In [5]:
##=========================##
##   Load the base model   ##
##=========================##

##  "model_tag" (global section) names the entry of the "models" section holding the file to load
model_tag   = cfg_global["model_tag"]
model_fname = cfg_model[model_tag]

##  Trained text-to-text model whose layers are re-wired into multi-loop models below
base_model  = backend.load_text_to_text_model(model_fname)




In [6]:
##  List the loaded model's layers, then index them by name so they can be re-wired below
for layer in base_model.layers :
    print(layer.name)

layers = dict((layer.name, layer) for layer in base_model.layers)

mathsformer_LLM_encoder_input_layer
mathsformer_LLM_encoder_enumerate
mathsformer_LLM_encoder_embedding
mathsformer_LLM_encoder_position_encoding
mathsformer_LLM_encoder_emb_and_pos
mathsformer_LLM_encoder_block_1
mathsformer_LLM_encoder_block_2
mathsformer_LLM_decoder_input_layer
mathsformer_LLM_decoder_enumerate
mathsformer_LLM_decoder_embedding
mathsformer_LLM_decoder_position_encoding
mathsformer_LLM_decoder_emb_and_pos
mathsformer_LLM_encoder_output_norm
mathsformer_LLM_encoder_output_norm_idem0
mathsformer_LLM_encoder_output_norm_idem1
mathsformer_LLM_decoder_block_1
mathsformer_LLM_decoder_block_2
mathsformer_LLM_output


In [7]:
##============================================================##
##   Re-wire the trained layers into models with a variable   ##
##   number of encoder loops                                  ##
##============================================================##

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

##  Token-id inputs of arbitrary sequence length for the encoder and decoder
x_in_enc = Input((None,), dtype=tf.int32, name="multi_loop_encoder_input_layer")
x_in_dec = Input((None,), dtype=tf.int32, name="multi_loop_decoder_input_layer")

##  Encoder embedding: token embedding combined with a position encoding of the enumerated positions
x_enc1 = layers["mathsformer_LLM_encoder_embedding"        ](x_in_enc)
x_enc2 = layers["mathsformer_LLM_encoder_enumerate"        ](x_in_enc)
x_enc2 = layers["mathsformer_LLM_encoder_position_encoding"](x_enc2)
x_enc  = layers["mathsformer_LLM_encoder_emb_and_pos"      ]([x_enc1, x_enc2])

##  Decoder embedding, built the same way
x_dec1 = layers["mathsformer_LLM_decoder_embedding"        ](x_in_dec)
x_dec2 = layers["mathsformer_LLM_decoder_enumerate"        ](x_in_dec)
x_dec2 = layers["mathsformer_LLM_decoder_position_encoding"](x_dec2)
x_dec  = layers["mathsformer_LLM_decoder_emb_and_pos"      ]([x_dec1, x_dec2])

##  Maps number of encoder loops -> model; key -1 holds the originally loaded model.
##  cfg_model["num_loops"] is an exclusive upper bound, so models with 1 .. num_loops-1 loops are built.
loop_models = {-1:base_model}

##  The encoder stack is extended incrementally: the tensor for n loops is the tensor for n-1 loops
##  passed once more through both encoder blocks. Each n-loop model therefore has the same graph as
##  when rebuilding from scratch, but construction needs O(num_loops) instead of O(num_loops^2) block calls.
x_enc_looped = x_enc
for num_loops in range(1, cfg_model["num_loops"]) :
    x_enc_looped = layers["mathsformer_LLM_encoder_block_1"](x_enc_looped)
    x_enc_looped = layers["mathsformer_LLM_encoder_block_2"](x_enc_looped)
    ##  Normalise only the branch fed to the decoder; x_enc_looped stays un-normalised for the next loop
    x_enc_this = layers["mathsformer_LLM_encoder_output_norm"](x_enc_looped)
    x_dec_this = layers["mathsformer_LLM_decoder_block_1"]([x_dec     , x_enc_this])
    x_dec_this = layers["mathsformer_LLM_decoder_block_2"]([x_dec_this, x_enc_this])
    x_out      = layers["mathsformer_LLM_output"](x_dec_this)
    model      = Model([x_in_enc, x_in_dec], x_out, name=f"multi_loop_model_{num_loops}loops")
    ##  Fresh metric/loss objects per model so no metric state is shared between compiled models.
    ##  mask_value=0 presumably excludes token id 0 (the mask character 'M') — confirm in tf_objects.
    acc_metric = tfo.MaskedCategoricalAccuracy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0)
    loss_fn    = tfo.MaskedSparseCategoricalCrossentropy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0, from_logits=True)
    model.compile(loss=loss_fn, metrics=[acc_metric])
    loop_models[num_loops] = model
    

In [8]:
##======================##
##   Create tokeniser   ##
##======================##

##  Build the character <-> token-id mapping from the data section and log a summary of it
token_transform = data.TokenTransform.from_dictionary(cfg_data)
token_transform.summary(print_fn = logger.info)


   INFO summary: TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']


INFO:mathsformer:TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']


   INFO summary: Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)


INFO:mathsformer:Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)


   INFO summary: Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}


INFO:mathsformer:Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}


   INFO summary: Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


INFO:mathsformer:Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


In [9]:

##============================================================##
##   Evaluate every loop model on the configured test cases   ##
##============================================================##

def _token_accuracy(model, data_gen) :
    """Return the masked per-token accuracy reported by model.evaluate on data_gen.

    The rebuilt loop models report the metric as "masked_categorical_accuracy". Otherwise the
    name prefixed by the output layer ("mathsformer_LLM_output_...") is looked up, which is
    what the loaded base model is expected to report (presumably because it has named
    outputs — confirm).
    """
    evals        = model.evaluate(data_gen)
    metric_names = model.metrics_names
    if "masked_categorical_accuracy" in metric_names :
        metric_idx = metric_names.index("masked_categorical_accuracy")
    else :
        metric_idx = metric_names.index("mathsformer_LLM_output_masked_categorical_accuracy")
    return evals[metric_idx]


def _sequence_accuracy(model, data_gen) :
    """Return the fraction of sequences in data_gen whose every non-mask token is predicted exactly.

    Target tokens equal to 0 (the mask character) are ignored in the comparison. Predictions are
    the argmax over the last axis of the model output; if the model returns a list of outputs,
    only the first is used. Returns nan if data_gen yields no sequences.
    """
    num_correct, num_total = 0, 0
    for X, Y_true in data_gen :
        Y_true = Y_true.numpy()
        ##  predict_on_batch avoids model.predict's per-call dataset/loop setup for a single batch
        Y_pred = model.predict_on_batch(X)
        if isinstance(Y_pred, list) :
            Y_pred = Y_pred[0]
        Y_pred = np.argmax(Y_pred, axis=-1)
        for y_true, y_pred in zip(Y_true, Y_pred) :
            keep         = y_true != 0
            num_correct += bool((y_true[keep] == y_pred[keep]).all())
            num_total   += 1
    return num_correct / num_total if num_total > 0 else float("nan")


num_ints    = cfg_data["test_data"]["num_ints"]
int_lengths = cfg_data["test_data"]["int_lengths"]

##  model_*_accs[num_loops][i][j] holds the accuracy for num_ints[i] integers of length int_lengths[j]
model_token_accs, model_result_accs = {}, {}

for num_loops, model in loop_models.items() :

    token_accs, result_accs = [], []

    for N in num_ints :

        token_accs .append([])
        result_accs.append([])

        for L in int_lengths :

            print(f"Running ({num_loops} , {N} , {L})")

            ##  Generator for this single (N, L) case; gen_base_seed / gen_reproducible presumably make
            ##  every model see the same data — confirm in data.RandomDataGenerator_Addition
            data_gen = data.RandomDataGenerator_Addition(
                token_transform = token_transform,
                int_lengths     = [L],
                num_ints        = [N],
                batch_size      = cfg_data["test_data"]["batch_size"],
                num_batches     = cfg_data["test_data"]["num_batches"],
                base_seed       = cfg_data["test_data"]["gen_base_seed"],
                reproducible    = cfg_data["test_data"]["gen_reproducible"],
                negative_char   = cfg_data["negative_char"],
            )

            token_accs[-1].append(_token_accuracy(model, data_gen))

            acc = _sequence_accuracy(model, data_gen)
            result_accs[-1].append(acc)
            print(f"---->  {acc}")

    ##  Store once per model, after all (N, L) cases have been evaluated
    model_token_accs [num_loops] = token_accs
    model_result_accs[num_loops] = result_accs


Running (-1 , 7 , 3)


2023-06-28 11:02:10.275898: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


---->  0.0
Running (1 , 7 , 3)
---->  0.0
Running (2 , 7 , 3)
---->  0.0
Running (3 , 7 , 3)
---->  0.0
Running (4 , 7 , 3)
---->  0.0
Running (5 , 7 , 3)
---->  0.0
Running (6 , 7 , 3)
---->  0.0
Running (7 , 7 , 3)
---->  0.0
Running (8 , 7 , 3)
---->  0.003125
Running (9 , 7 , 3)
---->  0.003125
Running (10 , 7 , 3)
---->  0.003125
Running (11 , 7 , 3)


---->  0.003125
Running (12 , 7 , 3)
---->  0.003125


In [10]:
##  Fraction of test sequences predicted fully correctly, keyed by number of encoder loops
##  (key -1 is the originally loaded model); inner lists are indexed [num_ints index][int_lengths index].
##  Note the test case (num_ints=[7]) uses more integers than any train_data setting (num_ints up to 5).
model_result_accs

{-1: [[0.0]],
 1: [[0.0]],
 2: [[0.0]],
 3: [[0.0]],
 4: [[0.0]],
 5: [[0.0]],
 6: [[0.0]],
 7: [[0.0]],
 8: [[0.003125]],
 9: [[0.003125]],
 10: [[0.003125]],
 11: [[0.003125]],
 12: [[0.003125]]}

In [11]:

##========================================================##
##   Print example predictions for one chosen loop model  ##
##========================================================##

##  Number of encoder loops of the model to inspect (must be a key of loop_models)
table_num_loops = 5

##  Build the data explicitly instead of relying on the `data_gen` left over from the evaluation
##  loop: use the last (num_ints, int_lengths) test case with the configured seed, so the cell
##  does not depend on hidden kernel state.
table_data_gen = data.RandomDataGenerator_Addition(
    token_transform = token_transform,
    int_lengths     = [cfg_data["test_data"]["int_lengths"][-1]],
    num_ints        = [cfg_data["test_data"]["num_ints"   ][-1]],
    batch_size      = cfg_data["test_data"]["batch_size"],
    num_batches     = cfg_data["test_data"]["num_batches"],
    base_seed       = cfg_data["test_data"]["gen_base_seed"],
    reproducible    = cfg_data["test_data"]["gen_reproducible"],
    negative_char   = cfg_data["negative_char"],
)

transformer = transformers.Transformer_Text_to_Text(loop_models[table_num_loops], token_transform)

##  NOTE(review): num_print is fixed at 20 here although cfg_evaluate["num_print"] is 50 — confirm intent
transformer.print_predictions_table(
    table_data_gen,
    num_print      = 20,
    max_tokens     = 20,
    max_col_length = 30,
    negative_char  = cfg_data["negative_char"],
    print_fn       = logger.info)


   INFO print_predictions_table: ------------------------------------------------------------------------------------------------


INFO:mathsformer:------------------------------------------------------------------------------------------------


   INFO print_predictions_table:                           INPUT         TRUE   PRED(MASK)    PRED(GEN)      CORRECT     RESIDUAL


INFO:mathsformer:                          INPUT         TRUE   PRED(MASK)    PRED(GEN)      CORRECT     RESIDUAL


   INFO print_predictions_table: ------------------------------------------------------------------------------------------------


INFO:mathsformer:------------------------------------------------------------------------------------------------


   INFO print_predictions_table:  147+781-N491+858-N760-N738-142         3633        1527E        15927                     12294


INFO:mathsformer: 147+781-N491+858-N760-N738-142         3633        1527E        15927                     12294


   INFO print_predictions_table:  563+546+140+N527+N967+N275+211         N309        N1254        N1254                      -945


INFO:mathsformer: 563+546+140+N527+N967+N275+211         N309        N1254        N1254                      -945


   INFO print_predictions_table:   211-407-536-179-N641+672-N197          599        1887E         1518                       919


INFO:mathsformer:  211-407-536-179-N641+672-N197          599        1887E         1518                       919


   INFO print_predictions_table:    N802+183-911+123-881-584-430        N3302        N7722        N7722                     -4420


INFO:mathsformer:   N802+183-911+123-881-584-430        N3302        N7722        N7722                     -4420


   INFO print_predictions_table:  512+N958-N646-N560+N944-172...        N1350        N1896        N1897                      -547


INFO:mathsformer: 512+N958-N646-N560+N944-172...        N1350        N1896        N1897                      -547


   INFO print_predictions_table:  N472-N938-354-N815-349-185+...         N260        N1277        N1277                     -1017


INFO:mathsformer: N472-N938-354-N815-349-185+...         N260        N1277        N1277                     -1017


   INFO print_predictions_table:  N310+998+N350-N765-987+677-...         1415        1102E         1112                      -303


INFO:mathsformer: N310+998+N350-N765-987+677-...         1415        1102E         1112                      -303


   INFO print_predictions_table:  992+N727+356-N438+N799-868-...         N461        N8253        N8233                     -7772


INFO:mathsformer: 992+N727+356-N438+N799-868-...         N461        N8253        N8233                     -7772


   INFO print_predictions_table:  274-424+593+250-N748+N766-N969         1644        N023E        N5077                     -6721


INFO:mathsformer: 274-424+593+250-N748+N766-N969         1644        N023E        N5077                     -6721


   INFO print_predictions_table:  N435-N896-774+N639-169+N446...         N690        N5595        N5595                     -4905


INFO:mathsformer: N435-N896-774+N639-169+N446...         N690        N5595        N5595                     -4905


   INFO print_predictions_table:  N425+182+285-985-799+N427+N609        N2778        N1420       N10520                     -7742


INFO:mathsformer: N425+182+285-985-799+N427+N609        N2778        N1420       N10520                     -7742


   INFO print_predictions_table:  251+334+N131+390-N269-N312-425         1000        4167E         4195                      3195


INFO:mathsformer: 251+334+N131+390-N269-N312-425         1000        4167E         4195                      3195


   INFO print_predictions_table:  N359+115-984-848-734+N211-N937        N2084        N5805        N5805                     -3721


INFO:mathsformer: N359+115-984-848-734+N211-N937        N2084        N5805        N5805                     -3721


   INFO print_predictions_table:  793+N944+164-N985-978+N533+...         N793        N4102        N4102                     -3309


INFO:mathsformer: 793+N944+164-N985-978+N533+...         N793        N4102        N4102                     -3309


   INFO print_predictions_table:  N184+607+495+N717-N493+N548...         1074        NN71E        N4839                     -5913


INFO:mathsformer: N184+607+495+N717-N493+N548...         1074        NN71E        N4839                     -5913


   INFO print_predictions_table:  N517-N782+N797+N381+251-669...         N539        N7178        N7178                     -6639


INFO:mathsformer: N517-N782+N797+N381+251-669...         N539        N7178        N7178                     -6639


   INFO print_predictions_table:  N853-N527-N707-N740+937+476...         2428        6399E         6309                      3881


INFO:mathsformer: N853-N527-N707-N740+937+476...         2428        6399E         6309                      3881


   INFO print_predictions_table:  829-189-N595-967+N872-N302+...        N1250        21433         2527                      3777


INFO:mathsformer: 829-189-N595-967+N872-N302+...        N1250        21433         2527                      3777


   INFO print_predictions_table:  410+N295+N243-N721-458+146-310          N29        1645E         1385                      1414


INFO:mathsformer: 410+N295+N243-N721-458+146-310          N29        1645E         1385                      1414


   INFO print_predictions_table:  456-719+N674-976-N537-116-N277        N1215        N1473        N1473                      -258


INFO:mathsformer: 456-719+N674-976-N537-116-N277        N1215        N1473        N1473                      -258
