# RNN from scratch in PyTorch to generate word sequences

## Support code

In [91]:
import pandas as pd
import numpy as np
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
#from torch.nn.functional import softmax
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from typing import Sequence

# Compact, wide NumPy printing for inspecting weight matrices and vectors.
np.set_printoptions(precision=3, suppress=True, linewidth=3000, threshold=20000)

# Floating-point type for model tensors.
dtype = torch.float64

In [92]:
def randn(n1, n2, dtype=torch.float64, mean=0.0, std=0.01, requires_grad=True):
    """
    Return an (n1, n2) tensor of normal samples with the given mean and std.

    The result is a leaf tensor whose requires_grad flag is set from
    `requires_grad`, so it can be handed directly to an optimizer.
    """
    samples = mean + std * torch.randn(n1, n2, dtype=dtype)
    return samples.requires_grad_(requires_grad)

## Use fastai human numbers data

The data is fastai's HUMAN_NUMBERS dataset from [fastai book chapter 12](https://github.com/fastai/fastbook/blob/master/12_nlp_dive.ipynb): one spelled-out number per line. It looks like:

```
one 
two 
three 
...
two hundred seven 
two hundred eight 
...
```


In [93]:
# Fetch fastai's HUMAN_NUMBERS dataset; `path` is the directory that
# train.txt is read from further down. untar_data presumably downloads and
# extracts on first use and reuses a cached copy afterwards — confirm.
# NOTE(review): `fastai2` was the pre-release package name of fastai v2;
# current installs expose this API as `fastai.text.all` — confirm which
# package is available in your environment.
from fastai2.text.all import untar_data, URLs
path = untar_data(URLs.HUMAN_NUMBERS)

## Support code for loading text

In [94]:
import codecs
import os
import re
import string
import numpy as np
import pandas as pd
from typing import Sequence
from sklearn.model_selection import train_test_split

import tensorflow_addons as tfa
from keras.datasets import mnist
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import models, layers, callbacks, optimizers, Sequential, losses
import tqdm
from tqdm.keras import TqdmCallback

def get_text(filename: "str | os.PathLike[str]") -> str:
    """
    Load and return the full text of a file decoded as latin-1.

    `filename` may be a str or a path-like object such as pathlib.Path.
    Line endings are returned untranslated (newline=''), which matches the
    binary-mode reading done by codecs.open(). The `with` block guarantees
    the file is closed even if reading or decoding raises.
    """
    with open(filename, mode='r', encoding='latin-1', newline='') as f:
        return f.read()

## Load corpus and numericalize tokens

In [95]:
# Read the training file and keep only its first 50,000 characters so that
# experiments run quickly (TESTING!!!), then peek at the start.
text = get_text(path / 'train.txt')[:50_000]
text[:30]

'one \ntwo \nthree \nfour \nfive \ns'

In [96]:
# Replace every run of whitespace with ' . ' so '.' becomes an explicit
# separator token between numbers. \s+ (rather than [ \n]+) also absorbs
# '\r' and tabs, so CRLF line endings cannot leak stray '\r' tokens into
# the vocabulary.
text = re.sub(r'\s+', ' . ', text)
text[:20]

'one . two . three . '

In [97]:
# Split on single spaces and discard the final element, which is normally
# the empty string left after the trailing ' . ' separator.
tokens = text.split(' ')[:-1]
tokens[:5]

['one', '.', 'two', '.', 'three']

In [98]:
# Unique vocabulary in first-seen order (not sorted), with the '.' separator
# pinned at index 0 so that 'one'=1, 'two'=2, ... follow the data's order.
# dict.fromkeys keeps insertion order and silently ignores repeats.
vocab = list(dict.fromkeys(['.', *tokens]))
vocab[:10]

['.', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']

In [99]:
# Vocabulary size: number of distinct tokens, including the '.' separator.
len(vocab)

30

In [100]:
# Word -> position in vocab; then replace every word token with its id.
index = dict(zip(vocab, range(len(vocab))))
tokens = list(map(index.__getitem__, tokens))
tokens[:10]

[1, 0, 2, 0, 3, 0, 4, 0, 5, 0]

In [101]:
# Next-token pairs: input is token t, target is token t+1.
X, y = torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])
len(X), len(y), len(tokens)

(15353, 15353, 15354)

In [102]:
X[:5], y[:5]  # first few (input, target) id pairs

(tensor([1, 0, 2, 0, 3]), tensor([0, 2, 0, 3, 0]))

## Split out validation set

In [103]:
# Chronological 80/20 split (no shuffling): a random split such as
# train_test_split would scatter positions and break the token order that
# the RNN learns from.
ntrain = int(len(X) * 0.80)
X_train, X_valid = X[:ntrain], X[ntrain:]
y_train, y_valid = y[:ntrain], y[ntrain:]

## Get vocab

In [104]:
# Word -> integer id lookup, following the vocab order.
wtoi = dict(zip(vocab, range(len(vocab))))
wtoi

{'.': 0,
 'one': 1,
 'two': 2,
 'three': 3,
 'four': 4,
 'five': 5,
 'six': 6,
 'seven': 7,
 'eight': 8,
 'nine': 9,
 'ten': 10,
 'eleven': 11,
 'twelve': 12,
 'thirteen': 13,
 'fourteen': 14,
 'fifteen': 15,
 'sixteen': 16,
 'seventeen': 17,
 'eighteen': 18,
 'nineteen': 19,
 'twenty': 20,
 'thirty': 21,
 'forty': 22,
 'fifty': 23,
 'sixty': 24,
 'seventy': 25,
 'eighty': 26,
 'ninety': 27,
 'hundred': 28,
 'thousand': 29}

In [105]:
def onehot(ci:int, vocab):
    """
    Return a (len(vocab), 1) float64 column vector that is 1 at row `ci`
    and 0 everywhere else. `ci` may be an int or a 0-d integer tensor.
    """
    column = torch.zeros(len(vocab), 1, dtype=torch.float64)
    column[ci, 0] = 1
    return column

In [106]:
onehot(ci=2, vocab=vocab)  # one-hot column for word id 2

tensor([[0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], dtype=torch.float64)

## Train

In [107]:
def sample(h0, ci, n, temperature=0.1):
    """
    Generate `n` words from the RNN, starting from the word with index `ci`.

    Derived from Karpathy: https://gist.github.com/karpathy/d4dee566867f8291f086

    Each step feeds the previously generated word back in as the next input
    (the seed `ci` on the first step). It advances the hidden state with the
    global weights W, U, V and draws the next word from
    softmax(logits / temperature). Lower temperatures make sampling greedier.

    h0: initial hidden state, a column tensor usable with W.mm (e.g. shape
        (nhidden, 1)).
    ci: vocab index of the seed word.
    n: number of words to generate.
    Returns a list of n+1 words: the seed word followed by the samples.
    """
    h = h0
    words = [vocab[ci]]
    with torch.no_grad():
        for _ in range(n):
            x = onehot(ci, vocab)
            h = torch.relu(W.mm(h) + U.mm(x))  # ReLU keeps h in [0, +inf)
            logits = V.mm(h).reshape(-1)  # unnormalized log probabilities for next word
            # softmax subtracts the max internally, so a small temperature
            # cannot overflow exp() into inf/nan.
            p = F.softmax(logits / temperature, dim=0).numpy()
            ci = int(np.random.choice(len(vocab), p=p))
            words.append(vocab[ci])
    return words

In [108]:
def predict(h0, input):
    """
    Return the model's most likely next word after each token of `input`.

    Starting from hidden state `h0`, each token index in `input` is fed
    through the recurrence using the global weights W, U, V. The argmax of
    the output logits is recorded at every step. The result has one word per
    input position: element i is the prediction for the token that follows
    input[i].
    """
    h = h0
    words = []
    with torch.no_grad():
        for ci in input:
            x = onehot(ci, vocab)
            h = torch.relu(W.mm(h) + U.mm(x))  # ReLU keeps h in [0, +inf)
            logits = V.mm(h).reshape(-1)  # unnormalized log probabilities
            # argmax of the logits equals argmax of softmax(logits).
            words.append(vocab[int(torch.argmax(logits))])
    return words

In [109]:
def forward(input):
    """
    Run the RNN over a sequence of token indices and return its logits.

    The hidden state starts at zero on every call. For each token,
    h = relu(W @ h + U @ onehot(token)), and the logits are V @ h. The
    function uses the global weights W, U, V and the global vocab.

    Returns a (len(input), V.shape[0]) float64 tensor. Row i holds the
    unnormalized scores for the token following input[i]. Gradients flow
    back to W, U and V when they require grad.
    """
    h = torch.zeros(W.shape[0], 1, dtype=torch.float64)
    rows = []
    for ci in input:
        x = onehot(ci, vocab)
        h = torch.relu(W.mm(h) + U.mm(x))  # ReLU keeps h in [0, +inf)
        rows.append(V.mm(h).reshape(-1))
    if not rows:  # empty input: still return a (0, len(vocab)) tensor
        return torch.empty(0, len(vocab), dtype=torch.float64)
    # Stacking (instead of writing into a float32 torch.empty buffer) keeps
    # the logits in the same float64 precision as W, U and V.
    return torch.stack(rows)

In [111]:
nhidden = 64
nfeatures = len(vocab)
nclasses = len(vocab) # predicting the next word token
seqlen = 16

W = randn(nhidden, nhidden)
U = randn(nhidden, nfeatures)
V = randn(nclasses, nhidden)

n = (len(X_train) // seqlen) * seqlen # make it a multiple of seqlen
X_train = X_train[:n]
y_train = y_train[:n]
# The validation split is evaluated as a whole, so it is not trimmed.

learning_rate = 0.001
weight_decay = 0.0
optimizer = torch.optim.Adam([W,U,V], lr=learning_rate, weight_decay=weight_decay)
nepochs=20

def evaluate(Xs, ys):
    """Return (cross-entropy loss, accuracy) of forward(Xs) against ys, without tracking gradients."""
    with torch.no_grad():
        logits = forward(Xs)
        loss = F.cross_entropy(logits, ys)
        y_pred = torch.argmax(logits, dim=1)
        return loss.item(), accuracy_score(ys, y_pred)

for epoch in range(nepochs+1):
    # forward() restarts from a zero hidden state on every call, so each
    # seqlen-long chunk is trained independently; no state is carried over.
    for p in range(0, n, seqlen): # do one epoch
        seq_outputs = forward(X_train[p:p+seqlen])
        loss = F.cross_entropy(seq_outputs, y_train[p:p+seqlen])
        optimizer.zero_grad()
        loss.backward() # autograd computes W.grad, U.grad and V.grad
        optimizer.step()

    # NOTE: evaluation runs each split as one long sequence (hidden state
    # carried across all tokens), unlike the chunked training above.
    train_loss, metric_train = evaluate(X_train, y_train)
    valid_loss, metric_valid = evaluate(X_valid, y_valid)
    print(f"Epoch {epoch:3d} loss {train_loss:7.4f} accuracy {metric_train:4.3f}     loss {valid_loss:7.4f} accuracy {metric_valid:4.3f}")

# Example: generate 40 words after 'one' (index 1):
# print(sample(h0=torch.zeros(nhidden, 1, dtype=torch.float64), ci=1, n=40))


Epoch   0 loss  1.3989 accuracy 0.606     loss  1.2747 accuracy 0.697
Epoch   1 loss  1.4621 accuracy 0.607     loss  1.0880 accuracy 0.738
Epoch   2 loss  1.8122 accuracy 0.631     loss  1.3815 accuracy 0.757
Epoch   3 loss  1.9845 accuracy 0.653     loss  1.3950 accuracy 0.758
Epoch   4 loss  1.7803 accuracy 0.643     loss  1.3222 accuracy 0.768
Epoch   5 loss  1.7045 accuracy 0.651     loss  1.3061 accuracy 0.765
Epoch   6 loss  1.6898 accuracy 0.635     loss  1.2245 accuracy 0.767
Epoch   7 loss  1.7622 accuracy 0.654     loss  1.2443 accuracy 0.770
Epoch   8 loss  1.6787 accuracy 0.668     loss  1.1817 accuracy 0.771
Epoch   9 loss  1.7953 accuracy 0.662     loss  1.1798 accuracy 0.772
Epoch  10 loss  1.8556 accuracy 0.647     loss  1.3997 accuracy 0.770
Epoch  11 loss  1.6937 accuracy 0.662     loss  1.2926 accuracy 0.769
Epoch  12 loss  1.7508 accuracy 0.668     loss  1.4438 accuracy 0.769
Epoch  13 loss  1.6104 accuracy 0.663     loss  1.3482 accuracy 0.772
Epoch  14 loss  2.27

KeyboardInterrupt: 