In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy.random as npr
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from keras.optimizers import Adam
from keras_nlp.layers import PositionEmbedding

In [3]:
seed = 428

np.random.seed(seed)
tf.random.set_seed(seed)
random.seed(seed)

In [15]:
def get_masked_input_and_labels(encoded_texts, n_cat):
    # For each sentence, mask each word one-by-one

    encoded_texts_masked = []
    y_labels = []

    for encoded_text in encoded_texts:
        for i in range(len(encoded_text)):
            encoded_text_masked = np.copy(encoded_text)
            y_label = encoded_text_masked[i]
            encoded_texts_masked.append(np.delete(encoded_text_masked, i))
            y_labels.append(np.array([y_label]))

    return np.array(encoded_texts_masked), np.array(y_labels)

In [16]:
#### K = number of countries = number of capitals = number of currencies
#### M = number of words only used by each topic
#### S = number of words used by both topics
#### L = sentence length
#### q1, q2 = probability of having 1 or 2 pairs
#### embed_dim = dimension of embeddings
#### n_sentences = number of training sentences

def train_model(K, M, S, L, q1, q2, embed_dim, n_sentences):
    
    countries = ['country_' + str(i) for i in range(K)]
    capitals = ['capital_' + str(i) for i in range(K)]
    currencies = ['currency_' + str(i) for i in range(K)]
    random_capitals = ['random_capital_' + str(i) for i in range(M)]
    random_currencies = ['random_currency_' + str(i) for i in range(M)]
    randoms = ['random_' + str(i) for i in range(S)]

    vocabs = countries + capitals + currencies + random_capitals + random_currencies + randoms
    vocab_map = {}

    for i in range(len(vocabs)):
        vocab_map[vocabs[i]] = i
        
    sentences = []
    sentences_number = []
    
    q0 = 1 - q1 - q2

    for i in range(n_sentences):

        sentence = []
        
        temp = npr.uniform()
        temp2 = npr.uniform()
        
        if temp2 <= q0:
            n_pairs = 0
        elif temp2 <= q0 + q1:
            n_pairs = 1
        else:
            n_pairs = 2
        
        if temp <= 0.5: ### country - capital
        
            pairs = np.random.choice(np.arange(K), n_pairs, replace = False)
#             pairs2 = np.random.choice(np.arange(K), n_pairs, replace = False)
            for pair in pairs:
                sentence.append(countries[pair])
                sentence.append(capitals[pair])
#             for pair in pairs2:
#                 temp3 = npr.uniform()
#                 if temp3 <= 0.8:
#                     sentence.append(random_capitals[pair])
#                 else:
#                     sentence.append(random_currencies[pair])

            randoms_dup = 4 * random_capitals + 2 * randoms + 1 * random_currencies
            sentence += list(np.random.choice(randoms_dup, L - 2 * n_pairs, replace = False))  
                 
        else: ### country - currency
            
            pairs = np.random.choice(np.arange(K), n_pairs, replace = False)
            pairs2 = np.random.choice(np.arange(K), n_pairs, replace = False)
            for pair in pairs:
                sentence.append(countries[pair])
                sentence.append(currencies[pair])        
#             for pair in pairs2:
#                 temp3 = npr.uniform()
#                 if temp3 <= 0.8:
#                     sentence.append(random_currencies[pair])
#                 else:
#                     sentence.append(random_capitals[pair])
            
            randoms_dup = 1 * random_capitals + 2 * randoms + 4 * random_currencies
            sentence += list(np.random.choice(randoms_dup, L - 2 * n_pairs, replace = False))  
            
#         sentence += list(np.random.choice(randoms, L - 3 * n_pairs, replace = False))  

        sentence_number = [vocab_map[i] for i in sentence]
        sentences.append(sentence)
        sentences_number.append(sentence_number)
        
    x_train = np.array(sentences_number)
    n_cat = len(vocab_map)
    x_masked_train, y_masked_labels_train = get_masked_input_and_labels(x_train, n_cat)
    
    callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
    inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)
    word_embeddings = layers.Embedding(n_cat, embed_dim, name="word_embedding")(inputs)
    encoder_output = layers.GlobalAveragePooling1D()(word_embeddings)
    mlm_output = layers.Dense(n_cat, name="mlm_cls", activation="softmax", use_bias=False)(encoder_output)
    mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)
    adam = Adam()
    mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)

    history = mlm_model.fit(x_masked_train, y_masked_labels_train,
                        validation_split = 0.5, callbacks = [callback], 
                        epochs=500, batch_size=128, verbose=0)
    
    return sentences, vocab_map, mlm_model

In [17]:
def get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples):
    
    sentences, vocab_map, current_model = train_model(K, M, S, L, q1, q2, embed_dim, n_sentences)

    acc_capitals = []
    prob_capitals = []

    for _ in range(n_samples):
        sentence = []
        random_capitals = np.random.choice(np.arange(K), int(L/2), replace = False)
        for random_capital in random_capitals:
            sentence.append('country_' + str(random_capital))
            sentence.append('capital_' + str(random_capital))
        sentence = sentence[:-1]
        sentence_number = [vocab_map[i] for i in sentence]
        temp = keras.backend.function(inputs = current_model.layers[0].input, outputs = current_model.layers[-1].output) \
            (np.array(sentence_number).reshape(1,len(sentence_number)))
        actual = vocab_map['capital_' + str(random_capitals[-1])]
        acc_capitals.append(1 if np.argsort(-1 * temp)[0][0] == actual else 0)
        prob_capitals.append(temp[0][vocab_map['capital_' + str(random_capitals[-1])]])
        
    acc_currencies = []
    prob_currencies = []

    for _ in range(n_samples):
        sentence = []
        random_currencies = np.random.choice(np.arange(K), int(L/2), replace = False)
        for random_currency in random_currencies:
            sentence.append('country_' + str(random_currency))
            sentence.append('currency_' + str(random_currency))
        sentence = sentence[:-1]
        sentence_number = [vocab_map[i] for i in sentence]
        temp = keras.backend.function(inputs = current_model.layers[0].input, outputs = current_model.layers[-1].output) \
            (np.array(sentence_number).reshape(1,len(sentence_number)))
        actual = vocab_map['currency_' + str(random_currencies[-1])]
        acc_currencies.append(1 if np.argsort(-1 * temp)[0][0] == actual else 0)
        prob_currencies.append(temp[0][vocab_map['currency_' + str(random_currencies[-1])]])
        

    return sentences, current_model, vocab_map, (np.mean(acc_capitals), np.mean(prob_capitals)), \
                (np.mean(acc_currencies), np.mean(prob_currencies))

In [18]:
K = 10 # number of countries
L = 8 # sentence length
M = 20 # number of words used by each topic
S = 20 # number of words used by both topics
embed_dim = 100 # CBOW embedding dimension
n_sentences = 50000 # number of sentences in the training set
n_samples = 1000

In [19]:
q0 = 0 # probability of having 0 pairs
q1 = 1 # probability of having 1 pair
q2 = 0 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(0.0, 3.5944326e-15)
(0.0, 2.757914e-14)
(0.0, 1.2790812e-14)
(0.0, 4.5745624e-15)
(0.0, 7.055414e-15)
(0.0, 3.6637418e-16)
(0.0, 1.3081684e-15)
(0.0, 6.070273e-15)
(0.0, 4.5148156e-15)
(0.0, 7.729426e-15)
(0.0, 2.3726957e-13)
(0.0, 8.1500155e-14)
(0.0, 1.9516595e-13)
(0.0, 6.433365e-14)
(0.0, 6.1655664e-15)
(0.0, 3.3151508e-15)
(0.0, 2.8108558e-15)
(0.0, 1.849911e-15)
(0.0, 8.651988e-15)
(0.0, 1.2067215e-14)
(0.0, 4.7932756631893166e-14)
(0.0, 2.0938585806752907e-14)


In [20]:
q0 = 0 # probability of having 0 pairs
q1 = 0 # probability of having 1 pair
q2 = 1 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(0.0, 3.1898244e-07)
(0.0, 5.853088e-07)
(0.0, 1.8220868e-06)
(0.0, 1.8744369e-06)
(0.0, 4.434433e-06)
(0.0, 7.5463695e-06)
(0.0, 1.4470506e-07)
(0.0, 1.7710906e-07)
(0.0, 3.6418434e-05)
(0.0, 2.892952e-05)
(0.0, 6.897493e-05)
(0.0, 3.741454e-05)
(0.0, 2.504861e-06)
(0.0, 2.2556653e-06)
(0.0, 1.4139414e-06)
(0.0, 1.3601484e-06)
(0.0, 1.0800929e-05)
(0.0, 1.6999922e-05)
(0.0, 7.753401e-06)
(0.0, 6.987807e-06)
(0.0, 1.3458670279931082e-05)
(0.0, 1.0413082569016296e-05)


In [21]:
q0 = 1/2 # probability of having 0 pairs
q1 = 1/2 # probability of having 1 pair
q2 = 0 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(1.0, 0.7787374)
(1.0, 0.72903186)
(1.0, 0.874276)
(1.0, 0.86882067)
(1.0, 0.8170653)
(1.0, 0.8003274)
(1.0, 0.82229626)
(1.0, 0.77562314)
(1.0, 0.8536942)
(1.0, 0.8557219)
(1.0, 0.8695247)
(1.0, 0.81739074)
(1.0, 0.83527845)
(1.0, 0.81028306)
(1.0, 0.8140205)
(1.0, 0.8633002)
(1.0, 0.7886727)
(1.0, 0.828625)
(1.0, 0.78305924)
(1.0, 0.77305895)
(0.9999999999999999, 0.8236624777317046)
(0.9999999999999999, 0.8122182965278625)


In [22]:
q0 = 1/2 # probability of having 0 pairs
q1 = 0 # probability of having 1 pair
q2 = 1/2 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(1.0, 0.99991727)
(1.0, 0.9999479)
(1.0, 0.99988955)
(1.0, 0.9998941)
(1.0, 0.9998204)
(1.0, 0.99976075)
(1.0, 0.9999633)
(1.0, 0.99995667)
(1.0, 0.9997715)
(1.0, 0.99978197)
(1.0, 0.9999869)
(1.0, 0.9999811)
(1.0, 0.99988776)
(1.0, 0.9998825)
(1.0, 0.99993235)
(1.0, 0.99994236)
(1.0, 0.99988186)
(1.0, 0.99985117)
(1.0, 0.9999317)
(1.0, 0.99995214)
(0.9999999999999999, 0.9998982548713685)
(0.9999999999999999, 0.9998950660228729)


In [23]:
q0 = 0 # probability of having 0 pairs
q1 = 1/2 # probability of having 1 pair
q2 = 1/2 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(1.0, 0.99658686)
(1.0, 0.99681866)
(1.0, 0.99803025)
(1.0, 0.9977919)
(1.0, 0.9978691)
(1.0, 0.99741685)
(1.0, 0.99852353)
(1.0, 0.9983657)
(1.0, 0.99776936)
(1.0, 0.998373)
(1.0, 0.9975255)
(1.0, 0.9973251)
(1.0, 0.9978946)
(1.0, 0.997852)
(1.0, 0.9980767)
(1.0, 0.9970582)
(1.0, 0.9981928)
(1.0, 0.9983706)
(1.0, 0.99757713)
(1.0, 0.9975261)
(0.9999999999999999, 0.9978045761585236)
(0.9999999999999999, 0.997689813375473)


In [24]:
q0 = 1/3 # probability of having 0 pairs
q1 = 1/3 # probability of having 1 pair
q2 = 1/3 # probability of having 2 pairs

accs_c = 0
probs_c = 0
accs_d = 0
probs_d = 0

for _ in range(10):
    sentences, mlm_model, vocab_map, acc_c, acc_d \
        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)
    
    print(acc_c)
    print(acc_d)
    
    accs_c += acc_c[0]/10
    probs_c += acc_c[1]/10
    accs_d += acc_d[0]/10
    probs_d += acc_d[1]/10
    
print((accs_c, probs_c))
print((accs_d, probs_d))

(1.0, 0.9992926)
(1.0, 0.99933267)
(1.0, 0.99881834)
(1.0, 0.9990907)
(1.0, 0.9992466)
(1.0, 0.9994664)
(1.0, 0.99878234)
(1.0, 0.9990516)
(1.0, 0.99892384)
(1.0, 0.9990865)
(1.0, 0.99885935)
(1.0, 0.9990305)
(1.0, 0.99896413)
(1.0, 0.99908215)
(1.0, 0.9991895)
(1.0, 0.9989625)
(1.0, 0.9989498)
(1.0, 0.9989823)
(1.0, 0.9993551)
(1.0, 0.99916196)
(0.9999999999999999, 0.9990381598472595)
(0.9999999999999999, 0.9991247236728668)
