In [1]:
##=========================##
##   All imports go here   ##
##=========================##

##  Import entire python stdlib packages
import glob, logging, os, sys

##  Import entire pypi packages
import tensorflow as tf
import numpy      as np

##  Raise the threshold of the tensorflow python logger to WARNING, so that INFO messages are not shown
tf.get_logger().setLevel('WARNING')

##  Add the parent of the current working directory to the system path to expose the mathsformer
##  package location. os.path.dirname is used rather than splitting on "/" so that the platform's
##  own path separator is respected (splitting on "/" yields an empty string on Windows)
sys.path.append(os.path.dirname(os.getcwd()))

##  Import individual modules/objects from packages
from matplotlib  import patches, pyplot as plt
from mathsformer import config, data, transformers, utils
from mathsformer import selfsupervised_learning_addition_model_backend as backend


In [2]:
##==============================##
##   Set custom config values   ##
##==============================##

##  Values given here are loaded on top of backend.DEFAULT_CONFIG using cfg.load_dict
custom_config = {
    "global" : {
        ##  model_tag is the key into the "models" section, selecting which checkpoint directory is evaluated
        "model_tag"        : "loop_idem2",
        ##  base_seed, working_dir and the two log levels are passed to utils.initialise_program
        ##  NOTE(review): the logged directory name suggests "[date]" is expanded there - confirm
        "base_seed"        : -1,
        "working_dir"      : "evaluate_training_[date]",
        "log_lvl_iostream" : logging.INFO,
        "log_lvl_fstream"  : logging.DEBUG,
    },
    "data" : {
        ##  Read by make_plot to outline the (num_ints, int_lengths) cells inside the training region
        "train_data" : {
            "int_lengths"      : [1, 2, 3, 4],
            "num_ints"         : [1, 2, 4, 5],
        },
        ##  int_lengths and num_ints define the evaluation grid; the remaining values are passed to
        ##  data.RandomDataGenerator_Addition when generating the test data for each grid cell
        "test_data" : {
            "int_lengths"      : [1, 2, 3, 4, 5],
            "num_ints"         : [1, 2, 3, 4, 5, 6, 7, 8],
            "batch_size"       : 32,
            "num_batches"      : 2,
            "gen_base_seed"    : 200,
            "gen_reproducible" : True,
        },
        ##  The data section is passed to data.TokenTransform.from_dictionary to build the tokeniser
        "characters"              : ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-'],
        "mask_char"               : 'M',
        "seq_start_char"          : 'B',
        "seq_end_char"            : 'E',
        "negative_char"           : 'N',
        "dtype"                   : "int32",
    },
    ##  Maps each model tag to the directory that is searched for model_checkpoint*.keras files
    "models" : {
        "loop_idem0" : "save/SSL_loopy_enc_dec_notebook_int1234_num1245_embed128_enc_2blocks_5loops_width512_dec_2blocks_1loops_width512_post3_width512_idemm1_2023_06_25_v2",
        "loop_idem2" : "SSL_loopy_enc_dec_notebook_int1234_num1245_embed128_enc_2blocks_5loops_width512_dec_2blocks_1loops_width512_post3_width512_idem2_2023_06_27",
    },
    "evaluate" : {
        ##  Checkpoints with an epoch number <= this value are skipped by the evaluation loop
        "skip_first_epochs" : 72,
    },
}


In [3]:
##===================================##
##   Load and validate full config   ##
##===================================##

##  Start from the backend default values, then layer the notebook-specific overrides on top
cfg = config.Config(backend.DEFAULT_CONFIG)
cfg.load_dict(custom_config)

##  Validate the combined config before any later cell reads from it
backend.validate_config(cfg)

##  Announce that the config is ready
print(utils.fancy_message("Config created"))

##  Keep one handle per config section, for convenience in later cells
cfg_global, cfg_data, cfg_model, cfg_training, cfg_evaluate = (
    cfg[section] for section in ("global", "data", "models", "training", "evaluate")
)


===   Config created   ===


In [4]:
##==============================##
##   Create working directory   ##
##==============================##

##  Initialise the program through utils.initialise_program, keeping the returned working directory,
##  logger and seeds. The logged output shows that the working directory is created and that logging
##  begins as part of this call
working_dir, logger, base_seed, np_seed, tf_seed = utils.initialise_program(
    "evaluate_training (notebook)", 
    working_dir       = cfg_global["working_dir"], 
    cfg               = cfg,
    base_seed         = cfg_global["base_seed"],
    log_lvl_iostream  = cfg_global["log_lvl_iostream"],
    log_lvl_fstream   = cfg_global["log_lvl_fstream" ],
)


===   Working directory created at evaluate_training_2023_06_28_v7   ===
   INFO initialise_logging: Begin logging on 2023-06-28 at 11:42:08
   INFO initialise_program: Program description: evaluate_training (notebook)
   INFO initialise_program: Working directory: evaluate_training_2023_06_28_v7
   INFO log_versions: ------------------------------------------------------+------------------------------------------------------
   INFO log_versions:                                              PACKAGE  |  VERSION
   INFO log_versions: ------------------------------------------------------+------------------------------------------------------
   INFO log_versions:                                               Python  |  3.11.3 (main, May 15 2023, 18:01:31) [Clang 14.0.6 ]
   INFO log_versions:                                              IPython  |  8.14.0
   INFO log_versions:                                 IPython.core.release  |  8.14.0
   INFO log_versions:                          

   INFO log_versions:                                             pygments  |  2.15.1
   INFO log_versions:                                            pyparsing  |  3.1.0
   INFO log_versions:                                                   re  |  2.2.1
   INFO log_versions:                                             requests  |  2.31.0
   INFO log_versions:                                 requests.__version__  |  2.31.0
   INFO log_versions:                                                 idna  |  3.4
   INFO log_versions:                                        idna.idnadata  |  15.0.0
   INFO log_versions:                                    idna.package_data  |  3.4
   INFO log_versions:                                              urllib3  |  1.26.16
   INFO log_versions:                                     urllib3._version  |  1.26.16
   INFO log_versions:                                   urllib3.connection  |  1.26.16
   INFO log_versions:                                 urlli

   INFO initialise_program: Registered config value data > characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO initialise_program: Registered config value data > mask_char: M
   INFO initialise_program: Registered config value data > seq_start_char: B
   INFO initialise_program: Registered config value data > seq_end_char: E
   INFO initialise_program: Registered config value data > negative_char: N
   INFO initialise_program: Registered config value data > dtype: int32
   INFO initialise_program: Registered config value model > load_pretrained_model: None
   INFO initialise_program: Registered config value model > name: mathsformer_LLM
   INFO initialise_program: Registered config value model > dtype: float32
   INFO initialise_program: Registered config value model > dropout: 0.1
   INFO initialise_program: Registered config value model > learning_rate: 0.001
   INFO initialise_program: Registered config value model > jit_compile: Fals

In [5]:
##======================##
##   Get model fnames   ##
##======================##

##  Directory holding the checkpoints of the model selected by the "model_tag" config value
model_tag   = cfg_global["model_tag"]
model_fname = cfg_model[model_tag]

##  Checkpoint files are expected to be named model_checkpoint_epoch<N>..., with <N> the epoch number
checkpoint_prefix = "model_checkpoint_epoch"
checkpoint_paths  = sorted(glob.glob(f"{model_fname}/model_checkpoint*.keras"))

##  Build dictionary of {epoch number : checkpoint filename}
model_fnames = {}
for checkpoint_path in checkpoint_paths :

    ##  Search the file name only, so that the directory part of the path cannot affect the parsing
    checkpoint_name = os.path.basename(checkpoint_path)
    prefix_idx      = checkpoint_name.find(checkpoint_prefix)
    if prefix_idx < 0 :
        logger.warning(f"Skipping file without '{checkpoint_prefix}' in its name: {checkpoint_path}")
        continue

    ##  The epoch number is the run of digits directly after the prefix, whatever character follows it
    remainder  = checkpoint_name[prefix_idx + len(checkpoint_prefix):]
    num_digits = len(remainder) - len(remainder.lstrip("0123456789"))
    if num_digits == 0 :
        logger.warning(f"Skipping file with no epoch number after '{checkpoint_prefix}': {checkpoint_path}")
        continue

    model_fnames[int(remainder[:num_digits])] = checkpoint_path

##  An empty result usually means the configured directory is wrong, so make it easy to spot
if not model_fnames :
    logger.warning(f"No usable checkpoint files found in directory: {model_fname}")

logger.info(f"Found {len(model_fnames)} model files")


   INFO <module>: Found 72 model files


In [6]:
##======================##
##   Create tokeniser   ##
##======================##

##  Build the token transform from the data section of the config, then write its summary to the logger
token_transform = data.TokenTransform.from_dictionary(cfg_data)
token_transform.summary(print_fn=logger.info)


   INFO summary: TokenTransform of dtype int32 with 16 characters: ['M', 'B', 'E', 'N', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-']
   INFO summary: Special characters are seq_start_char (B), seq_end_char (E), mask_char (M)
   INFO summary: Tokeniser dictionary is {'M': 0, 'B': 1, 'E': 2, 'N': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '+': 14, '-': 15}
   INFO summary: Detokeniser dictionary is {0: 'M', 1: 'B', 2: 'E', 3: 'N', 4: '0', 5: '1', 6: '2', 7: '3', 8: '4', 9: '5', 10: '6', 11: '7', 12: '8', 13: '9', 14: '+', 15: '-'}


In [7]:
##=============================================================##
##   Define method for retrieving answer accuracy from model   ##
##=============================================================##

def get_answer_accuracy(model, token_transform, N, L, verbose=0) :
    """
    Return the fraction of generated test sequences for which the model predicts every answer token.

    Test data are drawn from data.RandomDataGenerator_Addition, configured with num_ints=[N] and
    int_lengths=[L] and with the remaining settings read from the global cfg_data. A sequence counts
    as correct only when the predicted token matches the true token at every position where the true
    token is non-zero.

    Inputs:

        >  model
           Object with a predict(X, verbose=...) method returning per-token scores, or a list/tuple
           whose first element is those scores

        >  token_transform
           Token transform passed through to the data generator

        >  N, L
           Values passed to the data generator as num_ints=[N] and int_lengths=[L]

        >  verbose, default=0
           Verbosity passed to model.predict

    Returns:

        >  float, number of fully correct sequences divided by the number of sequences evaluated,
           or nan if the generator produced no sequences
    """
    test_cfg = cfg_data["test_data"]
    data_gen = data.RandomDataGenerator_Addition(
                                        token_transform = token_transform, 
                                        int_lengths     = [L],
                                        num_ints        = [N],
                                        batch_size      = test_cfg["batch_size"],
                                        num_batches     = test_cfg["num_batches"],
                                        base_seed       = test_cfg["gen_base_seed"],
                                        reproducible    = test_cfg["gen_reproducible"],
                                        negative_char   = cfg_data["negative_char"],
    )

    num_correct, num_wrong = 0, 0
    for X, Y_true in data_gen :
        Y_true = Y_true.numpy()
        Y_pred = model.predict(X, verbose=verbose)
        ##  Models with several outputs return a sequence, of which the first entry is used
        if isinstance(Y_pred, (list, tuple)) :
            Y_pred = Y_pred[0]
        Y_pred = np.argmax(Y_pred, axis=-1)
        for y_true, y_pred in zip(Y_true, Y_pred) :
            ##  Positions with true token 0 are ignored
            ##  NOTE(review): token 0 is the mask character for the tokeniser configured here - confirm
            ##  it is also what the generator uses for padding
            keep = y_true != 0
            if (y_true[keep] == y_pred[keep]).all() : num_correct += 1
            else : num_wrong += 1

    ##  Accuracy is undefined when nothing was evaluated, so return nan instead of dividing by zero
    num_total = num_correct + num_wrong
    if num_total == 0 :
        return float("nan")
    return num_correct / num_total


In [8]:
##======================================================##
##   Define method for creating answer accuracy table   ##
##======================================================##

def get_answer_accuracy_table(model, token_transform, num_ints, int_lengths, verbose=0, log=True) :
    """
    Evaluate get_answer_accuracy for every combination of num_ints and int_lengths.

    Inputs:

        >  model, token_transform
           Passed through to get_answer_accuracy

        >  num_ints, int_lengths
           Iterables of the N and L values to evaluate

        >  verbose, default=0
           Verbosity passed through to get_answer_accuracy

        >  log, default=True
           If True then log the accuracy of every (N, L) combination using the global logger

    Returns:

        >  np.ndarray with one row per entry of num_ints and one column per entry of int_lengths
    """
    result_accs = []
    for N in num_ints :
        row_accs = []
        for L in int_lengths :
            result_acc = get_answer_accuracy(model, token_transform, N, L, verbose=verbose)
            row_accs.append(result_acc)
            if log :
                logger.info(f"N={N}, L={L} with result accuracy {100.*result_acc:.2f}%")
        result_accs.append(row_accs)
    return np.array(result_accs)


In [9]:
##==========================================##
##   Define axes of answer accuracy table   ##
##==========================================##

##  Values of N (number of integers) and L (integer length) spanning the evaluation grid, stored as
##  numpy arrays
num_ints    = np.array(cfg_data["test_data"]["num_ints"])
int_lengths = np.array(cfg_data["test_data"]["int_lengths"])


In [10]:
##=====================================##
##   Define method for creating plot   ##
##=====================================##

def make_plot(X, Y, Z, epoch_num=-1, model_tag="unknown", savefig=None, show=True, close=True, dpi=150) :
    """
    Draw a table of per-answer accuracies as a heatmap, outlining the cells inside the training region.

    The training region is read from the "train_data" section of the global cfg_data. Only training
    combinations that appear on both axes are outlined.

    Inputs:

        >  X
           Sequence of values labelling the columns of Z (number of integers)

        >  Y
           Sequence of values labelling the rows of Z (integer length)

        >  Z
           2D array of accuracies with one row per entry of Y and one column per entry of X; the colour
           scale is fixed to the range [0, 1]

        >  epoch_num, default=-1
           Epoch number to display above the plot; not displayed if negative

        >  model_tag, default="unknown"
           Model tag to display above the plot

        >  savefig, default=None
           If not None then the figure is saved to this filename

        >  show, default=True
           If True then display the figure using plt.show

        >  close, default=True
           If True then close the figure before returning

        >  dpi, default=150
           Resolution used when saving the figure
    """
    ##  Accept plain sequences as well as numpy arrays
    X, Y = np.asarray(X), np.asarray(Y)

    fig = plt.figure(figsize=(0.55*len(X), 0.55*len(Y)))
    ax  = fig.add_axes([0, 0, 1, 1])
    ax.tick_params(axis="both", which="both", direction="in", right=True, top=True, labelsize=12)

    ##  imshow draws column/row i at coordinate i, so ticks sit at the indices and are labelled with the
    ##  values (identical to placing them at value-1 when the values are 1, 2, 3, ...)
    ax.set_xticks(np.arange(len(X)))
    ax.xaxis.set_ticklabels(X)

    ax.set_yticks(np.arange(len(Y)))
    ax.yaxis.set_ticklabels(Y)

    ax.set_xlabel("Num. integers.", va="top", labelpad=20, fontsize=14)
    ax.set_ylabel("Integer\nlength", va="bottom", ha="right", rotation=0, labelpad=20, fontsize=14)

    image = ax.imshow(Z, cmap="Blues", vmin=0, vmax=1)

    cax = fig.add_axes([1.06, 0, 0.05, 1])
    cax.tick_params(axis="both", which="both", direction="in", right=True, top=True, labelsize=12)
    fig.colorbar(image, cax=cax)

    cax.set_ylabel("Accuracy\nper-answer", ha="left", va="bottom", labelpad=15, rotation=0, fontsize=14)

    if epoch_num >= 0 :
        ax.text(0, 1.19, f"Epoch {epoch_num}", ha="left", va="bottom", fontsize=8, transform=ax.transAxes,
               bbox={"fc":"palegreen", "ec":"darkgreen", "lw":0.6})
    
    ax.text(0, 1.05, f"Model tag: {model_tag}", ha="left", va="bottom", style="italic", fontsize=10,
           transform=ax.transAxes)

    ##  Look-up tables from axis value to column/row index, used to position the training-region outlines
    col_of = {value : idx for idx, value in enumerate(X.tolist())}
    row_of = {value : idx for idx, value in enumerate(Y.tolist())}

    patch_drawn = False
    for L in cfg_data["train_data"].get("int_lengths", []) :
        for N in cfg_data["train_data"].get("num_ints", []) :
            ##  Training combinations that are not on the plotted grid have no cell to outline
            if L not in row_of or N not in col_of : continue
            ##  Only the first outline is labelled, so that the legend contains a single entry
            rect = patches.Rectangle((col_of[N]-0.475, row_of[L]-0.475), 0.95, 0.95, lw=1.5, ls="-", ec="r",
                                     fc='none', alpha=1,
                                     label="" if patch_drawn else "= inside training region")
            ax.add_patch(rect)
            patch_drawn = True

    ##  A legend without any labelled artists is empty and triggers a matplotlib warning
    if patch_drawn :
        ax.legend(loc=(0.57, 1.03), frameon=False, fontsize=10, handlelength=1, handletextpad=0.4)

    if savefig is not None :
        fig.savefig(savefig, bbox_inches="tight", dpi=dpi)

    ##  plt.show displays all open figures and does not accept a figure argument
    if show :
        plt.show()

    if close :
        plt.close(fig)
        

In [11]:
##=========================================##
##   Evaluate checkpoints and save plots   ##
##=========================================##

##  Create the figures directory along with any missing parent directories. An existing directory is
##  reused, so figures from an earlier run with the same model tag may be overwritten
figs_dir = f"figures/evaluate_training/{model_tag}"
os.makedirs(figs_dir, exist_ok=True)

##  Checkpoints with epoch_num <= skip_first_epochs are not evaluated
skip_first_epochs = cfg_evaluate.get("skip_first_epochs", -1)
logger.info(f"Skipping the first {skip_first_epochs} epochs")


for epoch_num in sorted(model_fnames) :
    
    if epoch_num <= skip_first_epochs : continue
    
    model_fname = model_fnames[epoch_num]

    logger.info(f"Processing model at epoch {epoch_num}: {model_fname}")

    model = backend.load_text_to_text_model(model_fname)
    
    ##  The evaluation is retried once on failure. Only Exception is caught, so that a keyboard
    ##  interrupt still stops the loop, and the traceback of the first failure is written to the log
    try :
        result_accs = get_answer_accuracy_table(model, token_transform, num_ints, int_lengths)
    except Exception :
        logger.warning("Recovered from execution error, trying one more time...", exc_info=True)
        result_accs = get_answer_accuracy_table(model, token_transform, num_ints, int_lengths)
    
    ##  The table has one row per num_ints entry, so it is transposed to put num_ints on the x-axis
    make_plot(num_ints, 
              int_lengths, 
              np.transpose(result_accs), 
              model_tag = model_tag, 
              epoch_num = epoch_num,
              savefig   = f"{figs_dir}/evaluate_model_epoch{epoch_num}.png",
              dpi       = 70,
              show      = False
    )


   INFO <module>: Skipping the first 71 epochs
   INFO <module>: Processing model at epoch 72: SSL_loopy_enc_dec_notebook_int1234_num1245_embed128_enc_2blocks_5loops_width512_dec_2blocks_1loops_width512_post3_width512_idem2_2023_06_27/model_checkpoint_epoch72_val_loss_0.027003.keras


2023-06-28 11:42:11.748898: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


   INFO get_answer_accuracy_table: N=1, L=1 with result accuracy 100.00%


INFO:mathsformer:N=1, L=1 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=1, L=2 with result accuracy 100.00%


INFO:mathsformer:N=1, L=2 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=1, L=3 with result accuracy 100.00%


INFO:mathsformer:N=1, L=3 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=1, L=4 with result accuracy 100.00%


INFO:mathsformer:N=1, L=4 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=1, L=5 with result accuracy 0.00%


INFO:mathsformer:N=1, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=2, L=1 with result accuracy 100.00%


INFO:mathsformer:N=2, L=1 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=2, L=2 with result accuracy 100.00%


INFO:mathsformer:N=2, L=2 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=2, L=3 with result accuracy 100.00%


INFO:mathsformer:N=2, L=3 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=2, L=4 with result accuracy 100.00%


INFO:mathsformer:N=2, L=4 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=2, L=5 with result accuracy 0.00%


INFO:mathsformer:N=2, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=3, L=1 with result accuracy 92.19%


INFO:mathsformer:N=3, L=1 with result accuracy 92.19%


   INFO get_answer_accuracy_table: N=3, L=2 with result accuracy 98.44%


INFO:mathsformer:N=3, L=2 with result accuracy 98.44%


   INFO get_answer_accuracy_table: N=3, L=3 with result accuracy 100.00%


INFO:mathsformer:N=3, L=3 with result accuracy 100.00%


   INFO get_answer_accuracy_table: N=3, L=4 with result accuracy 90.62%


INFO:mathsformer:N=3, L=4 with result accuracy 90.62%


   INFO get_answer_accuracy_table: N=3, L=5 with result accuracy 0.00%


INFO:mathsformer:N=3, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=4, L=1 with result accuracy 95.31%


INFO:mathsformer:N=4, L=1 with result accuracy 95.31%


   INFO get_answer_accuracy_table: N=4, L=2 with result accuracy 85.94%


INFO:mathsformer:N=4, L=2 with result accuracy 85.94%


   INFO get_answer_accuracy_table: N=4, L=3 with result accuracy 84.38%


INFO:mathsformer:N=4, L=3 with result accuracy 84.38%


   INFO get_answer_accuracy_table: N=4, L=4 with result accuracy 75.00%


INFO:mathsformer:N=4, L=4 with result accuracy 75.00%


   INFO get_answer_accuracy_table: N=4, L=5 with result accuracy 0.00%


INFO:mathsformer:N=4, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=5, L=1 with result accuracy 84.38%


INFO:mathsformer:N=5, L=1 with result accuracy 84.38%


   INFO get_answer_accuracy_table: N=5, L=2 with result accuracy 84.38%


INFO:mathsformer:N=5, L=2 with result accuracy 84.38%


   INFO get_answer_accuracy_table: N=5, L=3 with result accuracy 64.06%


INFO:mathsformer:N=5, L=3 with result accuracy 64.06%


   INFO get_answer_accuracy_table: N=5, L=4 with result accuracy 43.75%


INFO:mathsformer:N=5, L=4 with result accuracy 43.75%


   INFO get_answer_accuracy_table: N=5, L=5 with result accuracy 0.00%


INFO:mathsformer:N=5, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=6, L=1 with result accuracy 82.81%


INFO:mathsformer:N=6, L=1 with result accuracy 82.81%


   INFO get_answer_accuracy_table: N=6, L=2 with result accuracy 62.50%


INFO:mathsformer:N=6, L=2 with result accuracy 62.50%


   INFO get_answer_accuracy_table: N=6, L=3 with result accuracy 26.56%


INFO:mathsformer:N=6, L=3 with result accuracy 26.56%


   INFO get_answer_accuracy_table: N=6, L=4 with result accuracy 0.00%


INFO:mathsformer:N=6, L=4 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=6, L=5 with result accuracy 0.00%


INFO:mathsformer:N=6, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=7, L=1 with result accuracy 64.06%


INFO:mathsformer:N=7, L=1 with result accuracy 64.06%


   INFO get_answer_accuracy_table: N=7, L=2 with result accuracy 15.62%


INFO:mathsformer:N=7, L=2 with result accuracy 15.62%


   INFO get_answer_accuracy_table: N=7, L=3 with result accuracy 0.00%


INFO:mathsformer:N=7, L=3 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=7, L=4 with result accuracy 0.00%


INFO:mathsformer:N=7, L=4 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=7, L=5 with result accuracy 0.00%


INFO:mathsformer:N=7, L=5 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=8, L=1 with result accuracy 35.94%


INFO:mathsformer:N=8, L=1 with result accuracy 35.94%


   INFO get_answer_accuracy_table: N=8, L=2 with result accuracy 1.56%


INFO:mathsformer:N=8, L=2 with result accuracy 1.56%


   INFO get_answer_accuracy_table: N=8, L=3 with result accuracy 0.00%


INFO:mathsformer:N=8, L=3 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=8, L=4 with result accuracy 0.00%


INFO:mathsformer:N=8, L=4 with result accuracy 0.00%


   INFO get_answer_accuracy_table: N=8, L=5 with result accuracy 0.00%


INFO:mathsformer:N=8, L=5 with result accuracy 0.00%


## Notes

- Potential representational problem: multi-digit numbers do not start out in a good representation, and it is hard for our two simple repeated layers to handle both the transition between representations and the compositional logic within a representation at the same time
- Factor methods into functions, loop over all models in training trajectory
- Pre-representation layers before loops?