# This notebook contains an implementation of neural-network (NN) based language modelling.
At its core, a language model is a sequence classifier that takes the tokens produced so far as input and produces a probability distribution over all possible next tokens (a token could be a word, a character, or something in between). In this notebook the context is limited to the previous N-1 characters. We can then either use the classifier's "best possible guess" as the next token, or sample the next token according to that distribution.

In fact, we get this probability distribution for free when we build a neural classifier with a softmax output activation. Therefore, nothing actually changes compared to "before", when we simply built classifiers.

Once we have trained the model, we repeatedly ask for next tokens, and add these to the context. This is called "autoregressive sequence generation".

In [41]:
import torch
import torch.nn as nn
import random
from collections import defaultdict

In [42]:
# load the data

source = "data/shakespeare-en.txt" # other files in data: lyrik-de.txt, dialoge-de.txt, merkel-de.txt
N = 4  # n-gram order: the model predicts one character from the previous N-1


START_SYMBOL = "<s>"   # pseudo-token used to pad the context at the start of each line
END_SYMBOL = "</s>"    # pseudo-token used as the final target of each line
# should be simple plain text file; `with` guarantees the handle is closed
with open(source, 'r') as data_file:
    data = data_file.read()

# Vocabulary: every character seen in the data plus the two pseudo-tokens,
# sorted so that the id assignment is deterministic across runs.
characters = sorted(set(data) | {START_SYMBOL, END_SYMBOL})
NUM_CHARACTERS = len(characters)
int2char = list(characters)                            # id -> symbol
char2int = {c: i for i, c in enumerate(characters)}    # symbol -> id
print(characters, len(characters))
NUM_CLASSES = NUM_CHARACTERS
NUM_CLASSES = NUM_CHARACTERS

['\n', ' ', '!', '"', '&', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '</s>', '<s>', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '}'] 86


In [43]:
def ngrams_ids_from_text(f, N):
    '''
    Given a text file (f), yield all (context, next) pairs in that file as character IDs.

    The file is processed line by line. Trailing whitespace (including the
    newline) is stripped from each line. The context starts as N-1
    START_SYMBOLs, and END_SYMBOL is appended as the line's final prediction
    target, so contexts never span line boundaries.
    NOTE: as compared to the previous version, this outputs characterIDs, rather than actual characters

    Args:
        f: path of the text file to read.
        N: n-gram order; every context holds the previous N-1 symbols.

    Yields:
        (context_ids, next_id): a list of N-1 ints and a single int, looked up
        through the module-level ``char2int`` mapping.
    '''
    # `with` closes the file (also when the generator is abandoned early);
    # iterating the handle avoids reading all lines into memory at once.
    with open(f, 'r') as text:
        for line in text:
            context = [START_SYMBOL] * (N - 1)
            for last in list(line.rstrip()) + [END_SYMBOL]:
                yield ([char2int[c] for c in context], char2int[last])
                # Slide the window: append the new symbol, drop the oldest one.
                context.append(last)
                context.pop(0)

In [48]:
# Peek at the first (context ids, next id) pair without materialising the whole list.
next(ngrams_ids_from_text(source, N))

([25, 25, 25], 24)

In [59]:

INPUT_SIZE = NUM_CHARACTERS * (N-1)   # N-1 concatenated one-hot vectors
NUM_CLASSES = NUM_CHARACTERS
MAX_GENERATION_LENGTH = 800

class LM(nn.Module):
    """Character (N-1)-context language model made of a single linear layer.

    The previous N-1 character ids are one-hot encoded and concatenated. One
    linear layer maps them to a score per vocabulary entry, and ``forward``
    normalises those scores with a log-softmax.
    """

    def __init__(self):
        super().__init__()
        self.final_layer = nn.Linear(INPUT_SIZE, NUM_CLASSES)

    def forward(self, xs) -> torch.Tensor:
        """Return log-probabilities of the next character.

        Args:
            xs: N-1 character ids (list or tensor), or a batch of them with
                shape (..., N-1).

        Returns:
            Tensor of shape (..., NUM_CLASSES) holding log-probabilities
            (log-softmax output). It can be fed directly to ``nll_loss``, and
            ``.exp()`` gives the probabilities.
        """
        ids = torch.as_tensor(xs, dtype=torch.long)
        one_hot = nn.functional.one_hot(ids, num_classes=NUM_CHARACTERS).float()
        # Concatenate the N-1 one-hot vectors: (..., N-1, C) -> (..., INPUT_SIZE).
        features = one_hot.flatten(start_dim=-2)
        return nn.functional.log_softmax(self.final_layer(features), dim=-1)

    @torch.no_grad()
    def generate(self, xs=None, sample="max") -> torch.Tensor:
        """Autoregressively generate character ids.

        Generation stops at END_SYMBOL or after MAX_GENERATION_LENGTH characters.

        Args:
            xs: initial context of N-1 character ids. Defaults to N-1
                START_SYMBOL ids, which is a fresh value on every call.
            sample: "max" for max likelihood (argmax) or "prop" for
                proportional sampling from the predicted distribution.

        Returns:
            1-D long tensor of generated ids, excluding the terminating
            END_SYMBOL. It is empty if END_SYMBOL is predicted first.

        Raises:
            ValueError: if ``sample`` is neither "max" nor "prop".
        """
        if sample not in ("max", "prop"):
            raise ValueError("only max and prop are possible values for sample!")
        if xs is None:
            xs = [char2int[START_SYMBOL]] * (N - 1)
        context = torch.as_tensor(xs, dtype=torch.long)
        end_id = char2int[END_SYMBOL]

        output = []
        for _ in range(MAX_GENERATION_LENGTH):
            log_probs = self(context)
            if sample == "max":
                next_id = torch.argmax(log_probs)
            else:
                # multinomial needs non-negative weights: turn log-probs into probabilities.
                next_id = torch.multinomial(log_probs.exp(), 1)[0]
            if next_id.item() == end_id:
                break
            output.append(next_id)
            # Slide the window: append the new id (0-d -> shape (1,) for cat),
            # then drop the oldest one.
            context = torch.cat([context, next_id.view(1)])[1:]
        return torch.stack(output) if output else torch.tensor([], dtype=torch.long)

In [None]:
MAX_EPOCHS = 5

lm = LM()
optimizer = torch.optim.SGD(lm.parameters(), lr=0.01)

ngrams = list(ngrams_ids_from_text(source, N))

def training(ngrams):
    """Train the global ``lm`` with ``optimizer`` for MAX_EPOCHS epochs.

    Updates are per-example SGD.

    Args:
        ngrams: list of (context ids, next id) pairs. It is shuffled in place
            at the start of every epoch.

    Returns:
        The trained model ``lm``. After each epoch the average training loss
        is printed, together with one greedy ("max") and one sampled ("prop")
        generation.
    """
    for epoch in range(MAX_EPOCHS):
        print("Epoch {} starting".format(epoch + 1))
        # Shuffle before every epoch, including the first, so SGD never sees
        # the examples in file order.
        random.shuffle(ngrams)
        total_loss = 0.0
        for context, target in ngrams:
            optimizer.zero_grad()
            output = lm(context)
            # cross_entropy applies log_softmax itself. log_softmax is
            # idempotent, so this is correct whether lm returns raw scores or
            # log-probabilities. The target must be a tensor, not a plain int.
            loss = nn.functional.cross_entropy(output, torch.tensor(target))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print("average loss: {:.4f}".format(total_loss / max(len(ngrams), 1)))
        print("freemax:" + "".join([int2char[x] for x in lm.generate()]))
        print("fresamp:" + "".join([int2char[x] for x in lm.generate(sample="prop")]))
    return lm

training(ngrams)

In [50]:
# Size of the model's output layer: one class per vocabulary entry (NUM_CLASSES = NUM_CHARACTERS).
NUM_CLASSES

86

In [39]:
# Inspect one (context ids, next id) training pair. training() shuffles `ngrams`
# in place, so after training this index refers to an arbitrary pair.
ngrams[3]

([25, 47, 35], 32)