In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
import random

from torch.nn.utils.rnn import PackedSequence
%matplotlib inline

In [2]:
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence
# Pack three ragged sequences directly, then again from a zero-padded
# (batch, time, feature) tensor; both routes must give identical packed data.
a, b, c = torch.randn(3, 10), torch.randn(2, 10), torch.randn(7, 10)
d = pack_sequence([a, b, c], enforce_sorted=False)

# Zero-pad the shorter sequences along time to the longest length (7).
a = F.pad(a, (0, 0, 0, 4))
b = F.pad(b, (0, 0, 0, 5))
f = torch.stack((a, b, c))
e = pack_padded_sequence(f, [3, 2, 7], batch_first=True, enforce_sorted=False)
(d.data == e.data).all()

tensor(True)

In [3]:
def cmp(label, t1, t2, rtol=1e-04, atol=1e-07):
  """Print an exact / approximate equality report for two tensors.

  Args:
    label: prefix for the printed line (anything with a str form, e.g. a step index).
    t1, t2: tensors to compare; they must be broadcast-compatible.
    rtol, atol: tolerances forwarded to torch.allclose.

  Prints one report line followed by a blank line; returns None so a cell that
  ends with cmp(...) shows only the printed report.
  """
  if t1.shape != t2.shape:
    # Broadcasting can make e.g. a (1, 20) row vs a (3, 20) block look like a
    # normal comparison, so surface the mismatch instead of hiding it.
    print(f'{label} | note: shapes differ {tuple(t1.shape)} vs {tuple(t2.shape)}; comparing after broadcasting')
  ex = torch.all(t1 == t2).item()
  app = torch.allclose(t1, t2, rtol=rtol, atol=atol)
  maxdiff = (t1 - t2).abs().max().item()
  print(f'{label} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
  print()

In [4]:
# Reference run: torch.nn.LSTM over a packed batch of four variable-length
# sequences (lengths 3, 2, 7, 6), seeded so the manual cells can be compared.
torch.manual_seed(42)
a = torch.randn(3, 10)
b = torch.randn(2, 10)
c = torch.randn(7, 10)
d = torch.randn(6, 10)
e = pack_sequence([a, b, c, d], enforce_sorted=False)
lstm = torch.nn.LSTM(input_size=10, hidden_size=20)
# Call the module (runs any hooks) rather than .forward directly.
out, (hn, cn) = lstm(e)
out_unpacked = torch.nn.utils.rnn.unpack_sequence(out)
# Public parameter attributes instead of the private `_flat_weights`; same
# tensors in the same order: weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0.
weights = [lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0]

In [16]:
e.data.shape  # packed rows = total timesteps across all sequences

torch.Size([18, 10])

In [15]:
# Packed output rows, and the layer-0 slices of the final hidden / cell states.
(out.data.shape, hn[0].shape, cn[0].shape)

(torch.Size([18, 20]), torch.Size([4, 20]), torch.Size([4, 20]))

In [12]:
[param.shape for param in weights]  # one entry per LSTM parameter tensor

[torch.Size([80, 10]),
 torch.Size([80, 20]),
 torch.Size([80]),
 torch.Size([80])]

In [13]:
# Private torch attribute naming the tensors behind `lstm._flat_weights`;
# the names line up with the shapes listed by the previous cell.
lstm._flat_weights_names

['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']

In [17]:
# Zero initial hidden and cell state for one sequence (1 row, hidden size 20).
h0, c0 = torch.zeros(1, 20), torch.zeros(1, 20)

In [18]:
a[0:1].shape  # first timestep of `a`, kept 2-D as (1, input_size)

torch.Size([1, 10])

In [34]:
# One LSTM step for the first timestep of every sequence at once: the first
# batch_sizes[0] rows of the packed data, starting from the zero state.
# h0 / c0 are (1, hidden) and broadcast across the batch rows.
first_step = e.data[: int(e.batch_sizes[0])]
preact = (first_step @ weights[0].T + weights[2]) + (h0 @ weights[1].T + weights[3])
# Gate order follows torch.nn.LSTM: input, forget, cell candidate, output.
i, f, g, o = preact.chunk(4, dim=1)
i = torch.sigmoid(i)
f = torch.sigmoid(f)
g = torch.tanh(g)
o = torch.sigmoid(o)
c1 = f * c0 + i * g
h1 = o * torch.tanh(c1)

In [35]:
(h1.shape, c1.shape)  # (batch, hidden) for the batched first step

(torch.Size([4, 20]), torch.Size([4, 20]))

In [36]:
# Unroll the LSTM by hand over sequence `a` (expected to have exactly 3
# timesteps), one (1, input_size) row at a time, using the weights taken
# from torch.nn.LSTM. Gate order: input, forget, cell candidate, output.
h_t, c_t = h0, c0
manual_h, manual_c = [], []
for x_t in a.split(1, dim=0):
    preact = (x_t @ weights[0].T + weights[2]) + (h_t @ weights[1].T + weights[3])
    i, f, g, o = preact.chunk(4, dim=1)
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    g = torch.tanh(g)
    o = torch.sigmoid(o)
    c_t = f * c_t + i * g
    h_t = o * torch.tanh(c_t)
    manual_h.append(h_t)
    manual_c.append(c_t)
# Per-step states under the names the comparison cells use.
h1, h2, h3 = manual_h
c1, c2, c3 = manual_c

In [37]:
# Sequence `a` is batch entry 0: compare its hand-computed hidden states and
# final (h, c) with what torch.nn.LSTM produced.
a_out = out_unpacked[0]
hs = torch.cat((h1, h2, h3), dim=0)
cmp("hs", hs, a_out)
cmp("h3", h3, hn.squeeze(0)[:1])
cmp("c3", c3, cn.squeeze(0)[:1])

hs | exact: False | approximate: True  | maxdiff: 4.470348358154297e-08

h3 | exact: False | approximate: True  | maxdiff: 4.470348358154297e-08

c3 | exact: False | approximate: True  | maxdiff: 5.960464477539063e-08



In [30]:
cmp("h1", h1, out_unpacked[0][:1])  # first output row of sequence `a`

h1 | exact: False | approximate: True  | maxdiff: 2.9802322387695312e-08



In [25]:
h1.size()

torch.Size([1, 20])

In [29]:
out_unpacked[0].size()  # (timesteps of `a`, hidden)

torch.Size([3, 20])

In [6]:
# Vanilla-RNN (tanh) sanity check. At this point `weights`, `out` and
# `out_unpacked` hold the torch.nn.LSTM's 4*hidden-row weights and outputs,
# which a tanh cell cannot use: h1 would come out (1, 80) and the h2 matmul
# fails. `hidden` is also not defined by any earlier cell. Build a reference
# torch.nn.RNN here so this cell and the RNN cells right after it work on a
# fresh kernel. This intentionally rebinds `out`, `out_unpacked` and
# `weights`, and defines `hidden` (the RNN's h_n).
torch.manual_seed(42)
rnn = torch.nn.RNN(input_size=10, hidden_size=20)
out, hidden = rnn(e)
out_unpacked = torch.nn.utils.rnn.unpack_sequence(out)
weights = [rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0]

# Unroll the three timesteps of sequence `a` by hand from a zero state.
h0 = torch.zeros(1, 20)
h1 = torch.tanh((a[[0], :] @ weights[0].T + weights[2]) + (h0 @ weights[1].T + weights[3]))
h2 = torch.tanh((a[[1], :] @ weights[0].T + weights[2]) + (h1 @ weights[1].T + weights[3]))
h3 = torch.tanh((a[[2], :] @ weights[0].T + weights[2]) + (h2 @ weights[1].T + weights[3]))

In [7]:
# Compare the hand-unrolled RNN states for sequence `a` (batch entry 0).
# NOTE(review): `hidden` must hold the reference RNN's final hidden state
# (h_n) and `out_unpacked` its unpacked outputs; make sure an earlier cell
# defines them, or this cell fails on a fresh kernel.
a_out = out_unpacked[0]
hs = torch.cat((h1, h2, h3), dim=0)
cmp("hs", hs, a_out)
cmp("h3", h3, hidden.squeeze(0)[:1])

hs | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

h3 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07



In [8]:
# Replay a vanilla (tanh) RNN step by step over the packed batch `e`, checking
# every timestep and the final hidden state against the reference module.
# Assumes `out`, `weights` and `hidden` are a torch.nn.RNN's packed output,
# flat weights and h_n for `e`.
# Inside packed data the sequences are in sorted (longest-first) order, so
# when the batch shrinks, the rows dropping off the end have just finished.
packed_data, batch_sizes, sorted_indices, unsorted_indices = e
split_inputs = torch.split(packed_data, batch_sizes.tolist(), dim=0)
answers = torch.split(out.data, batch_sizes.tolist(), dim=0)
h = torch.zeros(int(batch_sizes[0]), weights[1].shape[1])
final = []  # final hidden rows of finished sequences, in sorted order
for step, x_t in enumerate(split_inputs):
    while x_t.shape[0] < h.shape[0]:
        final.insert(0, h[[-1], :])
        h = h[:-1, :]
    h = torch.tanh((x_t @ weights[0].T + weights[2]) + (h @ weights[1].T + weights[3]))
    cmp(step, h, answers[step])
final.insert(0, h)
final = torch.cat(final, dim=0).index_select(0, unsorted_indices)
cmp("final", final, hidden)

0 | exact: True  | approximate: True  | maxdiff: 0.0

1 | exact: True  | approximate: True  | maxdiff: 0.0

2 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

3 | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08

4 | exact: False | approximate: True  | maxdiff: 5.960464477539063e-08

5 | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07

6 | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08

final | exact: False | approximate: True  | maxdiff: 1.1920928955078125e-07



In [39]:
class LSTM:
    """Hand-written single-layer, unidirectional LSTM over a PackedSequence.

    Uses the same gate math and parameter layout as torch.nn.LSTM: gates in
    (input, forget, cell candidate, output) order, with separate input and
    hidden biases. Its outputs can therefore be checked against the built-in
    module.
    """

    def __init__(self, input_size, hidden_size, device = torch.device("cpu")) -> None:
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        # Every parameter ~ U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).
        bound = hidden_size ** -0.5
        self.wih = torch.empty(4 * hidden_size, input_size).uniform_(-bound, bound).to(device)
        self.whh = torch.empty(4 * hidden_size, hidden_size).uniform_(-bound, bound).to(device)
        self.bih = torch.empty(4 * hidden_size).uniform_(-bound, bound).to(device)
        self.bhh = torch.empty(4 * hidden_size).uniform_(-bound, bound).to(device)

    def _to_sorted_order(self, state, sorted_indices):
        """Move a caller-supplied h or c state to self.device, in packed (sorted) batch order."""
        if state.dim() == 3:
            # torch.nn.LSTM layout (num_layers, batch, hidden); only one layer exists here.
            if state.shape[0] != 1:
                raise ValueError(f"expected a single-layer state, got shape {tuple(state.shape)}")
            state = state[0]
        state = state.to(self.device)
        if sorted_indices is not None:
            state = state.index_select(0, sorted_indices.to(self.device))
        return state

    def forward(self, input, hx = None):
        """Run the LSTM over a packed batch.

        Args:
            input: PackedSequence (data, batch_sizes, sorted_indices,
                unsorted_indices). The index fields may be None, e.g. when
                packed with enforce_sorted=True.
            hx: optional (h_0, c_0), each (batch, hidden_size) or
                (1, batch, hidden_size), in the caller's original batch order.

        Returns:
            (output, (h_n, c_n)). `output` is a PackedSequence laid out like
            `input` holding the hidden state at every step. h_n and c_n are
            (batch, hidden_size) in the caller's original batch order.

        Raises:
            ValueError: if a supplied initial state has the wrong shape.
        """
        data, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        if hx is None:
            hidden = torch.zeros(max_batch_size, self.hidden_size, device=self.device)
            cell = torch.zeros(max_batch_size, self.hidden_size, device=self.device)
        else:
            hidden = self._to_sorted_order(hx[0], sorted_indices)
            cell = self._to_sorted_order(hx[1], sorted_indices)
            expected = (max_batch_size, self.hidden_size)
            if tuple(hidden.shape) != expected or tuple(cell.shape) != expected:
                raise ValueError(
                    f"initial state must be {expected}, got {tuple(hidden.shape)} and {tuple(cell.shape)}"
                )

        steps = torch.split(data, batch_sizes.tolist(), dim=0)
        outputs = []
        # The packed batch shrinks over time. When it does, the trailing rows
        # (the shortest sequences) have finished and their final state is set
        # aside. Chunks are therefore collected last-rows-first and reversed
        # before concatenation.
        finished_hidden = []
        finished_cell = []
        for x_t in steps:
            size = x_t.shape[0]
            if size < hidden.shape[0]:
                finished_hidden.append(hidden[size:])
                finished_cell.append(cell[size:])
                hidden = hidden[:size]
                cell = cell[:size]

            preact = x_t @ self.wih.T + self.bih + hidden @ self.whh.T + self.bhh
            in_gate, forget_gate, cell_cand, out_gate = torch.split(preact, self.hidden_size, dim=1)
            in_gate = torch.sigmoid(in_gate)
            forget_gate = torch.sigmoid(forget_gate)
            cell_cand = torch.tanh(cell_cand)
            out_gate = torch.sigmoid(out_gate)
            cell = forget_gate * cell + in_gate * cell_cand
            hidden = out_gate * torch.tanh(cell)
            outputs.append(hidden)

        finished_hidden.append(hidden)
        finished_cell.append(cell)
        h_n = torch.cat(finished_hidden[::-1], dim=0)
        c_n = torch.cat(finished_cell[::-1], dim=0)
        if unsorted_indices is not None:
            h_n = h_n.index_select(0, unsorted_indices)
            c_n = c_n.index_select(0, unsorted_indices)

        output = PackedSequence(torch.cat(outputs, dim=0), batch_sizes, sorted_indices, unsorted_indices)
        return output, (h_n, c_n)

    def parameters(self):
        """Return [weight_ih, weight_hh, bias_ih, bias_hh], matching torch.nn.LSTM's layer-0 order."""
        return [self.wih, self.whh, self.bih, self.bhh]

    def to(self, device):
        """Move all parameters to `device` and return self, so calls chain like nn.Module.to."""
        self.device = device
        self.wih = self.wih.to(device)
        self.whh = self.whh.to(device)
        self.bih = self.bih.to(device)
        self.bhh = self.bhh.to(device)
        return self

In [42]:
# Rebuild the same seeded packed batch (lengths 3, 2, 7, 6) and run the
# hand-written LSTM on it.
torch.manual_seed(42)
a, b, c, d = (torch.randn(seq_len, 10) for seq_len in (3, 2, 7, 6))
e = pack_sequence([a, b, c, d], enforce_sorted=False)
lstm = LSTM(input_size=10, hidden_size=20)
out, (hn, cn) = lstm.forward(e)
out_unpacked = torch.nn.utils.rnn.unpack_sequence(out)
weights = lstm.parameters()

In [45]:
# Replay the custom LSTM step by step over the packed batch `e` and check each
# timestep and the final (h, c) against lstm.forward's results.
# Inside packed data the sequences are in sorted (longest-first) order, so
# when the batch shrinks, the rows dropping off the end have just finished.
packed_data, batch_sizes, sorted_indices, unsorted_indices = e
hidden_size = weights[1].shape[1]
max_batch = int(batch_sizes[0])
# h_t / c_t rather than h / c, so the input sequence `c` is not clobbered.
h_t = torch.zeros(max_batch, hidden_size)
c_t = torch.zeros(max_batch, hidden_size)
split_inputs = torch.split(packed_data, batch_sizes.tolist(), dim=0)
answers = torch.split(out.data, batch_sizes.tolist(), dim=0)
final_hidden = []
final_cell = []
for j, x_t in enumerate(split_inputs):
    size = x_t.shape[0]
    while size < h_t.shape[0]:
        final_hidden.insert(0, h_t[[-1], :])
        final_cell.insert(0, c_t[[-1], :])
        h_t = h_t[:-1, :]
        c_t = c_t[:-1, :]

    # Same summation order as LSTM.forward, so the results can match bit for bit.
    preact = x_t @ weights[0].T + weights[2] + h_t @ weights[1].T + weights[3]
    i, f, g, o = torch.split(preact, hidden_size, dim=1)
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    g = torch.tanh(g)
    o = torch.sigmoid(o)
    c_t = f * c_t + i * g
    h_t = o * torch.tanh(c_t)
    cmp(j, h_t, answers[j])
final_hidden.insert(0, h_t)
final_cell.insert(0, c_t)
final_hidden = torch.cat(final_hidden, dim=0).index_select(0, unsorted_indices)
final_cell = torch.cat(final_cell, dim=0).index_select(0, unsorted_indices)
cmp("final hidden", final_hidden, hn)
cmp("final cell", final_cell, cn)

0 | exact: True  | approximate: True  | maxdiff: 0.0

1 | exact: True  | approximate: True  | maxdiff: 0.0

2 | exact: True  | approximate: True  | maxdiff: 0.0

3 | exact: True  | approximate: True  | maxdiff: 0.0

4 | exact: True  | approximate: True  | maxdiff: 0.0

5 | exact: True  | approximate: True  | maxdiff: 0.0

6 | exact: True  | approximate: True  | maxdiff: 0.0

final | exact: True  | approximate: True  | maxdiff: 0.0

final | exact: True  | approximate: True  | maxdiff: 0.0

