In [1]:
import os, sys
sys.path.append("/".join(os.getcwd().split("/")[:-1]))

import numpy      as np
import tensorflow as tf

from tensorflow.keras.layers     import Add, Average, Concatenate, Embedding, Input
from tensorflow.keras.models     import Model
from tensorflow.keras.optimizers import Adam

from mathsformer.tf_objects import (DecoderBlock, EncoderBlock, Enumerate, FeedForwardBlock, LearnableMixture, MaskedCategoricalAccuracy,
                                    MaskedSparseCategoricalCrossentropy, PositionalEncoding)

from mathsformer.data import RandomDataGenerator_Addition, TokenTransform


In [11]:
def create_text_to_text_model(vocab_length:int, 
                              name:str, 
                              do_compile:bool     = True,
                              dtype_in            = tf.int32, 
                              dtype               = tf.float32, 
                              dropout:float       = 0.1, 
                              jit_compile:bool    = None,
                              optimizer           = Adam,
                              optimizer_args:dict = None,
                              pos_enc_num_freqs:int       = 16, pos_enc_min_period:float     = 5, pos_enc_max_period:float = 10000,
                              pos_enc_learnable:bool      = False,
                              ndim_embedding:int          = 32, comb_type:str                = "average",
                              num_pre_layers_encoder:int  = 1 , ndim_pre_layers_encoder:int  = 128, skip_connect_pre_encoder:bool = True,
                              num_pre_layers_decoder:int  = 1 , ndim_pre_layers_decoder:int  = 128, skip_connect_pre_decoder:bool = True,
                              num_encoder_blocks:int      = 2 , ndim_encoder:int             = 32 , skip_connect_encoder:bool     = True,
                              num_heads_encoder:int       = 8 , ndim_att_hidden_encoder:int  = 64 , ndim_ff_hidden_encoder:int    = 64  , 
                              num_decoder_blocks:int      = 2 , ndim_decoder:int             = 32 , skip_connect_decoder:bool     = True,
                              num_heads_decoder:int       = 8 , ndim_att_hidden_decoder:int  = 64 , ndim_ff_hidden_decoder:int    = 64  , 
                              num_post_layers_decoder:int = 2 , ndim_post_layers_decoder:int = 128, 
                             ) :
    """
    """
    ##  Resolve mutable default args
    if optimizer_args is None :
        optimizer_args = {'learning_rate': 0.001}
    
    ##=============================================##
    ##===   Input layer - Output shape [B, S]   ===##
    ##=============================================##
    x_in_enc = Input((None,), dtype=dtype_in, name=f"{name}_encoder_input_layer")
    x_in_dec = Input((None,), dtype=dtype_in, name=f"{name}_decoder_input_layer")
            
    ##===========================================================================##
    ##===  Token embedding, masking 0s - Output shape [B, S, ndim_embedding]  ===##
    ##===========================================================================##
    x_embed_enc = Embedding(vocab_length, 
                            ndim_embedding, 
                            mask_zero=True, 
                            dtype=dtype, 
                            name=f"{name}_encoder_embedding")(x_in_enc)
    x_embed_dec = Embedding(vocab_length, 
                            ndim_embedding, 
                            mask_zero=True, 
                            dtype=dtype, 
                            name=f"{name}_decoder_embedding")(x_in_dec)
    
    model1 = Model(x_in_enc, x_embed_enc, name="model_encoder_embedding")
    model2 = Model(x_in_dec, x_embed_dec, name="model_decoder_embedding")
    
    ##=========================================================================##
    ##===  Enumerate indices for positional encoding - Output shape [B, S]  ===##
    ##=========================================================================##
    ##  -  if comb_type will lead to broadcasting with embeddings later on, then we don't need to repeat the enumerations
    ##     along the batch axis and can use minimal_dims=True for an ouput of [1, S] instead. This saves us memory here
    ##     and reduces the number of operations in the positional encoding step by a factor of B
    minimal_dims = comb_type.lower() in ["add", "sum", "average", "mean", "mixture"]
    x_pos_enc    = Enumerate(name=f"{name}_encoder_enumerate", dtype=dtype)(x_in_enc, minimal_dims=minimal_dims)
    x_pos_dec    = Enumerate(name=f"{name}_decoder_enumerate", dtype=dtype)(x_in_dec, minimal_dims=minimal_dims)
    
    model3 = Model(x_in_enc, x_pos_enc, name="model_encoder_enum")
    model4 = Model(x_in_dec, x_pos_dec, name="model_decoder_enum")
    
    ##========================================================================##
    ##===  Positional encoding - Output shape [B, S, 2*pos_enc_num_freqs]  ===##
    ##========================================================================##
    x_pos_enc = PositionalEncoding(num_freqs  = pos_enc_num_freqs, 
                                   min_period = pos_enc_min_period, 
                                   max_period = pos_enc_max_period, 
                                   learnable  = pos_enc_learnable,
                                   dtype      = dtype, 
                                   name       = f"{name}_encoder_position_encoding")(x_pos_enc)
    x_pos_dec = PositionalEncoding(num_freqs  = pos_enc_num_freqs, 
                                   min_period = pos_enc_min_period, 
                                   max_period = pos_enc_max_period, 
                                   learnable  = pos_enc_learnable,
                                   dtype      = dtype, 
                                   name       = f"{name}_decoder_position_encoding")(x_pos_dec)

    model5 = Model(x_in_enc, x_pos_enc, name="model_encoder_pos_enc")
    model6 = Model(x_in_dec, x_pos_dec, name="model_decoder_pos_enc")
    
    ##==============================================================================================##
    ##===  Combine embeddings end pos enc - Output shape [B, S, N] where N depends on comb_type  ===##
    ##==============================================================================================##
    allowed_comb_types = ["add", "sum", "average", "mean", "concat", "concatenate", "mixture"]
    match comb_type.lower() :
        case "add" | "sum" :
            x_enc = Add(name=f"{name}_encoder_emb_and_pos", dtype=dtype)([x_embed_enc, x_pos_enc])
            x_dec = Add(name=f"{name}_decoder_emb_and_pos", dtype=dtype)([x_embed_dec, x_pos_dec])
        case "average" | "mean" :
            x_enc = Average(name=f"{name}_encoder_emb_and_pos", dtype=dtype)([x_embed_enc, x_pos_enc])
            x_dec = Average(name=f"{name}_decoder_emb_and_pos", dtype=dtype)([x_embed_dec, x_pos_dec])
        case "concat" | "concatenate" :
            x_enc = Concatenate(name=f"{name}_encoder_emb_and_pos", dtype=dtype)([x_embed_enc, x_pos_enc])
            x_dec = Concatenate(name=f"{name}_decoder_emb_and_pos", dtype=dtype)([x_embed_dec, x_pos_dec])
        case "mixture" :
            x_enc = LearnableMixture(name=f"{name}_encoder_emb_and_pos", dtype=dtype)([x_embed_enc, x_pos_enc])
            x_dec = LearnableMixture(name=f"{name}_decoder_emb_and_pos", dtype=dtype)([x_embed_dec, x_pos_dec])
        case _ :
            raise RuntimeError(f"comb_type '{comb_type}' not recognised, recognised keywords are {allowed_comb_types}")

    model7 = Model(x_in_enc, x_enc, name="model_encoder_comb")
    model8 = Model(x_in_dec, x_dec, name="model_decoder_comb")
    
    ##=========================================================================##
    ##===  Initial pre-processing - Output shape [B, S, ndim_(en/de)coder]  ===##
    ##=========================================================================##
    ##  - use layer_norm instead of batch_norm because elements in sequence are not independent
    if num_pre_layers_encoder >= 0 :
        x_enc = FeedForwardBlock(ndim_encoder, 
                                 ndim_hidden       = ndim_pre_layers_encoder, 
                                 num_hidden_layers = num_pre_layers_encoder, 
                                 dropout           = dropout, 
                                 layer_norm        = True, 
                                 batch_norm        = False,  
                                 skip_connect      = skip_connect_pre_encoder, 
                                 dtype             = dtype, 
                                 name              = f"{name}_encoder_feedfwd_block_pre_attention")(x_enc)
    if num_pre_layers_decoder >= 0 :
        x_dec = FeedForwardBlock(ndim_decoder, 
                                 ndim_hidden       = ndim_pre_layers_decoder, 
                                 num_hidden_layers = num_pre_layers_decoder, 
                                 dropout           = dropout, 
                                 layer_norm        = True, 
                                 batch_norm        = False,  
                                 skip_connect      = skip_connect_pre_decoder, 
                                 dtype             = dtype, 
                                 name              = f"{name}_decoder_feedfwd_block_pre_attention")(x_dec)
    
    model9  = Model(x_in_enc, x_enc, name="model_encoder_pre")
    model10 = Model(x_in_dec, x_dec, name="model_decoder_pre")
    
    ##============================================================##
    ##===  Encoder blocks - Output shape [B, S, ndim_encoder]  ===##
    ##============================================================##
    for layer_idx in range(num_encoder_blocks) :
        x_enc = EncoderBlock(ndim_encoder, 
                             num_heads_encoder, 
                             ndim_att_hidden_encoder, 
                             ndim_ff_hidden_encoder, 
                             dropout_mha  = dropout, 
                             dtype        = dtype, 
                             layer_norm   = True, 
                             skip_connect = skip_connect_encoder, 
                             name         = f"{name}_encoder_block_{layer_idx+1}")(x_enc)
    
    model11 = Model(x_in_enc, x_enc, name="model_encoder")
    
    ##============================================================##
    ##===  Decoder blocks - Output shape [B, S, ndim_decoder]  ===##
    ##============================================================##
    for layer_idx in range(num_decoder_blocks) :
        x_dec = DecoderBlock(ndim_decoder, 
                             num_heads_decoder, 
                             ndim_att_hidden_decoder, 
                             ndim_ff_hidden_decoder, 
                             dropout_mha  = dropout, 
                             dtype        = dtype, 
                             layer_norm   = True, 
                             skip_connect = skip_connect_decoder, 
                             name         = f"{name}_decoder_block_{layer_idx+1}")([x_dec, x_enc])
        
    model12 = Model([x_in_enc, x_in_dec], x_dec, name="model_decoder")
    
    ##==================================================================================================##
    ##===  Predict logit probabilities using feed-forward block - Output shape [B, S, vocab_length]  ===##
    ##==================================================================================================##
    ##  - use layer_norm instead of batch_norm because elements in sequence are not independent
    x = FeedForwardBlock(vocab_length, 
                         ndim_hidden       = ndim_post_layers_decoder, 
                         num_hidden_layers = num_post_layers_decoder, 
                         skip_connect      = False, 
                         layer_norm        = True, 
                         batch_norm        = False, 
                         dtype             = dtype, 
                         name              = f"{name}_feedfwd_block_post_attention")(x_dec)
        
    model13 = Model([x_in_enc, x_in_dec], x, name="model_decoder_post")
    
    ##  Create model
    model = Model([x_in_enc, x_in_dec], x, name=name)
    
    ##  Compile model with sparse categorical crossentropy loss and accuracy metric
    if do_compile :
        acc  = MaskedCategoricalAccuracy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0)
        loss = MaskedSparseCategoricalCrossentropy(scalar_output=True, equal_token_weight=True, use_keras_mask=False, mask_value=0, from_logits=True)
        model.compile(loss        = loss, 
                      optimizer   = optimizer(**optimizer_args), 
                      metrics     = [acc],
                      jit_compile = jit_compile)
    
    ##  Return model
    return [model1, model2, model3, model4, model5, model6, model7, model8, model9, model10, model11, model12, model13, model]


In [15]:
models = create_text_to_text_model(32, name="test_model", dtype=tf.float16, comb_type="concat")

RuntimeError: Exception encountered when calling layer "test_model_encoder_feedfwd_block_pre_attention" (type FeedForwardBlock).

in user code:

    File "/Users/Ste/PROJECTS/misc/ML-sandbox/ML-sandbox/Project_Maths_Transformer/mathsformer/tf_objects.py", line 933, in call  *
        if x_dims != y_dims : raise RuntimeError(f"Cannot apply skip-connection combining tensors of different dimensions {x_dims} and {y_dims}")

    RuntimeError: Cannot apply skip-connection combining tensors of different dimensions 64 and 32


Call arguments received by layer "test_model_encoder_feedfwd_block_pre_attention" (type FeedForwardBlock):
  • x=tf.Tensor(shape=(None, None, 64), dtype=float16)
  • training=None
  • mask=tf.Tensor(shape=(None, None), dtype=bool)

In [16]:
models[4].layers[-1].weights

[<tf.Variable 'test_model_encoder_position_encoding_frequencies:0' shape=(1, 16) dtype=float16, numpy=
 array([[6.285e-04, 1.043e-03, 1.731e-03, 2.872e-03, 4.768e-03, 7.919e-03,
         1.314e-02, 2.180e-02, 3.619e-02, 6.009e-02, 9.973e-02, 1.655e-01,
         2.749e-01, 4.561e-01, 7.568e-01, 1.257e+00]], dtype=float16)>]

In [17]:
X = np.random.randint(low=0, high=32, size=(2, 4))
Y = np.random.randint(low=0, high=32, size=(2, 4))

for model in models :
    inp = [X, Y] if type(model.input) is list else X
    print(model.name, model.layers[-1].dtype)
    print(np.mean(model.predict(inp)))
    print(np.mean(model(inp, training=False)))


model_encoder_embedding float16
0.002018
0.002018
model_decoder_embedding float16
0.001136
0.001136
model_encoder_enum float16
1.5
1.5
model_decoder_enum float16
1.5
1.5
model_encoder_pos_enc float16
0.5024
0.5024
model_decoder_pos_enc float16
0.5024
0.5024
model_encoder_comb float16
0.0
0.2524
model_decoder_comb float16
0.0
0.252
model_encoder_pre float16
nan
0.3079
model_decoder_pre float16
nan
0.064
model_encoder float16
nan
0.0265
model_decoder float16
nan
0.1294
model_decoder_post float16
nan
-0.01807
test_model float16
nan
-0.01807


In [6]:

token_transform = TokenTransform(['M', 'B', 'E', 'N', '+', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], 
                                mask_char='M', seq_start_char='B', seq_end_char='E')

gen = RandomDataGenerator_Addition(token_transform, int_lengths=[3], num_ints=[3], batch_size=32, num_batches=1, negative_char='N')

models[-1].evaluate(gen)




[nan, 0.0]

In [7]:
[X1, X2], Y_true = gen[0]

Y_pred = models[-1].predict([X1, X2])



In [8]:
Y_pred

array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]],

       [[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]],

       [[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]],

       ...,

       [[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, na

In [9]:
'''
num_test = 1
X = np.random.randint(2, size=(num_test, 4), dtype=np.int32)
X = tf.constant(X)
print(X)

for dtype in [tf.float16, tf.float32] :
    for use_causal_mask in [True, False] :
        print(dtype, use_causal_mask)
        m = create_model(dtype=dtype, use_causal_mask=use_causal_mask)
        Y = m(X)
        print(Y)
'''

'\nnum_test = 1\nX = np.random.randint(2, size=(num_test, 4), dtype=np.int32)\nX = tf.constant(X)\nprint(X)\n\nfor dtype in [tf.float16, tf.float32] :\n    for use_causal_mask in [True, False] :\n        print(dtype, use_causal_mask)\n        m = create_model(dtype=dtype, use_causal_mask=use_causal_mask)\n        Y = m(X)\n        print(Y)\n'

In [10]:
'''
m = create_model(dtype=tf.float16, use_causal_mask=True)
l = m.layers[2]
print(l.dtype)
for d in l.get_weights() :
    print(d.dtype, d.shape)
'''

'\nm = create_model(dtype=tf.float16, use_causal_mask=True)\nl = m.layers[2]\nprint(l.dtype)\nfor d in l.get_weights() :\n    print(d.dtype, d.shape)\n'