# RNN / LSTM / GRU

## setup

In [1]:
import os
from datetime import datetime

import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

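## data

Download *Alice's Adventures in Wonderland* from Project Gutenberg, build a character-level vocabulary, and wrap overlapping next-character windows in a `DataLoader`.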

In [10]:
# Download *Alice's Adventures in Wonderland* (Project Gutenberg ebook #11) and build a
# character-level vocabulary from it.
url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail loudly rather than train on an HTTP error page
# NOTE: ASCII decoding with 'ignore' silently drops every non-ASCII character; this is
# likely why samples contain words like "Im" / "Donot" (typographic apostrophes removed).
book = response.content.decode('ascii', 'ignore')
# Sort the characters: iteration order of a `set` of str depends on hash randomization
# (PYTHONHASHSEED) and changes between interpreter sessions, which would silently scramble
# the char <-> index mapping for any weights saved with `t.save` and reloaded later.
vocab = sorted(set(book))
d_vocab = len(vocab)
d_hidden = 100   # size of the recurrent hidden state
d_batch = 30000  # sequences per batch; exceeds the ~20.5k training windows -> 1 batch/epoch
atoi = {a: i for i, a in enumerate(vocab)}
itoa = {i: a for a, i in atoi.items()}

In [11]:
def to_dataloader(text, seq_len=25, batch_size=d_batch, stride=None):
    """Build a shuffled DataLoader of (input, next-character target) windows over `text`.

    Each sample pairs the character indices of `text[i:i+seq_len]` with those of the same
    window shifted one character to the right, `text[i+1:i+seq_len+1]`.

    Args:
        text: training string; every character must be a key of `atoi`.
        seq_len: length of each window.
        batch_size: number of windows per batch.
        stride: step between consecutive window starts. Defaults to `seq_len // 3` (at
            least 1), so windows overlap and there is more data to batch in parallel;
            pass `stride=seq_len` for non-overlapping windows.

    Returns:
        A shuffling DataLoader yielding `(x, y)` integer tensors of shape (batch, seq_len).

    Raises:
        ValueError: if `seq_len` or `stride` is not positive, or `text` is too short to
            yield a single window.
    """
    if seq_len < 1:
        raise ValueError(f'seq_len must be >= 1, got {seq_len}')
    if stride is None:
        # `seq_len // 3` is 0 for seq_len < 3, and `range` rejects a zero step.
        stride = max(1, seq_len // 3)
    if stride < 1:
        raise ValueError(f'stride must be >= 1, got {stride}')
    # The last valid start is len(text) - seq_len - 1 (the target needs text[i + seq_len]),
    # so the exclusive end of the range is len(text) - seq_len.
    starts = range(0, len(text) - seq_len, stride)
    if len(starts) == 0:
        raise ValueError(f'text of length {len(text)} is too short for seq_len={seq_len}')
    x = t.tensor([[atoi[a] for a in text[i:i + seq_len]] for i in starts])
    y = t.tensor([[atoi[a] for a in text[i + 1:i + seq_len + 1]] for i in starts])
    print(x.shape)
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One shared training loader over the whole book, used by both the RNN and the LSTM.
dataloader = to_dataloader(text=book)

torch.Size([20539, 25])


## Recurrent Neural Networks (RNN)

### model

In [145]:
class RNN(nn.Module):
    """Vanilla (Elman) recurrent network.

    At every step: memory = tanh(embed(x_t) + hidden(memory)); output_t = unembed(memory).
    """

    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.embed = nn.Linear(d_in, d_hidden)       # input -> hidden
        self.hidden = nn.Linear(d_hidden, d_hidden)  # previous hidden -> hidden (recurrence)
        self.unembed = nn.Linear(d_hidden, d_out)    # hidden -> output logits

    def forward(self, xs, memory=None, return_memory=False):
        """Run the network over a whole sequence.

        Args:
            xs: inputs of shape (batch, d_context, d_in); one-hot characters in this notebook.
            memory: initial hidden state of shape (batch, d_hidden); zeros when None.
            return_memory: also return the final hidden state (e.g. to continue sampling).

        Returns:
            Logits of shape (batch, d_context, d_out), plus the final hidden state of shape
            (batch, d_hidden) when `return_memory` is True. An empty sequence yields
            (batch, 0, d_out) logits and the initial memory.
        """
        batch, d_context, _ = xs.shape
        if memory is None:
            # new_zeros matches both the device and the dtype of the inputs
            memory = xs.new_zeros(batch, self.hidden.in_features)
        outs = []
        for i in range(d_context):
            memory = t.tanh(self.embed(xs[:, i]) + self.hidden(memory))
            outs.append(self.unembed(memory))
        if outs:
            logits = t.stack(outs, dim=1)
        else:
            # t.stack([]) raises, so build the empty result explicitly
            logits = xs.new_zeros(batch, 0, self.unembed.out_features)
        if return_memory:
            return logits, memory
        return logits

# Character-level RNN: one-hot characters in, logits over the same vocabulary out.
model = RNN(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab)

### train

In [200]:
@t.no_grad()
def sample(model, text='A', d_sample=100):
    """Generate text from an RNN-style model, one sampled character at a time.

    The whole prompt is fed through the model first. Each new character is then drawn
    from the softmax of the model's last output and fed back in as the next input.

    Args:
        model: called as `model(x, memory=..., return_memory=True)` and exposing
            `model.hidden.in_features` as its hidden size (like `RNN`); moved to `device`.
        text: non-empty prompt; every character must be a key of `atoi`.
        d_sample: total length of the returned string, prompt included. A prompt that is
            already at least this long is returned unchanged.

    Returns:
        The prompt followed by sampled characters, `max(len(text), d_sample)` chars long.

    Raises:
        ValueError: if `text` is empty (there is nothing to condition on).
    """
    if not text:
        raise ValueError('text must contain at least one character')
    model = model.to(device)
    memory = t.zeros(1, model.hidden.in_features, device=device)
    prompt_ids = t.tensor([[atoi[c] for c in text]], device=device)
    x = F.one_hot(prompt_ids, num_classes=d_vocab).float()  # (1, len(text), d_vocab)
    # `<` (not `<=`) so exactly `d_sample` characters come back, matching `LSTM.sample`.
    while len(text) < d_sample:
        outs, memory = model(x, memory=memory, return_memory=True)
        probs = outs[0, -1].softmax(dim=0)
        next_sample = t.multinomial(probs, num_samples=1)  # shape (1,)
        text += itoa[next_sample.item()]
        # feed back only the new character: (1,) -> (1, d_vocab) -> (1, 1, d_vocab)
        x = F.one_hot(next_sample, num_classes=d_vocab).float()[:, None, :]
    return text

# sample(model)

In [195]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4):
    """Train a character-level RNN with next-character cross-entropy.

    On epochs divisible by 50, prints that epoch's mean training loss. On epochs divisible
    by 1000, prints `sample(model)`. On epochs divisible by 10000, saves `model.state_dict()`
    to `weights/rnn_<timestamp>.pt`, creating `weights/` if needed.

    Args:
        model: maps one-hot inputs (batch, seq, d_vocab) to logits (batch, seq, d_vocab).
        dataloader: yields `(xs, ys)` integer index tensors of shape (batch, seq).
        epochs: number of passes over `dataloader`.
        d_vocab: vocabulary size used for the one-hot encoding.
        opt: optimizer; defaults to Adam over `model.parameters()` with learning rate `lr`.
        lr: learning rate of the default optimizer (ignored when `opt` is given).

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        loss_sum, n_seqs = 0.0, 0
        for xs, ys in dataloader:
            out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
            # cross_entropy expects classes on dim 1: (batch, seq, vocab) -> (batch, vocab, seq)
            loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
            # accumulate on-device so there is no host sync on every batch
            loss_sum = loss_sum + loss.detach() * len(xs)
            n_seqs += len(xs)
        if n_seqs == 0:
            raise ValueError('dataloader yielded no batches')
        if epoch % 50 == 0:
            print(f'loss={(loss_sum / n_seqs).item():.4f}')
        if epoch % 1000 == 0:
            print(sample(model))
        if epoch % 10000 == 0:
            os.makedirs('weights', exist_ok=True)  # t.save does not create parent dirs
            t.save(model.state_dict(), f'weights/rnn_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

# Effectively "train until interrupted": stop the cell manually when the samples look good.
train(model, dataloader, epochs=1_000_000)

  0%|          | 0/1000000 [00:00<?, ?it/s]

loss=1.2772
Alice their writting in the did magie doors turself bew to kne with cate of rull, withoumpotink: to t
loss=1.2772
loss=1.2769
loss=1.2769
loss=1.2769
loss=1.2769
loss=1.2769
loss=1.2769
loss=1.2770
loss=1.2771
loss=1.2770
loss=1.2769
loss=1.2769
loss=1.2769
loss=1.2772
loss=1.2770
loss=1.2769
loss=1.2768
loss=1.2771
loss=1.2771
Adlb, and thright forms me minacht,

Essais and out, and this whole sharplly the Foof now, invershe
loss=1.2768
loss=1.2770
loss=1.2768
loss=1.2770
loss=1.2771
loss=1.2768
loss=1.2769
loss=1.2770
loss=1.2768
loss=1.2771
loss=1.2768
loss=1.2771
loss=1.2767
loss=1.2768
loss=1.2767
loss=1.2770
loss=1.2766
loss=1.2771
loss=1.2769
loss=1.2768
Additten trightr had to laks and beturious you dinching her head garemark a fied; toute
out, ifore I
loss=1.2766
loss=1.2769
loss=1.2766
loss=1.2768
loss=1.2768
loss=1.2767
loss=1.2768
loss=1.2769
loss=1.2766
loss=1.2768
loss=1.2767
loss=1.2765
loss=1.2768
loss=1.2764
loss=1.2766
loss=1.2768
loss=1.2764
loss=1.2764
l

KeyboardInterrupt: 

In [196]:
# Checkpoint the RNN after an interrupted run. t.save does not create missing parent
# directories, so make sure `weights/` exists first.
os.makedirs('weights', exist_ok=True)
t.save(model.state_dict(), f'weights/rnn_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

In [202]:
# A longer sample from the trained RNN, prompted with the default 'A'.
print(sample(model, text='A', d_sample=2000))

Alice begained to tenns but as she
cant purreveng at little hurty ot ratuol by thats Rrout the
hustR Yoo and
frop.geater pon the
course, the word-202X02.1yw. Lork, prowintly.

Yoo Alinks bug, said the copy ancomman a fromy of what she such at everyeds morations a tur Crokh 1.F.3. Im toin; them loughingent
he
seemed with as puriouse!

Donot finiven tong goid. Ball off.

Ther_ tr one replies, and the Come aft under have grogstly.

How.
A Lotting!
    The King! in these air, of
tigering to say _stelff
when replice houmed my walk mes tw mand some
that could newly, and tuld to she ran wond shratcl what yourd much time or onied! Ther, sho krise it, said not, in tegling-or of ser any indotrenter a the
it been ugeness. I cistur well bewe Quee into listed the Rurwn
on, Im spents. Ever upoutiffe; and see a dredown on showher weye happen a minulden refayer, the musting most be question of the Gryphon, call cacs with the fee ant rateed for Furdly, for fanthed it, of at
rething and formauted, bits,

## Long Short-Term Memory (LSTM)

### model

In [12]:
class LSTMCell(nn.Module):
    """A single LSTM time step.

    Each gate is an affine map of the concatenation [x, h_prev]:
        c_next = sigmoid(W_f) * c_prev + sigmoid(W_i) * tanh(W_c)   (long-term memory)
        h_next = sigmoid(W_o) * tanh(c_next)                        (short-term memory)
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        d_joint = d_in + d_hidden  # gates read the input and the previous h side by side
        self.W_f = nn.Linear(d_joint, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_joint, d_hidden)  # input gate
        self.W_c = nn.Linear(d_joint, d_hidden)  # candidate cell-state update
        self.W_o = nn.Linear(d_joint, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """Advance one step and return the new `(h, c)` pair.

        `x` and `h_prev` are concatenated along dim 1, so both must share their leading
        (batch) dimension.
        """
        joint = t.cat((x, h_prev), dim=1)
        forget = self.W_f(joint).sigmoid()
        admit = self.W_i(joint).sigmoid()
        candidate = self.W_c(joint).tanh()
        c_next = forget * c_prev + admit * candidate
        h_next = self.W_o(joint).sigmoid() * c_next.tanh()
        return h_next, c_next

class LSTM(nn.Module):
    """Character-level LSTM: an `LSTMCell` unrolled over the sequence plus a linear readout."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)  # hidden state -> output logits

    def forward(self, xs, h_prev=None, c_prev=None, return_state=False):
        """Run the LSTM over a whole sequence.

        Args:
            xs: inputs of shape (batch, d_context, d_in); one-hot characters in this notebook.
            h_prev, c_prev: initial hidden / cell states of shape (batch, d_hidden);
                zeros when None.
            return_state: also return the final `(h, c)` pair (e.g. to continue a sequence).

        Returns:
            Logits of shape (batch, d_context, d_out), or `(logits, (h, c))` when
            `return_state` is True. An empty sequence yields (batch, 0, d_out) logits.
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = xs.new_zeros(batch, self.d_hidden)
        if c_prev is None:
            c_prev = xs.new_zeros(batch, self.d_hidden)
        outs = []
        for i in range(d_context):
            h_prev, c_prev = self.lstm_cell(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(h_prev))
        if outs:
            logits = t.stack(outs, dim=1)
        else:
            # t.stack([]) raises, so build the empty result explicitly
            logits = xs.new_zeros(batch, 0, self.unembed.out_features)
        if return_state:
            return logits, (h_prev, c_prev)
        return logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text one sampled character at a time, conditioned on `seed`.

        Every character of `seed` is fed through the cell first (previously only a
        single-character seed worked). Tensors are created on the model's own device.

        Args:
            seed: non-empty prompt; every character must be a key of `atoi`.
            d_sample: total length of the returned string, seed included.

        Returns:
            `seed` followed by sampled characters, `max(len(seed), d_sample)` chars long.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        text = seed
        param_device = next(self.parameters()).device
        seed_ids = t.tensor([atoi[ch] for ch in seed], device=param_device)
        prompt = F.one_hot(seed_ids, num_classes=d_vocab).float()  # (len(seed), d_vocab)
        h_prev = t.zeros(1, self.d_hidden, device=param_device)
        c_prev = t.zeros(1, self.d_hidden, device=param_device)
        # Consume all but the last seed character; the last one is the loop's first input.
        for i in range(len(seed) - 1):
            h_prev, c_prev = self.lstm_cell(prompt[i:i + 1], h_prev, c_prev)
        x = prompt[-1:]  # (1, d_vocab)
        while len(text) < d_sample:
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            next_sample = t.multinomial(probs, num_samples=1)  # shape (1,)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample, num_classes=d_vocab).float()
        return text

lstm = LSTM(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
# Before training the sample is essentially random characters.
lstm.sample(seed='A')

'AL,F_9Z4Or1F3BFQ6KYTj5uDKsbN$%m:OQ?Er5KCDKUth%/Lv/0wg!.YHQI-:#f8.Vg.aCWgz_]4)T)*WJ1LE9xT.u8dS4pI?p.4'

In [15]:
import gc

# Full collection (generation 2 is also the default); the count of unreachable objects found
# is shown as the cell output.
gc.collect(2)

897

### train

In [16]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4):
    """Train a character-level LSTM with next-character cross-entropy.

    On epochs divisible by 50, prints that epoch's mean training loss. On epochs divisible
    by 1000, prints `model.sample()`. On epochs divisible by 10000, saves
    `model.state_dict()` to `weights/lstm_<timestamp>.pt`, creating `weights/` if needed.

    Args:
        model: maps one-hot inputs (batch, seq, d_vocab) to logits (batch, seq, d_vocab)
            and provides a `sample()` method (like `LSTM`).
        dataloader: yields `(xs, ys)` integer index tensors of shape (batch, seq).
        epochs: number of passes over `dataloader`.
        d_vocab: vocabulary size used for the one-hot encoding.
        opt: optimizer; defaults to Adam over `model.parameters()` with learning rate `lr`.
        lr: learning rate of the default optimizer (ignored when `opt` is given).

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        loss_sum, n_seqs = 0.0, 0
        for xs, ys in dataloader:
            out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
            # cross_entropy expects classes on dim 1: (batch, seq, vocab) -> (batch, vocab, seq)
            loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
            # accumulate on-device so there is no host sync on every batch
            loss_sum = loss_sum + loss.detach() * len(xs)
            n_seqs += len(xs)
        if n_seqs == 0:
            raise ValueError('dataloader yielded no batches')
        if epoch % 50 == 0:
            print(f'loss={(loss_sum / n_seqs).item():.4f}')
        if epoch % 1000 == 0:
            print(model.sample())
        if epoch % 10000 == 0:
            os.makedirs('weights', exist_ok=True)  # t.save does not create parent dirs
            t.save(model.state_dict(), f'weights/lstm_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

# Default epoch count; pass e.g. epochs=1_000_000 for a run-until-interrupted session.
train(model=lstm, dataloader=dataloader)

  0%|          | 0/2000 [00:00<?, ?it/s]

loss=4.4206
F:jL$#IeCnqUu:Gh5oW,DT4#m#xb2roKQ:
abw ,
01THK_PE*i
loss=3.9444
loss=3.3473
loss=3.3060
loss=3.2807
loss=3.2591
loss=3.2370
loss=3.2122
loss=3.1823
loss=3.1425
loss=3.0909
loss=3.0284
loss=2.9610
loss=2.8950
loss=2.8325
loss=2.7757
loss=2.7241
loss=2.6776
loss=2.6360
loss=2.5990
loss=2.5658
tfowdWldishunkerd bn
I

cPltee,dcors,m ,le,ly hedte il enehD,tasf fasrete n. soru
loss=2.5355
loss=2.5076
loss=2.4815
loss=2.4567
loss=2.4331
loss=2.4108
loss=2.3897
loss=2.3695
loss=2.3503
loss=2.3321
loss=2.3147
loss=2.2980
loss=2.2821
loss=2.2667
loss=2.2520
loss=2.2377
loss=2.2238
loss=2.2103
loss=2.1973


In [18]:
lstm.sample(seed='A', d_sample=100)

'Ageg!\r\n\r\nAbice geturg, late wele pleath.qF Adde Aoich, porld pevenghl!\r\n\r\n  of them nat thing, as th'