# LANGUAGE MODEL
*VietAI Advanced NLP*

# Addition with LSTM

Inspired by Andrej Karpathy's [play_math](https://github.com/karpathy/minGPT/blob/master/play_math.ipynb), this tutorial shows how to teach an LSTM to add numbers.

You will learn how to:
1. Create a custom dataset.
2. Model the input and output for a well-structured dataset.
3. Mask the labels so that the model only learns from a specific part of them.
4. Build an LSTM model, then train and evaluate it.

## 1. Problem Statement

> The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our
    encoding will simply be the n-digit first number, n-digit second number, 
    and (n+1)-digit result, **all simply concatenated together**. Because each addition
    problem is so structured, there is no need to bother the model with encoding
    +, =, or other tokens. Each possible sequence has the **same length**, and simply
    contains the raw digits of the addition problem. - [From Andrej Karpathy's play_math](https://github.com/karpathy/minGPT/blob/master/play_math.ipynb)
    

Examples:
  - ```85 + 50 = 135``` becomes the sequence ```[8, 5, 5, 0, 1, 3, 5]```
  - ```47 + 17 =  64``` becomes the sequence ```[4, 7, 1, 7, 0, 6, 4]```
  
  etc.


We will also train the LSTM only on the final (n+1) digits, because the two
n-digit operands are always given. At test time we feed it, for example,
the sequence ```[4, 7, 1, 7]```, which encodes ```47 + 17```, and hope that the model completes it with ```[0, 6, 4]``` in 3 sequential steps.
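
The encoding described above can be sketched in a few lines of plain Python. The helper name `encode_problem` is an assumption made for illustration and is not part of the notebook; it zero-pads both operands to n digits and the sum to n+1 digits, then splits the result into the prompt and the expected answer.

```python
def encode_problem(a, b, ndigit):
    """Render a + b as one flat list of digits (hypothetical helper)."""
    c = a + b
    # zero-pad both operands to ndigit digits and the sum to ndigit + 1 digits
    rendered = f"{a:0{ndigit}d}{b:0{ndigit}d}{c:0{ndigit + 1}d}"
    return [int(ch) for ch in rendered]


ndigit = 2
seq = encode_problem(47, 17, ndigit)
prompt, answer = seq[:2 * ndigit], seq[2 * ndigit:]

print(seq)     # [4, 7, 1, 7, 0, 6, 4]
print(prompt)  # [4, 7, 1, 7]  -> what the model is given
print(answer)  # [0, 6, 4]     -> what the model should generate
```

Every sequence has length 3n+1, so no padding or separator tokens are needed.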

---
Example of an item generated by this dataset (```input, target = train_dataset[0]```):

> input: ```tensor([4, 7, 1, 7, 0, 6])```

> target: ```tensor([-100, -100, -100,    0,    6,    4])```
    
---

Explanation:

1. equation: ```47 + 17 = 064```
2. concat all digits together: ```[4, 7, 1, 7, 0, 6, 4]```
3. prepare data for language model teacher-forcing objective (predict next digits)

```
    target:  7     1     7     0     6     4  
             |     |     |     |     |     |
    input :  4     7     1     7     0     6
```

4. Since the model only needs to learn to produce the sum, we can ignore the loss on the given operands to make learning more efficient, by masking those target positions with a special index (e.g. -100).

```
    contribute to loss:   ✖    ✖     ✖     ✔     ✔     ✔ 
               target : -100  -100  -100    0     6     4
                          |     |     |     |     |     |
               input  :   4     7     1     7     0     6
```
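
As a minimal sketch of steps 3 and 4, the snippet below builds the shifted input/target pair for `47 + 17` and checks that a cross-entropy loss with `ignore_index=-100` equals the loss computed on the three answer positions alone. The random `scores` tensor is only a stand-in for a model's output and is an assumption of this example.

```python
import torch
import torch.nn.functional as F

ndigit = 2
seq = torch.tensor([4, 7, 1, 7, 0, 6, 4])  # 47 + 17 = 064

x = seq[:-1].clone()        # model input
y = seq[1:].clone()         # next-digit targets
y[:2 * ndigit - 1] = -100   # hide targets that only repeat the given operands

# random scores stand in for a model's output: (seq_len, vocab_size)
torch.manual_seed(0)
scores = torch.randn(len(x), 10)

masked_loss = F.cross_entropy(scores, y, ignore_index=-100)
# averaging over the three answer positions only gives the same value
answer_only_loss = F.cross_entropy(scores[-3:], y[-3:])

print(x.tolist())  # [4, 7, 1, 7, 0, 6]
print(y.tolist())  # [-100, -100, -100, 0, 6, 4]
print(torch.allclose(masked_loss, answer_only_loss))  # True
```

With the default mean reduction, ignored positions are excluded from both the sum and the count, so they contribute no gradient.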

## 2. Define the dataset

In [None]:
# !pip install torch tqdm numpy

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class AdditionDataset(Dataset):

    def __init__(self, ndigit, split):
        """
          ndigit: number of digits 
          split: train/ test
        """
        self.split = split # train/test
        self.ndigit = ndigit
        self.vocab_size = 10 # 10 possible digits 0..9
        
        # split up all addition problems into either training data or test data
        num = (10**self.ndigit)**2 # total number of possible combinations
        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000
        
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(num)
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        nd = 10**self.ndigit
        a = idx // nd
        b = idx %  nd
        c = a + b
        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes "0325028" 
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be input to LSTM and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        
        y[:self.ndigit*2-1] = -100 # we will only train in the output locations. -100 will mask loss to zero
        return x, y

In [None]:
# create a dataset for e.g. 2-digit addition
ndigit = 2
train_dataset = AdditionDataset(ndigit= ndigit, split= 'train')
test_dataset = AdditionDataset(ndigit= ndigit, split= 'test')

In [None]:
train_dataset[0]
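
To make the index arithmetic in `AdditionDataset` concrete, here is a small NumPy-only sketch of the same split and decoding logic for 2-digit problems. The helper `decode_index` is a hypothetical name used only for this illustration.

```python
import numpy as np

ndigit = 2
nd = 10 ** ndigit
num = nd ** 2                          # 10,000 ordered pairs (a, b)
num_test = min(int(num * 0.2), 1000)   # 20% of the data, capped at 1000

# same fixed seed as the dataset class, so the split is deterministic
perm = np.random.RandomState(1337).permutation(num)
test_ixes, train_ixes = perm[:num_test], perm[num_test:]
print(len(train_ixes), len(test_ixes))  # 9000 1000


def decode_index(idx):
    """Recover the operands encoded in a flat problem index (hypothetical helper)."""
    return idx // nd, idx % nd


print(decode_index(4717))  # (47, 17)
```

Because train and test indices come from one permutation, no addition problem appears in both splits.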

## 3. Define the model


We will implement a neural network with two LSTM layers, a vocabulary of 10 tokens (the digits 0 to 9), a hidden state of size 1024, and dropout with p_dropout = 0.2 between the LSTM layers.
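
The following is one possible way to build such a network, offered as a sketch and not as the reference solution to the exercise below. The class name `AdditionLSTMSketch` and the choice of making the embedding width equal to `hidden_size` are assumptions; the tutorial does not specify the embedding size.

```python
import torch
import torch.nn as nn


class AdditionLSTMSketch(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=1024, n_layers=2, p_dropout=0.2):
        super().__init__()
        # assumption: embedding width equals hidden_size
        self.layer1 = nn.Embedding(vocab_size, hidden_size)
        # batch_first=True means inputs are shaped (batch, seq_len, features)
        self.layer2 = nn.LSTM(hidden_size, hidden_size, num_layers=n_layers,
                              dropout=p_dropout, batch_first=True)
        self.layer3 = nn.Linear(hidden_size, vocab_size)

    def forward(self, word_seq):
        g_seq = self.layer1(word_seq)
        h_seq, (h_final, c_final) = self.layer2(g_seq)
        score_seq = self.layer3(h_seq)
        return score_seq, h_final, c_final


sketch = AdditionLSTMSketch(hidden_size=32)  # small width just for a quick shape check
scores, h, c = sketch(torch.randint(0, 10, (4, 6)))
print(scores.shape)  # torch.Size([4, 6, 10])
print(h.shape)       # torch.Size([2, 4, 32])
```

Note that the final hidden and cell states are shaped (n_layers, batch, hidden_size) even when `batch_first=True`.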

In [None]:
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

In [None]:
n_layers = 2
vocab_size = 10
hidden_size = 1024
p_dropout = 0.2

In [None]:
class AdditionLSTM(nn.Module):
    def __init__(self):
        super(AdditionLSTM, self).__init__()
        """
        layers:
          1. Embedding
          2. LSTM (with n_layers layers, dropout rate p_dropout, batch size as the first dimension)
          3. Linear
        """
        ### YOUR CODE HERE ###
        # use the nn module in PyTorch; replace each None with the matching layer
        self.layer1 = None  # Embedding
        self.layer2 = None  # LSTM
        self.layer3 = None  # Linear
        
    def forward(self, word_seq):
        g_seq                      =   self.layer1( word_seq )
        h_seq , (h_final,c_final)  =   self.layer2(g_seq)      
        score_seq                  =   self.layer3( h_seq )
        return score_seq,  h_final , c_final

In [None]:
net = AdditionLSTM()
print(net)

## 4. Implement the training loop

In [None]:
import random
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [None]:
# Some config for the training
n_epochs = 500
batch_size = 512
lr = 0.002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is available
log_every = 10

net = net.to(device)
optimizer = optim.AdamW(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

In [None]:
def run_epoch(split):
    is_train = split == "train"
    net.train(is_train) # training mode for 'train', evaluation mode otherwise
    data = train_dataset if is_train else test_dataset
    # __getitem__ draws no random numbers, so no per-worker seeding is needed
    loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2)

    losses = []
    
    for input, target in loader:
        # send mini batch to device
        input = input.to(device)
        target = target.to(device)
    
        # forward the model
        with torch.set_grad_enabled(is_train):
            scores, h, c = net(input) # scores: batch_size, seq_len, vocab_size

            # reshape the scores and target to huge batch of size bs*seq_length
            scores = scores.view(-1, vocab_size)
            target = target.flatten()
            loss = criterion(scores, target)
            losses.append(loss.item())
        
        if is_train:
            net.zero_grad()
            loss.backward()
            optimizer.step()
        
    # return only after every mini batch has been processed
    return float(np.mean(losses))
        
        
for epoch in range(n_epochs):
    train_loss = run_epoch('train')
    test_loss = run_epoch('test')
    if (epoch + 1) % log_every == 0:
        print(f"Epoch: {epoch + 1:3d} Train loss: {train_loss:.5f} Test loss: {test_loss:.5f}")

## 5. Define the sampling

We now define how to sample from the model's output to get the final prediction.

Normally, we predict the index with the highest score. In this section, however, we define a more general sampling function based on the k highest scores.

In some situations we may prefer to sample from the top k scores instead of taking the single best prediction, or to keep the k best candidates (e.g. beam search). These functions help you do that.

```
1. top_k_logits: keep the top k elements and mask the others to -INFINITY
2. generate_top_k: generate the prediction
  * top_k (int): predict using only the top k logits
  * sample (True): sample randomly from the top k scores (just naive sampling 😆);
    otherwise simply take the index with the highest score
```

*Our evaluation simply uses the highest score to generate prediction*
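
As a hedged sketch of the masking idea behind `top_k_logits`, the snippet below keeps the k largest values in each row and sets the rest to negative infinity. The name `top_k_logits_sketch` is an assumption, and this is only one way to solve the exercise; it finds the k-th largest value per row with `torch.topk` and masks everything below it.

```python
import torch


def top_k_logits_sketch(logits, k):
    # k largest values along the last dimension, sorted in descending order
    values, _ = torch.topk(logits, k, dim=-1)
    # the k-th largest value per row, kept as a trailing column for broadcasting
    kth = values[..., -1, None]
    # everything below that threshold becomes -inf; ties with the k-th value are kept
    return logits.masked_fill(logits < kth, float("-inf"))


a = torch.tensor([[1.2, 5.0, -3.4, 4.2, -2.1],
                  [2.1, -0.4, -3.0, 1.7, 0.1]])
out = top_k_logits_sketch(a, k=2)
print(out)

# after softmax, masked entries receive probability exactly 0
probs = torch.softmax(out, dim=-1)
print(probs)
```

Masking with negative infinity before the softmax means sampling with `torch.multinomial` can never pick a digit outside the top k.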

In [None]:
import torch.nn.functional as F


def top_k_logits(logits, k):
    """
    Keep the top k logits and mask the others to -INFINITY

    logits: input logits (size [..., vocab_size]; generate_top_k passes [B, vocab_size])
    k: the number of elements to keep

    examples:
     a = [[1.2, 5.0, -3.4, 4.2, -2.1],
          [2.1, -0.4, -3.0, 1.7, 0.1]]

     => top_k_logits(a, k = 2) = [[-INF, 5.0, -INF, 4.2, -INF],
                                  [2.1, -INF, -INF, 1.7, -INF]]
    """

    # hint: torch.topk() might be useful
    ### YOUR CODE HERE ###

    return out


def generate_top_k(ids, n_char=200, sample=False, top_k=None):
    
    """
    ids: list of input token ids (the prompt)
    n_char: number of digits to generate (an n-digit problem needs n + 1 steps)
    sample: if True, sample from the predicted distribution instead of taking the argmax
    top_k: if set, restrict the prediction to the k highest scores
    """
    
    prompt_ids = ids[:] # copy
    with torch.no_grad():
        for _ in range(n_char):
            input = torch.LongTensor(prompt_ids).reshape(1, -1).to(device)
            scores, h, c = net(input)
            # scores shape (B, L, V), B = 1
            logits = scores[:, -1, :]

            # optionally keep only the top k logits
            if top_k is not None:
                logits = top_k_logits(logits, top_k)

            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)

            if sample:
                # draw the next digit from the (possibly truncated) distribution
                ix = torch.multinomial(probs, num_samples=1)
            else:
                # greedy decoding: take the most likely digit
                _, ix = torch.topk(probs, k=1, dim=-1)

            prompt_ids.append(ix.item())
    return prompt_ids

In [None]:
net.eval()

In [None]:
generate_top_k([1, 2, 3, 5], n_char=3) # We expect [1, 2, 3, 5, 0, 4, 7] since 12 + 35 = 047

In [None]:
def to_integer(list_digit):
    """
    Convert a list of digits to number
    e.g [1, 2] => 12
    """
    out = 0
    reverse_list_digit = list_digit[::-1]
    for factor, i in enumerate(reverse_list_digit):
        out += 10**factor*i
    return out

to_integer([1, 2, 3])

In [None]:
# now let's evaluate our trained model
def evaluate(dataset):
    
    results = []
    for x, y in tqdm(dataset):
        d1d2 = x[:ndigit*2].numpy().tolist() # take the digits of the two operands as the prompt
        d1d2d3 = generate_top_k(d1d2, ndigit+1)
        d3 = d1d2d3[-(ndigit+1):] # Take the last ndigit+1
        
        # decode the integers from individual digits
        d1i = to_integer(d1d2[:ndigit])
        d2i = to_integer(d1d2[ndigit:])
        
        d3i_pred = to_integer(d3)
        d3i_gt = d1i + d2i
        correct = (d3i_pred == d3i_gt)
        results.append(int(correct))
        
        judge = 'YEP!!!' if correct else 'NOPE'
        if not correct:
            print("LSTM claims that %03d + %03d = %03d (gt is %03d; %s)" 
                % (d1i, d2i, d3i_pred, d3i_gt, judge))
        
    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))



In [None]:
# training set: how well did we memorize?
evaluate(train_dataset)

In [None]:
# test set: how well did we generalize?
evaluate(test_dataset)