In [1]:
import os, sys

import numpy as np
import tensorflow as tf

from tensorflow.keras.layers  import Embedding, Input, Softmax
from tensorflow.keras.losses  import sparse_categorical_crossentropy, SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.models  import Model

##  Make the parent of the current working directory importable (presumably where the local
##  mathsformer package lives - confirm against the repository layout). os.path.dirname is used
##  instead of splitting on "/" so the result does not depend on the platform path separator.
sys.path.append(os.path.dirname(os.getcwd()))

from mathsformer.tf_objects import EncoderBlock, FeedForwardBlock

In [2]:
##  Loss object shared by the functions below: expects logits (from_logits=True) and applies
##  no reduction (reduction='none'), so that callers can apply their own masking before reducing
sparse_categorical_crossentropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction='none')

##  Single metric instance; it is used as the default `acc` argument of create_model
sparse_accuracy_metric = SparseCategoricalAccuracy()

def masked_accuracy(y, y_pred, mask_value=0) :
    """
    Fraction of unmasked label positions for which the predicted token-id matches the label.

    Inputs:

        >  y
           Integer token-id labels; positions equal to mask_value are ignored.

        >  y_pred
           Per-token scores with one extra trailing axis compared with y; the predicted
           token-id is the argmax along that final axis.

        >  mask_value, default=0
           Label value marking positions to exclude from the average.

    Returns:

        Scalar float32 tensor: (number of unmasked matches) / (number of unmasked positions).
        Returns 0 instead of NaN when every position is masked.
    """
    ##  Create the mask: True wherever the label is NOT the mask value
    mask = y != mask_value

    ##  Get the predicted token-id using an argmax along the final axis
    y_pred = tf.argmax(y_pred, axis=-1)

    ##  Determine whether the predicted token matches the label (argmax dtype may differ from the labels)
    y_pred = tf.cast(y_pred, y.dtype)
    match  = y == y_pred

    ##  Mask the matches, so masked positions never count as correct
    match = match & mask

    ##  Cast matches and mask to float, and compute masked average
    match = tf.cast(match, dtype=tf.float32)
    mask  = tf.cast(mask , dtype=tf.float32)

    ##  divide_no_nan guards against a batch with no unmasked positions (0/0 would give NaN)
    acc = tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))

    ##  Return accuracy
    return acc

def masked_sparse_categorical_crossentropy(y, y_pred, mask_value:int=0, weight_seq_by_length:bool=False) :
    """
    Sparse categorical cross-entropy accumulated over the unmasked label positions only.

    The per-token loss comes from the module-level sparse_categorical_crossentropy_loss object,
    which is configured with from_logits=True and reduction='none'.

    Inputs:

        >  y
           Integer token-id labels; positions equal to mask_value contribute nothing to the loss.

        >  y_pred
           Logits passed to sparse_categorical_crossentropy_loss together with y.

        >  mask_value, default=0
           Label value marking positions to exclude.

        >  weight_seq_by_length, default=False
           If False, return the mean loss over the unmasked positions. If True, return the
           un-normalised sum over the unmasked positions.

    Returns:

        Scalar tensor with the masked loss. In the averaged case the result is 0 instead of NaN
        when every position is masked.
    """
    ##  Create the mask: True wherever the label is NOT the mask value
    mask = y != mask_value

    ##  Calculate the loss for every token, including masked tokens
    loss = sparse_categorical_crossentropy_loss(y, y_pred)

    ##  Cast the mask to the same dtype as the loss values
    mask = tf.cast(mask, dtype=loss.dtype)

    ##  Calculate sum loss over all tokens, excluding the masked values
    loss *= mask
    loss = tf.reduce_sum(loss)

    ##  Calculate average loss over the unmasked values if configured
    ##  divide_no_nan guards against a batch with no unmasked positions (0/0 would give NaN)
    if not weight_seq_by_length :
        loss = tf.math.divide_no_nan(loss, tf.reduce_sum(mask))

    ##  Return loss value
    return loss


In [3]:
def create_model(vocab_size, name, loss=sparse_categorical_crossentropy_loss, acc=sparse_accuracy_metric) :
    """
    Build and compile a small test model: Embedding -> FeedForwardBlock -> EncoderBlock -> FeedForwardBlock.

    Inputs:

        >  vocab_size
           Passed to the Embedding layer and used as the first argument of the final FeedForwardBlock.

        >  name
           Name given to the Keras Model.

        >  loss, default=sparse_categorical_crossentropy_loss
           Loss passed to model.compile.

        >  acc, default=sparse_accuracy_metric
           Metric passed to model.compile. NOTE: the default is one module-level metric instance,
           so every model created with the default receives that same object.

    Returns:

        The compiled Model.
    """
    ##  Variable-length sequence of token-ids
    inputs = Input((None,))

    ##  Embed tokens, then apply the feed-forward and encoder blocks in sequence
    embedded = Embedding(vocab_size, 16)(inputs)
    hidden   = FeedForwardBlock(
        16,
        num_hidden_layers = 1,
        ndim_hidden       = 32,
        skip_connect      = True,
        layer_norm        = True,
        batch_norm        = False,
    )(embedded)
    encoded  = EncoderBlock(16, num_heads=4, ndim_hidden_mha=16, ndim_hidden_ff=16)(hidden)

    ##  Final block produces vocab_size values per token
    outputs  = FeedForwardBlock(
        vocab_size,
        num_hidden_layers = 1,
        ndim_hidden       = 32,
        skip_connect      = False,
        layer_norm        = False,
        batch_norm        = False,
    )(encoded)

    ##  Assemble and compile
    compiled = Model(inputs, outputs, name=name)
    compiled.compile(loss=loss, optimizer="adam", metrics=[acc])
    return compiled
    

In [4]:
##  Toy inputs: two sequences of five token-ids, right-padded with 0
x = tf.constant([[1, 2, 3, 0, 0],
                 [3, 3, 3, 3, 0],], dtype=tf.int32)
print(x.shape)

##  Toy labels with the same shape as x; 0 is the value that the masked loss/accuracy ignore by default
y = tf.constant([[2, 3, 4, 0, 0],
                 [3, 3, 3, 3, 0]], dtype=tf.int32)
print(y.shape)


(2, 5)
(2, 5)


In [5]:
##  Model 1 uses the create_model defaults: unmasked loss and unmasked accuracy
model1 = create_model(5, "test_model_1")
model1.summary()

##  Model 2 uses the masked loss and masked accuracy defined above
model2 = create_model(5, "test_model_2", loss=masked_sparse_categorical_crossentropy, acc=masked_accuracy)
model2.summary()

Model: "test_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, None)]            0         
                                                                 
 embedding (Embedding)       (None, None, 16)          80        
                                                                 
 feed_forward_block (FeedFor  (None, None, 16)         1104      
 wardBlock)                                                      
                                                                 
 encoder_block (EncoderBlock  (None, None, 16)         4912      
 )                                                               
                                                                 
 feed_forward_block_1 (FeedF  (None, None, 5)          709       
 orwardBlock)                                                    
                                                      

In [6]:
##  Train each model for a single epoch, re-using the training data as the validation data
model1.fit(x, y, epochs=1, validation_data=(x, y))
model2.fit(x, y, epochs=1, validation_data=(x, y))

2023-04-08 12:53:13.394963: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz




<keras.callbacks.History at 0x15743ef80>

In [7]:
def get_loss(y, y_pred, mask_value=0) :
    """
    Reference calculation of the mean per-token loss, with and without masking, using numpy.

    Inputs:

        >  y
           Labels; must provide a .numpy() method.

        >  y_pred
           Logits passed to sparse_categorical_crossentropy with from_logits=True.

        >  mask_value, default=0
           Label value marking positions to exclude from the masked average.

    Returns:

        Tuple (loss_unmasked, loss_masked): the per-token loss averaged over all label positions,
        and averaged over only the positions where the label differs from mask_value.
    """
    ##  Per-token loss as a numpy array
    token_loss = sparse_categorical_crossentropy(y, y_pred, from_logits=True).numpy()

    ##  True wherever the label is NOT the mask value
    labels = y.numpy()
    mask   = labels != mask_value

    ##  Mean over every position, then mean over unmasked positions only
    loss_unmasked = token_loss.sum() / labels.size
    loss_masked   = np.where(mask, token_loss, 0).sum() / mask.sum()
    return loss_unmasked, loss_masked

def get_acc(y, y_pred) :
    """
    Reference calculation of token accuracy, with and without masking, using numpy.

    Inputs:

        >  y
           Labels; must provide a .numpy() method. Positions equal to 0 are treated as masked.

        >  y_pred
           Per-token scores; must provide a .numpy() method. The predicted token-id is the
           argmax along the final axis.

    Returns:

        Tuple (acc_unmasked, acc_masked): the fraction of matches over all label positions,
        and the fraction of matches over only the positions where the label is non-zero.
    """
    labels    = y.numpy()
    predicted = y_pred.numpy().argmax(axis=-1)

    ##  Element-wise correctness, and the positions that count towards the masked average
    is_match = labels == predicted
    is_kept  = labels != 0

    ##  Matches over all positions, then matches over kept positions only
    acc_unmasked = is_match.sum() / labels.size
    acc_masked   = (is_match & is_kept).sum() / is_kept.sum()
    return acc_unmasked, acc_masked

In [8]:

##  Forward pass of model 1 (unmasked loss and accuracy) on the toy inputs
y_pred = model1(x)

##  All-ones sample weights with the same shape as x
sample_weight = tf.ones_like(x)

##  Loss as computed by the model, without and with the sample weights
print(model1.compute_loss(y=y, y_pred=y_pred).numpy())
print(model1.compute_loss(y=y, y_pred=y_pred, sample_weight=sample_weight).numpy())

##  Metrics as tracked by the model
for name, val in model1.compute_metrics(x=x, y=y, y_pred=y_pred, sample_weight=sample_weight).items() :
    print(name, val.numpy())

##  Reference (unmasked, masked) values computed with numpy for comparison
print(get_loss(y, y_pred))
print(get_acc(y, y_pred))

[[1.9450715 1.3341032 2.0708637 2.0139081 2.0139081]
 [1.0918609 1.0918609 1.0918609 1.0918608 1.77385  ]]
[[1.9450715 1.3341032 2.0708637 2.0139081 2.0139081]
 [1.0918609 1.0918609 1.0918609 1.0918608 1.77385  ]]
loss 1.5519147
sparse_categorical_accuracy 0.5
(1.5519147872924806, 1.388211795261928)
(0.5, 0.7142857142857143)


In [9]:

##  Forward pass of model 2 (masked loss and accuracy) on the toy inputs
y_pred = model2(x)

##  All-ones sample weights with the same shape as x
sample_weight = tf.ones_like(x)

##  Loss as computed by the model, without and with the sample weights
print(model2.compute_loss(y=y, y_pred=y_pred).numpy())
print(model2.compute_loss(y=y, y_pred=y_pred, sample_weight=sample_weight).numpy())

##  Metrics as tracked by the model
for name, val in model2.compute_metrics(x=x, y=y, y_pred=y_pred, sample_weight=sample_weight).items() :
    print(name, val.numpy())

##  Reference (unmasked, masked) values computed with numpy for comparison
print(get_loss(y, y_pred))
print(get_acc(y, y_pred))

1.1604853
1.1604853
loss 1.1604853
masked_accuracy 0.5714286
(1.331087303161621, 1.1604852676391602)
(0.4, 0.5714285714285714)


In [10]:
##  Draw 5 mini-batches of 2 row indices each (sampled with replacement from the 2 rows of y).
##  A seeded Generator is used so that the cell gives the same batches on every Restart & Run All.
rng        = np.random.default_rng(42)
bs_indices = rng.choice(2, size=(5,2))

##  Show the label rows selected for each mini-batch
for indcs in bs_indices :
    print(tf.gather(y, indcs))

tf.Tensor(
[[3 3 3 3 0]
 [2 3 4 0 0]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[3 3 3 3 0]
 [2 3 4 0 0]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[2 3 4 0 0]
 [2 3 4 0 0]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[3 3 3 3 0]
 [3 3 3 3 0]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[3 3 3 3 0]
 [3 3 3 3 0]], shape=(2, 5), dtype=int32)
