In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy.random as npr
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from keras.optimizers import Adam
from keras_nlp.layers import PositionEmbedding

In [5]:
# Seed NumPy's legacy global RNG, TensorFlow's global seed and Python's `random`
# module with one shared value so data splits and weight init are repeatable.
seed = 428

for seed_rng in (np.random.seed, tf.random.set_seed, random.seed):
    seed_rng(seed)

In [6]:
def bert_module(query, key, value, embed_dim, num_head, i):
    """Build one post-norm Transformer encoder block with causal self-attention.

    The block applies, in order: multi-head attention with a causal mask, a residual
    Add + LayerNormalization, a two-layer feed-forward network (``2 * embed_dim`` ReLU
    units, then ``embed_dim`` linear units), and a second residual Add +
    LayerNormalization.

    Args:
        query: Keras tensor used as the attention query and as the first residual
            input. Its last dimension must equal ``embed_dim`` so the second residual
            Add (with the FFN's ``embed_dim``-wide output) is shape-compatible.
        key: Keras tensor used as the attention key.
        value: Keras tensor used as the attention value.
        embed_dim: Model width; each head gets ``embed_dim // num_head`` key dims.
        num_head: Number of attention heads.
        i: Block index; only used to build the ``encoder_{i}/...`` layer names.

    Returns:
        The block's output tensor (same shape as ``query``).
    """
    # Multi-headed self-attention. MultiHeadAttention.call's positional order is
    # (query, value, key=None), not (query, key, value): pass query/value in that
    # order and key by keyword so each tensor plays the role its parameter names.
    # use_causal_mask=True lets position t attend only to positions <= t, as needed
    # for next-token prediction.
    attention_output = layers.MultiHeadAttention(
        num_heads=num_head,
        key_dim=embed_dim // num_head,
        name="encoder_{}/multiheadattention".format(i),
    )(query, value, key=key, use_causal_mask=True)

    # Add & Normalize
    attention_output = layers.Add()([query, attention_output])  # Skip Connection
    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)

    # Position-wise feed-forward network: expand to 2 * embed_dim with ReLU, project back.
    ff_net = keras.models.Sequential([
        layers.Dense(2 * embed_dim, activation='relu', name="encoder_{}/ffn_dense_1".format(i)),
        layers.Dense(embed_dim, name="encoder_{}/ffn_dense_2".format(i)),
    ])

    # Apply Feedforward network
    ffn_output = ff_net(attention_output)

    # Add & Normalize
    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection
    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)

    return ffn_output

In [7]:
def get_sinusoidal_embeddings(sequence_length, embedding_dim):
    """Return the fixed sinusoidal position-encoding table of Vaswani et al. (2017).

    For position ``pos`` and dimension index ``j`` (with ``i = j // 2``)::

        PE[pos, 2i]   = sin(pos / 10000 ** (2i / embedding_dim))
        PE[pos, 2i+1] = cos(pos / 10000 ** (2i / embedding_dim))

    so each sin/cos pair of dimensions shares one frequency. Row 0 is therefore
    ``[0, 1, 0, 1, ...]`` (sin(0), cos(0)).

    Args:
        sequence_length: Number of positions (rows); 0 yields an empty table.
        embedding_dim: Encoding width (columns); an odd width ends on a sin column.

    Returns:
        A ``tf.float32`` tensor of shape ``(sequence_length, embedding_dim)``.
    """
    # (sequence_length, 1) column of positions and (1, embedding_dim) row of dim indices.
    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]
    dims = np.arange(embedding_dim)[np.newaxis, :]

    # Pair dims (2i, 2i+1) onto one frequency. Using the raw index j in the exponent
    # (instead of 2 * (j // 2)) would give every cos column a different frequency
    # from its sin partner.
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / embedding_dim)
    angles = positions * angle_rates

    # Even dims take sin, odd dims take cos (broadcast over all positions).
    position_enc = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))
    return tf.cast(position_enc, dtype=tf.float32)

In [8]:
# Vocabulary: N synthetic tokens "word_0" ... "word_{N-1}" and their integer ids.
N = 20  # vocab_size

vocabs = [f"word_{idx}" for idx in range(N)]
vocab_map = {token: idx for idx, token in enumerate(vocabs)}

# Every ordered triple of three *distinct* tokens, in lexicographic order of the
# token positions (despite the name, these are triples, not pairs).
pairs = [
    (first, second, third)
    for first in vocabs
    for second in vocabs
    for third in vocabs
    if len({first, second, third}) == 3
]

# Fair-coin split: indicator 1 -> train, 0 -> test (roughly half each).
indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])

pairs_train = [triple for triple, flag in zip(pairs, indicator) if flag == 1]
pairs_test = [triple for triple, flag in zip(pairs, indicator) if flag == 0]

In [9]:
def _encode_triples(triples):
    """Turn (a, b, c) triples into the sentence ``a b c a`` and its next-token targets.

    Args:
        triples: Sequence of 3-tuples of token strings that are keys of ``vocab_map``.

    Returns:
        ``(sentences, sentence_ids, x, y)`` where ``sentences`` is a list of
        ``[a, b, c, a]`` token lists, ``sentence_ids`` maps them through
        ``vocab_map``, ``x = np.array(sentence_ids)`` and ``y`` holds each row's
        ids shifted left by one (``[b, c, a]``).
    """
    sentences = [[a, b, c, a] for a, b, c in triples]
    sentence_ids = [[vocab_map[token] for token in sentence] for sentence in sentences]
    x = np.array(sentence_ids)
    # Target at position t is the token at t + 1, i.e. [b, c, a].
    y = np.array([ids[1:] for ids in sentence_ids])
    return sentences, sentence_ids, x, y


# Despite the "masked" names, nothing is masked: x is the full 4-token sentence and
# y holds the next-token targets for its first three positions.
(sentences_train, sentences_number_train,
 x_masked_train, y_masked_labels_train) = _encode_triples(pairs_train)
(sentences_test, sentences_number_test,
 x_masked_test, y_masked_labels_test) = _encode_triples(pairs_test)

# Shuffle training rows with one shared permutation so inputs and labels stay aligned.
# `pairs` was generated in lexicographic order, and Keras' `fit(validation_split=...)`
# holds out the *last* rows (per the Model.fit docs), so without this the validation
# split would be a contiguous block of triples.
perm = np.random.permutation(len(x_masked_train))
x_masked_train = x_masked_train[perm]
y_masked_labels_train = y_masked_labels_train[perm]

In [10]:
embed_dim = 10
num_head = 2

# Token ids -> word embeddings plus keras_nlp position embeddings for each slot.
inputs = layers.Input(shape=(x_masked_train.shape[1],), dtype=tf.int64)
word_embeddings = layers.Embedding(
    input_dim=N, output_dim=embed_dim, name="word_embedding"
)(inputs)
position_embeddings = PositionEmbedding(
    sequence_length=len(x_masked_train[0])
)(word_embeddings)
encoder_output = word_embeddings + position_embeddings

# A single causally-masked encoder block.
for i in range(1):
    encoder_output = bert_module(
        query=encoder_output,
        key=encoder_output,
        value=encoder_output,
        embed_dim=embed_dim,
        num_head=num_head,
        i=i,
    )

# Drop the final position: it would predict a token after the sentence end, and the
# labels only cover the first three positions.
encoder_output = layers.Lambda(lambda t: t[:, :-1, :], name="slice")(encoder_output)
mlm_output = layers.Dense(N, activation="softmax", name="mlm_cls")(encoder_output)
mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)

adam = Adam()
mlm_model.compile(optimizer=adam, loss="sparse_categorical_crossentropy")

# Stop when val_loss stalls for 5 epochs and roll back to the best weights.
callback = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history = mlm_model.fit(
    x_masked_train,
    y_masked_labels_train,
    batch_size=5000,
    epochs=2000,
    validation_split=0.5,
    callbacks=[callback],
    verbose=2,
)

2024-05-02 17:57:37.181665: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2024-05-02 17:57:37.181755: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
2024-05-02 17:57:37.181780: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3384.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist
2024-05-02 17:57:37.182089: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Epoch 1/2000
1/1 - 5s - loss: 3.2596 - val_loss: 3.2190 - 5s/epoch - 5s/step
Epoch 2/2000
1/1 - 0s - loss: 3.2195 - val_loss: 3.1927 - 147ms/epoch - 147ms/step
Epoch 3/2000
1/1 - 0s - loss: 3.1926 - val_loss: 3.1738 - 149ms/epoch - 149ms/step
Epoch 4/2000
1/1 - 0s - loss: 3.1731 - val_loss: 3.1591 - 135ms/epoch - 135ms/step
Epoch 5/2000
1/1 - 0s - loss: 3.1580 - val_loss: 3.1468 - 138ms/epoch - 138ms/step
Epoch 6/2000
1/1 - 0s - loss: 3.1449 - val_loss: 3.1362 - 160ms/epoch - 160ms/step
Epoch 7/2000
1/1 - 0s - loss: 3.1334 - val_loss: 3.1273 - 141ms/epoch - 141ms/step
Epoch 8/2000
1/1 - 0s - loss: 3.1236 - val_loss: 3.1202 - 144ms/epoch - 144ms/step
Epoch 9/2000
1/1 - 0s - loss: 3.1155 - val_loss: 3.1145 - 144ms/epoch - 144ms/step
Epoch 10/2000
1/1 - 0s - loss: 3.1089 - val_loss: 3.1101 - 137ms/epoch - 137ms/step
Epoch 11/2000
1/1 - 0s - loss: 3.1036 - val_loss: 3.1062 - 135ms/epoch - 135ms/step
Epoch 12/2000
1/1 - 0s - loss: 3.0988 - val_loss: 3.1026 - 143ms/epoch - 143ms/step
Epoch 1

Epoch 99/2000
1/1 - 0s - loss: 2.7444 - val_loss: 2.7546 - 154ms/epoch - 154ms/step
Epoch 100/2000
1/1 - 0s - loss: 2.7405 - val_loss: 2.7467 - 152ms/epoch - 152ms/step
Epoch 101/2000
1/1 - 0s - loss: 2.7330 - val_loss: 2.7402 - 151ms/epoch - 151ms/step
Epoch 102/2000
1/1 - 0s - loss: 2.7257 - val_loss: 2.7379 - 163ms/epoch - 163ms/step
Epoch 103/2000
1/1 - 0s - loss: 2.7224 - val_loss: 2.7326 - 174ms/epoch - 174ms/step
Epoch 104/2000
1/1 - 0s - loss: 2.7173 - val_loss: 2.7229 - 153ms/epoch - 153ms/step
Epoch 105/2000
1/1 - 0s - loss: 2.7067 - val_loss: 2.7200 - 153ms/epoch - 153ms/step
Epoch 106/2000
1/1 - 0s - loss: 2.7025 - val_loss: 2.7160 - 157ms/epoch - 157ms/step
Epoch 107/2000
1/1 - 0s - loss: 2.6992 - val_loss: 2.7059 - 144ms/epoch - 144ms/step
Epoch 108/2000
1/1 - 0s - loss: 2.6871 - val_loss: 2.7001 - 140ms/epoch - 140ms/step
Epoch 109/2000
1/1 - 0s - loss: 2.6807 - val_loss: 2.6956 - 139ms/epoch - 139ms/step
Epoch 110/2000
1/1 - 0s - loss: 2.6774 - val_loss: 2.6866 - 144ms/

Epoch 196/2000
1/1 - 0s - loss: 2.2968 - val_loss: 2.3140 - 190ms/epoch - 190ms/step
Epoch 197/2000
1/1 - 0s - loss: 2.2945 - val_loss: 2.3117 - 165ms/epoch - 165ms/step
Epoch 198/2000
1/1 - 0s - loss: 2.2916 - val_loss: 2.3082 - 154ms/epoch - 154ms/step
Epoch 199/2000
1/1 - 0s - loss: 2.2886 - val_loss: 2.3055 - 138ms/epoch - 138ms/step
Epoch 200/2000
1/1 - 0s - loss: 2.2856 - val_loss: 2.3028 - 163ms/epoch - 163ms/step
Epoch 201/2000
1/1 - 0s - loss: 2.2829 - val_loss: 2.3003 - 129ms/epoch - 129ms/step
Epoch 202/2000
1/1 - 0s - loss: 2.2806 - val_loss: 2.2985 - 135ms/epoch - 135ms/step
Epoch 203/2000
1/1 - 0s - loss: 2.2784 - val_loss: 2.2963 - 132ms/epoch - 132ms/step
Epoch 204/2000
1/1 - 0s - loss: 2.2767 - val_loss: 2.2959 - 133ms/epoch - 133ms/step
Epoch 205/2000
1/1 - 0s - loss: 2.2754 - val_loss: 2.2950 - 123ms/epoch - 123ms/step
Epoch 206/2000
1/1 - 0s - loss: 2.2754 - val_loss: 2.2945 - 125ms/epoch - 125ms/step
Epoch 207/2000
1/1 - 0s - loss: 2.2737 - val_loss: 2.2914 - 125ms

Epoch 293/2000
1/1 - 0s - loss: 2.1346 - val_loss: 2.1587 - 125ms/epoch - 125ms/step
Epoch 294/2000
1/1 - 0s - loss: 2.1336 - val_loss: 2.1576 - 120ms/epoch - 120ms/step
Epoch 295/2000
1/1 - 0s - loss: 2.1326 - val_loss: 2.1566 - 122ms/epoch - 122ms/step
Epoch 296/2000
1/1 - 0s - loss: 2.1314 - val_loss: 2.1555 - 121ms/epoch - 121ms/step
Epoch 297/2000
1/1 - 0s - loss: 2.1304 - val_loss: 2.1545 - 119ms/epoch - 119ms/step
Epoch 298/2000
1/1 - 0s - loss: 2.1294 - val_loss: 2.1537 - 154ms/epoch - 154ms/step
Epoch 299/2000
1/1 - 0s - loss: 2.1284 - val_loss: 2.1526 - 115ms/epoch - 115ms/step
Epoch 300/2000
1/1 - 0s - loss: 2.1274 - val_loss: 2.1517 - 118ms/epoch - 118ms/step
Epoch 301/2000
1/1 - 0s - loss: 2.1263 - val_loss: 2.1506 - 114ms/epoch - 114ms/step
Epoch 302/2000
1/1 - 0s - loss: 2.1253 - val_loss: 2.1497 - 112ms/epoch - 112ms/step
Epoch 303/2000
1/1 - 0s - loss: 2.1243 - val_loss: 2.1488 - 168ms/epoch - 168ms/step
Epoch 304/2000
1/1 - 0s - loss: 2.1233 - val_loss: 2.1478 - 152ms

Epoch 390/2000
1/1 - 0s - loss: 2.0591 - val_loss: 2.0901 - 130ms/epoch - 130ms/step
Epoch 391/2000
1/1 - 0s - loss: 2.0586 - val_loss: 2.0896 - 127ms/epoch - 127ms/step
Epoch 392/2000
1/1 - 0s - loss: 2.0580 - val_loss: 2.0892 - 140ms/epoch - 140ms/step
Epoch 393/2000
1/1 - 0s - loss: 2.0575 - val_loss: 2.0887 - 131ms/epoch - 131ms/step
Epoch 394/2000
1/1 - 0s - loss: 2.0570 - val_loss: 2.0883 - 121ms/epoch - 121ms/step
Epoch 395/2000
1/1 - 0s - loss: 2.0565 - val_loss: 2.0878 - 126ms/epoch - 126ms/step
Epoch 396/2000
1/1 - 0s - loss: 2.0560 - val_loss: 2.0875 - 132ms/epoch - 132ms/step
Epoch 397/2000
1/1 - 0s - loss: 2.0555 - val_loss: 2.0869 - 131ms/epoch - 131ms/step
Epoch 398/2000
1/1 - 0s - loss: 2.0550 - val_loss: 2.0868 - 128ms/epoch - 128ms/step
Epoch 399/2000
1/1 - 0s - loss: 2.0545 - val_loss: 2.0862 - 125ms/epoch - 125ms/step
Epoch 400/2000
1/1 - 0s - loss: 2.0541 - val_loss: 2.0861 - 136ms/epoch - 136ms/step
Epoch 401/2000
1/1 - 0s - loss: 2.0536 - val_loss: 2.0854 - 151ms

Epoch 487/2000
1/1 - 0s - loss: 2.0211 - val_loss: 2.0608 - 122ms/epoch - 122ms/step
Epoch 488/2000
1/1 - 0s - loss: 2.0209 - val_loss: 2.0609 - 103ms/epoch - 103ms/step
Epoch 489/2000
1/1 - 0s - loss: 2.0206 - val_loss: 2.0606 - 114ms/epoch - 114ms/step
Epoch 490/2000
1/1 - 0s - loss: 2.0203 - val_loss: 2.0604 - 134ms/epoch - 134ms/step
Epoch 491/2000
1/1 - 0s - loss: 2.0200 - val_loss: 2.0602 - 122ms/epoch - 122ms/step
Epoch 492/2000
1/1 - 0s - loss: 2.0197 - val_loss: 2.0600 - 125ms/epoch - 125ms/step
Epoch 493/2000
1/1 - 0s - loss: 2.0194 - val_loss: 2.0599 - 147ms/epoch - 147ms/step
Epoch 494/2000
1/1 - 0s - loss: 2.0192 - val_loss: 2.0597 - 126ms/epoch - 126ms/step
Epoch 495/2000
1/1 - 0s - loss: 2.0189 - val_loss: 2.0597 - 116ms/epoch - 116ms/step
Epoch 496/2000
1/1 - 0s - loss: 2.0187 - val_loss: 2.0593 - 119ms/epoch - 119ms/step
Epoch 497/2000
1/1 - 0s - loss: 2.0184 - val_loss: 2.0593 - 147ms/epoch - 147ms/step
Epoch 498/2000
1/1 - 0s - loss: 2.0181 - val_loss: 2.0590 - 114ms

Epoch 584/2000
1/1 - 0s - loss: 2.0000 - val_loss: 2.0509 - 108ms/epoch - 108ms/step
Epoch 585/2000
1/1 - 0s - loss: 1.9998 - val_loss: 2.0510 - 104ms/epoch - 104ms/step
Epoch 586/2000
1/1 - 0s - loss: 1.9997 - val_loss: 2.0508 - 114ms/epoch - 114ms/step
Epoch 587/2000
1/1 - 0s - loss: 1.9995 - val_loss: 2.0508 - 104ms/epoch - 104ms/step
Epoch 588/2000
1/1 - 0s - loss: 1.9993 - val_loss: 2.0504 - 124ms/epoch - 124ms/step
Epoch 589/2000
1/1 - 0s - loss: 1.9991 - val_loss: 2.0504 - 115ms/epoch - 115ms/step
Epoch 590/2000
1/1 - 0s - loss: 1.9990 - val_loss: 2.0506 - 112ms/epoch - 112ms/step
Epoch 591/2000
1/1 - 0s - loss: 1.9988 - val_loss: 2.0504 - 129ms/epoch - 129ms/step
Epoch 592/2000
1/1 - 0s - loss: 1.9988 - val_loss: 2.0508 - 125ms/epoch - 125ms/step
Epoch 593/2000
1/1 - 0s - loss: 1.9987 - val_loss: 2.0505 - 135ms/epoch - 135ms/step
Epoch 594/2000
1/1 - 0s - loss: 1.9984 - val_loss: 2.0503 - 141ms/epoch - 141ms/step
Epoch 595/2000
1/1 - 0s - loss: 1.9981 - val_loss: 2.0506 - 162ms

In [11]:
# Evaluate on a random subset of held-out triples. Each row is [a, b, c, a]; the
# model's output has one prediction per kept position (3 after the slice), and the
# last one is the distribution over the 4th token given [a, b, c]. Its target is `a`.
n_eval = min(1000, x_masked_test.shape[0])
x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]

# One batched forward pass instead of rebuilding a backend function per sentence.
last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]
targets = x_test_subset[:, -1]

# acc: 1 when the top-1 prediction equals the target, else 0.
# prob: probability the model assigns to the target token.
acc = (last_step_probs.argmax(axis=-1) == targets).astype(int).tolist()
prob = list(last_step_probs[np.arange(len(targets)), targets])

In [12]:
# (top-1 accuracy on the final token, mean probability assigned to the final token)
tuple(np.mean(scores) for scores in (acc, prob))

(1.0, 0.92229855)

In [26]:
embed_dim = 100
num_head = 2

# Token ids -> word embeddings plus keras_nlp position embeddings for each slot.
inputs = layers.Input(shape=(x_masked_train.shape[1],), dtype=tf.int64)
word_embeddings = layers.Embedding(
    input_dim=N, output_dim=embed_dim, name="word_embedding"
)(inputs)
position_embeddings = PositionEmbedding(
    sequence_length=len(x_masked_train[0])
)(word_embeddings)
encoder_output = word_embeddings + position_embeddings

# A single causally-masked encoder block.
for i in range(1):
    encoder_output = bert_module(
        query=encoder_output,
        key=encoder_output,
        value=encoder_output,
        embed_dim=embed_dim,
        num_head=num_head,
        i=i,
    )

# Drop the final position: it would predict a token after the sentence end, and the
# labels only cover the first three positions.
encoder_output = layers.Lambda(lambda t: t[:, :-1, :], name="slice")(encoder_output)
mlm_output = layers.Dense(N, activation="softmax", name="mlm_cls")(encoder_output)
mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)

adam = Adam()
mlm_model.compile(optimizer=adam, loss="sparse_categorical_crossentropy")

# Stop when val_loss stalls for 5 epochs and roll back to the best weights.
callback = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history = mlm_model.fit(
    x_masked_train,
    y_masked_labels_train,
    batch_size=5000,
    epochs=2000,
    validation_split=0.5,
    callbacks=[callback],
    verbose=2,
)

Epoch 1/2000
1/1 - 1s - loss: 3.7903 - val_loss: 3.3893 - 1s/epoch - 1s/step
Epoch 2/2000
1/1 - 0s - loss: 3.3982 - val_loss: 3.2298 - 76ms/epoch - 76ms/step
Epoch 3/2000
1/1 - 0s - loss: 3.2282 - val_loss: 3.1684 - 78ms/epoch - 78ms/step
Epoch 4/2000
1/1 - 0s - loss: 3.1586 - val_loss: 3.1380 - 81ms/epoch - 81ms/step
Epoch 5/2000
1/1 - 0s - loss: 3.1217 - val_loss: 3.1117 - 74ms/epoch - 74ms/step
Epoch 6/2000
1/1 - 0s - loss: 3.0913 - val_loss: 3.0886 - 75ms/epoch - 75ms/step
Epoch 7/2000
1/1 - 0s - loss: 3.0656 - val_loss: 3.0736 - 78ms/epoch - 78ms/step
Epoch 8/2000
1/1 - 0s - loss: 3.0494 - val_loss: 3.0661 - 72ms/epoch - 72ms/step
Epoch 9/2000
1/1 - 0s - loss: 3.0413 - val_loss: 3.0600 - 80ms/epoch - 80ms/step
Epoch 10/2000
1/1 - 0s - loss: 3.0347 - val_loss: 3.0518 - 79ms/epoch - 79ms/step
Epoch 11/2000
1/1 - 0s - loss: 3.0261 - val_loss: 3.0427 - 71ms/epoch - 71ms/step
Epoch 12/2000
1/1 - 0s - loss: 3.0164 - val_loss: 3.0350 - 75ms/epoch - 75ms/step
Epoch 13/2000
1/1 - 0s - loss

In [27]:
# Evaluate on a random subset of held-out triples. Each row is [a, b, c, a]; the
# model's output has one prediction per kept position (3 after the slice), and the
# last one is the distribution over the 4th token given [a, b, c]. Its target is `a`.
n_eval = min(1000, x_masked_test.shape[0])
x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]

# One batched forward pass instead of rebuilding a backend function per sentence.
last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]
targets = x_test_subset[:, -1]

# acc: 1 when the top-1 prediction equals the target, else 0.
# prob: probability the model assigns to the target token.
acc = (last_step_probs.argmax(axis=-1) == targets).astype(int).tolist()
prob = list(last_step_probs[np.arange(len(targets)), targets])

In [28]:
# (top-1 accuracy on the final token, mean probability assigned to the final token)
tuple(np.mean(scores) for scores in (acc, prob))

(1.0, 0.97649634)