In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
import random

from torch.nn.utils.rnn import PackedSequence
%matplotlib inline

In [2]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+2 for i,s in enumerate(chars)}
stoi['#'] = 0
stoi['.'] = 1
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'x', 26: 'y', 27: 'z', 0: '#', 1: '.'}
28


In [4]:
[*map(stoi.get, ['a','b', 'c'])][1:]

[3, 4]

In [5]:
# shuffle up the words
random.seed(42)
random.shuffle(words)

In [6]:
# # build the dataset

# def build_dataset(words):  
#   X, Y = [], []
  
#   for w in words:
#     context = [0] * block_size
#     for ch in w + '.':
#       ix = stoi[ch]
#       X.append(context)
#       Y.append(ix)
#       context = context[1:] + [ix] # crop and append

#   X = torch.tensor(X)
#   Y = torch.tensor(Y)
#   print(X.shape, Y.shape)
#   return X, Y

# n1 = int(0.8*len(words))
# n2 = int(0.9*len(words))
# Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
# Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
# Xte,  Yte  = build_dataset(words[n2:])     # 10%

In [7]:
thing = torch.randn(1,2,3)

In [8]:
thing

tensor([[[ 0.2158,  0.7893,  0.3984],
         [-0.8623, -0.2438, -0.5209]]])

In [9]:
thing.size()

torch.Size([1, 2, 3])

In [10]:
thing.shape

torch.Size([1, 2, 3])

In [11]:
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence
a = torch.randn(3, 10)
b = torch.randn(2, 10)
c = torch.randn(7, 10)
d = pack_sequence([a, b, c], enforce_sorted=False)

a = torch.cat((a, torch.zeros(4, 10)))
b = torch.cat((b, torch.zeros(5, 10)))
f = torch.stack([a, b, c])
e = pack_padded_sequence(f, [3,2,7], batch_first = True, enforce_sorted=False)
(d[0] == e[0]).all()

tensor(True)

In [12]:
def cmp(label, t1, t2):
  ex = torch.all(t1 == t2).item()
  app = torch.allclose(t1, t2, rtol=1e-04, atol=1e-07)
  maxdiff = (t1 - t2).abs().max().item()
  print(f'{label} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
  print()

In [48]:
torch.manual_seed(42)
a = torch.randn(3, 10)
b = torch.randn(2, 10)
c = torch.randn(7, 10)
d = torch.randn(6, 10)
e = pack_sequence([a, b, c, d], enforce_sorted=False)
rnn = torch.nn.RNN(input_size = 10, hidden_size = 20, bidirectional = True)
out, hidden = rnn.forward(e)
out_unpacked = torch.nn.utils.rnn.unpack_sequence(out)
weights = rnn._flat_weights

In [20]:
hidden.shape, out[0].shape

(torch.Size([2, 4, 20]), torch.Size([18, 40]))

In [15]:
a[[0], :].shape, weights[0].shape

(torch.Size([1, 10]), torch.Size([20, 10]))

In [59]:
for w, name in zip(weights, rnn._flat_weights_names):
    print(w.shape, name)

torch.Size([20, 10]) weight_ih_l0
torch.Size([20, 20]) weight_hh_l0
torch.Size([20]) bias_ih_l0
torch.Size([20]) bias_hh_l0
torch.Size([20, 10]) weight_ih_l0_reverse
torch.Size([20, 20]) weight_hh_l0_reverse
torch.Size([20]) bias_ih_l0_reverse
torch.Size([20]) bias_hh_l0_reverse


In [75]:
h0f = torch.zeros(1, 20)
h1f = torch.tanh((a[[0], :] @ weights[0].T + weights[2]) + (h0f @ weights[1].T + weights[3]))
h2f = torch.tanh((a[[1], :] @ weights[0].T + weights[2]) + (h1f @ weights[1].T + weights[3]))
h3f = torch.tanh((a[[2], :] @ weights[0].T + weights[2]) + (h2f @ weights[1].T + weights[3]))

h0r = torch.zeros(1, 20)
h1r = torch.tanh((a[[2], :] @ weights[4].T + weights[6]) + (h0r @ weights[5].T + weights[7]))
h2r = torch.tanh((a[[1], :] @ weights[4].T + weights[6]) + (h1r @ weights[5].T + weights[7]))
h3r = torch.tanh((a[[0], :] @ weights[4].T + weights[6]) + (h2r @ weights[5].T + weights[7]))

h0 = torch.cat([h0f, h0r], dim=1)
h1 = torch.cat([h1f, h3r], dim=1)
h2 = torch.cat([h2f, h2r], dim=1)
h3 = torch.cat([h3f, h1r], dim=1)

In [79]:
hs = torch.cat([h1, h2, h3], dim=0)
a_out = out_unpacked[0]
cmp("hs", hs, a_out)
cmp("h3", torch.cat([h3f, h3r], dim = 0), hidden[:, 0, :])

hs | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

h3 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07



In [74]:
cmp("h3", h1r, a_out[[2], 20:])

h3 | exact: False | approximate: True  | maxdiff: 5.960464477539063e-08



In [35]:
a_out.shape

torch.Size([3, 40])

In [7]:
h = torch.zeros(4, 20)
input, batch_sizes, sorted_indices, unsorted_indices = e
split_inputs = torch.split(input, list(batch_sizes), dim=0)
answers = torch.split(out[0], list(batch_sizes), dim=0)
# ordered_list = [c, d, a, b]
final = []
for i in range(len(batch_sizes)):
    size = batch_sizes[i]
    while size < h.shape[0]:
        final.insert(0, h[[-1], :])
        h = h[:-1, :]
    # h = h[:size, :]
    h = torch.tanh((split_inputs[i] @ weights[0].T + weights[2]) + (h @ weights[1].T + weights[3]))
    cmp(i, h, answers[i])
final.insert(0, h)
final = torch.cat(final, dim=0)
final = final.index_select(0, unsorted_indices)
cmp("final", final, hidden)

0 | exact: True  | approximate: True  | maxdiff: 0.0

1 | exact: True  | approximate: True  | maxdiff: 0.0

2 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

3 | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08

4 | exact: False | approximate: True  | maxdiff: 5.960464477539063e-08

5 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

6 | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08

final | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07



In [110]:
h.shape

torch.Size([4, 20])

In [100]:
out[0].shape

torch.Size([18, 20])

In [101]:
answers = torch.split(out[0], list(batch_sizes), dim=0)

In [102]:
[t.shape for t in answers]

[torch.Size([4, 20]),
 torch.Size([4, 20]),
 torch.Size([3, 20]),
 torch.Size([2, 20]),
 torch.Size([2, 20]),
 torch.Size([2, 20]),
 torch.Size([1, 20])]

In [98]:
h.shape

torch.Size([4, 20])

In [95]:
batch_sizes

tensor([4, 4, 3, 2, 2, 2, 1])

In [96]:
[t.shape for t in torch.split(input, list(batch_sizes), dim=0)]

[torch.Size([4, 10]),
 torch.Size([4, 10]),
 torch.Size([3, 10]),
 torch.Size([2, 10]),
 torch.Size([2, 10]),
 torch.Size([2, 10]),
 torch.Size([1, 10])]

In [84]:
(a[[0], :] @ weights[0].T).shape, weights[2].shape

(torch.Size([1, 20]), torch.Size([20]))

In [77]:
(hs == a_out)

tensor([[False, False,  True, False, False, False,  True, False,  True, False,
          True, False, False, False,  True,  True, False,  True,  True, False],
        [False,  True, False, False,  True,  True, False, False, False, False,
          True, False,  True, False, False,  True,  True, False, False, False],
        [False,  True,  True, False, False, False, False,  True, False, False,
         False,  True, False,  True,  True, False,  True,  True,  True, False]])

exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07



In [81]:
hidden.squeeze(0).shape

torch.Size([4, 20])

In [57]:
weights = rnn._flat_weights
[w.shape for w in weights]

[torch.Size([20, 10]),
 torch.Size([20, 20]),
 torch.Size([20]),
 torch.Size([20])]

In [85]:
rnn._flat_weights_names

['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']

In [17]:
d[0].shape, d[1], d[2], d[3]

(torch.Size([12, 10]),
 tensor([3, 3, 2, 1, 1, 1, 1]),
 tensor([2, 0, 1]),
 tensor([1, 2, 0]))

In [16]:
out[0][0].shape, out[0][1], out[0][2], out[0][3]

(torch.Size([12, 10]),
 tensor([3, 3, 2, 1, 1, 1, 1]),
 tensor([2, 0, 1]),
 tensor([1, 2, 0]))

In [18]:
out[1].shape

torch.Size([1, 3, 10])

In [None]:
torch.nn.utils.rnn.unpack_sequence()

In [28]:
[w.shape for w in rnn._flat_weights]

[torch.Size([10, 10]),
 torch.Size([10, 10]),
 torch.Size([10]),
 torch.Size([10])]

In [54]:
f.shape

torch.Size([3, 7, 10])

In [55]:
d.data.shape, e.data.shape

(torch.Size([12, 10]), torch.Size([12, 10]))

In [56]:
e

PackedSequence(data=tensor([[ 3.0735e-01, -2.6678e-01, -3.0902e-01, -1.1208e+00, -3.2936e-01,
          9.3997e-01,  4.4931e-01, -3.2858e-01,  1.5793e+00, -4.5184e-01],
        [ 2.9325e-01, -2.9306e-01, -8.9733e-01, -5.9929e-02,  2.3169e-01,
         -4.0382e-01, -1.3285e+00, -1.4221e+00,  6.5656e-01, -1.2048e+00],
        [ 1.4668e+00, -1.2524e+00,  8.0810e-01,  1.6108e-01,  7.2651e-01,
          1.9860e-01, -1.2810e+00,  6.2184e-01,  3.7555e-01, -4.4060e-01],
        [ 6.1136e-01,  2.2960e+00,  7.1144e-01, -9.4992e-01,  1.6331e+00,
         -1.3578e-01, -6.8200e-01,  1.1770e+00, -6.7226e-01,  4.7319e-01],
        [ 1.1006e+00, -1.1622e+00,  3.4273e+00,  2.6590e-02,  3.6384e-01,
          1.8749e+00,  1.0857e+00, -5.7192e-01, -1.9600e-01, -1.0567e+00],
        [-5.5180e-01, -8.6533e-01, -1.0941e-01, -8.8519e-01,  4.8411e-01,
         -2.4756e-02, -5.5840e-01, -6.4298e-01, -1.7691e-01, -5.6977e-01],
        [-5.4802e-01, -6.3640e-01,  8.6379e-01, -8.6455e-01,  1.5797e+00,
          8.

In [151]:
# fix initialisation values

class RNN:
    def __init__(self, input_size, hidden_size, act_fn = torch.tanh) -> None:
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.act_fn = act_fn
        self.wih = torch.randn(hidden_size, input_size)
        self.whh = torch.randn(hidden_size, hidden_size)
        self.bih = torch.randn(hidden_size)
        self.bhh = torch.randn(hidden_size)

    def forward(self, input, hidden=None):
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        if hidden is None:
            hidden = torch.zeros(max_batch_size, self.hidden_size)
        else:
            hidden = hidden.index_select(0, sorted_indices)

        split_inputs = torch.split(input, list(batch_sizes), dim=0)
        out = []
        out_hidden = []
        for i in range(len(batch_sizes)):
            size = batch_sizes[i]
            diff = hidden.shape[0] - size
            if diff > 0:
                out_hidden.insert(0, hidden[-diff:, :])
                hidden = hidden[:size, :]
            hidden = torch.tanh((split_inputs[i] @ self.wih.T + self.bih) + hidden @ self.whh.T + self.bhh)
            out.append(hidden)
        out_hidden.insert(0, hidden)
        out_hidden = torch.cat(out_hidden, dim=0).index_select(0, unsorted_indices)
        out = torch.cat(out, dim=0)
        out = PackedSequence(out, batch_sizes, sorted_indices, unsorted_indices)

        return out, out_hidden

    def parameters(self):
        return [self.wih, self.whh, self.bih, self.bhh]

In [154]:
torch.manual_seed(42)
a = torch.randn(3, 10)
b = torch.randn(2, 10)
c = torch.randn(7, 10)
d = torch.randn(6, 10)
e = pack_sequence([a, b, c, d], enforce_sorted=False)
rnn = RNN(input_size = 10, hidden_size = 20)
out, hidden = rnn.forward(e)
out_unpacked = torch.nn.utils.rnn.unpack_sequence(out)
weights = rnn.parameters()

In [157]:
h = torch.zeros(4, 20)
input, batch_sizes, sorted_indices, unsorted_indices = e
split_inputs = torch.split(input, list(batch_sizes), dim=0)
answers = torch.split(out[0], list(batch_sizes), dim=0)
# ordered_list = [c, d, a, b]
final = []
for i in range(len(batch_sizes)):
    size = batch_sizes[i]
    while size < h.shape[0]:
        final.insert(0, h[[-1], :])
        h = h[:-1, :]
    # h = h[:size, :]
    h = torch.tanh((split_inputs[i] @ weights[0].T + weights[2]) + (h @ weights[1].T + weights[3]))
    cmp(i, h, answers[i])
final.insert(0, h)
final = torch.cat(final, dim=0)
final = final.index_select(0, unsorted_indices)
cmp("final", final, hidden)

0 | exact: True  | approximate: True  | maxdiff: 0.0

1 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

2 | exact: False | approximate: True  | maxdiff: 2.980232238769531e-07

3 | exact: False | approximate: True  | maxdiff: 3.8743019104003906e-07

4 | exact: False | approximate: True  | maxdiff: 3.5762786865234375e-07

5 | exact: False | approximate: True  | maxdiff: 7.748603820800781e-07

6 | exact: False | approximate: True  | maxdiff: 1.043081283569336e-07

final | exact: False | approximate: True  | maxdiff: 7.748603820800781e-07

