# RNN / LSTM / GRU

## setup

In [22]:
import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F

## utils

In [16]:
# download alice in wonderland
url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
book = requests.get(url).content
book = book.decode('ascii', 'ignore')
alphabet = ['<eof>'] + list(set(book))

## Recurrent Neural Networks (RNN)

In [76]:
class RNN(nn.Module):
    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.embed = nn.Linear(d_in, d_hidden)
        self.hidden = nn.Linear(d_hidden, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, memory=None):
        # xs: (batch, d_context, d_vocab)
        batch, d_context, _ = xs.shape
        outs = []
        if memory is None: memory = t.zeros(batch, self.hidden.in_features)
        for i in range(d_context):
            x = xs[:, i]
            memory = F.tanh(self.embed(x) + self.hidden(memory))
            outs.append(self.unembed(memory)) # this is wasteful if we only care about the last output.
        return t.stack(outs, dim=1)

model = RNN(1, 5, 1)
opt = t.optim.Adam(model.parameters(), lr=3e-4)

In [78]:
ds = t.tensor([
    [0, 1, 0, 1, 0, 1, 0, 1, 0],
    [1, 0, 1, 0, 1, 0, 1, 0, 1],
    [0, 0, 1, 0, 0, 1, 0, 0, 1],
    [0, 1, 0, 0, 1, 0, 0, 1, 0],
    [1, 0, 0, 1, 0, 0, 1, 0, 0],
], dtype=t.float)

def train_final_output(epochs=2000):
    xs = ds[:, :-1, None]
    ys = ds[:, -1:]
    for epoch in range(epochs):
        out = model(xs)[:, -1]
        loss = F.mse_loss(out, ys)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if epoch % 100:
            print(f'loss={loss.item():.4f}')

def train_all_output(epochs=2000):
    xs = ds[:, :-1, None]
    ys = ds[:, 1:, None]
    for epoch in range(epochs):
        out = model(xs)
        loss = F.mse_loss(out, ys)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if epoch % 100:
            print(f'loss={loss.item():.4f}')

# train_final_output()
train_all_output()

loss=0.0493
loss=0.0493
loss=0.0492
loss=0.0492
loss=0.0492
loss=0.0492
loss=0.0491
loss=0.0491
loss=0.0491
loss=0.0491
loss=0.0491
loss=0.0490
loss=0.0490
loss=0.0490
loss=0.0490
loss=0.0489
loss=0.0489
loss=0.0489
loss=0.0489
loss=0.0489
loss=0.0488
loss=0.0488
loss=0.0488
loss=0.0488
loss=0.0487
loss=0.0487
loss=0.0487
loss=0.0487
loss=0.0487
loss=0.0486
loss=0.0486
loss=0.0486
loss=0.0486
loss=0.0486
loss=0.0485
loss=0.0485
loss=0.0485
loss=0.0485
loss=0.0484
loss=0.0484
loss=0.0484
loss=0.0484
loss=0.0484
loss=0.0483
loss=0.0483
loss=0.0483
loss=0.0483
loss=0.0482
loss=0.0482
loss=0.0482
loss=0.0482
loss=0.0482
loss=0.0481
loss=0.0481
loss=0.0481
loss=0.0481
loss=0.0481
loss=0.0480
loss=0.0480
loss=0.0480
loss=0.0480
loss=0.0479
loss=0.0479
loss=0.0479
loss=0.0479
loss=0.0479
loss=0.0478
loss=0.0478
loss=0.0478
loss=0.0478
loss=0.0478
loss=0.0477
loss=0.0477
loss=0.0477
loss=0.0477
loss=0.0476
loss=0.0476
loss=0.0476
loss=0.0476
loss=0.0476
loss=0.0475
loss=0.0475
loss=0.0475
loss

In [95]:
@t.no_grad()
def pred(xs):
    out = model(t.tensor(xs, dtype=t.float)[None, :, None])[0]
    pretty = ', '.join(f'{x[0]: 2.2f}' for x in out.tolist())
    print(f'[{pretty}]')

# predict on training set
pred([0, 1, 0, 1, 0, 1, 0, 1])
pred([0, 0, 1, 0, 0, 1, 0, 0])
# on longer sequence
pred([0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1])

[ 0.67,  0.01,  0.50,  0.01,  1.01, -0.00,  0.99,  0.00]
[ 0.67,  0.98,  0.03, -0.01,  0.99, -0.01, -0.01,  0.99]
[ 0.67,  0.98,  0.03, -0.01,  0.99, -0.01, -0.01,  0.99, -0.01,  0.00,  0.99, -0.01,  0.00,  0.99, -0.01]
