# RNN / LSTM / GRU

## setup

In [1]:
from datetime import datetime
from torchvision import datasets, transforms
import wandb
import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

device = 'cuda' if t.cuda.is_available() else 'cpu'

# Seed torch's RNGs (weight init, DataLoader shuffling, multinomial sampling) so that
# Restart & Run All reproduces the same runs.
SEED = 0
t.manual_seed(SEED)

## utils

In [2]:
# Download "Alice's Adventures in Wonderland" (Project Gutenberg plain text) and build a
# character-level vocabulary from it.
url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail loudly rather than training on an error page
book = response.content.decode('ascii', 'ignore')  # non-ASCII characters are dropped
# Sort the characters. Iterating a set of str follows hash-randomized order, so the
# char <-> index mapping (and every checkpoint saved under weights/) would otherwise
# change meaning after each kernel restart.
vocab = sorted(set(book))
d_vocab = len(vocab)
d_hidden = 100
d_batch = 10000
atoi = {a: i for i, a in enumerate(vocab)}
itoa = {i: a for a, i in atoi.items()}

In [3]:
def to_dataloader(text, seq_len=25, batch_size=d_batch, stride=None):
    """Cut `text` into (input, target) windows for next-character prediction.

    A window starting at s has input text[s:s+seq_len]. Its target is the same span shifted
    right by one character, text[s+1:s+seq_len+1].

    Args:
        text: training corpus; every character must be a key of `atoi`.
        seq_len: characters per window.
        batch_size: DataLoader batch size.
        stride: distance between consecutive window starts. Defaults to
            max(1, seq_len // 2), i.e. roughly 50% overlap.

    Returns:
        A shuffling DataLoader over a TensorDataset of int64 index tensors (x, y), each of
        shape (n_windows, seq_len).

    Raises:
        ValueError: if `text` is too short to yield a single window.
    """
    if stride is None:
        # seq_len // 2 is 0 for seq_len == 1, which is not a valid range step.
        stride = max(1, seq_len // 2)
    # The last target character of a window starting at s is text[s + seq_len],
    # so valid starts satisfy s <= len(text) - seq_len - 1.
    starts = range(0, len(text) - seq_len, stride)
    if len(starts) == 0:
        raise ValueError(f'text of length {len(text)} is too short for seq_len={seq_len}')
    x = t.tensor([[atoi[a] for a in text[s:s + seq_len]] for s in starts])
    y = t.tensor([[atoi[a] for a in text[s + 1:s + seq_len + 1]] for s in starts])
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)


dataloader = to_dataloader(book, 50)

In [4]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename='', wnb=True):
    """Train a character-level model with next-character cross-entropy.

    Args:
        model: module mapping one-hot inputs (batch, seq, d_vocab) to logits
            (batch, seq, d_vocab). It must also provide `sample(d_sample=...)`, which is
            used for progress samples.
        dataloader: yields (xs, ys) batches of integer character indices, shape (batch, seq).
        epochs: number of passes over `dataloader`.
        d_vocab: width of the one-hot input encoding.
        opt: optimizer; defaults to Adam(lr) over model.parameters().
        lr: learning rate for the default optimizer.
        filename: wandb project name and prefix of saved checkpoints.
        wnb: log to Weights & Biases when True.

    Side effects:
        - Prints the loss every 50 epochs and a sample every 1000 epochs.
        - Writes checkpoints to weights/ every 1000 epochs and after the last epoch.
        - Logs to wandb when `wnb` is True.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    import os

    model.train()
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    os.makedirs('weights', exist_ok=True)  # t.save does not create missing directories

    def save_checkpoint():
        stamp = datetime.now().strftime("%Y-%m-%d_%Hh%M")
        t.save(model.state_dict(), f'weights/{filename}_{stamp}.pt')

    if wnb:
        wandb.init(project=filename)
    try:
        for epoch in tqdm(range(epochs)):
            loss = None
            for xs, ys in dataloader:
                out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
                # cross_entropy expects class logits on dim 1: (batch, d_vocab, seq) vs targets (batch, seq)
                loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
            if loss is None:
                raise ValueError('dataloader yielded no batches')
            # `loss` is the loss of the epoch's last batch, not an epoch average.
            if wnb:
                wandb.log({'loss': loss.item()})
                if epoch % 100 == 0:
                    wandb.log({'sample_html': wandb.Html(f'<p>{model.sample(d_sample=200)}</p>')})
            if epoch % 50 == 0:
                print(f'loss={loss.item():.4f}')
            if epoch % 1000 == 0:
                print(model.sample())
                save_checkpoint()
        if epochs > 0 and (epochs - 1) % 1000 != 0:
            save_checkpoint()  # keep the final weights even if the last epoch wasn't checkpointed
    finally:
        if wnb:
            wandb.finish()  # close the run even when training raises or is interrupted

## Recurrent Neural Networks (RNN)

### model

In [25]:
class RNN(nn.Module):
    """Single-layer vanilla RNN over one-hot character inputs.

    Recurrence: h_t = tanh(embed(x_t) + hidden(h_{t-1})); logits_t = unembed(h_t).
    """

    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.d_hidden = d_hidden
        self.embed = nn.Linear(d_in, d_hidden)       # input -> hidden
        self.hidden = nn.Linear(d_hidden, d_hidden)  # previous hidden -> hidden
        self.unembed = nn.Linear(d_hidden, d_out)    # hidden -> output logits

    def _step(self, x, memory):
        """Return the next hidden state for one timestep of input `x` (batch, d_in)."""
        return t.tanh(self.embed(x) + self.hidden(memory))

    def forward(self, xs, memory=None, return_memory=False):
        """Run the RNN over a batch of sequences.

        Args:
            xs: (batch, d_context, d_in) inputs, e.g. one-hot characters.
            memory: optional initial hidden state (batch, d_hidden); zeros if None.
            return_memory: also return the final hidden state.

        Returns:
            Logits of shape (batch, d_context, d_out). If `return_memory` is True,
            returns (logits, final hidden state).
        """
        batch, d_context, _ = xs.shape
        if memory is None:
            memory = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            memory = self._step(xs[:, i], memory)
            outs.append(self.unembed(memory))
        logits = t.stack(outs, dim=1)
        return (logits, memory) if return_memory else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text of length `d_sample` by sampling from the model's softmax.

        Every character of `seed` is fed through the network first, so the hidden state
        reflects the whole prompt. A `seed` that already has `d_sample` or more characters
        is returned unchanged. Uses the notebook-level `atoi`, `itoa` and `d_vocab`.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model's device, not the global `device`

        def one_hot(idx):
            return F.one_hot(t.tensor([idx], device=dev), num_classes=d_vocab).float()

        memory = t.zeros(1, self.d_hidden, device=dev)
        for ch in seed[:-1]:
            memory = self._step(one_hot(atoi[ch]), memory)
        x = one_hot(atoi[seed[-1]])
        text = seed
        while len(text) < d_sample:
            memory = self._step(x, memory)
            probs = self.unembed(memory)[0].softmax(-1)
            idx = t.multinomial(probs, num_samples=1).item()
            text += itoa[idx]
            x = one_hot(idx)
        return text

# Untrained baseline: sample before any training.
rnn = RNN(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
rnn.sample()

'A6Db7;?2Q/4Z[sY ]7TTbBmjlYjyTUfh.jzeCQJ?]1G(e8Ye]kdNh*RFlq)G]\rKMhj:r12lWDp6HH[P23IX?p;_p3pxguN6zRbyJ'

### train

In [33]:
train(rnn, dataloader, epochs=50, filename='rnn')  # short run: 50 epochs only

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeluche[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/50 [00:00<?, ?it/s]

loss=4.4495
A27D9USZ$b6Vf]u4f;PdWxKq*'BSI1_mI:1IvK'u?fvgI84ZY#XwiAy$freQaWUTH)u01)U6)Xso#4NX$,y*],ay  di7;Eh;UFg




VBox(children=(Label(value='0.002 MB of 0.008 MB uploaded\r'), FloatProgress(value=0.27701849086941543, max=1.…

0,1
loss,████████▇▇▇▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁

0,1
loss,3.35577


In [29]:
print(rnn.sample(seed='A', d_sample=2000))

a%-s?pukC2_DbNW /H!Iikd8hNT*wD$WS,lFxE0!e-[je:WN;ri5TdC;u?By2-0LpzWvNB5Jm4O6baw%!#wd1,Otw.jaH psg)_7CAfswZ9
3:C:fwRHPc-kWpe*xT_BV_7VqV#a0OdlA pwzrFz)m.L H
rZnECCz)
wVt)#,iajYRxx3*n/z5Mw,MLT*.3,v0Vvf-)sV2ozqh,abpyq*W?O_JhJqbvh
4*F*ghhvQrd!DZPgc
fMTL
W xjBQKyIyL)joZ;G#)6 rvjzQG6s/C$v?
RCk5-ozY?7c8hXC;Fsh. Xzoo*Z9H!$IpuEM%F(nRIL39cog2ncvQa]V]59])oGv]Dm-mw4PR7sjvH0qt?L: LkH4jL:a/7?xA0Y,1UZvns[v.'XwX1,8(I7_Ie0N';J5:Whz,6_[;]EKjwkHiWJI$
 9I*_'UCj!yTzk3PR5G;1CLzM)7Z1FCY:*][2.[dlM6#PJx'v-P(7[Gog5UrDR1o$3Z:V( /po/-%h)#]AQA1_31QF(G_hx;HwFT2;8qgDzTbheh9?nmYpWNe1!GGK5)hhb9$:7MZG%P!)Yx aM[.4yO (R.,k)'$qE_%OY,X
iHq*g]:0K
678ulSD-0DQ19hD3ZNbEsXCDS*ap25jf5Un784u 1yAh5F_A*1iS ,/d_KrR'a3)CATD4dCI1a'mCHI_ba]1O3sFZyZFK!utD;mVNs[:wc#G!oGo:R6y1kg_;MmTcl-35G,:7ty6 /hhL*.#E/!#-KtR$'_y)kDVBK/lci1
3sI9xlW#j!CJV6MdH9T3I'x9ez,;01(fS[25rRRG N!Wsm7DOZB*M43;NNutOW6B31-?Z4KBXB%(?lF%Qk(qpI#R[ik#;_HC.RJRNhl*bFx#Sy;H%NCQo4Hk6L4Q_'M
ldg/7*b-)CZ!hjgj_I/jeM'AAqidp;$_%yTCy
Lx1De6yjpTp#)R9q-v:rh1LWhulr?Ex5Hrx(mym8ChnCk3ofZ4'

## Long Short-Term Memory (LSTM)

### model

In [5]:
class LSTMCell(nn.Module):
    """One LSTM timestep with a separate linear map per gate.

    Every gate reads the concatenation [x, h_prev]:
        f = sigmoid(W_f [x, h])     forget gate
        i = sigmoid(W_i [x, h])     input gate
        g = tanh(W_c [x, h])        candidate cell update
        o = sigmoid(W_o [x, h])     output gate
        c_next = f * c_prev + i * g
        h_next = o * tanh(c_next)
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.W_f = nn.Linear(d_in + d_hidden, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_in + d_hidden, d_hidden)  # input gate
        self.W_c = nn.Linear(d_in + d_hidden, d_hidden)  # cell state update
        self.W_o = nn.Linear(d_in + d_hidden, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """x: (batch, d_in); h_prev, c_prev: (batch, d_hidden). Returns (h_next, c_next)."""
        xh = t.cat((x, h_prev), dim=1)
        # long-term memory `c`: forget part of the old state, write part of the candidate
        f_gate = t.sigmoid(self.W_f(xh))
        i_gate = t.sigmoid(self.W_i(xh))
        c_update = t.tanh(self.W_c(xh))
        c_next = f_gate * c_prev + i_gate * c_update
        # short-term memory `h`: a gated view of the new cell state
        o_gate = t.sigmoid(self.W_o(xh))
        h_next = o_gate * t.tanh(c_next)
        return h_next, c_next


class LSTM(nn.Module):
    """Single-layer LSTM language model: an LSTMCell unrolled over time plus a linear unembedding."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, c_prev=None, return_state=False):
        """Run the LSTM over a batch of sequences.

        Args:
            xs: (batch, d_context, d_in) inputs, e.g. one-hot characters.
            h_prev, c_prev: optional initial states (batch, d_hidden); zeros if None.
            return_state: also return the final (h, c), so a later call can continue
                the sequence.

        Returns:
            Logits of shape (batch, d_context, d_out). If `return_state` is True,
            returns (logits, (h, c)).
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        if c_prev is None:
            c_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev, c_prev = self.lstm_cell(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(h_prev))
        logits = t.stack(outs, dim=1)
        return (logits, (h_prev, c_prev)) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text of length `d_sample` by sampling from the model's softmax.

        Every character of `seed` is fed through the network first, so the state reflects
        the whole prompt. A `seed` that already has `d_sample` or more characters is
        returned unchanged. Uses the notebook-level `atoi`, `itoa` and `d_vocab`.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model's device, not the global `device`

        def one_hot(idx):
            return F.one_hot(t.tensor([idx], device=dev), num_classes=d_vocab).float()

        h_prev = t.zeros(1, self.d_hidden, device=dev)
        c_prev = t.zeros(1, self.d_hidden, device=dev)
        for ch in seed[:-1]:
            h_prev, c_prev = self.lstm_cell(one_hot(atoi[ch]), h_prev, c_prev)
        x = one_hot(atoi[seed[-1]])
        text = seed
        while len(text) < d_sample:
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            idx = t.multinomial(probs, num_samples=1).item()
            text += itoa[idx]
            x = one_hot(idx)
        return text

# Untrained baseline: sample before any training.
lstm = LSTM(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
lstm.sample()

"A3a n):iV!n7.!k.Xzt;7C2f0CF(7d2XHX#0[gr4WQ*Y2II(0'.tjw,72bhPL)c!.!3kRJ7[-Q%JbGB5P3_)NZOdi]\r%b[491IP$"

### train

In [6]:
train(lstm, dataloader, epochs=2001, filename='lstm')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeluche[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2001 [00:00<?, ?it/s]

loss=4.4037
Wu7LhND.QiVqU(#9eAy%v!V Owk,R
hnLFs[d-NzsG/pq#mg4-vS5uME49
loss=4.0279
loss=3.3457
loss=3.3031
loss=3.2774
loss=3.2545
loss=3.2291
loss=3.1992
loss=3.1635
loss=3.1189
loss=3.0653
loss=3.0023
loss=2.9343
loss=2.8653
loss=2.7961
loss=2.7328
loss=2.6785
loss=2.6316
loss=2.5909
loss=2.5550
loss=2.5228
AMAn?

ou gh  ou itevlse!F
, bisleboit Iae shorf tl, aad, ant
oh, csl lacils at Ie shen
$ar
)es 
loss=2.4933
loss=2.4649
loss=2.4384
loss=2.4144
loss=2.3919
loss=2.3710
loss=2.3510
loss=2.3323
loss=2.3145
loss=2.2977
loss=2.2818
loss=2.2665
loss=2.2517
loss=2.2376
loss=2.2240
loss=2.2108
loss=2.1980
loss=2.1854
loss=2.1732
loss=2.1614
Df bothe wor, Alidee carderstale, wikh, sat ingwtit the llofune tas nated




VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,2.1614


In [19]:
print(lstm.sample(seed='A', d_sample=2000))

Alice trying ear lister as), to polles was very curkkiled tonfuped be waithe it her stern soning! said the Kings and live y a
have going to face about height!

Well! some and fromale tone, howeman inner you glieBlo
loner a moment in whicentage. It tweem his, you forrono. I cutt__Syy
UbbicbjecV ccco*rGG tramp, in throug he was indow I amblid. Ante fon be a little came a litt! said that! You fullive edst by elvose inviteninged shy (Gut THAT  BRD T pUbsb4:,
QQutends. 3.03. UUASBfSSUEQ.. X H FOF DEBBKAGEERE YoY) WYOIURAPTYSH
F-UROONTTOANWLA
********************TOF*************************** ****** ME'Q  KMKVMKUKKezzzzze abadd GrateboWhH33
KidpK)) (F2KUKEKQRKQEKNGMMERE3EEEXNXMM8ENNGoESMa**zkN!
over.

 Let byot, Ale (sseding little toudes alove ome a little propers
1.ACI chart  im((Hf hau any licean tow reirnees
5.e thregs inne it is eath disar! said the Doromot AlYoxid__TVTEE.
ESUE! Wh!_
(EVEVIg. PET9_ gree horman, she saidetod ofter three
1.1. boxes,
and vonus aspersthozyNO by! Grypiol.
Si

## Gated Recurrent Unit (GRU)

### model

In [None]:
class GRUCell(nn.Module):
    """One GRU timestep with a separate linear map per gate.

        r = sigmoid(W_r [x, h])           reset gate
        z = sigmoid(W_z [x, h])           update gate
        h_cand = tanh(W_h [x, r * h])     candidate state
        h_next = (1 - z) * h + z * h_cand
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.W_r = nn.Linear(d_in + d_hidden, d_hidden)  # reset gate
        self.W_z = nn.Linear(d_in + d_hidden, d_hidden)  # update gate
        self.W_h = nn.Linear(d_in + d_hidden, d_hidden)  # hidden state update

    def forward(self, x, h_prev):
        """x: (batch, d_in); h_prev: (batch, d_hidden). Returns h_next (batch, d_hidden)."""
        xh = t.cat((x, h_prev), dim=1)
        r_gate = t.sigmoid(self.W_r(xh))
        z_gate = t.sigmoid(self.W_z(xh))
        # the reset gate controls how much of the old state the candidate sees
        h_candidate = t.tanh(self.W_h(t.cat((x, r_gate * h_prev), dim=1)))
        # the update gate interpolates between the old state and the candidate
        return (1 - z_gate) * h_prev + z_gate * h_candidate


class GRU(nn.Module):
    """Single-layer GRU language model: a GRUCell unrolled over time plus a linear unembedding."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.gru_cell = GRUCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, return_state=False):
        """Run the GRU over a batch of sequences.

        Args:
            xs: (batch, d_context, d_in) inputs, e.g. one-hot characters.
            h_prev: optional initial hidden state (batch, d_hidden); zeros if None.
            return_state: also return the final hidden state.

        Returns:
            Logits of shape (batch, d_context, d_out). If `return_state` is True,
            returns (logits, final h).
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev = self.gru_cell(xs[:, i], h_prev)
            outs.append(self.unembed(h_prev))
        logits = t.stack(outs, dim=1)
        return (logits, h_prev) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text of length `d_sample` by sampling from the model's softmax.

        Every character of `seed` is fed through the network first, so the state reflects
        the whole prompt. A `seed` that already has `d_sample` or more characters is
        returned unchanged. Uses the notebook-level `atoi`, `itoa` and `d_vocab`.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model's device, not the global `device`

        def one_hot(idx):
            return F.one_hot(t.tensor([idx], device=dev), num_classes=d_vocab).float()

        h_prev = t.zeros(1, self.d_hidden, device=dev)
        for ch in seed[:-1]:
            h_prev = self.gru_cell(one_hot(atoi[ch]), h_prev)
        x = one_hot(atoi[seed[-1]])
        text = seed
        while len(text) < d_sample:
            h_prev = self.gru_cell(x, h_prev)
            probs = self.unembed(h_prev)[0].softmax(-1)
            idx = t.multinomial(probs, num_samples=1).item()
            text += itoa[idx]
            x = one_hot(idx)
        return text

# Untrained baseline: sample before any training.
gru = GRU(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
gru.sample()

### train

In [13]:
train(gru, dataloader, epochs=40000, filename='gru')

  0%|          | 0/40000 [00:00<?, ?it/s]

loss=4.4247
A$o:(':QNQRG:A Jv1yx-jk3AOkc_*5Oxuhn
Neb1LO[33)0Yg0jwKTxlI3w;vzR;e6h.#ZEbnu_MONG
loss=4.0179
loss=3.3122
loss=3.2710
loss=3.2471
loss=3.2247
loss=3.1993
loss=3.1658
loss=3.1193
loss=3.0528
loss=2.9717
loss=2.8811
loss=2.7868
loss=2.6927
loss=2.6137
loss=2.5506
loss=2.4985
loss=2.4530
loss=2.4121
loss=2.3748
loss=2.3402
Ale ano, mashed maamronir what. Gus vatd iolo lico
nenuks.y wta dattde ftice ehlfthers wens veas nh
loss=2.3075
loss=2.2764
loss=2.2465
loss=2.2173
loss=2.1886
loss=2.1608
loss=2.1339
loss=2.1081
loss=2.0835
loss=2.0601
loss=2.0378
loss=2.0165
loss=1.9957
loss=1.9757
loss=1.9562
loss=1.9373
loss=1.9188
loss=1.9010
loss=1.8837
loss=1.8667
Alice.

It you camed
wat?pene maly, Gutenb:rd, in the Houts or and shoughe thousd no somergo of t
loss=1.8503
loss=1.8344
loss=1.8190
loss=1.8038
loss=1.7890
loss=1.7748
loss=1.7607
loss=1.7471
loss=1.7338
loss=1.7208
loss=1.7082
loss=1.6960
loss=1.6840
loss=1.6722
loss=1.6607
loss=1.6496
loss=1.6386
loss=1.6278
loss=1.6171
l



VBox(children=(Label(value='0.164 MB of 0.164 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,1.12053


In [22]:
print(gru.sample(seed='A', d_sample=2000))

Ali s mano you, wooup. But sloal the
sooptint nolefilk the greanl,
puromile hor5 and a
deoftone.
Alice, It veny; Aliced, the ssen in
uto clyouged a phon entt te maly
I leapy dosplule, they seac Tuplly ofutent oufnre esazing sal on a paoment you
to corefed croand, said at wareme thean, thewr thee
some, said. Whice proced youns it abeed infnesenting sal she dinking.

Ahe got sead, the cuthen she dentor haid toremily.

Thiurd once say
in a sa nous!

 hery down and you rabe suct omoutt
ever? of coren aby to it tpound.

I bed ih wint nither to gite ma grois: atlitsteney
tnem anc_, po the

Heride and ig ver as buthen for anl Projecze, So_ was che her sioninging, shine
Alice, very to foo uped han weer oustle raject jurm she lact, tho gow ssEe wiln ters os in.

Ind of to dime
thet seeped hermely salt perypled offenther yeald! Whish FiR RHAHAciten.

The Duck Iboud in w. Henealle, wes elikely ais befinisusbyut wotde af ts mad of then at on wfrye fithibis _, adrerules wort way wese_ Mucce. I_ po_

## MNIST

In [5]:
# MNIST as a sequence task: each 28x28 image is read row by row (28 steps of 28 pixels).
batch_size = 6400
vocab_size = 10    # number of output classes (digits 0-9)
hidden_size = 100

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: pixels from [0, 1] to [-1, 1]
])
trainset = datasets.MNIST('./mnist/', train=True, download=True, transform=transform)
testset = datasets.MNIST('./mnist/', train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

In [16]:
@t.no_grad()
def accuracy(model, dataloader=None):
    """Fraction of correctly classified examples in `dataloader`.

    Args:
        model: classifier returning either (batch, n_classes) logits or per-timestep
            logits (batch, seq, n_classes). For the latter, the last timestep's prediction
            is used, matching `train_mnist`.
        dataloader: yields (xs, ys) with xs of shape (batch, 1, H, W). Defaults to the
            notebook's `testloader`, looked up at call time.

    Returns:
        A Python float in [0, 1].

    Raises:
        ValueError: if `dataloader` yields no examples.

    The model is switched to eval mode for the pass and then restored to the mode it was in.
    """
    if dataloader is None:
        dataloader = testloader
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else device
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        for xs, ys in dataloader:
            # squeeze(1) drops only the channel dim; plain squeeze() would also drop a batch dim of 1
            out = model(xs.squeeze(1).to(dev))
            if out.dim() == 3:
                out = out[:, -1]  # sequence model: classify from the last timestep
            correct += (out.argmax(-1) == ys.to(dev)).sum().item()
            total += len(xs)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError('accuracy() got an empty dataloader')
    return correct / total

In [None]:
def train_mnist(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename='', wnb=True):
    """Train an MNIST classifier that reads each image as a sequence of rows.

    Models returning per-timestep logits (batch, seq, n_classes) are trained on the last
    timestep's output. Models that already return (batch, n_classes) are used as is.

    Args:
        model: classifier over inputs of shape (batch, H, W).
        dataloader: yields (xs, ys) with xs of shape (batch, 1, H, W) and integer labels ys.
        epochs: number of passes over `dataloader`.
        d_vocab: unused; kept so the signature mirrors `train`.
        opt: optimizer; defaults to Adam(lr) over model.parameters().
        lr: learning rate for the default optimizer.
        filename: wandb project name and prefix of saved checkpoints.
        wnb: log loss and test accuracy to Weights & Biases when True.

    Side effects:
        - Prints loss and accuracy every 50 epochs.
        - Writes a checkpoint to weights/ every 10000 epochs.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    import os

    model.train()
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    os.makedirs('weights', exist_ok=True)  # t.save does not create missing directories
    if wnb:
        wandb.init(project=filename)
    try:
        for epoch in tqdm(range(epochs)):
            loss = None
            for xs, ys in dataloader:
                # squeeze(1) drops only the channel dim; plain squeeze() would also drop a batch dim of 1
                out = model(xs.squeeze(1).to(device))
                if out.dim() == 3:
                    out = out[:, -1]  # classify from the last timestep
                loss = F.cross_entropy(out, ys.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
            if loss is None:
                raise ValueError('dataloader yielded no batches')
            report = epoch % 50 == 0
            if wnb or report:
                # one full test-set pass per epoch, shared by wandb and the printout
                acc = float(accuracy(model))
            if wnb:
                wandb.log({'loss': loss.item(), 'accuracy': acc})
            if report:
                print(f'loss={loss.item():.4f} accuracy={acc:.4f}')
            if epoch % 10000 == 0:
                stamp = datetime.now().strftime("%Y-%m-%d_%Hh%M")
                t.save(model.state_dict(), f'weights/{filename}_{stamp}.pt')
    finally:
        if wnb:
            wandb.finish()  # close the run even when training raises or is interrupted

# Single LSTM over image rows: 28 pixels in per step, one logit per digit class out.
# Uses the MNIST section's own `hidden_size` rather than the Alice section's `d_hidden`.
lstm_mnist = LSTM(28, hidden_size, vocab_size).to(device)
train_mnist(lstm_mnist, trainloader, epochs=100)
# gru_mnist = GRU(28, hidden_size, vocab_size).to(device)
# train_mnist(gru_mnist, trainloader, epochs=30, wnb=False)

In [8]:
class GRU_chain(nn.Module):
    """GRUCell unrolled over a sequence, returning every hidden state.

    There is no output projection, so chains can be stacked inside nn.Sequential.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.gru_cell = GRUCell(d_in, d_hidden)

    def forward(self, xs, h_prev=None):
        """xs: (batch, d_context, d_in) -> hidden states (batch, d_context, d_hidden)."""
        n_batch, n_steps, _ = xs.shape
        h = t.zeros(n_batch, self.d_hidden, device=xs.device) if h_prev is None else h_prev
        states = []
        for step in range(n_steps):
            h = self.gru_cell(xs[:, step], h)
            states.append(h)
        return t.stack(states, dim=1)

class MNISTer(nn.Module):
    """Two stacked GRU_chain layers (ReLU in between) and a per-timestep linear classifier."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        recurrent = [GRU_chain(d_in, d_hidden), nn.ReLU(), GRU_chain(d_hidden, d_hidden)]
        self.layers = nn.Sequential(*recurrent, nn.Linear(d_hidden, d_out))

    def forward(self, xs):
        """xs: (batch, d_context, d_in) -> logits (batch, d_context, d_out)."""
        return self.layers(xs)

# mnister = MNISTer(28, 100, vocab_size).to(device)
# train_mnist(mnister, trainloader, epochs=100)

In [None]:
class LSTM_chain(nn.Module):
    """LSTMCell unrolled over a sequence, returning every hidden state.

    There is no output projection, so chains can be stacked inside nn.Sequential.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)

    def forward(self, xs, h_prev=None, c_prev=None):
        """xs: (batch, d_context, d_in) -> hidden states (batch, d_context, d_hidden)."""
        n_batch, n_steps, _ = xs.shape
        h = t.zeros(n_batch, self.d_hidden, device=xs.device) if h_prev is None else h_prev
        c = t.zeros(n_batch, self.d_hidden, device=xs.device) if c_prev is None else c_prev
        hiddens = []
        for step in range(n_steps):
            h, c = self.lstm_cell(xs[:, step], h, c)
            hiddens.append(h)
        return t.stack(hiddens, dim=1)

class MNIST_lstm(nn.Module):
    """Two stacked LSTM_chain layers (ReLU in between) and a per-timestep linear classifier."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        recurrent = [LSTM_chain(d_in, d_hidden), nn.ReLU(), LSTM_chain(d_hidden, d_hidden)]
        self.layers = nn.Sequential(*recurrent, nn.Linear(d_hidden, d_out))

    def forward(self, xs):
        """xs: (batch, d_context, d_in) -> logits (batch, d_context, d_out)."""
        return self.layers(xs)

# Two-layer LSTM classifier with 100 hidden units.
mnist_lstm_r = MNIST_lstm(d_in=28, d_hidden=100, d_out=vocab_size).to(device)
train_mnist(mnist_lstm_r, trainloader, epochs=100)

In [None]:
# Same architecture with a much smaller hidden state (10 units).
mnist_lstm_rs = MNIST_lstm(d_in=28, d_hidden=10, d_out=vocab_size).to(device)
train_mnist(mnist_lstm_rs, trainloader, epochs=100)

In [None]:
def train_mnist2(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename='', wnb=True):
    """Train an MNIST classifier whose output is already (batch, n_classes), e.g. an MLP.

    The model output is fed to cross-entropy directly; no last-timestep slice is taken.

    Args:
        model: classifier over inputs of shape (batch, H, W).
        dataloader: yields (xs, ys) with xs of shape (batch, 1, H, W) and integer labels ys.
        epochs: number of passes over `dataloader`.
        d_vocab: unused; kept so the signature mirrors `train`.
        opt: optimizer; defaults to Adam(lr) over model.parameters().
        lr: learning rate for the default optimizer.
        filename: wandb project name and prefix of saved checkpoints.
        wnb: log loss and test accuracy to Weights & Biases when True.

    Side effects:
        - Prints loss and accuracy every 50 epochs.
        - Writes a checkpoint to weights/ every 10000 epochs.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    import os

    model.train()
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    os.makedirs('weights', exist_ok=True)  # t.save does not create missing directories
    if wnb:
        wandb.init(project=filename)
    try:
        for epoch in tqdm(range(epochs)):
            loss = None
            for xs, ys in dataloader:
                # squeeze(1) drops only the channel dim; plain squeeze() would also drop a batch dim of 1
                out = model(xs.squeeze(1).to(device))
                loss = F.cross_entropy(out, ys.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
            if loss is None:
                raise ValueError('dataloader yielded no batches')
            report = epoch % 50 == 0
            if wnb or report:
                # one full test-set pass per epoch, shared by wandb and the printout
                acc = float(accuracy(model))
            if wnb:
                wandb.log({'loss': loss.item(), 'accuracy': acc})
            if report:
                print(f'loss={loss.item():.4f} accuracy={acc:.4f}')
            if epoch % 10000 == 0:
                stamp = datetime.now().strftime("%Y-%m-%d_%Hh%M")
                t.save(model.state_dict(), f'weights/{filename}_{stamp}.pt')
    finally:
        if wnb:
            wandb.finish()  # close the run even when training raises or is interrupted


# Baseline: a plain MLP on the flattened 28x28 image (no recurrence).
x = nn.Sequential(
    nn.Flatten(start_dim=1),
    nn.Linear(28**2, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
).to(device)
train_mnist2(x, trainloader, epochs=100)

## Alice, take 2: four stacked `LSTM_chain` layers

In [13]:
# class LSTM(nn.Module):
#     def __init__(self, d_in, d_hidden, d_out):
#         super().__init__()
#         self.d_hidden = d_hidden
#         self.lstm_cell = LSTMCell(d_in, d_hidden)
#         self.unembed = nn.Linear(d_hidden, d_out)

#     def forward(self, xs, h_prev=None, c_prev=None):
#         # xs: (batch, d_context, d_vocab)
#         batch, d_context, _ = xs.shape
#         outs = []
#         if h_prev is None: h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
#         if c_prev is None: c_prev = t.zeros(batch, self.d_hidden, device=xs.device)
#         for i in range(d_context):
#             x = xs[:, i]
#             h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
#             outs.append(self.unembed(h_prev))
#         return t.stack(outs, dim=1)

class LSTM_chain(nn.Module):
    """LSTMCell unrolled over a sequence, emitting the hidden state at every step.

    This re-declares the LSTM_chain from the MNIST section. It has no output head, so it
    composes inside nn.Sequential.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)

    def forward(self, xs, h_prev=None, c_prev=None):
        """xs: (batch, d_context, d_in) -> hidden states (batch, d_context, d_hidden)."""
        n_batch, _, _ = xs.shape
        h = h_prev if h_prev is not None else t.zeros(n_batch, self.d_hidden, device=xs.device)
        c = c_prev if c_prev is not None else t.zeros(n_batch, self.d_hidden, device=xs.device)
        per_step = []
        for x_t in xs.unbind(dim=1):
            h, c = self.lstm_cell(x_t, h, c)
            per_step.append(h)
        return t.stack(per_step, dim=1)

class LSTMs(nn.Module):
    """Four stacked LSTM_chain layers (ReLU between them) followed by a linear head."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.layers = nn.Sequential(
            LSTM_chain(d_in, d_hidden),
            nn.ReLU(),
            LSTM_chain(d_hidden, d_hidden),
            nn.ReLU(),
            LSTM_chain(d_hidden, d_hidden),
            nn.ReLU(),
            LSTM_chain(d_hidden, d_hidden))
        self.head = nn.Linear(d_hidden, d_out)

    def forward(self, x):
        """x: (batch, d_context, d_in) -> logits (batch, d_context, d_out)."""
        x = self.layers(x)
        x = self.head(x)
        return x

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text of length `d_sample` by sampling from the model's softmax.

        The (h, c) state of every recurrent layer is carried from one character to the next,
        so each new character costs a single timestep instead of re-running the stack over
        the whole prefix. This draws from the same distributions as re-running
        `forward(text)[0, -1]` at every step.

        Every character of `seed` is fed in first. A `seed` that already has `d_sample` or
        more characters is returned unchanged. Uses the notebook-level `atoi`, `itoa`
        and `d_vocab`.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.head.weight.device  # follow the model's device, not the global `device`
        # Per-layer recurrent state; None for stateless layers (the ReLUs).
        states = [
            (t.zeros(1, layer.d_hidden, device=dev), t.zeros(1, layer.d_hidden, device=dev))
            if hasattr(layer, 'lstm_cell') else None
            for layer in self.layers
        ]

        def step(x):
            """Advance every layer by one timestep; return the head's logits (1, d_out)."""
            for i, layer in enumerate(self.layers):
                if states[i] is None:
                    x = layer(x)  # elementwise activation
                else:
                    h, c = layer.lstm_cell(x, *states[i])
                    states[i] = (h, c)
                    x = h
            return self.head(x)

        def one_hot(idx):
            return F.one_hot(t.tensor([idx], device=dev), num_classes=d_vocab).float()

        for ch in seed[:-1]:
            step(one_hot(atoi[ch]))
        x = one_hot(atoi[seed[-1]])
        text = seed
        while len(text) < d_sample:
            probs = step(x)[0].softmax(-1)
            idx = t.multinomial(probs, num_samples=1).item()
            text += itoa[idx]
            x = one_hot(idx)
        return text

# Untrained baseline: sample before any training.
lstms = LSTMs(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
lstms.sample()

"A7FX%JPZY*99\nxWeZ-#x!1W'(CbOwjqM;h;$[BL$ d9Ln*GASpw_ e3zQje41[_Tn9CJM3tTeimvOkPJ*?WE:K\nP%,luO'/a8TOl"

In [14]:
train(lstms, dataloader, epochs=2001, filename='stacked_lstm')  # TODO: redo a longer one

  0%|          | 0/2001 [00:00<?, ?it/s]

loss=4.4564
5vmn#8hz0h[ 5iUJXrPBS2j'ehLHpZP_Hm1jF*n;h;n-3
loss=3.6371
loss=3.3002
loss=3.2595
loss=3.2365
loss=3.2200
loss=3.2086
loss=3.2006
loss=3.1925
loss=3.1831
loss=3.1743
loss=3.1690
loss=3.1662
loss=3.1646
loss=3.1629
loss=3.1615
loss=3.1602
loss=3.1587
loss=3.1577
loss=3.1567
loss=3.1514
Ao, orfknnd be kt o hiR
 n,luwbsitIhrwki_  A hesiyvisn o hoteo
loss=3.1477
loss=3.1478
loss=3.1460
loss=3.1419
loss=3.0508
loss=2.8778
loss=2.7744
loss=2.7153
loss=2.6709
loss=2.6161
loss=2.5502
loss=2.5028
loss=2.4669
loss=2.4337
loss=2.4043
loss=2.3744
loss=2.3456
loss=2.3203
loss=2.2951
loss=2.2669
Aloes oe IH lhod saees onbyse it to her
nan_ hint il
feod, yaod.

Ye-tded moed onset afrintd olr




VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁

0,1
loss,2.26691


## Stacked LSTM

In [7]:
class LSTMCell(nn.Module):
    """One LSTM timestep with a separate linear map per gate.

    This duplicates the LSTMCell defined in the LSTM section. Every gate reads [x, h_prev]:
        f = sigmoid(W_f [x, h]), i = sigmoid(W_i [x, h]), g = tanh(W_c [x, h]), o = sigmoid(W_o [x, h])
        c_next = f * c_prev + i * g
        h_next = o * tanh(c_next)
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.W_f = nn.Linear(d_in + d_hidden, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_in + d_hidden, d_hidden)  # input gate
        self.W_c = nn.Linear(d_in + d_hidden, d_hidden)  # cell state update
        self.W_o = nn.Linear(d_in + d_hidden, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """x: (batch, d_in); h_prev, c_prev: (batch, d_hidden). Returns (h_next, c_next)."""
        xh = t.cat((x, h_prev), dim=1)
        f_gate = t.sigmoid(self.W_f(xh))
        i_gate = t.sigmoid(self.W_i(xh))
        c_update = t.tanh(self.W_c(xh))
        c_next = f_gate * c_prev + i_gate * c_update
        o_gate = t.sigmoid(self.W_o(xh))
        h_next = o_gate * t.tanh(c_next)
        return h_next, c_next


class StackedLSTM(nn.Module):
    """`d_layers` LSTMCells stacked in depth.

    At every timestep, layer k's hidden state is layer k+1's input, and the top layer's
    hidden state is unembedded to logits.
    """

    def __init__(self, d_in, d_hidden, d_out, d_layers):
        super().__init__()
        if d_layers < 1:
            raise ValueError(f'd_layers must be >= 1, got {d_layers}')
        self.d_hidden = d_hidden
        self.d_layers = d_layers
        self.lstm_cells = nn.ModuleList(
            [LSTMCell(d_in if layer == 0 else d_hidden, d_hidden) for layer in range(d_layers)])
        self.unembed = nn.Linear(d_hidden, d_out)

    def _step(self, x, h_prev, c_prev):
        """Advance every layer by one timestep.

        Args:
            x: (batch, d_in) input for this timestep.
            h_prev, c_prev: (d_layers, batch, d_hidden) states from the previous timestep.

        Returns:
            (top-layer hidden state, new stacked h, new stacked c).
        """
        h_next, c_next = [], []
        for cell, h, c in zip(self.lstm_cells, h_prev, c_prev):
            h, c = cell(x, h, c)
            h_next.append(h)
            c_next.append(c)
            x = h  # this layer's output feeds the next layer
        return x, t.stack(h_next), t.stack(c_next)

    def forward(self, xs, h_prev=None, c_prev=None, return_state=False):
        """Run the stack over a batch of sequences.

        Args:
            xs: (batch, d_context, d_in) inputs, e.g. one-hot characters.
            h_prev, c_prev: optional initial states (d_layers, batch, d_hidden); zeros if None.
            return_state: also return the final (h, c), so a later call can continue
                the sequence.

        Returns:
            Logits of shape (batch, d_context, d_out). If `return_state` is True,
            returns (logits, (h, c)).
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(self.d_layers, batch, self.d_hidden, device=xs.device)
        if c_prev is None:
            c_prev = t.zeros(self.d_layers, batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            top, h_prev, c_prev = self._step(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(top))
        logits = t.stack(outs, dim=1)
        return (logits, (h_prev, c_prev)) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text of length `d_sample` by sampling from the model's softmax.

        Every character of `seed` is fed through the stack first, so the states reflect the
        whole prompt. A `seed` that already has `d_sample` or more characters is returned
        unchanged. Uses the notebook-level `atoi`, `itoa` and `d_vocab`.

        Raises:
            ValueError: if `seed` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model's device, not the global `device`

        def one_hot(idx):
            return F.one_hot(t.tensor([idx], device=dev), num_classes=d_vocab).float()

        h_prev = t.zeros(self.d_layers, 1, self.d_hidden, device=dev)
        c_prev = t.zeros(self.d_layers, 1, self.d_hidden, device=dev)
        for ch in seed[:-1]:
            _, h_prev, c_prev = self._step(one_hot(atoi[ch]), h_prev, c_prev)
        x = one_hot(atoi[seed[-1]])
        text = seed
        while len(text) < d_sample:
            top, h_prev, c_prev = self._step(x, h_prev, c_prev)
            probs = self.unembed(top)[0].softmax(-1)
            idx = t.multinomial(probs, num_samples=1).item()
            text += itoa[idx]
            x = one_hot(idx)
        return text

# Untrained baseline: 3-layer stack, sampled before any training.
stacked_lstm = StackedLSTM(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab, d_layers=3).to(device)
stacked_lstm.sample()

"A#]Epf7K*ahT9pN*rDDgPlwBM-iBT#2M9.Z?sqzFu'P*i1FK]u'dEc5z956S\rkW!1-F/weG$,\ns0Spff,[%N?GqdF'U!A\r#atgOR"

In [61]:
# t.save(stacked_lstm.state_dict(), f'weights/stacked_lstm_3_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

In [8]:
train(stacked_lstm, dataloader, epochs=1000000, filename='stacked_lstm_3_layers_long')  # very long run



VBox(children=(Label(value='0.019 MB of 0.019 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
loss,1.43789


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112476034193403, max=1.0…

  0%|          | 0/1000000 [00:00<?, ?it/s]

loss=4.4226
$]rL(MZOt
XT(z2fWkI
CY9Y3,oFGEE*s1IzS#[m-ttU_ft!E,58P6/E/A.(kEmBBsw
loss=3.3822
loss=3.2366
loss=3.2184
loss=3.2071
loss=3.1990
loss=3.1929
loss=3.1882
loss=3.1846
loss=3.1817
loss=3.1792
loss=3.1768
loss=3.1744
loss=3.1722
loss=3.1702
loss=3.1685
loss=3.1670
loss=3.1660
loss=3.1651
loss=3.1646
loss=3.1642
Arof, utwiys
asniwt i lh eui
da h
l ytunptss  u rI drc ecrunAklet   
loss=3.1639
loss=3.1637
loss=3.1635
loss=3.1633
loss=3.1632
loss=3.1630
loss=3.1628
loss=3.1625
loss=3.1617
loss=3.1601
loss=3.1564
loss=3.1427
loss=3.0749
loss=2.9620
loss=2.9036
loss=2.8659
loss=2.8336
loss=2.8019
loss=2.7647
loss=2.7260
nirng coowee
get uogemdeme tiua ti
the-wayt a fuiite w thir s 
I aas,e 
Cw_sd anog 
den
loss=2.6893
loss=2.6585
loss=2.6325
loss=2.6090
loss=2.5877
loss=2.5675
loss=2.5480
loss=2.5308
loss=2.5140
loss=2.4986
loss=2.4851
loss=2.4705
loss=2.4568
loss=2.4444
loss=2.4281
loss=2.4122
loss=2.3966
loss=2.3809
loss=2.3654
loss=2.3495
Atirlind
nw ipd orminl sibse cecacgarwieb t

In [None]:
print(stacked_lstm.sample(seed='A', d_sample=2000))