# RNN / LSTM / GRU

## setup

In [1]:
from datetime import datetime
import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

device = 'cuda' if t.cuda.is_available() else 'cpu'

## utils

In [2]:
# download alice in wonderland
url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
book = requests.get(url).content
book = book.decode('ascii', 'ignore')
vocab = set(book)
d_vocab = len(vocab)
d_hidden = 100
d_batch = 10000
atoi = {a: i for i, a in enumerate(vocab)}
itoa = {i: a for a, i in atoi.items()}

In [3]:
def to_dataloader(text, seq_len=25, batch_size=d_batch):
    x = [text[i:i+seq_len] for i in range(0, len(text)-seq_len-1, seq_len)]
    y = [text[i+1:i+seq_len+1] for i in range(0, len(text)-seq_len-1, seq_len)]
    x = t.tensor([[atoi[a] for a in s] for s in x])
    y = t.tensor([[atoi[a] for a in s] for s in y])
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataloader = to_dataloader(book)

torch.Size([6573, 25])


## Recurrent Neural Networks (RNN)

### model

In [4]:
class RNN(nn.Module):
    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.embed = nn.Linear(d_in, d_hidden)
        self.hidden = nn.Linear(d_hidden, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, memory=None, return_memory=False):
        # xs: (batch, d_context, d_vocab)
        batch, d_context, _ = xs.shape
        outs = []
        if memory is None: memory = t.zeros(batch, self.hidden.in_features, device=xs.device)
        for i in range(d_context):
            x = xs[:, i]
            memory = F.tanh(self.embed(x) + self.hidden(memory))
            outs.append(self.unembed(memory))
        if return_memory:
            return t.stack(outs, dim=1), memory
        return t.stack(outs, dim=1)

rnn = RNN(d_vocab, d_hidden, d_vocab)

### train

In [5]:
@t.no_grad()
def sample(model, text='A', d_sample=100):
    model = model.to(device)
    memory = t.zeros(1, model.hidden.in_features).to(device)
    x = t.tensor([[atoi[c] for c in text]])
    x = F.one_hot(x, num_classes=d_vocab).float().to(device)
    while len(text) <= d_sample:
        outs, memory = model(x, memory=memory, return_memory=True)
        probs = outs[0, -1].softmax(dim=0)
        next_sample = t.multinomial(probs, num_samples=1)
        text += itoa[next_sample.item()]
        x = F.one_hot(next_sample, num_classes=d_vocab).float().to(device)[:, None, :]
    return text

In [6]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4):
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        for xs, ys in dataloader:
            out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
            loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
        if epoch % 50 == 0:
            print(f'loss={loss.item():.4f}')
        if epoch % 1000 == 0:
            print(sample(model))
        if epoch % 10000 == 0:
            t.save(model.state_dict(), f'weights/rnn_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

train(rnn, dataloader)

  0%|          | 0/2001 [00:00<?, ?it/s]

loss=4.4122
ASntTexs1z!JC#O#dLnn'PkTtr77NB0H;[fXg64#'OIfGbtCc2#je):-xkDb.y/_,stP-5fjFpV!BIR2fS96q l_D)0,Xj0P.NwI?
loss=3.3527
loss=3.2293
loss=3.2073
loss=3.1836
loss=3.1493
loss=3.0961
loss=3.0164
loss=2.9224
loss=2.8226
loss=2.7290
loss=2.6469
loss=2.5771
loss=2.5167
loss=2.4629
loss=2.4139
loss=2.3692
loss=2.3295
loss=2.2943
loss=2.2622
loss=2.2327
Aliter tains ghe ke she to f saicJ tc
wotrehi, koaidi6g! Ihd terang of uthid tll_mtcouve m*se   he 
loss=2.2052
loss=2.1795
loss=2.1553
loss=2.1327
loss=2.1115
loss=2.0915
loss=2.0724
loss=2.0544
loss=2.0371
loss=2.0206
loss=2.0047
loss=1.9894
loss=1.9748
loss=1.9606
loss=1.9470
loss=1.9338
loss=1.9211
loss=1.9088
loss=1.8969
loss=1.8853
Alice: thead, all s
vertur, wand rek to be werhar hirghny Fron coulbebedoiget, but ol the King if to


In [7]:
print(sample(rnn, d_sample=2000))

Alice soon the Gryphan?cas  be tuthous it low, said GU
Tha boof jurs wore on phorda git. jad oo sher.

Thes of
herelird about aqlimnsee the  aimens corsaid ous meem.

1PDoIm 
 A CAtH
AUNPREATI M NOR.HD IHH T61.ED3. OE LIN I R (HM HX Gut dbarg 
ffellferrved, iny_ fry. I
dor  o  o pome g a deputseen all
to dowence, all tare, and toite say, was unde to rowick orkint hes asave lotsus sheaver on, and was heswould, said thee southe
foume.

The piole hoandef, the laggion,
and chouttongroon courd  to bes
she tay thed not kinl 

yon tho kinll I dagant the leece sQuein, the  ho chiment it witht and heat inea nol, th thel nel saed to quetre fortanis hendes und it dawry hat she Lit
the DokbigeMo

Thene; in it has seadide betas sain nlermuth a? nat futthes the Micg.

Why, sto hai kagying anverson ir was ple witn, thing hind the Creltintis whed in the Mack of eloop ofe undared of, it make theep, taitid waid
on then the gedar tames eroughtabling, s outhea wouthing whet whele, Imbeig. Whe hergelo to g

## Long Short-Term Memory (LSTM)

### model

In [8]:
class LSTMCell(nn.Module):
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.W_f = nn.Linear(d_in + d_hidden, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_in + d_hidden, d_hidden)  # input gate
        self.W_c = nn.Linear(d_in + d_hidden, d_hidden)  # cell state update
        self.W_o = nn.Linear(d_in + d_hidden, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        x = t.cat((x, h_prev), dim=1)
        # handle long-term memory `C`
        f_gate = t.sigmoid(self.W_f(x))
        i_gate = t.sigmoid(self.W_i(x))
        c_update = t.tanh(self.W_c(x))
        c_prev = f_gate * c_prev + i_gate * c_update
        # handle short-term memory `h`
        o_gate = t.sigmoid(self.W_o(x))
        h_prev = o_gate * t.tanh(c_prev)
        return h_prev, c_prev

class LSTM(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, c_prev=None):
        # xs: (batch, d_context, d_vocab)
        batch, d_context, _ = xs.shape
        outs = []
        if h_prev is None: h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        if c_prev is None: c_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        for i in range(d_context):
            x = xs[:, i]
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            outs.append(self.unembed(h_prev))
        return t.stack(outs, dim=1)

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        text = seed
        x = F.one_hot(t.tensor([atoi[seed]]), num_classes=d_vocab).float().to(device)
        h_prev = t.zeros(1, self.d_hidden, device=x.device)
        c_prev = t.zeros(1, self.d_hidden, device=x.device)
        while len(text) < d_sample:
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            out = self.unembed(h_prev)
            probs = out[0].softmax(-1)
            next_sample = t.multinomial(probs, num_samples=1)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample, num_classes=d_vocab).float().to(device)
        return text

lstm = LSTM(d_vocab, d_hidden, d_vocab).to(device)
lstm.sample()

'A\nmEbZQqPp5?EewrPh%lKTu.HX$5U?2\ruQ_9.7:0Of8(QZ ,fLRcvhJ/]00UJe6C,T.XHpqkS%k\rWf]vuj\r)/N55Q.830j[Gg5H1'

### train

In [9]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4):
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        for xs, ys in dataloader:
            out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
            loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
        if epoch % 50 == 0:
            print(f'loss={loss.item():.4f}')
        if epoch % 1000 == 0:
            print(model.sample())
        if epoch % 10000 == 0:
            t.save(model.state_dict(), f'weights/lstm_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

train(lstm, dataloader)

  0%|          | 0/2001 [00:00<?, ?it/s]

loss=4.4149
A*pe;GMal3dJ,dV*S(8!cp2m0L7_,57,TH:
91h!OQ8e[l*hphwOlTU$:[8tKkrL_A4%oIx
loss=3.9804
loss=3.3483
loss=3.3042
loss=3.2772
loss=3.2533
loss=3.2274
loss=3.1974
loss=3.1605
loss=3.1141
loss=3.0566
loss=2.9929
loss=2.9276
loss=2.8631
loss=2.8013
loss=2.7443
loss=2.6925
loss=2.6456
loss=2.6031
loss=2.5642
loss=2.5291
gsana tharornet tli
aingglald had nn thre sodeittat Hhoud,gang,e!gs i
loss=2.4978
loss=2.4695
loss=2.4436
loss=2.4195
loss=2.3970
loss=2.3757
loss=2.3550
loss=2.3355
loss=2.3168
loss=2.2991
loss=2.2822
loss=2.2659
loss=2.2502
loss=2.2348
loss=2.2199
loss=2.2057
loss=2.1921
loss=2.1792
loss=2.1668
loss=2.1548
Ayor bege ang thy wit diy Alothe totherite Th ued fondushing to clous. H:
Alice brsaidst-wepkey bur


In [11]:
print(lstm.sample(d_sample=2000))

As, yano hhaline duthen touroquer, and said comd yat
_ts buped qhey  avid to nnguir: pnojce med rolereag coid.

Mothe , Alis!, on her ins  he seases, ave, merind
nteriBusthe.
coule inithat
niald Aidr toen it it om le Lernan y arly tas ang thar ovtgin berer alle the batm bure by
I
Ne the sus
one taot
nge doln.

Wh th Pryou string oof anh fow becego s whe meendily he gile at, buling ardeus Aidk var utlidtn, nowy h or
sthe,,S wich hick mogtto in wefbladg, ily
caid-trined bu toveen.

Aads bes_, Yas it sh srens! bu theble t atles, Abdite tohe bere megond thitfing oflrno ghen it Alist at ente
neryicg ao_ muvet afstm bund anche  aAlided iof ove cotticn betes,, an  he tourd douutine hers adll,
wiss sose tour in cheab be whowhs ingadotheneshoucang yun ating! You
sita fomedests, ansery,
agting t wa do tusthice to t in Ale wi ginver fovg is mukt on leraat: I liwt dowe taode, shy
Dumagthew nfere? whit: Ous ly;ehk to dber tha Ques rogsve, tad st the wich oned p ard of tt and oulk: the seroningenfel

## Gated Recurrent Unit (GRU)

### model

### train