In [5]:
from fastai.text import *
from fastai import *
import torch

In [7]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

# AWD LSTM Language model

In this notebook, we will go through the full implementation of the AWD LSTM neural net architecture used as language model in the [ULMFIT](https://arxiv.org/pdf/1801.06146.pdf) paper. Most of the upcoming code is heavily based on the [fastai](https://docs.fast.ai) library and its deep learning course, which has already a full implementation of the ulmfit approach for NLP. However considering the complexity of the fastai code and its simplicity to use we figured it would helpful for readers to get a full bottom up implementation using pytorch as a baseline. Still we expect that if you read this notebook, you have a good knowledge and understanding of RNNs, language modeling (see paper) and pytorch.

## LSTM

The core of the AWD LSTM architecture is of course the LSTM neural net. It is an improvement to the standart RNN way of dealing with sequential data (such as text). LSTM deals with the [vanishing/exploding gradient problem](https://medium.com/learn-love-ai/the-curious-case-of-the-vanishing-exploding-gradient-bf58ec6822eb) that come up in simple RNNs using cell connection gates. For further intuition on 'why' most the upcoming implentations I recommend colah's blog post : [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) 

### LSTM cell

![LSTM cell and equations](images/lstm.jpg)
(picture from [Understanding LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) by Chris Olah.)

The LSTM archutecture is composed of a repeated cell which is shown in the image above. Its inputs are :
- **xt** which in our case is the embedding vector of the nth word of a batch of sentences
- **ht-1** the output of the last cell just like in RNNs.
- **ct-1** again output form last cell which is called the *cell state* used to prevent long-term dependencies problem.

The $\sigma$ reprenstents the [Sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) function applied element-wise to its input. Both x and + connections are elemnt-wise multiplication and addition respectively. 

Let us implement it using the pytorch nn.Module class. We use a two big matrix multiplication to compute x*U and and h*U instead of 4 for each of them.

In [95]:
class LSTMCell(nn.Module):
    def __init__(self, x_s, h_s):
        super().__init__()
        self.h_s = h_s
        self.x_s = x_s
        self.U = nn.Linear(x_s,4*h_s)
        self.W = nn.Linear(h_s,4*h_s)

    def forward(self, input, state):
        #inputs from last cell
        h,c = state
        
        #computing itermedtiate gates
        gates = (self.U(input) + self.W(h)).chunk(4, 1)
        i_t,f_t,o_t = map(torch.sigmoid, gates[:3])
        c_t = gates[3].tanh()
        c = (f_t*c) + (i_t*c_t)
        h = o_t * c.tanh()
        
        #outputting the usualt h output and the state to give to next cell if needed
        return h, (h,c)

### LSTM Layer

The next building block of the LSTM is the LSTM layer wich consisit of appling the LSTM cell to each sequential input in a recurrent manner with each time forwarding its state to the next time step.

![LSTM layer](images/LSTM3.png)
(picture from [Understanding LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) by Chris Olah.)

In [79]:
class LSTMLayer(nn.Module):
    def __init__(self, x_s, h_s):
        super().__init__()
        self.lstm_cell = LSTMCell(x_s, h_s)

    def forward(self, input, state):
        # divide the input in the sequence dimension to get x_0, x_1, x_2, ...
        inputs = input.unbind(1)
        
        #prepare to store the output of each cell
        outputs = []
        
        #applying the cell recursively 
        for i in range(len(inputs)):
            out, state = self.lstm_cell(inputs[i], state)
            outputs += [out]
        
        #return the stacked outputs
        return torch.stack(outputs, dim=1), state

(For the state for the first cell we simply use tensors with only zeroes)

### Stacked LSTM layers

The last step before having fully implemented pytorch's LSTM module is stacking multiple LSTM layers one above each other as shown in the following diagram :

![Stacked RNN](images/RNN_Stacking.png)
(picture from : https://leonardoaraujosantos.gitbooks.io/artificial-inteligence/content/recurrent_neural_networks.html?q= )

- <font color='pink'>pink rectangles</font>  : The sequential inputs  
- <font color='green'>green rectangles</font>  : An LSTM cell, each line is a layer so the cells on the same line are the same
- <font color='blue'>blue rectangles</font> : The sequential outputs

On the previous implementations, we created the initial state **h0** and **c0** outside of the model at the same time as we gave the input. This means that we could create h0 and c0 with the right sizes by simply comparing it to the input sizes. This time around, we will create the initial state to give to all the layers inside the model itself. As the dimensions of the state depend on the batch size of the input given, we need to create the initial states at the start of the forward pass when we are given the input (and thus we know the batch size) with the help of the **reset()** method. We also want to keep the last states from last batch if the batch size did not change.

We aso have to be careful of the sives of the hidden layers inputs. For example, on the image above, the first layer takes as input the initial sequence and outputs a hidden sequence whereas the second and third layers take as input a hidden sequence and outputs a hidden sequence. Because of that we must have a different sizes for the first layer and the other layers. And of course the same goes for the initial states.

In [94]:
class fullLSTM(nn.Module):
    def __init__(self, x_s, h_s, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.lstm_layers = nn.ModuleList([LSTMLayer(x_s if i==0 else h_s, h_s ) for i in range(n_layers)])
        self.bs = 0

    def forward(self, input):
        
        #get the batch size from the first dimension of the input 
        bs, sl, _ = input.size()
        if self.bs != bs :
            self.bs = bs
            self.reset()
       
        # now we have the initial states and we can go through all the layers recursively
        for j in range(self.n_layers) :
            layer = self.lstm_layers[j]
            input, self.hidden[j] = layer(input, self.hidden[j])
                
        
        #return the outputs
        return input
    
    def reset(self) :
        st = next(self.parameters()).new(self.bs, h_s).zero_()
        self.hidden = [(st, st) for l in range(self.n_layers)]
        

This implementation is pretty much the same as pytorch's nn.LSTM. The difference is that pytorch uses CuDNN to make the computations faster. We will now use pytorch's implementation instead of ours

## Generalization : Dropout

Usual regularization techniques used in feed-forward and convolutional neural nets such as dropout and batchnorm do not work well in RNNs. The AWD LSTM uses extensions of those to regularize its model. Correspondigns ections of the [paper](https://arxiv.org/pdf/1708.02182.pdf) will be provided for more info.

### Variational dropout 
Section 4.2

The idea in variational dropout is to use the **same** drop out mask to a squential input over the sequence dimension. In essence, if you have an input x with shape *(bs, seq_len, x_s)*, the dropout mask will be of shape *(bs, 1, x_s)* and will be applied to each slice of sequence.
This dropout will be used on each output/input of the LSTM layers. Additionally we divide every activations that have not been set to 0 by the mask by 1-p (p: probability of dropout) to keep the average. We use [broadcasting](https://pytorch.org/docs/stable/notes/broadcasting.html) to be efficient in the element-wise computations

In [162]:
def dropout_mask(x, sz, p):
    return x.new(*sz).bernoulli_(1-p).div_(1-p)

In [163]:
class VDropout(nn.Module) :
    def __init__(self, p=0.5) :
        super().__init__()
        self.p = p
    def forward(self, x) :
        #The dropout should only be used during training and not eval 
        if not self.training or self.p == 0.: return x
        #the mask
        m = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
        #element-wise multiplication with broadcasting
        return x*m

In [169]:
m = RNNDropout(0.3)
x = torch.randn(3,3,7)
x, m(tst_input)

(tensor([[[ 0.8941,  0.5911, -0.1017,  1.4399, -1.6575,  0.6746,  0.4393],
          [-0.7288,  0.7585,  1.1424,  1.5135, -1.9283,  1.0334, -0.4931],
          [-0.2847, -0.2492, -0.8360,  0.9637, -0.3897, -0.0111, -0.7577]],
 
         [[ 0.7182,  1.2078,  0.0514,  0.2906,  1.9491,  0.9235,  0.3016],
          [-0.2377,  1.4278, -0.4590, -0.0160,  0.9702, -0.3551,  0.3034],
          [ 0.4043,  0.5498, -0.5338, -0.0617, -0.4097, -0.4234,  0.3979]],
 
         [[-0.6652, -0.6661,  2.0000, -1.7812, -0.6374, -0.6694, -1.1852],
          [-0.9589, -1.2132, -0.0476,  0.2971, -1.8044,  0.3118,  1.5059],
          [-0.3349, -0.1490, -0.5113,  0.4419,  1.3988, -0.5077, -2.2353]]]),
 tensor([[[-0.8154,  0.4457, -0.7568, -2.3807,  1.9234, -0.0000,  1.5853],
          [-1.0636, -0.1794, -1.0083,  2.3365, -0.6700, -0.0000, -0.3559],
          [-1.4767,  3.2588, -0.9015,  1.4758, -1.0860, -0.0000, -1.5421]],
 
         [[-0.2077,  0.5429, -0.3351, -1.9916,  2.1828,  0.7989,  0.8166],
          [-0

Here we can see that the dropped is consistent in the second dimension

### Embedding dropout 
Section 4.3

For embedding dropout we simply nulifiy entire rows of the word embedding matrix with probability p. Again broadcastiong is used

In [171]:
class EmbeddingDropout(nn.Module):
    
    def __init__(self, emb, embed_p):
        super().__init__()
        self.emb,self.embed_p = emb,embed_p
        self.pad_idx = self.emb.padding_idx
        if self.pad_idx is None: self.pad_idx = -1

    def forward(self, words, scale=None):
        if self.training and self.embed_p != 0:
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: masked_embed = self.emb.weight
        if scale: masked_embed.mul_(scale)
        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)

In [173]:
enc = nn.Embedding(100, 7, padding_idx=1)
enc_dp = EmbeddingDropout(enc, 0.5)
tst_input = torch.randint(0,100,(8,))
enc_dp(tst_input)

tensor([[-0.0369, -4.0773, -1.0820, -2.0126, -1.0526,  0.7776, -2.6893],
        [ 1.2100,  5.0159,  0.5262,  1.1307, -1.9710,  3.0681,  0.3349],
        [ 1.2875, -5.1421, -0.1379, -2.6644,  1.0783, -1.3284,  2.4453],
        [ 3.9419, -1.8886,  0.0901,  0.2247,  1.3759,  0.3034,  0.0748],
        [ 0.7805, -2.3229,  0.7034,  1.4335,  0.6519,  0.2499,  2.6581],
        [ 0.6230,  1.3893,  2.1869, -0.1947,  1.9668,  2.0854,  2.4333],
        [-0.4102, -1.8052,  1.8866,  1.1053, -1.2866,  1.9148,  1.5755],
        [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000, -0.0000]],
       grad_fn=<EmbeddingBackward>)

We can see that entire rows have been dropped

### Weight-dropout
Section 2

Weight dropout is a dropout applied to the weights inside the LSTM cells : U and W.

In order to keep the speed of the LSTM layer, we simply replace the weight matrix of the LSTM by a masked version and keep the non-masked version. We can then simply apply the LSTM layer and it will use its new weights.

In [142]:
# The name of the parameter in the nn.LSTM module containing the weights 
WEIGHT_HH = 'weight_hh_l0'

class WeightDropout(nn.Module):
    def __init__(self, module, weight_p=[0.], layer_names=[WEIGHT_HH]):
        super().__init__()
        self.module,self.weight_p,self.layer_names = module,weight_p,layer_names
        for layer in self.layer_names:
            #Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            #
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)

    def _setweights(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training)

    def forward(self, *args):
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

In [174]:
module = nn.LSTM(5, 2)
dp_module = WeightDropout(module, 0.4)
getattr(dp_module.module, WEIGHT_HH)

Parameter containing:
tensor([[-0.0398,  0.6572],
        [ 0.5677, -0.6067],
        [ 0.1554, -0.3794],
        [ 0.4172,  0.6862],
        [-0.3063, -0.5804],
        [-0.1082, -0.1653],
        [ 0.6647, -0.3769],
        [-0.4278, -0.3355]], requires_grad=True)

In [175]:
tst_input = torch.randn(4,20,5)
h = (torch.zeros(1,20,2), torch.zeros(1,20,2))
x,h = dp_module(tst_input,h)
getattr(dp_module.module, WEIGHT_HH)

tensor([[-0.0663,  0.0000],
        [ 0.0000, -1.0111],
        [ 0.2590, -0.6323],
        [ 0.6953,  0.0000],
        [-0.0000, -0.9673],
        [-0.1804, -0.2754],
        [ 1.1079, -0.0000],
        [-0.7129, -0.5591]], grad_fn=<MulBackward0>)

As we can see, the dropout is applied to the weights during the forward pass

## Full model

At this point we have everything ready to implement the entire AWSD LSTM model, the following code might look really complicated at first but it is in fact pretty much the same as our fullLSTM except we use the different kinds of dropout disscussed above. It also takes care of the word embeddings whereas our fullLSTM assumed it was already done so we need to take care of that. Another difference is that the last layer outputs a different size tensor 

In [176]:
def to_detach(h):
    "Detaches `h` from its history."
    return h.detach() if type(h) == torch.Tensor else tuple(to_detach(v) for v in h)

In [180]:
class AWD_LSTM(nn.Module):
    initrange=0.1

    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token,
                 hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
        super().__init__()
        """Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            vocab_sz (int): number of words in the vocab
            emb_sz (int): size of the word embedding vector
            n_hid (int): size of the hidden vector 
            n_layers (int): number of layers in the LSTM
            pad_token (int): id of the pad_idx for the embedding matrix
            hidden_p (float): dropout probability for variational dropout on hidden activations
            input_p (float): dropout probability for variational dropout on input activations
            embed_p (float):dropout probability for embedding dropout 
            weight_p (float):dropout probability for weight dropout

        """
        self.bs,self.emb_sz,self.n_hid,self.n_layers = 1,emb_sz,n_hid,n_layers
        self.emb = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.emb_dp = EmbeddingDropout(self.emb, embed_p)
        self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz), 1,
                             batch_first=True) for l in range(n_layers)]
        self.rnns = nn.ModuleList([WeightDropout(rnn, weight_p) for rnn in self.rnns])
        self.emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

    def forward(self, input):
        bs,sl = input.size()
        if bs!=self.bs:
            self.bs=bs
            self.reset()
        raw_output = self.input_dp(self.emb_dp(input))
        print(raw_output.shape)
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            outputs.append(raw_output) 
        self.hidden = to_detach(new_hidden)
        return raw_outputs, outputs

    def _one_hidden(self, l):
        "Return one hidden state."
        nh = self.n_hid if l != self.n_layers - 1 else self.emb_sz
        return next(self.parameters()).new(1, self.bs, nh).zero_()

    def reset(self):
        "Reset the hidden states."
        self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]

We also need a decoder which takes the output of our AWD_LSTM and transform it into the prediction of the wrord

In [181]:
class LinearDecoder(nn.Module):
    def __init__(self, n_out, n_hid, output_p, tie_encoder=None, bias=True):
        super().__init__()
        self.output_dp = RNNDropout(output_p)
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight
        else: init.kaiming_uniform_(self.decoder.weight)

    def forward(self, input):
        raw_outputs, outputs = input
        output = self.output_dp(outputs[-1]).contiguous()
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded, raw_outputs, outputs

We can combine both of them using a sequential module 

In [182]:
class SequentialRNN(nn.Sequential):
    "A sequential module that passes the reset call to its children."
    def reset(self):
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()

In [183]:
def get_language_model(vocab_sz, emb_sz, n_hid, n_layers, pad_token, output_p=0.4, hidden_p=0.2, input_p=0.6, 
                       embed_p=0.1, weight_p=0.5, tie_weights=True, bias=True):
    rnn_enc = AWD_LSTM(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token,
                       hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    enc = rnn_enc.emb if tie_weights else None
    return SequentialRNN(rnn_enc, LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=enc, bias=bias))