# AWD-LSTM

In [1]:
# Reload imported modules (e.g. the `exp.*` helpers used below) automatically
# before each cell runs, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [24]:
import pickle
from fastai import datasets
from functools import partial

from torch import nn

from exp.nlp import TextList, TokenizeProcessor, NumericalizeProcessor, lm_databunchify
from exp.data import SplitData, random_splitter, label_by_func

## Data

In [3]:
# Obtain the IMDB dataset through fastai's `untar_data`; `path` is the root
# that later cells extend with `/` to locate files.
path = datasets.untar_data(datasets.URLs.IMDB)

In [4]:
# Build the text item list from `path`, restricted to the 'train', 'test' and
# 'unsup' entries, then split it with `random_splitter` using p_valid=0.1
# (presumably ~10% of the items go to validation — confirm against exp.data).
# NOTE(review): no random seed is set in the visible cells, so the split may
# differ from run to run — confirm whether random_splitter seeds itself.
il = TextList.from_files(path, include=['train', 'test', 'unsup'])
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [5]:
# Input processors handed to the labelling step below as proc_x, in the order
# [tokenize, numericalize]. The tokenizer is created with max_workers=3.
proc_tok, proc_num = TokenizeProcessor(max_workers=3), NumericalizeProcessor()

In [10]:
# Every item receives the constant label 0 (`lambda x: 0`); only the inputs
# get processors (proc_x). Presumably the label is a placeholder because the
# language-model targets come from the text itself — confirm against
# lm_databunchify.
ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok, proc_num])

In [13]:
# Cache the labelled data and the numericalizer's vocab on disk under `path`.
# The `with` blocks close each file deterministically, so the pickles are
# fully flushed before any later cell reads them back.
with open(path/'ll_lm.pkl', 'wb') as pkl_file:
    pickle.dump(ll, pkl_file)
with open(path/'vocab_lm.pkl', 'wb') as pkl_file:
    pickle.dump(proc_num.vocab, pkl_file)

In [14]:
# Reload the cached labelled data and vocab written by the previous cell.
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# files that this notebook produced itself.
with open(path/'ll_lm.pkl', 'rb') as pkl_file:
    ll = pickle.load(pkl_file)
with open(path/'vocab_lm.pkl', 'rb') as pkl_file:
    vocab = pickle.load(pkl_file)

In [23]:
# bs and bptt are handed to lm_databunchify together with the labelled data.
# Presumably bs is the batch size and bptt the sequence length per batch
# (back-propagation through time) — confirm against exp.nlp.
bs,bptt = 64,70
data = lm_databunchify(ll, bs, bptt)

## AWD-LSTM

In [40]:
import torch
from torch import Tensor

In [30]:
# chunk(2) cuts the 1-D tensor into two pieces; see the output below.
Tensor([1, 2, 3, 4]).chunk(2)

(tensor([1., 2.]), tensor([3., 4.]))

In [38]:
# chunk(2, 1) splits along dim 1: each row is cut into a left half (a) and a
# right half (b). LSTMCell below relies on this to separate its gates.
a, b = Tensor([[1,2,3,4], [5,6,7,8]]).chunk(2, 1)
a

tensor([[1., 2.],
        [5., 6.]])

In [39]:
# Right-hand half of the dim-1 split from the previous cell.
b

tensor([[3., 4.],
        [7., 8.]])

In [41]:
class LSTMCell(nn.Module):
    """A single LSTM step built from two fused linear projections.

    `ih` projects the input (`ni` features) and `hh` projects the hidden
    state (`nh` features) to `4*nh` pre-activations each; their sum is split
    into four gate slices of width `nh`.
    """

    def __init__(self, ni, nh):
        super().__init__()
        self.ih = nn.Linear(ni, 4*nh)
        self.hh = nn.Linear(nh, 4*nh)

    def forward(self, input, state):
        """Advance one time step.

        `state` is an `(h, c)` pair. Returns `(h_new, (h_new, c_new))`, i.e.
        the new hidden state doubles as the step's output.
        """
        hidden, cell = state
        # One large projection covering all four gates instead of four
        # separate, smaller multiplications; split along dim 1 afterwards.
        pre_act = self.ih(input) + self.hh(hidden)
        in_pre, forget_pre, out_pre, cell_pre = pre_act.chunk(4, 1)
        ingate = torch.sigmoid(in_pre)
        forgetgate = torch.sigmoid(forget_pre)
        outgate = torch.sigmoid(out_pre)
        cellgate = cell_pre.tanh()

        # Standard LSTM update: keep part of the old cell, add gated candidate.
        cell = (forgetgate*cell) + (ingate*cellgate)
        hidden = outgate * cell.tanh()
        return hidden, (hidden, cell)

In [42]:
# unbind(1) removes dim 1 and returns one tensor per column; see the output
# below. LSTMLayer uses this to obtain one slice per time step.
torch.Tensor([[1,2,3],[4,5,6]]).unbind(1)

(tensor([1., 4.]), tensor([2., 5.]), tensor([3., 6.]))

In [43]:
class LSTMLayer(nn.Module):
    """Runs a recurrent cell over every step along dim 1 of the input.

    The cell is built as `cell(*cell_args)` and is called as
    `cell(step_input, state)`, returning `(output, new_state)`.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        """Feed each slice along dim 1 of `input` through the cell in order.

        Returns the per-step outputs stacked back along dim 1, together with
        the state left after the final step.
        """
        outputs = []
        # unbind(1) yields one slice per step. Iterating over the slices
        # themselves guarantees each step is fed its own input; indexing a
        # fixed position (e.g. inputs[1]) would feed the same step every time
        # and fail for length-1 sequences.
        for step_input in input.unbind(1):
            out, state = self.cell(step_input, state)
            outputs.append(out)
        return torch.stack(outputs, dim=1), state

In [44]:
# LSTMLayer instantiates LSTMCell(300, 300): ni=300 input features, nh=300
# hidden units.
lstm = LSTMLayer(LSTMCell, 300, 300)

In [52]:
# Random float input of shape (64, 70, 300). LSTMLayer unbinds dim 1, so 70 is
# the number of steps and 300 matches the cell's input size (the 64 and 70
# mirror bs and bptt above).
x = torch.randn(64, 70, 300)
# Initial (hidden, cell) state pair, each zeros of shape (64, 300).
h = (torch.zeros(64, 300), torch.zeros(64, 300))

In [53]:
# Sanity-check the dimensions of the random input.
x.shape

torch.Size([64, 70, 300])

In [54]:
# Time a full forward pass of the hand-written layer (10 loops per run).
%timeit -n 10 y, h1 = lstm(x, h)

104 ms ± 3.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [56]:
# Length of the training dataloader (presumably the number of batches).
len(data.train_dl)

6016

In [57]:
# Pull the first batch from the training dataloader.
# NOTE(review): this rebinds `x`, which earlier held the random float input
# used by the timing cell; re-running that cell after this one would pass this
# batch to `lstm` instead.
x, y = next(iter(data.train_dl))

In [60]:
# dtype of the batch inputs (int64 per the output below).
x.dtype

torch.int64