In [29]:
import random, string
from tqdm import tqdm
import tensorflow as tf
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from network import BaseModel

In [6]:
# Read the training vocabulary, one word per line. The context manager closes
# the file handle as soon as the contents are in memory (it was previously
# left open for the lifetime of the kernel).
with open("words_250000_train.txt", "r") as f:
  all_words = f.read().splitlines()
# Length of the longest word in the file.
maxlen = max(len(word) for word in all_words)

In [7]:
# Shuffle the vocabulary in place, then cut it once into a leading train
# portion and a trailing test portion.
random.shuffle(all_words)
n = len(all_words)
train_split = 0.8
split_at = int(train_split * n)
train_words, test_words = all_words[:split_at], all_words[split_at:]

In [8]:
def get_masked_words(word_list):
  """Build a mapping from each word to a randomly masked copy of itself.

  For every word a random, non-empty subset of its distinct characters is
  chosen, and every occurrence of the chosen characters is replaced by '_'.

  Args:
    word_list: iterable of strings.

  Returns:
    dict mapping each original word to its masked form. An empty word maps
    to itself. Duplicate words collapse into one entry (the last mask drawn
    for that word wins).
  """
  ret = {}
  for word in tqdm(word_list):
    # Sorted so that a seeded `random` produces the same masks on every run;
    # iteration order of a set of strings varies between interpreter sessions.
    chars = sorted(set(word))
    if not chars:
      # An empty word (e.g. a blank line in the word file) has nothing to
      # mask; drawing from an empty range would raise IndexError.
      ret[word] = word
      continue
    num_dels = random.randint(1, len(chars))
    # sample() draws without replacement, so exactly `num_dels` distinct
    # letters are masked (choices() could repeat letters and mask fewer).
    del_chars = set(random.sample(chars, num_dels))
    ret[word] = ''.join('_' if c in del_chars else c for c in word)
  return ret

In [11]:
# Mask each split independently: original word -> masked word.
train_masked_words = get_masked_words(train_words)
test_masked_words = get_masked_words(test_words)

  0%|          | 0/181840 [00:00<?, ?it/s]

100%|██████████| 181840/181840 [00:00<00:00, 264055.00it/s]
100%|██████████| 45460/45460 [00:00<00:00, 257825.17it/s]


In [17]:
# Inputs are the masked words (dict values); targets are the original words
# (dict keys). Both views iterate in the same order, so x[i] pairs with y[i].
x_train = list(train_masked_words.values())
y_train = list(train_masked_words)
x_test = list(test_masked_words.values())
y_test = list(test_masked_words)

In [14]:
# Character vocabulary: 'a'..'z' map to 1..26 and the mask symbol '_' maps
# to 27.
set_letters = set(string.ascii_lowercase)
letters = sorted(set_letters)
letter_dict = dict(zip(letters, range(1, len(letters) + 1)))
letter_dict['_'] = 27

def letter_to_num(c):
  """Return the integer id of character `c`.

  Raises KeyError if `c` is not a lowercase ASCII letter or '_'.
  """
  return letter_dict[c]

def process_input(words):
  """Encode words as integer id sequences padded to a common length.

  Each character is converted with `letter_to_num`; the sequences are then
  post-padded with 0 up to the module-level `maxlen`.

  Args:
    words: iterable of strings.

  Returns:
    The result of `pad_sequences` on the encoded words.
  """
  encoded = [[letter_to_num(ch) for ch in word] for word in words]
  return pad_sequences(encoded, maxlen=maxlen, padding="post", value=0)

In [18]:
# Replace the string lists with their padded integer encodings
# (same evaluation order as before: x_train, x_test, y_train, y_test).
x_train, x_test = process_input(x_train), process_input(x_test)
y_train, y_test = process_input(y_train), process_input(y_test)

In [21]:
# (input, target) pipelines: shuffle through a 1000-element buffer, then
# group into batches of 200.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1000)
    .batch(200)
)
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .shuffle(buffer_size=1000)
    .batch(200)
)

In [24]:
# vocab_size = 28 covers ids 0 (padding), 1..26 (letters) and 27 ('_').
base_model = BaseModel(
    vocab_size=28,
    maxlen=maxlen,
    embed_size=512,
    num_heads=64,
    key_dim=2,
)

In [25]:
# One integer id per character position. `shape` must be a tuple: `(maxlen)`
# is just the int `maxlen`, whereas `(maxlen,)` is the 1-tuple Keras expects.
inp = tf.keras.layers.Input(shape=(maxlen,))
attn = base_model(inp)
# Dense layers act on the last axis, so this stack is applied to every
# character position independently.
dense1 = tf.keras.layers.Dense(512, activation='relu')(attn)
dense2 = tf.keras.layers.Dense(256, activation='relu')(dense1)
dense3 = tf.keras.layers.Dense(64, activation='relu')(dense2)
# 28-way softmax per position: one probability per vocabulary id.
out = tf.keras.layers.Dense(28, activation='softmax')(dense3)

In [27]:
# Wrap the functional graph (input tensor -> softmax output) in a Model
# and print its layer table.
model = tf.keras.Model(inp, out)
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 29)]              0         
                                                                 
 base_model (BaseModel)      (None, 29, 512)           278400    
                                                                 
 dense (Dense)               (None, 29, 512)           262656    
                                                                 
 dense_1 (Dense)             (None, 29, 256)           131328    
                                                                 
 dense_2 (Dense)             (None, 29, 64)            16448     
                                                                 
 dense_3 (Dense)             (None, 29, 28)            1820      
                                                                 
Total params: 690652 (2.63 MB)
Trainable params: 690652 (2.

In [28]:
# Adam with default settings, and sparse categorical cross-entropy so the
# integer id targets are used directly (no one-hot encoding).
optimizer = tf.optimizers.Adam()
loss_fn = tf.losses.SparseCategoricalCrossentropy()
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
)

In [None]:
# Train for 5 epochs, evaluating on the held-out split after each one.
model.fit(
    train_ds,
    epochs=5,
    validation_data=test_ds,
)

In [None]:
# Persist only the BaseModel weights; the dense head is not saved here.
weights_path = 'base_model_weights_64x512.h5'
base_model.save_weights(weights_path)

In [None]:
# Draw a few held-out words (with replacement) and look up their masked forms.
num_samples = 10
target = random.choices(list(test_masked_words), k=num_samples)
source = list(map(test_masked_words.get, target))

In [None]:
# Per-position class probabilities for the sampled masked words, reduced to
# the most likely id at each position.
probs = model.predict(process_input(source))
pred = np.argmax(probs, axis=-1)

In [None]:
# Reverse lookup (id -> character). Id 0, the padding value, is rendered
# as '.'.
inv_letter = dict(zip(letter_dict.values(), letter_dict.keys()))
inv_letter[0] = '.'

def num_to_letter(num):
  """Return the character for integer id `num` (KeyError if unknown)."""
  return inv_letter[num]

In [None]:
# Decode each predicted id sequence to text and drop the '.' padding marks.
predicted_words = [
    ''.join(num_to_letter(idx) for idx in seq).replace('.', '')
    for seq in pred
]
print("Predicted Words: ", predicted_words)

In [None]:
# Show (masked, original) pairs as a list rather than a dict: a dict keyed by
# the masked form silently drops entries when two samples share a masked
# form or the same word is drawn twice, so the output could be shorter than
# `source` and lose its position-by-position correspondence.
print("Original Words: ", list(zip(source, target)))