## An LSTM language model in TF

For this task the inspiration comes from the famous [reference work of Andrej Karpathy](https://karpathy.github.io/2015/05/21/rnn-effectiveness/). 

Note, that in this case we will not use regularization, since we are willing to overfit - for the sake of play with the text. This is now an "overfitting competition", so _not_ a generally good practice!

## Reader

In [None]:
#!git clone https://github.com/solalatus/rejto_lm.git
#%cd rejto_lm

In [1]:
import rejto

In [2]:
rejto_corpus = rejto.Rejto_corpus()

In [3]:
len(rejto_corpus.sents())

192087

In [4]:
import numpy as np
import tensorflow as tf
import nltk

from numpy.random import seed
seed(1212)

tf.random.set_seed(1234)


# This can be an important parameter, so be aware of it...
max_seq_length = 30

# max_num_of_sents -- how many sentences should we read from the corpus 
max_num_of_sents = rejto_corpus.n_sents()

def generate_rejto_word_to_id_map():
    """Return a dictionary mapping downcased Rejto-words to their ids.
    Numbering starts from 1 since we use 0 for masking (!!!).
    """
    words = set()
    for word in rejto_corpus.words():
        words.add(word.lower())
    return {word: idx + 1 for idx, word in enumerate(sorted(words))}


class RejtoReader:
    """A secondary reader class for the Rejto corpus.
    """

    def __init__(self):
        self.word_to_id_map = generate_rejto_word_to_id_map()
        self.id_to_word_map = {idx: word for word, idx in self.word_to_id_map.items()}

    def n_words(self):
        return len(self.word_to_id_map)

    def sentence_to_ids(self, sentence):
        """Return the word ids of a sentence.
        """
        return [self.word_to_id_map[word.lower()] for word in sentence]
        
    def sentences(self):
        """Generator yielding features from the Rejto corpus.
        """
        return (self.sentence_to_ids(sentence) for sentence in rejto_corpus.sents())

    def sentence_matrixes(self):
        x = np.zeros((max_num_of_sents, max_seq_length-1))
        y = np.zeros((max_num_of_sents, max_seq_length-1))
        sents = self.sentences()
        for idx, sent in enumerate(sents):
            if idx == max_num_of_sents:
                breaka
            np_array = np.asarray(sent)
            length  = min(max_seq_length, len(np_array))
            x[idx, :length - 1] = np_array[:length - 1]
            y[idx, :length - 1] = np_array[1:length]
        return x, y


## Model

### Parameters

In [5]:
r = RejtoReader()
n_words = r.n_words()

# network parameters

lstm_size = 512
embedding_size = 150
max_input_length = max_seq_length - 1 # since our x/y input does not contain the last/first element of the sentences

### Network

In [6]:
from tensorflow.keras.layers import Input, Dense, Embedding, LSTM, TimeDistributed
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adadelta, Adam, SGD
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras import backend as K


tf.compat.v1.reset_default_graph() # It's good practice to clean and reset everything
K.clear_session            # even using Keras


# Model
########

x = Input(shape=(max_input_length,))

embedded_x =  Embedding(n_words + 1, embedding_size, input_length=max_input_length - 1, mask_zero=True)(x)

lstm_outputs = LSTM(lstm_size, return_sequences=True)(embedded_x)

lstm_outputs, hidden_state, cell_state = LSTM(lstm_size, return_sequences=True, return_state= True)(lstm_outputs)

predictions = Dense(n_words + 1, activation="softmax")(lstm_outputs)

model = Model(inputs=x, outputs=predictions)

model.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 14)]              0         
_________________________________________________________________
embedding (Embedding)        (None, 14, 150)           21221700  
_________________________________________________________________
lstm (LSTM)                  (None, 14, 512)           1357824   
_________________________________________________________________
lstm_1 (LSTM)                [(None, 14, 512), (None,  2099200   
_________________________________________________________________
dense (Dense)                (None, 14, 141478)        72578214  
Total params: 97,256,938
Trainable params: 97,256,938
Non-trainable params: 0
_________________________________________________________________


### Error, optimization, compilation

In [7]:
# Loss 

loss = sparse_categorical_crossentropy # we use this cross entropy variant as the input is not 
                                       # one-hot encoded

# Optimizer
# Choose an optimizer - adaptive ones work well here
optimizer = Adam()
 
# Compilation
#############

model.compile(optimizer=optimizer, loss=loss)

In [8]:
data_x, data_y = r.sentence_matrixes()

ValueError: could not broadcast input array from shape (0) into shape (13)

### Training

We generate the trainig data.

In [None]:
data_y = np.expand_dims(data_y, -1) # It seems that Keras needs this for the "one-cold" and softmax dims to match

And train the model:

In [None]:
history = model.fit(x=data_x , y=data_y, validation_split=0.1, epochs=10, batch_size=100)

## Demo 1: Predict next word

In [9]:
# Prediction
############

def str_to_input(s):
    """Convert a string to appropriate model input.
    """
    words = [x.lower() for x in s.split()[:max_input_length]]
    ids = [r.word_to_id_map[word] for word in words]
    ids_array = np.asarray(ids)
    length = min(max_input_length, len(ids_array))
    result = np.zeros((1, max_input_length))
    result[0, :length] = ids_array[:length]
    return result, length
    

while True:
    s = input("\nEnter a few starting words of a sentence or <return> to stop: ")
    if s == "":
        break
    else:
        try:
            x, length = str_to_input(s)
            predictions = model.predict(x)
            probs = predictions[0][length - 1]
            most_probable = np.argmax(probs)
            print("Predicted next word:", r.id_to_word_map[most_probable])
        except KeyError:
            print("Unknown words -- please try again!")


KeyboardInterrupt: 

## Demo 2: Similarity of sentences

First we define a function that generates the hidden state of the LSTM from an input sentence:

In [None]:
input_layer = model.get_layer("input_1")
lstm_2_layer = model.get_layer("lstm_1")

cell_state_fun = K.function([input_layer.input],[lstm_2_layer.output[2]])

def get_embedding(x):
    """Return the final cell state associated with the input.
       Returns the last cell state as a vector.
    """
    return cell_state_fun([x])[0].flatten()

Then we use the vectors for calculating the cosine distance between sentences.

In [None]:
def cos_sim(a, b):
	return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

while True:
    s1 = input("\nEnter the first sentence or <return> to quit: ")
    if s1 == "": break
    s2 = input("\nEnter the second sentence: ")
    try:
        x1, _ = str_to_input(s1)
        x2, _ = str_to_input(s2)
        e1 = get_embedding(x1)
        e2 = get_embedding(x2)
        print("The cosine similarity between the two sentences is", cos_sim(e1, e2))
    except KeyError:
        print("Unknown words -- please try again!")

## Demo 3: Mini search engine

We use the library [Annoy](https://github.com/spotify/annoy) published by Spotify to create a vector space index of the Rejto corpus from the LSTM's cell state. We assign a vector for each sentence, and then store it to be able to run nearest neighbor queries on it. With this we effectively created a **semantic search engine**.

(There are multiple solutions for approximate nearest neighbor search a scale which are worth looking into, one of them is [FAISS](https://code.fb.com/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) from Facebook Research.)

In [None]:
def rejto_sent_to_input(ids):
  ids_array = np.asarray(ids)
  length = min(max_input_length, len(ids_array))
  result = np.zeros((1, max_input_length))
  result[0, :length] = ids_array[:length]
  return result, length

In [None]:
sentlist = list(r.sentences())

In [None]:
!pip install annoy

In [None]:
INDEX_COVERAGE_PERCENT = 1.0 #How much of the corpus you want ot index? 1.0 means whole, 0.5 means half.
NEAREST_NEIGHBOR_NUM = 5

In [None]:
from annoy import AnnoyIndex
from tqdm import tqdm

index = AnnoyIndex(512, metric="angular")

for i in tqdm(range(int(len(sentlist)*INDEX_COVERAGE_PERCENT))):
  inputs,length = rejto_sent_to_input(sentlist[i])
  vector = get_embedding(inputs)
  index.add_item(i,vector)

print("Building index...")
index.build(100)
print("Index done, ready to query!")

In [None]:
def print_rejto_index(sentences, indices):
  for i in indices:
    word_ids_list = sentences[i]
    for j in word_ids_list:
      print(r.id_to_word_map[j]+" ", end='')
    print()

    

In [None]:
while True:
  query = input("\nEnter the query or <return> to quit: ")
  if query == "": break
  try:
    in_ids, length = str_to_input(query)
    in_vector = get_embedding(in_ids)
    nearest_sentence_indices = index.get_nns_by_vector(in_vector, NEAREST_NEIGHBOR_NUM)
    #print("nearest indices:", nearest_sentence_indices)
    print_rejto_index(sentlist, nearest_sentence_indices)

  except KeyError:
    print("Unknown words -- please try again!")

## Beam try:

In [6]:
loaded = False
try:
    model
except:
    model = tf.keras.models.load_model("rejto_model/rejto")
    loaded = True


In [79]:
import numpy as np

def search(model, src_input, k=5, sequence_max_len=50, loaded=False):
    # (log(1), initialize_of_zeros)
    k_beam = [(0, [0]*(sequence_max_len+1))]

    # l : point on target sentence to predict
    for l in range(sequence_max_len):
        all_k_beams = []
        for prob, sent_predict in k_beam:
            #predicted = model([np.array([src_input]), np.array([sent_predict])])[0]
            x, length = str_to_input(src_input)
            if loaded == False:
                predicted = model.predict(x)[0]

            else:
                predicted = model(x).numpy()[0]
            # top k!
            possible_k = predicted[l].argsort()[-k:][::-1]

            # add to all possible candidates for k-beams
            all_k_beams += [
                (
                    sum(np.log(predicted[i][sent_predict[i+1]]) for i in range(l)) + np.log(predicted[l][next_wid]),
                    list(sent_predict[:l+1])+[next_wid]+[0]*(sequence_max_len-l-1)
                )
                for next_wid in possible_k
            ]

        # top k
        k_beam = sorted(all_k_beams)[-k:]

    return k_beam

In [80]:
beams = search(model, "Jimmy", sequence_max_len = max_seq_length)

[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]]
[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-

[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]]
[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-

[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]]
[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-

[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]]
[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-

[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]]
[[3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-08
  3.5172981e-08 3.5724760e-08]
 ...
 [3.5338601e-08 3.4182950e-08 3.6516322e-08 ... 3.5636830e-08
  3.5172977e-08 3.5724756e-08]
 [3.5338608e-08 3.4182953e-08 3.6516326e-08 ... 3.5636834e-

IndexError: index 14 is out of bounds for axis 0 with size 14

In [82]:
for b in beams:
    #print(b[1])
    print([r.id_to_word_map[w] for w in b[1] if w != 0])

[',', ',', ',', 'a', ',', ',', ',', ',', ',', ',']
[',', ',', 'a', ',', ',', ',', ',', ',', ',', ',']
[',', 'a', ',', ',', ',', ',', ',', ',', ',', ',']
['a', ',', ',', ',', ',', ',', ',', ',', ',', ',']
[',', ',', ',', ',', ',', ',', ',', ',', ',', ',']
