In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy.random as npr
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from keras.optimizers import Adam
from keras_nlp.layers import PositionEmbedding

In [3]:
seed = 428  # single source of truth for every RNG seeded below

random.seed(seed)         # Python stdlib `random`
np.random.seed(seed)      # NumPy legacy global RNG (used e.g. by np.random.permutation / np.random.choice later)
tf.random.set_seed(seed)  # TensorFlow global seed (e.g. layer weight initialisers)

In [4]:
def bert_module(query, key, value, embed_dim, num_head, i, use_causal_mask=True, ff_multiplier=2):
    """One post-norm Transformer block: multi-head attention followed by a position-wise FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        query, key, value: Keras tensors fed to the attention layer. In this
            notebook all three are the same tensor (self-attention); presumably
            shaped (batch, seq_len, embed_dim) — TODO confirm for other callers.
            The residual adds require query's last dimension to equal embed_dim.
        embed_dim: model width; also the output width of the feed-forward net.
        num_head: number of attention heads; each head uses
            ``embed_dim // num_head`` dimensions for queries/keys.
        i: block index, used only to give this block's layers unique names.
        use_causal_mask: if True (the original, hard-coded behaviour) position t
            may only attend to positions <= t, i.e. this is a decoder-style block
            despite the "bert" name. Pass False for bidirectional attention.
        ff_multiplier: hidden width of the feed-forward net as a multiple of
            embed_dim (originally hard-coded to 2).

    Returns:
        Keras tensor with the same shape as ``query``.
    """
    # Multi-head attention. Keras' MultiHeadAttention.call takes
    # (query, value, key) positionally, so pass everything by keyword to stop
    # key and value being silently swapped when they are different tensors.
    attention_output = layers.MultiHeadAttention(
        num_heads=num_head,
        key_dim=embed_dim // num_head,
        name="encoder_{}/multiheadattention".format(i),
    )(query=query, value=value, key=key, use_causal_mask=use_causal_mask)

    # Residual connection + post-norm.
    attention_output = layers.Add()([query, attention_output])
    attention_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}/att_layernormalization".format(i)
    )(attention_output)

    # Position-wise feed-forward net: expand with ReLU, then project back to embed_dim.
    ff_net = keras.models.Sequential([
        layers.Dense(ff_multiplier * embed_dim, activation='relu',
                     name="encoder_{}/ffn_dense_1".format(i)),
        layers.Dense(embed_dim, name="encoder_{}/ffn_dense_2".format(i)),
    ])
    ffn_output = ff_net(attention_output)

    # Residual connection + post-norm.
    ffn_output = layers.Add()([attention_output, ffn_output])
    ffn_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}/ffn_layernormalization".format(i)
    )(ffn_output)

    return ffn_output

In [5]:
def get_sinusoidal_embeddings(sequence_length, embedding_dim, zero_first_position=True):
    """Fixed sinusoidal position encodings (Vaswani et al., 2017).

        PE[pos, 2k]   = sin(pos / 10000^(2k / embedding_dim))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / embedding_dim))

    Args:
        sequence_length: number of positions (rows of the table).
        embedding_dim: encoding width (columns of the table).
        zero_first_position: if True (the original behaviour), row 0 is left all
            zeros instead of the formula's sin(0)=0 / cos(0)=1 — presumably to
            reserve position 0 for padding. Pass False for the textbook table.

    Returns:
        tf.float32 tensor of shape (sequence_length, embedding_dim).
    """
    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]
    # Columns 2k and 2k+1 share one frequency, so the exponent must use the
    # pair index k = j // 2, not the raw column index j.
    pair_index = np.arange(embedding_dim) // 2
    angle_rates = 1.0 / np.power(10000.0, 2.0 * pair_index / embedding_dim)
    position_enc = positions * angle_rates  # (sequence_length, embedding_dim)

    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2k
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2k+1
    if zero_first_position and sequence_length > 0:
        position_enc[0, :] = 0.0
    return tf.cast(position_enc, dtype=tf.float32)

In [None]:
# NOTE(review): this cell was never executed (In [None]) and sits above the cell
# that defines `pairs_train`, so it raises NameError on Restart & Run All.
# Move it below the data-construction cell (and prefer `pairs_train[:5]`) or delete it.
pairs_train

In [7]:
# NOTE(review): executed as In [7], before the In [12] cell that (re)builds
# `pairs`, and placed above that cell, so it fails on a fresh kernel.
# It also dumps every triple; prefer `pairs[:5]` placed after the data-construction cell.
pairs

[('word_0', 'word_1', 'word_2'),
 ('word_0', 'word_1', 'word_3'),
 ('word_0', 'word_1', 'word_4'),
 ('word_0', 'word_1', 'word_5'),
 ('word_0', 'word_1', 'word_6'),
 ('word_0', 'word_1', 'word_7'),
 ('word_0', 'word_1', 'word_8'),
 ('word_0', 'word_1', 'word_9'),
 ('word_0', 'word_1', 'word_10'),
 ('word_0', 'word_1', 'word_11'),
 ('word_0', 'word_1', 'word_12'),
 ('word_0', 'word_1', 'word_13'),
 ('word_0', 'word_1', 'word_14'),
 ('word_0', 'word_1', 'word_15'),
 ('word_0', 'word_1', 'word_16'),
 ('word_0', 'word_1', 'word_17'),
 ('word_0', 'word_1', 'word_18'),
 ('word_0', 'word_1', 'word_19'),
 ('word_0', 'word_2', 'word_1'),
 ('word_0', 'word_2', 'word_3'),
 ('word_0', 'word_2', 'word_4'),
 ('word_0', 'word_2', 'word_5'),
 ('word_0', 'word_2', 'word_6'),
 ('word_0', 'word_2', 'word_7'),
 ('word_0', 'word_2', 'word_8'),
 ('word_0', 'word_2', 'word_9'),
 ('word_0', 'word_2', 'word_10'),
 ('word_0', 'word_2', 'word_11'),
 ('word_0', 'word_2', 'word_12'),
 ('word_0', 'word_2', 'word_13

In [12]:
N = 20  # vocabulary size

# Vocabulary "word_0" ... "word_{N-1}" and its word -> integer-id lookup.
vocabs = ['word_' + str(idx) for idx in range(N)]
vocab_map = {word: idx for idx, word in enumerate(vocabs)}

# Every ordered triple of three *distinct* words, in lexicographic id order
# (N * (N-1) * (N-2) triples in total).
pairs = [
    (first, second, third)
    for first in vocabs
    for second in vocabs
    for third in vocabs
    if len({first, second, third}) == 3
]

# Split on the id of the first word: ids 0-9 -> train, ids 10+ -> test, so no
# test triple starts with a word that ever starts a training triple.
# (A random 50/50 split via np.random.choice was tried earlier and is disabled.)
pairs_train = [triple for triple in pairs if vocab_map[triple[0]] <= 9]
pairs_test = [triple for triple in pairs if vocab_map[triple[0]] >= 10]

In [17]:
# Each triple (a, b, c) becomes the 4-token sentence "a b c a": the model has
# to reproduce the first word at the end of the sequence.
sentences_train = [[p[0], p[1], p[2], p[0]] for p in pairs_train]
sentences_test = [[p[0], p[1], p[2], p[0]] for p in pairs_test]

# The same sentences as integer ids.
sentences_number_train = [[vocab_map[w] for w in sent] for sent in sentences_train]
sentences_number_test = [[vocab_map[w] for w in sent] for sent in sentences_test]

# Next-token setup (nothing is actually masked despite the "masked" names):
# input = all 4 ids, target = the ids shifted left by one (3 targets per sentence).
x_masked_train = np.array(sentences_number_train)
y_masked_labels_train = np.array([ids[1:] for ids in sentences_number_train])
x_masked_test = np.array(sentences_number_test)
y_masked_labels_test = np.array([ids[1:] for ids in sentences_number_test])

# Shuffle training inputs and targets with one shared permutation. Training rows
# are otherwise ordered by first-word id, and Keras' validation_split holds out
# the *last* fraction of rows.
perm = np.random.permutation(len(x_masked_train))
x_masked_train, y_masked_labels_train = x_masked_train[perm], y_masked_labels_train[perm]

In [19]:
# Build and train a single-block Transformer on the "a b c -> a" copy task
# (model width embed_dim, num_head attention heads).
embed_dim = 10
num_head = 2

# Stop once val_loss has not improved for 5 epochs, then roll back to the best weights.
callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)
word_embeddings = layers.Embedding(N, embed_dim, name="word_embedding")(inputs)
# keras_nlp PositionEmbedding: a learned position table, added to the token embeddings.
position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(word_embeddings)
encoder_output = word_embeddings + position_embeddings

# range(1): exactly one Transformer block.
for i in range(1):
    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)

# Drop the last position: outputs at positions 0..2 are trained against the
# 3 left-shifted targets in y_masked_labels_train.
encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)
# Softmax over the N-word vocabulary at each remaining position. Despite the
# "mlm" name this is next-token prediction: bert_module's attention is called
# with use_causal_mask=True.
mlm_output = layers.Dense(N, name="mlm_cls", activation="softmax")(encoder_output)
mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)
adam = Adam()
# Integer targets + probability outputs -> sparse categorical cross-entropy.
mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)
# batch_size exceeds the number of training rows, so each epoch is a single
# full-batch step (the log shows 1/1). The last 50% of the (shuffled) training
# rows are held out as validation data for early stopping.
history = mlm_model.fit(x_masked_train, y_masked_labels_train,
                        validation_split = 0.5, callbacks = [callback], 
                        epochs=2000, batch_size=5000, 
                        verbose=2)

Epoch 1/2000
1/1 - 1s - loss: 3.3022 - val_loss: 3.2990 - 1s/epoch - 1s/step
Epoch 2/2000
1/1 - 0s - loss: 3.2886 - val_loss: 3.2856 - 41ms/epoch - 41ms/step
Epoch 3/2000
1/1 - 0s - loss: 3.2754 - val_loss: 3.2725 - 41ms/epoch - 41ms/step
Epoch 4/2000
1/1 - 0s - loss: 3.2625 - val_loss: 3.2598 - 43ms/epoch - 43ms/step
Epoch 5/2000
1/1 - 0s - loss: 3.2500 - val_loss: 3.2474 - 41ms/epoch - 41ms/step
Epoch 6/2000
1/1 - 0s - loss: 3.2378 - val_loss: 3.2355 - 41ms/epoch - 41ms/step
Epoch 7/2000
1/1 - 0s - loss: 3.2260 - val_loss: 3.2239 - 42ms/epoch - 42ms/step
Epoch 8/2000
1/1 - 0s - loss: 3.2146 - val_loss: 3.2126 - 42ms/epoch - 42ms/step
Epoch 9/2000
1/1 - 0s - loss: 3.2035 - val_loss: 3.2015 - 42ms/epoch - 42ms/step
Epoch 10/2000
1/1 - 0s - loss: 3.1926 - val_loss: 3.1907 - 41ms/epoch - 41ms/step
Epoch 11/2000
1/1 - 0s - loss: 3.1820 - val_loss: 3.1802 - 41ms/epoch - 41ms/step
Epoch 12/2000
1/1 - 0s - loss: 3.1717 - val_loss: 3.1701 - 43ms/epoch - 43ms/step
Epoch 13/2000
1/1 - 0s - loss

1/1 - 0s - loss: 2.8685 - val_loss: 2.8836 - 40ms/epoch - 40ms/step
Epoch 102/2000
1/1 - 0s - loss: 2.8664 - val_loss: 2.8816 - 40ms/epoch - 40ms/step
Epoch 103/2000
1/1 - 0s - loss: 2.8642 - val_loss: 2.8795 - 43ms/epoch - 43ms/step
Epoch 104/2000
1/1 - 0s - loss: 2.8619 - val_loss: 2.8772 - 43ms/epoch - 43ms/step
Epoch 105/2000
1/1 - 0s - loss: 2.8595 - val_loss: 2.8750 - 40ms/epoch - 40ms/step
Epoch 106/2000
1/1 - 0s - loss: 2.8571 - val_loss: 2.8726 - 39ms/epoch - 39ms/step
Epoch 107/2000
1/1 - 0s - loss: 2.8545 - val_loss: 2.8701 - 43ms/epoch - 43ms/step
Epoch 108/2000
1/1 - 0s - loss: 2.8519 - val_loss: 2.8675 - 43ms/epoch - 43ms/step
Epoch 109/2000
1/1 - 0s - loss: 2.8491 - val_loss: 2.8648 - 42ms/epoch - 42ms/step
Epoch 110/2000
1/1 - 0s - loss: 2.8462 - val_loss: 2.8620 - 40ms/epoch - 40ms/step
Epoch 111/2000
1/1 - 0s - loss: 2.8432 - val_loss: 2.8590 - 42ms/epoch - 42ms/step
Epoch 112/2000
1/1 - 0s - loss: 2.8401 - val_loss: 2.8559 - 43ms/epoch - 43ms/step
Epoch 113/2000
1/1 

Epoch 200/2000
1/1 - 0s - loss: 2.3807 - val_loss: 2.3852 - 42ms/epoch - 42ms/step
Epoch 201/2000
1/1 - 0s - loss: 2.3764 - val_loss: 2.3808 - 41ms/epoch - 41ms/step
Epoch 202/2000
1/1 - 0s - loss: 2.3720 - val_loss: 2.3763 - 41ms/epoch - 41ms/step
Epoch 203/2000
1/1 - 0s - loss: 2.3677 - val_loss: 2.3719 - 42ms/epoch - 42ms/step
Epoch 204/2000
1/1 - 0s - loss: 2.3634 - val_loss: 2.3675 - 42ms/epoch - 42ms/step
Epoch 205/2000
1/1 - 0s - loss: 2.3590 - val_loss: 2.3631 - 41ms/epoch - 41ms/step
Epoch 206/2000
1/1 - 0s - loss: 2.3547 - val_loss: 2.3587 - 42ms/epoch - 42ms/step
Epoch 207/2000
1/1 - 0s - loss: 2.3504 - val_loss: 2.3543 - 41ms/epoch - 41ms/step
Epoch 208/2000
1/1 - 0s - loss: 2.3462 - val_loss: 2.3499 - 42ms/epoch - 42ms/step
Epoch 209/2000
1/1 - 0s - loss: 2.3419 - val_loss: 2.3456 - 41ms/epoch - 41ms/step
Epoch 210/2000
1/1 - 0s - loss: 2.3377 - val_loss: 2.3413 - 42ms/epoch - 42ms/step
Epoch 211/2000
1/1 - 0s - loss: 2.3335 - val_loss: 2.3370 - 41ms/epoch - 41ms/step
Epoc

Epoch 299/2000
1/1 - 0s - loss: 2.1135 - val_loss: 2.1173 - 40ms/epoch - 40ms/step
Epoch 300/2000
1/1 - 0s - loss: 2.1123 - val_loss: 2.1161 - 41ms/epoch - 41ms/step
Epoch 301/2000
1/1 - 0s - loss: 2.1110 - val_loss: 2.1148 - 43ms/epoch - 43ms/step
Epoch 302/2000
1/1 - 0s - loss: 2.1097 - val_loss: 2.1136 - 43ms/epoch - 43ms/step
Epoch 303/2000
1/1 - 0s - loss: 2.1085 - val_loss: 2.1123 - 40ms/epoch - 40ms/step
Epoch 304/2000
1/1 - 0s - loss: 2.1072 - val_loss: 2.1111 - 40ms/epoch - 40ms/step
Epoch 305/2000
1/1 - 0s - loss: 2.1060 - val_loss: 2.1099 - 43ms/epoch - 43ms/step
Epoch 306/2000
1/1 - 0s - loss: 2.1048 - val_loss: 2.1087 - 42ms/epoch - 42ms/step
Epoch 307/2000
1/1 - 0s - loss: 2.1036 - val_loss: 2.1076 - 40ms/epoch - 40ms/step
Epoch 308/2000
1/1 - 0s - loss: 2.1025 - val_loss: 2.1064 - 39ms/epoch - 39ms/step
Epoch 309/2000
1/1 - 0s - loss: 2.1014 - val_loss: 2.1053 - 43ms/epoch - 43ms/step
Epoch 310/2000
1/1 - 0s - loss: 2.1002 - val_loss: 2.1042 - 43ms/epoch - 43ms/step
Epoc

Epoch 398/2000
1/1 - 0s - loss: 2.0422 - val_loss: 2.0467 - 42ms/epoch - 42ms/step
Epoch 399/2000
1/1 - 0s - loss: 2.0419 - val_loss: 2.0464 - 41ms/epoch - 41ms/step
Epoch 400/2000
1/1 - 0s - loss: 2.0415 - val_loss: 2.0460 - 41ms/epoch - 41ms/step
Epoch 401/2000
1/1 - 0s - loss: 2.0412 - val_loss: 2.0457 - 42ms/epoch - 42ms/step
Epoch 402/2000
1/1 - 0s - loss: 2.0408 - val_loss: 2.0454 - 42ms/epoch - 42ms/step
Epoch 403/2000
1/1 - 0s - loss: 2.0405 - val_loss: 2.0450 - 43ms/epoch - 43ms/step
Epoch 404/2000
1/1 - 0s - loss: 2.0402 - val_loss: 2.0447 - 40ms/epoch - 40ms/step
Epoch 405/2000
1/1 - 0s - loss: 2.0398 - val_loss: 2.0443 - 41ms/epoch - 41ms/step
Epoch 406/2000
1/1 - 0s - loss: 2.0395 - val_loss: 2.0440 - 42ms/epoch - 42ms/step
Epoch 407/2000
1/1 - 0s - loss: 2.0392 - val_loss: 2.0437 - 42ms/epoch - 42ms/step
Epoch 408/2000
1/1 - 0s - loss: 2.0389 - val_loss: 2.0434 - 41ms/epoch - 41ms/step
Epoch 409/2000
1/1 - 0s - loss: 2.0385 - val_loss: 2.0430 - 41ms/epoch - 41ms/step
Epoc

Epoch 497/2000
1/1 - 0s - loss: 2.0201 - val_loss: 2.0247 - 40ms/epoch - 40ms/step
Epoch 498/2000
1/1 - 0s - loss: 2.0199 - val_loss: 2.0246 - 41ms/epoch - 41ms/step
Epoch 499/2000
1/1 - 0s - loss: 2.0198 - val_loss: 2.0245 - 43ms/epoch - 43ms/step
Epoch 500/2000
1/1 - 0s - loss: 2.0197 - val_loss: 2.0243 - 41ms/epoch - 41ms/step
Epoch 501/2000
1/1 - 0s - loss: 2.0195 - val_loss: 2.0242 - 40ms/epoch - 40ms/step
Epoch 502/2000
1/1 - 0s - loss: 2.0194 - val_loss: 2.0241 - 41ms/epoch - 41ms/step
Epoch 503/2000
1/1 - 0s - loss: 2.0193 - val_loss: 2.0239 - 43ms/epoch - 43ms/step
Epoch 504/2000
1/1 - 0s - loss: 2.0191 - val_loss: 2.0238 - 41ms/epoch - 41ms/step
Epoch 505/2000
1/1 - 0s - loss: 2.0190 - val_loss: 2.0237 - 40ms/epoch - 40ms/step
Epoch 506/2000
1/1 - 0s - loss: 2.0189 - val_loss: 2.0236 - 41ms/epoch - 41ms/step
Epoch 507/2000
1/1 - 0s - loss: 2.0188 - val_loss: 2.0234 - 43ms/epoch - 43ms/step
Epoch 508/2000
1/1 - 0s - loss: 2.0186 - val_loss: 2.0233 - 42ms/epoch - 42ms/step
Epoc

Epoch 596/2000
1/1 - 0s - loss: 2.0103 - val_loss: 2.0153 - 43ms/epoch - 43ms/step
Epoch 597/2000
1/1 - 0s - loss: 2.0102 - val_loss: 2.0153 - 42ms/epoch - 42ms/step
Epoch 598/2000
1/1 - 0s - loss: 2.0102 - val_loss: 2.0152 - 40ms/epoch - 40ms/step
Epoch 599/2000
1/1 - 0s - loss: 2.0101 - val_loss: 2.0151 - 42ms/epoch - 42ms/step
Epoch 600/2000
1/1 - 0s - loss: 2.0100 - val_loss: 2.0151 - 43ms/epoch - 43ms/step
Epoch 601/2000
1/1 - 0s - loss: 2.0099 - val_loss: 2.0150 - 41ms/epoch - 41ms/step
Epoch 602/2000
1/1 - 0s - loss: 2.0099 - val_loss: 2.0150 - 40ms/epoch - 40ms/step
Epoch 603/2000
1/1 - 0s - loss: 2.0098 - val_loss: 2.0149 - 40ms/epoch - 40ms/step
Epoch 604/2000
1/1 - 0s - loss: 2.0097 - val_loss: 2.0148 - 43ms/epoch - 43ms/step
Epoch 605/2000
1/1 - 0s - loss: 2.0097 - val_loss: 2.0148 - 42ms/epoch - 42ms/step
Epoch 606/2000
1/1 - 0s - loss: 2.0096 - val_loss: 2.0147 - 40ms/epoch - 40ms/step
Epoch 607/2000
1/1 - 0s - loss: 2.0095 - val_loss: 2.0146 - 41ms/epoch - 41ms/step
Epoc

Epoch 695/2000
1/1 - 0s - loss: 2.0048 - val_loss: 2.0106 - 41ms/epoch - 41ms/step
Epoch 696/2000
1/1 - 0s - loss: 2.0048 - val_loss: 2.0105 - 41ms/epoch - 41ms/step
Epoch 697/2000
1/1 - 0s - loss: 2.0047 - val_loss: 2.0105 - 43ms/epoch - 43ms/step
Epoch 698/2000
1/1 - 0s - loss: 2.0047 - val_loss: 2.0105 - 41ms/epoch - 41ms/step
Epoch 699/2000
1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 41ms/epoch - 41ms/step
Epoch 700/2000
1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 41ms/epoch - 41ms/step
Epoch 701/2000
1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 43ms/epoch - 43ms/step
Epoch 702/2000
1/1 - 0s - loss: 2.0045 - val_loss: 2.0103 - 41ms/epoch - 41ms/step
Epoch 703/2000
1/1 - 0s - loss: 2.0045 - val_loss: 2.0103 - 41ms/epoch - 41ms/step
Epoch 704/2000
1/1 - 0s - loss: 2.0044 - val_loss: 2.0103 - 41ms/epoch - 41ms/step
Epoch 705/2000
1/1 - 0s - loss: 2.0044 - val_loss: 2.0103 - 43ms/epoch - 43ms/step
Epoch 706/2000
1/1 - 0s - loss: 2.0044 - val_loss: 2.0102 - 41ms/epoch - 41ms/step
Epoc

In [20]:
# Evaluate on up to 1000 random held-out sentences: does the model reproduce the
# first word at the final output position?
n_eval = min(1000, x_masked_test.shape[0])
x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]

# One batched forward pass. The original rebuilt a Keras backend function and
# ran it once for every sentence. mlm_model's output is the mlm_cls softmax
# layer, i.e. the same tensor the per-sentence backend function read.
# Output shape (n_eval, seq_len - 1, N); the last step predicts the final token.
last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]
targets = x_test_subset[:, -1]

# acc[k]  = 1 if the top-1 prediction for sentence k equals its target, else 0
# prob[k] = probability assigned to the correct target for sentence k
acc = [int(hit) for hit in last_step_probs.argmax(axis=-1) == targets]
prob = list(last_step_probs[np.arange(len(targets)), targets])

In [21]:
# (top-1 accuracy, mean probability of the correct final token) on the evaluated subset
tuple(np.mean(scores) for scores in (acc, prob))

(0.0, 0.002540875)

In [25]:
# Same model and training setup as the embed_dim=10 run, with a 10x wider model.
# NOTE(review): this cell is a copy-paste of that training cell with only
# embed_dim changed; consider a build_and_train(embed_dim, num_head) function.
embed_dim = 100
num_head = 2

# Stop once val_loss has not improved for 5 epochs, then roll back to the best weights.
callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)
word_embeddings = layers.Embedding(N, embed_dim, name="word_embedding")(inputs)
# keras_nlp PositionEmbedding: a learned position table, added to the token embeddings.
position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(word_embeddings)
encoder_output = word_embeddings + position_embeddings

# range(1): exactly one Transformer block.
for i in range(1):
    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)

# Drop the last position: outputs at positions 0..2 are trained against the
# 3 left-shifted targets in y_masked_labels_train.
encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)
# Softmax over the N-word vocabulary at each remaining position (next-token
# prediction; bert_module's attention is called with use_causal_mask=True).
mlm_output = layers.Dense(N, name="mlm_cls", activation="softmax")(encoder_output)
mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)
adam = Adam()
# Integer targets + probability outputs -> sparse categorical cross-entropy.
mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)
# Single full-batch step per epoch; the last 50% of the (shuffled) training
# rows are held out as validation data for early stopping.
history = mlm_model.fit(x_masked_train, y_masked_labels_train,
                        validation_split = 0.5, callbacks = [callback], 
                        epochs=2000, batch_size=5000, 
                        verbose=2)

Epoch 1/2000
1/1 - 4s - loss: 3.6708 - val_loss: 3.1908 - 4s/epoch - 4s/step
Epoch 2/2000
1/1 - 0s - loss: 3.1555 - val_loss: 3.0720 - 230ms/epoch - 230ms/step
Epoch 3/2000
1/1 - 0s - loss: 3.0556 - val_loss: 2.9671 - 234ms/epoch - 234ms/step
Epoch 4/2000
1/1 - 0s - loss: 2.9613 - val_loss: 2.9044 - 228ms/epoch - 228ms/step
Epoch 5/2000
1/1 - 0s - loss: 2.8997 - val_loss: 2.8759 - 230ms/epoch - 230ms/step
Epoch 6/2000
1/1 - 0s - loss: 2.8680 - val_loss: 2.8542 - 227ms/epoch - 227ms/step
Epoch 7/2000
1/1 - 0s - loss: 2.8421 - val_loss: 2.8374 - 224ms/epoch - 224ms/step
Epoch 8/2000
1/1 - 0s - loss: 2.8214 - val_loss: 2.8259 - 222ms/epoch - 222ms/step
Epoch 9/2000
1/1 - 0s - loss: 2.8076 - val_loss: 2.8151 - 227ms/epoch - 227ms/step
Epoch 10/2000
1/1 - 0s - loss: 2.7961 - val_loss: 2.8058 - 225ms/epoch - 225ms/step
Epoch 11/2000
1/1 - 0s - loss: 2.7876 - val_loss: 2.7992 - 229ms/epoch - 229ms/step
Epoch 12/2000
1/1 - 0s - loss: 2.7819 - val_loss: 2.7937 - 227ms/epoch - 227ms/step
Epoch 1

In [26]:
# Evaluate on up to 1000 random held-out sentences: does the model reproduce the
# first word at the final output position?
n_eval = min(1000, x_masked_test.shape[0])
x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]

# One batched forward pass. The original rebuilt a Keras backend function and
# ran it once for every sentence. mlm_model's output is the mlm_cls softmax
# layer, i.e. the same tensor the per-sentence backend function read.
# Output shape (n_eval, seq_len - 1, N); the last step predicts the final token.
last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]
targets = x_test_subset[:, -1]

# acc[k]  = 1 if the top-1 prediction for sentence k equals its target, else 0
# prob[k] = probability assigned to the correct target for sentence k
acc = [int(hit) for hit in last_step_probs.argmax(axis=-1) == targets]
prob = list(last_step_probs[np.arange(len(targets)), targets])

In [27]:
# (top-1 accuracy, mean probability of the correct final token) on the evaluated subset
tuple(np.mean(scores) for scores in (acc, prob))

(0.0, 0.0052961786)

In [28]:
# Held-out inputs: every row starts (and ends) with a word id >= 10, none of
# which ever starts a training sentence. NumPy elides the middle rows.
x_masked_test

array([[10,  0,  1, 10],
       [10,  0,  2, 10],
       [10,  0,  3, 10],
       ...,
       [19, 18, 15, 19],
       [19, 18, 16, 19],
       [19, 18, 17, 19]])