# RNN / LSTM / GRU

## setup

In [1]:
from datetime import datetime
import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## utils

In [2]:
# Download "Alice's Adventures in Wonderland" from Project Gutenberg and
# build a character-level vocabulary from it.
url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
response = requests.get(url, timeout=30)  # never hang forever on a stalled connection
response.raise_for_status()  # fail loudly instead of training on an error page
book = response.content.decode('ascii', 'ignore')  # non-ASCII bytes are dropped
# Sort the characters so the char <-> index mapping is identical in every
# session. Iterating a bare set of strings follows hash order, which changes
# between interpreter runs and would make saved weights meaningless on reload.
vocab = sorted(set(book))
d_vocab = len(vocab)  # number of distinct characters (one-hot width)
d_hidden = 100  # hidden-state width used when the models are instantiated
d_batch = 10000  # default batch size of to_dataloader
atoi = {a: i for i, a in enumerate(vocab)}  # character -> integer index
itoa = dict(enumerate(vocab))  # integer index -> character

In [3]:
def to_dataloader(text, seq_len=25, batch_size=d_batch):
    """Turn `text` into a shuffled DataLoader of next-character training pairs.

    The text is cut into non-overlapping windows of `seq_len` characters.
    For every window `x`, the target `y` is the same window shifted one
    character to the right, so `y[k]` is the character that follows `x[k]`.
    Trailing characters that cannot fill a complete (x, y) pair are dropped.

    Args:
        text: string whose characters are all keys of the module-level `atoi`.
        seq_len: number of characters per training sequence.
        batch_size: batch size handed to the DataLoader. The default is
            captured from `d_batch` when the function is defined.

    Returns:
        A DataLoader yielding `(x, y)` batches of integer character indices,
        each of shape (batch, seq_len).
    """
    # The last usable start index is len(text) - seq_len - 1 (the target
    # window needs one extra character); range's stop is exclusive, so the
    # stop must be len(text) - seq_len or that final window is lost.
    starts = range(0, len(text) - seq_len, seq_len)
    x = t.tensor([[atoi[a] for a in text[i:i + seq_len]] for i in starts])
    y = t.tensor([[atoi[a] for a in text[i + 1:i + seq_len + 1]] for i in starts])
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training data: the whole book, with to_dataloader's default sequence
# length and batch size.
dataloader = to_dataloader(book)

torch.Size([6573, 25])


## train

In [None]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename=''):
    """Train `model` with cross-entropy loss and periodically report progress.

    Args:
        model: module that maps a float tensor of shape
            (batch, seq, d_vocab) to logits of shape (batch, seq, classes)
            and provides a `sample()` method returning a string.
        dataloader: iterable of `(xs, ys)` batches of integer class indices.
        epochs: number of passes over `dataloader`.
        d_vocab: one-hot width for the inputs. The default is captured from
            the module-level `d_vocab` when the function is defined.
        opt: optimizer to use; when None, Adam over `model.parameters()`.
        lr: learning rate for the default Adam optimizer (ignored if `opt`
            is given).
        filename: prefix of the checkpoint files written under `weights/`.

    Side effects:
        Moves the model to the module-level `device`, prints the last batch
        loss every 50 epochs and a sample every 1000 epochs, and saves the
        state dict every 10000 epochs and after the final epoch.
    """
    import os  # local import: only needed to create the checkpoint directory

    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    loss = None  # stays None when the dataloader yields no batches
    for epoch in tqdm(range(epochs)):
        for xs, ys in dataloader:
            # Move the small index tensors first and one-hot encode on the
            # target device rather than transferring the large one-hot tensor.
            xs, ys = xs.to(device), ys.to(device)
            out = model(F.one_hot(xs, num_classes=d_vocab).float())
            # cross_entropy expects the class dimension at position 1.
            loss = F.cross_entropy(out.permute(0, 2, 1), ys)
            opt.zero_grad()
            loss.backward()
            opt.step()
        if epoch % 50 == 0 and loss is not None:
            print(f'loss={loss.item():.4f}')
        if epoch % 1000 == 0:
            print(model.sample())
        # Also checkpoint after the last epoch: with the default epochs=2001
        # the modulo test alone would only ever save the epoch-0 weights.
        if epoch % 10000 == 0 or epoch == epochs - 1:
            os.makedirs('weights', exist_ok=True)  # t.save does not create directories
            t.save(model.state_dict(), f'weights/{filename}_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

## Recurrent Neural Networks (RNN)

### model

In [25]:
class RNN(nn.Module):
    """Single-layer tanh recurrent network over one-hot encoded inputs.

    At every time step the hidden state is updated as
    `tanh(embed(x) + hidden(memory))` and projected to output logits by
    `unembed`.

    Args:
        d_in: width of each input vector (the one-hot size).
        d_hidden: width of the hidden state.
        d_out: number of output logits per time step.
    """

    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.d_hidden = d_hidden
        self.embed = nn.Linear(d_in, d_hidden)
        self.hidden = nn.Linear(d_hidden, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def _step(self, x, memory):
        """Advance the hidden state by one time step."""
        return t.tanh(self.embed(x) + self.hidden(memory))

    def forward(self, xs, memory=None, return_memory=False):
        """Run the network over a whole sequence.

        Args:
            xs: float tensor of shape (batch, d_context, d_in).
            memory: optional initial hidden state of shape (batch, d_hidden);
                zeros when omitted.
            return_memory: when True, also return the final hidden state.

        Returns:
            Logits of shape (batch, d_context, d_out), or the tuple
            `(logits, final_memory)` when `return_memory` is True.
        """
        batch, d_context, _ = xs.shape
        if memory is None:
            memory = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            memory = self._step(xs[:, i], memory)
            outs.append(self.unembed(memory))
        out = t.stack(outs, dim=1)
        return (out, memory) if return_memory else out

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time.

        Characters are mapped through the module-level `atoi` / `itoa`
        tables. The device and one-hot width are taken from the model's own
        parameters, so sampling works wherever the model currently lives.

        Args:
            seed: non-empty prompt; every character primes the hidden state
                and the prompt is kept at the start of the result.
            d_sample: total length of the returned text, seed included. If
                the seed is already that long it is returned unchanged.

        Returns:
            The seed followed by the sampled characters.

        Raises:
            ValueError: if `seed` is empty.
            KeyError: if `seed` contains a character missing from `atoi`.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.embed.weight.device
        d_in = self.embed.in_features

        def encode(ch):
            idx = t.tensor([atoi[ch]], device=dev)
            return F.one_hot(idx, num_classes=d_in).float()

        h_prev = t.zeros(1, self.d_hidden, device=dev)
        # Prime the state on all but the last seed character; the last one is
        # consumed by the first iteration of the sampling loop.
        for ch in seed[:-1]:
            h_prev = self._step(encode(ch), h_prev)
        x = encode(seed[-1])
        chars = list(seed)
        while len(chars) < d_sample:
            h_prev = self._step(x, h_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            next_sample = t.multinomial(probs, num_samples=1)
            chars.append(itoa[next_sample.item()])
            # Feed the sampled character back in as the next input.
            x = F.one_hot(next_sample, num_classes=d_in).float()
        return ''.join(chars)

# Build an untrained RNN whose input and output widths are both d_vocab and
# draw a sample from it.
rnn = RNN(d_vocab, d_hidden, d_vocab).to(device)
rnn.sample()

'A6Db7;?2Q/4Z[sY ]7TTbBmjlYjyTUfh.jzeCQJ?]1G(e8Ye]kdNh*RFlq)G]\rKMhj:r12lWDp6HH[P23IX?p;_p3pxguN6zRbyJ'

### train

In [28]:
# Train the RNN with train()'s default settings; checkpoint files get the
# 'rnn' prefix.
train(rnn, dataloader, filename='rnn')

  0%|          | 0/1 [00:00<?, ?it/s]

loss=4.4570
A/v3LurN7-fk]ZTu2o,h Y#Hm0I*dGXz1;.#RTe8von1]svmGG[*9:.#s)GZ:W%Ox5K[Tmt ESkSiikHrc(YRyq5fqQiD;N sMjg


In [29]:
# Print a 2000-character sample from the RNN.
print(rnn.sample(d_sample=2000))

a%-s?pukC2_DbNW /H!Iikd8hNT*wD$WS,lFxE0!e-[je:WN;ri5TdC;u?By2-0LpzWvNB5Jm4O6baw%!#wd1,Otw.jaH psg)_7CAfswZ9
3:C:fwRHPc-kWpe*xT_BV_7VqV#a0OdlA pwzrFz)m.L H
rZnECCz)
wVt)#,iajYRxx3*n/z5Mw,MLT*.3,v0Vvf-)sV2ozqh,abpyq*W?O_JhJqbvh
4*F*ghhvQrd!DZPgc
fMTL
W xjBQKyIyL)joZ;G#)6 rvjzQG6s/C$v?
RCk5-ozY?7c8hXC;Fsh. Xzoo*Z9H!$IpuEM%F(nRIL39cog2ncvQa]V]59])oGv]Dm-mw4PR7sjvH0qt?L: LkH4jL:a/7?xA0Y,1UZvns[v.'XwX1,8(I7_Ie0N';J5:Whz,6_[;]EKjwkHiWJI$
 9I*_'UCj!yTzk3PR5G;1CLzM)7Z1FCY:*][2.[dlM6#PJx'v-P(7[Gog5UrDR1o$3Z:V( /po/-%h)#]AQA1_31QF(G_hx;HwFT2;8qgDzTbheh9?nmYpWNe1!GGK5)hhb9$:7MZG%P!)Yx aM[.4yO (R.,k)'$qE_%OY,X
iHq*g]:0K
678ulSD-0DQ19hD3ZNbEsXCDS*ap25jf5Un784u 1yAh5F_A*1iS ,/d_KrR'a3)CATD4dCI1a'mCHI_ba]1O3sFZyZFK!utD;mVNs[:wc#G!oGo:R6y1kg_;MmTcl-35G,:7ty6 /hhL*.#E/!#-KtR$'_y)kDVBK/lci1
3sI9xlW#j!CJV6MdH9T3I'x9ez,;01(fS[25rRRG N!Wsm7DOZB*M43;NNutOW6B31-?Z4KBXB%(?lF%Qk(qpI#R[ik#;_HC.RJRNhl*bFx#Sy;H%NCQo4Hk6L4Q_'M
ldg/7*b-)CZ!hjgj_I/jeM'AAqidp;$_%yTCy
Lx1De6yjpTp#)R9q-v:rh1LWhulr?Ex5Hrx(mym8ChnCk3ofZ4'

## Long Short-Term Memory (LSTM)

### model

In [8]:
class LSTMCell(nn.Module):
    """One LSTM time step built from four linear maps.

    Each linear layer reads the concatenation `[x, h_prev]` and produces a
    `d_hidden`-wide vector: the forget gate, the input gate, the candidate
    cell update and the output gate.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        d_joined = d_in + d_hidden  # width of the concatenated [x, h_prev]
        self.W_f = nn.Linear(d_joined, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_joined, d_hidden)  # input gate
        self.W_c = nn.Linear(d_joined, d_hidden)  # cell state update
        self.W_o = nn.Linear(d_joined, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """Return the next `(h, c)` pair for input `x` and state `(h_prev, c_prev)`.

        `x` and `h_prev` are concatenated along dim 1, so both must be
        2-D tensors with the same leading size.
        """
        joined = t.cat((x, h_prev), dim=1)

        # Long-term memory `C`: keep part of the old cell state and add the
        # gated candidate update.
        keep = t.sigmoid(self.W_f(joined))
        write = t.sigmoid(self.W_i(joined))
        candidate = t.tanh(self.W_c(joined))
        c_next = keep * c_prev + write * candidate

        # Short-term memory `h`: a gated view of the squashed cell state.
        expose = t.sigmoid(self.W_o(joined))
        h_next = expose * t.tanh(c_next)
        return h_next, c_next

class LSTM(nn.Module):
    """Sequence model that unrolls an `LSTMCell` and projects `h` to logits.

    Args:
        d_in: width of each input vector (the one-hot size).
        d_hidden: width of the hidden and cell states.
        d_out: number of output logits per time step.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_in = d_in  # remembered so sample() can one-hot encode without globals
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, c_prev=None):
        """Run the cell over a whole sequence.

        Args:
            xs: float tensor of shape (batch, d_context, d_in).
            h_prev: optional initial hidden state, (batch, d_hidden);
                zeros when omitted.
            c_prev: optional initial cell state, (batch, d_hidden);
                zeros when omitted.

        Returns:
            Logits of shape (batch, d_context, d_out).
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        if c_prev is None:
            c_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev, c_prev = self.lstm_cell(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(h_prev))
        return t.stack(outs, dim=1)

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time.

        Characters are mapped through the module-level `atoi` / `itoa`
        tables. The device and one-hot width come from the model itself, so
        sampling works wherever the model currently lives.

        Args:
            seed: non-empty prompt; every character primes the recurrent
                state and the prompt is kept at the start of the result.
            d_sample: total length of the returned text, seed included. If
                the seed is already that long it is returned unchanged.

        Returns:
            The seed followed by the sampled characters.

        Raises:
            ValueError: if `seed` is empty.
            KeyError: if `seed` contains a character missing from `atoi`.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device

        def encode(ch):
            idx = t.tensor([atoi[ch]], device=dev)
            return F.one_hot(idx, num_classes=self.d_in).float()

        h_prev = t.zeros(1, self.d_hidden, device=dev)
        c_prev = t.zeros(1, self.d_hidden, device=dev)
        # Prime the state on all but the last seed character; the last one is
        # consumed by the first iteration of the sampling loop.
        for ch in seed[:-1]:
            h_prev, c_prev = self.lstm_cell(encode(ch), h_prev, c_prev)
        x = encode(seed[-1])
        chars = list(seed)
        while len(chars) < d_sample:
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            next_sample = t.multinomial(probs, num_samples=1)
            chars.append(itoa[next_sample.item()])
            # Feed the sampled character back in as the next input.
            x = F.one_hot(next_sample, num_classes=self.d_in).float()
        return ''.join(chars)

# Build an untrained LSTM whose input and output widths are both d_vocab and
# draw a sample from it.
lstm = LSTM(d_vocab, d_hidden, d_vocab).to(device)
lstm.sample()

'A\nmEbZQqPp5?EewrPh%lKTu.HX$5U?2\ruQ_9.7:0Of8(QZ ,fLRcvhJ/]00UJe6C,T.XHpqkS%k\rWf]vuj\r)/N55Q.830j[Gg5H1'

### train

In [30]:
# Train the LSTM with train()'s default settings; checkpoint files get the
# 'lstm' prefix.
train(lstm, dataloader, filename='lstm')

  0%|          | 0/1 [00:00<?, ?it/s]

loss=1.7207
Alice thought to hant are
( in roblift
very sirt, add offerantyNe! a2
QuekinNHle
Itlld theve I ki


In [19]:
# Print a 2000-character sample from the LSTM.
print(lstm.sample(d_sample=2000))

Alice trying ear lister as), to polles was very curkkiled tonfuped be waithe it her stern soning! said the Kings and live y a
have going to face about height!

Well! some and fromale tone, howeman inner you glieBlo
loner a moment in whicentage. It tweem his, you forrono. I cutt__Syy
UbbicbjecV ccco*rGG tramp, in throug he was indow I amblid. Ante fon be a little came a litt! said that! You fullive edst by elvose inviteninged shy (Gut THAT  BRD T pUbsb4:,
QQutends. 3.03. UUASBfSSUEQ.. X H FOF DEBBKAGEERE YoY) WYOIURAPTYSH
F-UROONTTOANWLA
********************TOF*************************** ****** ME'Q  KMKVMKUKKezzzzze abadd GrateboWhH33
KidpK)) (F2KUKEKQRKQEKNGMMERE3EEEXNXMM8ENNGoESMa**zkN!
over.

 Let byot, Ale (sseding little toudes alove ome a little propers
1.ACI chart  im((Hf hau any licean tow reirnees
5.e thregs inne it is eath disar! said the Doromot AlYoxid__TVTEE.
ESUE! Wh!_
(EVEVIg. PET9_ gree horman, she saidetod ofter three
1.1. boxes,
and vonus aspersthozyNO by! Grypiol.
Si

## Gated Recurrent Unit (GRU)

### model

In [20]:
class GRUCell(nn.Module):
    """One GRU time step built from three linear maps.

    The reset and update gates read `[x, h_prev]`; the candidate state reads
    `[x, reset * h_prev]`. The new state blends the old state and the
    candidate, with the update gate weighting the candidate.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        d_joined = d_in + d_hidden  # width of the concatenated [x, h]
        self.W_r = nn.Linear(d_joined, d_hidden)  # reset gate
        self.W_z = nn.Linear(d_joined, d_hidden)  # update gate
        self.W_h = nn.Linear(d_joined, d_hidden)  # hidden state update

    def forward(self, x, h_prev):
        """Return the next hidden state for input `x` and state `h_prev`.

        `x` and `h_prev` are concatenated along dim 1, so both must be
        2-D tensors with the same leading size.
        """
        gate_input = t.cat((x, h_prev), dim=1)
        reset = t.sigmoid(self.W_r(gate_input))
        update = t.sigmoid(self.W_z(gate_input))

        # The candidate only sees the part of the old state the reset gate lets through.
        candidate_input = t.cat((x, reset * h_prev), dim=1)
        candidate = t.tanh(self.W_h(candidate_input))

        return (1 - update) * h_prev + update * candidate

class GRU(nn.Module):
    """Sequence model that unrolls a `GRUCell` and projects `h` to logits.

    Args:
        d_in: width of each input vector (the one-hot size).
        d_hidden: width of the hidden state.
        d_out: number of output logits per time step.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_in = d_in  # remembered so sample() can one-hot encode without globals
        self.d_hidden = d_hidden
        self.gru_cell = GRUCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None):
        """Run the cell over a whole sequence.

        Args:
            xs: float tensor of shape (batch, d_context, d_in).
            h_prev: optional initial hidden state, (batch, d_hidden);
                zeros when omitted.

        Returns:
            Logits of shape (batch, d_context, d_out).
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev = self.gru_cell(xs[:, i], h_prev)
            outs.append(self.unembed(h_prev))
        return t.stack(outs, dim=1)

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time.

        Characters are mapped through the module-level `atoi` / `itoa`
        tables. The device and one-hot width come from the model itself, so
        sampling works wherever the model currently lives.

        Args:
            seed: non-empty prompt; every character primes the hidden state
                and the prompt is kept at the start of the result.
            d_sample: total length of the returned text, seed included. If
                the seed is already that long it is returned unchanged.

        Returns:
            The seed followed by the sampled characters.

        Raises:
            ValueError: if `seed` is empty.
            KeyError: if `seed` contains a character missing from `atoi`.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device

        def encode(ch):
            idx = t.tensor([atoi[ch]], device=dev)
            return F.one_hot(idx, num_classes=self.d_in).float()

        h_prev = t.zeros(1, self.d_hidden, device=dev)
        # Prime the state on all but the last seed character; the last one is
        # consumed by the first iteration of the sampling loop.
        for ch in seed[:-1]:
            h_prev = self.gru_cell(encode(ch), h_prev)
        x = encode(seed[-1])
        chars = list(seed)
        while len(chars) < d_sample:
            h_prev = self.gru_cell(x, h_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            next_sample = t.multinomial(probs, num_samples=1)
            chars.append(itoa[next_sample.item()])
            # Feed the sampled character back in as the next input.
            x = F.one_hot(next_sample, num_classes=self.d_in).float()
        return ''.join(chars)

# Build an untrained GRU whose input and output widths are both d_vocab and
# draw a sample from it.
gru = GRU(d_vocab, d_hidden, d_vocab).to(device)
gru.sample()

'A%kj1!6OZ\nYNQ1A\rLPR\rLP*cd(][PR9Jv#.hXjzvTyB/h 3Zli;X zCFzl!AOiW#bkg$s/ZS;ECXvsOqVbw8sptAKDrv7SpMm]4J'

### train

In [21]:
# Train the GRU; checkpoint files get the 'gru' prefix.
train(gru, dataloader, filename='gru')  # default epochs; pass epochs=... for a longer run

  0%|          | 0/2001 [00:00<?, ?it/s]

loss=4.4208
A4HF!P#Q6q-/pY*G
H X2q'bJ7jkI3'FlfGmb$8U0IMTD.o*Z_%t*qHRz/ hfvRa5sV%q Ryry,JHZG%d
j;6U1'W:zD2:
loss=3.9851
loss=3.3123
loss=3.2755
loss=3.2555
loss=3.2387
loss=3.2221
loss=3.2012
loss=3.1749
loss=3.1397
loss=3.0919
loss=3.0300
loss=2.9537
loss=2.8685
loss=2.7788
loss=2.6888
loss=2.6147
loss=2.5566
loss=2.5068
loss=2.4611
loss=2.4176
bind, mismins tuidky bther  iy out arlive thrdloun
loss=2.3762
loss=2.3370
loss=2.3000
loss=2.2656
loss=2.2333
loss=2.2029
loss=2.1740
loss=2.1464
loss=2.1197
loss=2.0937
loss=2.0686
loss=2.0445
loss=2.0217
loss=1.9998
loss=1.9787
loss=1.9583
loss=1.9386
loss=1.9193
loss=1.9008
loss=1.8828
Alice, bont hinksed ortes oo gell hessy tame yow the sore the Dochibs urdes? the Kers lratent
ar wr
Anited thall wat on whiling tad had youl dotn itsatling Gut, ag tuntots, they dien, and as themcling they loos,
at estirupep: asr toige at I id to is
was gringid ous no

There furker the eavert dokedley.

The could? She Fhomlliteminvy batker, wes is Anderghithtr

In [22]:
# Print a 2000-character sample from the GRU.
print(gru.sample(d_sample=2000))

Ali s mano you, wooup. But sloal the
sooptint nolefilk the greanl,
puromile hor5 and a
deoftone.
Alice, It veny; Aliced, the ssen in
uto clyouged a phon entt te maly
I leapy dosplule, they seac Tuplly ofutent oufnre esazing sal on a paoment you
to corefed croand, said at wareme thean, thewr thee
some, said. Whice proced youns it abeed infnesenting sal she dinking.

Ahe got sead, the cuthen she dentor haid toremily.

Thiurd once say
in a sa nous!

 hery down and you rabe suct omoutt
ever? of coren aby to it tpound.

I bed ih wint nither to gite ma grois: atlitsteney
tnem anc_, po the

Heride and ig ver as buthen for anl Projecze, So_ was che her sioninging, shine
Alice, very to foo uped han weer oustle raject jurm she lact, tho gow ssEe wiln ters os in.

Ind of to dime
thet seeped hermely salt perypled offenther yeald! Whish FiR RHAHAciten.

The Duck Iboud in w. Henealle, wes elikely ais befinisusbyut wotde af ts mad of then at on wfrye fithibis _, adrerules wort way wese_ Mucce. I_ po_