In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.optim as optim

from typing import *
from pathlib import Path

In [None]:
DATA_ROOT = Path("../data/brown")

In [2]:
N_EPOCHS = 1

# Building an LSTM from scratch

In this notebook, we'll be building our own LSTM and delving into why it performs so well across a wide range of tasks.

# The Basics of the LSTM

Before we actually build the LSTM, we'll need to understand its basic mechanism.

The diagram below shows the flow of information in an LSTM cell (image from Wikipedia).

In [None]:
from IPython.display import Image
# LSTM cell diagram (source: Wikipedia)
Image(url="https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/The_LSTM_cell.png/1920px-The_LSTM_cell.png")

The equation for the LSTM looks like this:

\begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
            c_t = f_t * c_{(t-1)} + i_t * g_t \\
            h_t = o_t * \tanh(c_t) \\
        \end{array}

It seems complex, but when you pick it apart, the LSTM is actually very simple. The core of the LSTM is the following equation:

\begin{array}{ll} \\
            c_t = f_t * c_{(t-1)} + i_t * g_t \\
\end{array}

Let's pick this equation apart: $ c_t $ is the new cell state, which is basically the memory of the LSTM. 

$ f_t $ is called the "forget gate": it dictates how much of the previous cell state to **retain**, which makes the name slightly confusing.

$ i_t $ is the "input gate" and dictates how much to update the cell state with new information.

Finally, $ g_t $ is the information we use to update the cell state.

Basically, an LSTM chooses to keep a certain portion of its previous cell state and add a certain amount of new information. These proportions are controlled using gates.
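To make the gating concrete, here is a small NumPy sketch of the cell-state update on its own, with hand-picked gate values (an assumption; in the real model they are computed from $ x_t $ and $ h_{t-1} $). A forget gate of 1 with an input gate of 0 keeps the old state untouched, the reverse overwrites it with $ g_t $, and intermediate values blend the two.

```python
import numpy as np

def cell_update(c_prev, f, i, g):
    """LSTM cell-state update c_t = f_t * c_{t-1} + i_t * g_t (all elementwise)."""
    return f * c_prev + i * g

c_prev = np.array([2.0, -1.0, 0.5])  # previous cell state (made-up values)
g = np.array([0.9, 0.9, 0.9])        # candidate information (made-up values)

keep = cell_update(c_prev, f=np.ones(3), i=np.zeros(3), g=g)          # old state retained
overwrite = cell_update(c_prev, f=np.zeros(3), i=np.ones(3), g=g)     # replaced by g
mix = cell_update(c_prev, f=np.full(3, 0.5), i=np.full(3, 0.5), g=g)  # half of each

print(keep)       # [ 2.  -1.   0.5]
print(overwrite)  # [0.9 0.9 0.9]
print(mix)        # approximately [1.45, -0.05, 0.7]
```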

Let's contrast this with the update rule of a simple RNN:

$$ c_t = \tanh(W_hc_{t-1} + W_ix_t) $$

(To make the contrast clearer, I'm representing the hidden state of the RNN as $ c_t $.)

As you can see, there is a huge difference between the simple RNN's update rule and the LSTM's update rule. Whereas the RNN computes the new hidden state from scratch based on the previous hidden state and the input, the LSTM computes the new hidden state by choosing what to **add** to the current state. This is similar to how ResNets learn: they learn what to add to the current state/block instead of directly learning the new state. In other words, LSTMs are great primarily because they are **additive**. We'll formalize this intuition later when we examine the gradient flow, but this is the basic idea behind the LSTM.
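A rough numerical sketch of that intuition, ahead of the proper treatment: along the direct path, $ \partial c_t / \partial c_{t-1} $ is $ \mathrm{diag}(f_t) $ for the LSTM but $ \mathrm{diag}(1 - c_t^2)\, W_h $ for the simple RNN above. The code multiplies these Jacobians over 50 steps. The random weights, their scale, and the constant forget-gate value of 0.99 are assumptions chosen for illustration, and the LSTM side ignores the indirect paths that run through $ h_{t-1} $ into the gates.

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 16, 50
W_h = rng.standard_normal((n, n)) * 0.5 / np.sqrt(n)  # assumed recurrent weights (small scale)
W_i = rng.standard_normal((n, n)) / np.sqrt(n)        # assumed input weights
xs = rng.standard_normal((T, n))                      # random inputs x_1..x_T

# Simple RNN: c_t = tanh(W_h c_{t-1} + W_i x_t), so dc_t/dc_{t-1} = diag(1 - c_t^2) W_h
c = np.zeros(n)
rnn_jac = np.eye(n)
for x in xs:
    c = np.tanh(W_h @ c + W_i @ x)
    rnn_jac = np.diag(1 - c ** 2) @ W_h @ rnn_jac   # chain rule: J_t @ (J_{t-1} ... J_1)

# LSTM, direct path only: c_t = f_t * c_{t-1} + i_t * g_t, so dc_t/dc_{t-1} = diag(f_t)
f = np.full(n, 0.99)                                 # assumed forget-gate values close to 1
lstm_jac = np.eye(n)
for _ in range(T):
    lstm_jac = np.diag(f) @ lstm_jac

rnn_norm = np.linalg.norm(rnn_jac, 2)    # spectral norm of dc_T/dc_0
lstm_norm = np.linalg.norm(lstm_jac, 2)
print(f"RNN:  {rnn_norm:.3e}")
print(f"LSTM: {lstm_norm:.3f}")          # 0.99 ** 50, about 0.605
```

With these arbitrary settings the RNN's gradient shrinks by many orders of magnitude, while the LSTM's direct path keeps roughly 60% of it. Real forget gates are learned and change every step, so this only illustrates the mechanism.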

Now that we have a basic understanding, let's start coding.

Side Note: One thing that is slightly confusing about the LSTM is that it has two "hidden states": $ c_t $ and $ h_t $. Intuitively, $ c_t $ is the "internal" hidden state that retains important information for longer timesteps, whereas $ h_t $ is the "external" hidden state that exposes that information to the outside world.
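One way to see the difference: $ c_t $ is a running sum that can grow well beyond 1, while $ h_t = o_t * \tanh(c_t) $ always stays inside $ (-1, 1) $. The sketch below holds the gates at fixed, hand-picked values (an assumption; real gates change every step).

```python
import numpy as np

f, i, g, o = 1.0, 1.0, 0.8, 0.5   # fixed gate values, hand-picked for illustration
c, h = 0.0, 0.0
for t in range(10):
    c = f * c + i * g             # internal state: accumulates 0.8 per step
    h = o * np.tanh(c)            # external state: squashed by tanh, scaled by o
print(c)  # roughly 8.0
print(h)  # just under 0.5, since tanh(8) is very close to 1
```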


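Side Note: If you look carefully, you'll notice that the bias terms are redundant: each gate has two biases (e.g. $ b_{ii} $ and $ b_{hi} $) that only ever appear as a sum, so a single bias per gate is equivalent. They are there for compatibility with the cuDNN backend. Until we touch on cuDNN, we'll use a single bias term.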
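The redundancy can be checked directly on PyTorch's `nn.LSTM`, which keeps two bias vectors per layer, `bias_ih_l0` and `bias_hh_l0`. In this sketch, folding the second into the first and zeroing it leaves the output unchanged up to floating-point rounding. The layer sizes and the random input are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True)
x = torch.randn(2, 5, 8)                    # (batch, sequence, feature)

with torch.no_grad():
    out_ref, _ = lstm(x)
    lstm.bias_ih_l0.add_(lstm.bias_hh_l0)   # b_i* + b_h* for all four gates at once
    lstm.bias_hh_l0.zero_()                 # the second bias no longer contributes
    out_merged, _ = lstm(x)

print(torch.allclose(out_ref, out_merged, atol=1e-6))
```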

# Implementing the LSTM

We'll use PyTorch to write our own LSTM.

In [3]:
from enum import IntEnum
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2
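Because `Dim` subclasses `IntEnum`, its members compare equal to plain integers and can be used wherever a dimension index is expected, while documenting which axis is meant. A quick check (the NumPy array is just a stand-in for a `(batch, sequence, feature)` tensor):

```python
from enum import IntEnum
import numpy as np

class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

x = np.zeros((5, 10, 32))               # (batch, sequence, feature)
print(Dim.seq == 1)                     # True
print(x.shape[Dim.seq])                 # 10
print(x.sum(axis=Dim.feature).shape)    # (5, 10)
```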

In [4]:
class NaiveLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # input gate
        self.W_ii = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hi = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_i = Parameter(torch.Tensor(hidden_sz))
        # forget gate
        self.W_if = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hf = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_f = Parameter(torch.Tensor(hidden_sz))
        # cell candidate g: the new information that may be written to the cell state
        self.W_ig = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hg = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_g = Parameter(torch.Tensor(hidden_sz))
        # output gate
        self.W_io = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_ho = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_o = Parameter(torch.Tensor(hidden_sz))
        
        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)
        
    def forward(self, x: torch.Tensor, 
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]]=None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
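            # start from all-zero states, one per sequence in the batch: (batch, hidden)
            h_t = torch.zeros(bs, self.hidden_size, device=x.device, dtype=x.dtype)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device, dtype=x.dtype)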
        else:
            h_t, c_t = init_states
        for t in range(seq_sz): # iterate over the time steps
            x_t = x[:, t, :]
            i_t = torch.sigmoid(x_t @ self.W_ii + h_t @ self.W_hi + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_if + h_t @ self.W_hf + self.b_f)
            g_t = torch.tanh(x_t @ self.W_ig + h_t @ self.W_hg + self.b_g)
            o_t = torch.sigmoid(x_t @ self.W_io + h_t @ self.W_ho + self.b_o)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))  # (1, batch, feature): add a leading time dim
        hidden_seq = torch.cat(hidden_seq, dim=0)  # (sequence, batch, feature)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (h_t, c_t)
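A quick sanity check on model size: each of the four gates has an input-to-hidden matrix, a hidden-to-hidden matrix and, in our version, a single bias. The sketch below compares that count with PyTorch's `nn.LSTM` for the sizes used in the language model later (50 inputs, 125 hidden units); `naive_lstm_param_count` is a hypothetical helper written for this check. The gap is exactly the second bias per gate from the side note on bias terms.

```python
import torch.nn as nn

def naive_lstm_param_count(input_sz: int, hidden_sz: int) -> int:
    # 4 gates x (W_i*: input x hidden, W_h*: hidden x hidden, b_*: hidden)
    return 4 * (input_sz * hidden_sz + hidden_sz * hidden_sz + hidden_sz)

ref = nn.LSTM(50, 125, batch_first=True)
ref_count = sum(p.numel() for p in ref.parameters())

print(naive_lstm_param_count(50, 125))  # 88000
print(ref_count)                        # 88500: 4 * 125 = 500 extra bias parameters
```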

Let's test it on some synthetic data:

In [None]:
bs, seq_len, feat_sz, hidden_sz = 5, 10, 32, 16
arr = torch.randn(bs, seq_len, feat_sz)
lstm = NaiveLSTM(feat_sz, hidden_sz)

In [None]:
hs, (hn, cn) = lstm(arr)

In [None]:
hs.shape

It looks like it works: `hs` has the expected shape `(batch, sequence, hidden) = (5, 10, 16)`.
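The shape check only shows that the code runs. A stronger test is to load the weights of PyTorch's `nn.LSTM` into an equivalent computation and compare outputs. The sketch below uses `fused_lstm`, an illustrative free-function variant of `NaiveLSTM.forward` (not part of the notebook) that lays the four gates side by side so that two matrix products per step compute all of them. It assumes PyTorch's documented layout: weights of shape `(4*hidden, input)` with gates ordered (i, f, g, o), and two biases that can simply be summed.

```python
import torch
import torch.nn as nn

def fused_lstm(x: torch.Tensor, W_ih: torch.Tensor, W_hh: torch.Tensor, b: torch.Tensor):
    """Single-layer LSTM over x of shape (batch, seq, feature).

    W_ih: (input, 4*hidden), W_hh: (hidden, 4*hidden), b: (4*hidden,),
    with the four gates laid out side by side in the order (i, f, g, o).
    """
    bs, seq_sz, _ = x.shape
    hidden = W_hh.shape[0]
    h_t = x.new_zeros(bs, hidden)
    c_t = x.new_zeros(bs, hidden)
    outputs = []
    for t in range(seq_sz):
        # two matrix products give the pre-activations of all four gates at once
        gates = x[:, t, :] @ W_ih + h_t @ W_hh + b
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
        i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
        g_t = torch.tanh(g_t)
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        outputs.append(h_t)
    return torch.stack(outputs, dim=1), (h_t, c_t)

torch.manual_seed(0)
bs, seq_len, feat_sz, hidden_sz = 5, 10, 32, 16   # same sizes as the synthetic test above
ref = nn.LSTM(feat_sz, hidden_sz, batch_first=True)
x = torch.randn(bs, seq_len, feat_sz)

with torch.no_grad():
    # nn.LSTM stores weights as (4*hidden, input), hence the transposes
    ours, (h_n, c_n) = fused_lstm(x, ref.weight_ih_l0.t(), ref.weight_hh_l0.t(),
                                  ref.bias_ih_l0 + ref.bias_hh_l0)
    theirs, (h_ref, c_ref) = ref(x)

print(ours.shape)                               # torch.Size([5, 10, 16])
print(torch.allclose(ours, theirs, atol=1e-5))
```

Fusing the gates also hints at one way to narrow the speed gap with `nn.LSTM`: two larger matrix products per step instead of eight small ones.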

# Testing our implementation

Now that we've covered the basics and have a minimally working LSTM, we'll put our model into action. Our testbed will be a character-level language modeling task. We'll be using the Brown Corpus, which you can download with the commands below.

In [None]:
!mkdir -p {DATA_ROOT}

In [None]:
!curl http://www.sls.hawaii.edu/bley-vroman/brown.txt -o {DATA_ROOT / "brown.txt"}

We'll let AllenNLP handle the complexity of training the language model.
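Conceptually, a character-level language model learns to predict each character from the ones before it, so every training example pairs a chunk of text with the same chunk shifted one position ahead. Below is a plain-Python sketch of that idea; `make_lm_pairs` is a hypothetical helper, and the exact chunking that AllenNLP's `LanguageModelingReader` performs may differ.

```python
from typing import List, Tuple

def make_lm_pairs(text: str, tokens_per_instance: int) -> List[Tuple[List[str], List[str]]]:
    """Split text into (input, target) character sequences; target is input shifted by one."""
    chars = list(text.lower())  # lowercase, mirroring CharacterTokenizer(lowercase_characters=True)
    pairs = []
    # windows of tokens_per_instance + 1 chars, advancing by tokens_per_instance,
    # so every character after the first is a prediction target exactly once
    for start in range(0, len(chars) - 1, tokens_per_instance):
        window = chars[start:start + tokens_per_instance + 1]
        pairs.append((window[:-1], window[1:]))
    return pairs

pairs = make_lm_pairs("Hello world", tokens_per_instance=4)
print(pairs[0])  # (['h', 'e', 'l', 'l'], ['e', 'l', 'l', 'o'])
```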

In [5]:
from allennlp.data.dataset_readers import LanguageModelingReader
from allennlp.data.tokenizers import CharacterTokenizer
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.training import Trainer
from sklearn.model_selection import train_test_split

char_tokenizer = CharacterTokenizer(lowercase_characters=True)

reader = LanguageModelingReader(
    tokens_per_instance=500,
    tokenizer=char_tokenizer,
    token_indexers = {"tokens": SingleIdTokenIndexer()},
)

# read the corpus downloaded into DATA_ROOT above
train_ds = reader.read(str(DATA_ROOT / "brown.txt"))
train_ds, val_ds = train_test_split(train_ds, random_state=0, test_size=0.1)

vocab = Vocabulary.from_instances(train_ds)

iterator = BasicIterator(batch_size=32)
iterator.index_with(vocab)

def train(model: nn.Module, epochs: int=10):
    trainer = Trainer(
        model=model.cuda() if torch.cuda.is_available() else model,
        optimizer=optim.Adam(model.parameters()),
        iterator=iterator, train_dataset=train_ds, 
        validation_dataset=val_ds, num_epochs=epochs,
        cuda_device=0 if torch.cuda.is_available() else -1
    )
    return trainer.train()

11/19/2019 14:30:43 - INFO - allennlp.data.dataset_readers.language_modeling -   Creating dataset from all text in file: /development/projects/statisticallyfit/github/learningmathstat/PythonNeuralNetNLP/src/NLPstudy/LSTMModel/data/brown.txt

100%|██████████| 11994/11994 [00:00<00:00, 43349.12it/s]

11994it [00:12, 983.21it/s]

11/19/2019 14:30:46 - INFO - allennlp.data.vocabulary -   Fitting token dictionary from dataset.

100%|██████████| 10794/10794 [00:06<00:00, 1772.10it/s]

In [6]:
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.models import Model

class LanguageModel(Model):
    def __init__(self, encoder: nn.Module, vocab: Vocabulary,
                 embedding_dim: int=50):
        super().__init__(vocab=vocab)
        # character embedding
        self.vocab_size = vocab.get_vocab_size()
        self.padding_idx = vocab.get_token_index("@@PADDING@@")
        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size(),
            embedding_dim=embedding_dim,
            padding_index=self.padding_idx,
        )
        self.embedding = BasicTextFieldEmbedder({"tokens": token_embedding})
        # any encoder mapping (batch, seq, embedding_dim) -> ((batch, seq, hidden), state)
        self.encoder = encoder
        # project each hidden state onto logits over the character vocabulary
        self.projection = nn.Linear(self.encoder.hidden_size, self.vocab_size)
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)

    def forward(self, input_tokens: Dict[str, torch.Tensor],
                output_tokens: Optional[Dict[str, torch.Tensor]] = None):
        embs = self.embedding(input_tokens)  # (batch, seq, embedding_dim)
        x, _ = self.encoder(embs)            # (batch, seq, hidden)
        x = self.projection(x)               # (batch, seq, vocab)
        if output_tokens is not None:
            # flatten batch and time so every character is one classification example
            loss = self.loss(x.view((-1, self.vocab_size)), output_tokens["tokens"].flatten())
        else:
            loss = None
        return {"loss": loss, "logits": x}
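The only subtle step in `forward` is the loss: the logits of shape `(batch, sequence, vocab)` are flattened to `(batch * sequence, vocab)` so that every character position becomes one classification example, and `ignore_index` drops positions whose target is padding. The sketch below checks, on made-up tensors, that this equals the per-character loss averaged over non-padding positions only. A padding index of 0 is an assumption here; in the model it comes from the vocabulary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, vocab_size, padding_idx = 2, 4, 6, 0  # padding index 0 is an assumption
logits = torch.randn(batch, seq, vocab_size)       # stand-in for the projection output
targets = torch.tensor([[3, 1, 2, 0],
                        [5, 4, 0, 0]])             # trailing zeros are padding

loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)
loss = loss_fn(logits.view(-1, vocab_size), targets.flatten())

# the same quantity by hand: per-character losses, averaged over non-padding positions
per_char = F.cross_entropy(logits.view(-1, vocab_size), targets.flatten(), reduction="none")
mask = targets.flatten() != padding_idx
manual = per_char[mask].mean()

print(torch.allclose(loss, manual))
```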

Now, let's try training it:

In [8]:
lm_naive = LanguageModel(NaiveLSTM(50, 125), vocab)
train(lm_naive, epochs=N_EPOCHS)



11/19/2019 08:30:30 - INFO - allennlp.training.trainer -   Beginning training.


11/19/2019 08:30:30 - INFO - allennlp.training.trainer -   Epoch 0/0


11/19/2019 08:30:30 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1658.848


11/19/2019 08:30:30 - INFO - allennlp.training.trainer -   Training


loss: 3.7216 ||:   8%|▊         | 28/338 [01:13<12:56,  2.51s/it]

KeyboardInterrupt: 

Now, let's compare with PyTorch's built-in `nn.LSTM`:

In [None]:
lm_comparison = LanguageModel(nn.LSTM(50, 125, batch_first=True), vocab)
train(lm_comparison, epochs=N_EPOCHS)

Our model is a lot slower, but we're getting similar performance, so it looks good! We'll look at how to make it faster later.

Now, let's compare the performance of the LSTM with a much simpler RNN:

In [7]:
class SimpleRNN(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_sz, self.hidden_size = input_sz, hidden_sz
        self.weight_ih = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.weight_hh = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.bias_hh = Parameter(torch.Tensor(hidden_sz))
        
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias_hh)
    
    def forward(self, x: torch.Tensor, init_state: Optional[torch.Tensor]=None
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_state is None:
            # start from an all-zero state of shape (batch, hidden)
            h_t = torch.zeros(bs, self.hidden_size, device=x.device, dtype=x.dtype)
        else:
            h_t = init_state

        for t in range(seq_sz):
            x_t = x[:, t, :]
            # the simple RNN update from above, plus a bias term
            h_t = torch.tanh(x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias_hh)
            hidden_seq.append(h_t.unsqueeze(0))  # (1, batch, feature): add a leading time dim
        hidden_seq = torch.cat(hidden_seq, dim=0)  # (sequence, batch, feature)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, h_t
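As with the LSTM, the update rule can be cross-checked against PyTorch's built-in `nn.RNN` with a tanh nonlinearity, which computes $ \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh}) $. In this sketch, `simple_rnn` is an illustrative free-function version of `SimpleRNN.forward`, and the two `nn.RNN` biases are summed into the single bias our version uses. The parameter count at the end shows why this model is cheaper: one weight pair and bias instead of four.

```python
import torch
import torch.nn as nn

def simple_rnn(x: torch.Tensor, W_ih: torch.Tensor, W_hh: torch.Tensor, b: torch.Tensor):
    """x: (batch, seq, feature); W_ih: (input, hidden); W_hh: (hidden, hidden); b: (hidden,)."""
    bs, seq_sz, _ = x.shape
    h_t = x.new_zeros(bs, W_hh.shape[0])
    outputs = []
    for t in range(seq_sz):
        h_t = torch.tanh(x[:, t, :] @ W_ih + h_t @ W_hh + b)
        outputs.append(h_t)
    return torch.stack(outputs, dim=1), h_t  # (batch, seq, hidden), (batch, hidden)

torch.manual_seed(0)
ref = nn.RNN(50, 125, nonlinearity="tanh", batch_first=True)
x = torch.randn(3, 7, 50)
with torch.no_grad():
    # nn.RNN stores weights as (hidden, input), hence the transposes
    ours, h_n = simple_rnn(x, ref.weight_ih_l0.t(), ref.weight_hh_l0.t(),
                           ref.bias_ih_l0 + ref.bias_hh_l0)
    theirs, h_ref = ref(x)

print(torch.allclose(ours, theirs, atol=1e-5))
# SimpleRNN(50, 125) parameters: weight_ih + weight_hh + one bias
print(50 * 125 + 125 * 125 + 125)  # 22000, a quarter of the 88000 in NaiveLSTM(50, 125)
```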

In [8]:
lm_simplernn = LanguageModel(SimpleRNN(50, 125), vocab)
train(lm_simplernn, epochs=N_EPOCHS)



11/19/2019 14:31:37 - INFO - allennlp.training.trainer -   Beginning training.


11/19/2019 14:31:37 - INFO - allennlp.training.trainer -   Epoch 0/0


11/19/2019 14:31:37 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1654.884


11/19/2019 14:31:37 - INFO - allennlp.training.trainer -   Training


loss: 2.7343 ||:  64%|██████▍   | 217/338 [05:33<03:11,  1.58s/it]

loss: 2.7325 ||:  64%|██████▍   | 218/338 [05:34<03:04,  1.54s/it]

loss: 2.7309 ||:  65%|██████▍   | 219/338 [05:36<03:01,  1.53s/it]

loss: 2.7292 ||:  65%|██████▌   | 220/338 [05:37<02:52,  1.47s/it]

loss: 2.7276 ||:  65%|██████▌   | 221/338 [05:38<02:48,  1.44s/it]

loss: 2.7259 ||:  66%|██████▌   | 222/338 [05:40<02:50,  1.47s/it]

loss: 2.7243 ||:  66%|██████▌   | 223/338 [05:42<02:54,  1.52s/it]

loss: 2.7226 ||:  66%|██████▋   | 224/338 [05:45<03:53,  2.05s/it]

loss: 2.7211 ||:  67%|██████▋   | 225/338 [05:47<03:40,  1.96s/it]

loss: 2.7194 ||:  67%|██████▋   | 226/338 [05:48<03:25,  1.84s/it]

loss: 2.7178 ||:  67%|██████▋   | 227/338 [05:50<03:14,  1.75s/it]

loss: 2.7161 ||:  67%|██████▋   | 228/338 [05:51<03:08,  1.72s/it]

loss: 2.7145 ||:  68%|██████▊   | 229/338 [05:53<03:16,  1.80s/it]

loss: 2.7127 ||:  68%|██████▊   | 230/338 [05:57<04:12,  2.34s/it]

loss: 2.7111 ||:  68%|██████▊   | 231/338 [06:02<05:31,  3.10s/it]

loss: 2.7094 ||:  69%|██████▊   | 232/338 [06:05<05:29,  3.11s/it]

loss: 2.7081 ||:  69%|██████▉   | 233/338 [06:07<04:58,  2.84s/it]

loss: 2.7065 ||:  69%|██████▉   | 234/338 [06:11<05:21,  3.09s/it]

loss: 2.7048 ||:  70%|██████▉   | 235/338 [06:14<05:09,  3.00s/it]

loss: 2.7031 ||:  70%|██████▉   | 236/338 [06:16<04:52,  2.87s/it]

loss: 2.7015 ||:  70%|███████   | 237/338 [06:19<04:51,  2.89s/it]

loss: 2.6999 ||:  70%|███████   | 238/338 [06:23<05:24,  3.24s/it]

loss: 2.6983 ||:  71%|███████   | 239/338 [06:26<04:58,  3.02s/it]

loss: 2.6967 ||:  71%|███████   | 240/338 [06:28<04:33,  2.79s/it]

loss: 2.6951 ||:  71%|███████▏  | 241/338 [06:31<04:46,  2.95s/it]

loss: 2.6936 ||:  72%|███████▏  | 242/338 [06:35<05:05,  3.18s/it]

loss: 2.6920 ||:  72%|███████▏  | 243/338 [06:39<05:14,  3.31s/it]

loss: 2.6904 ||:  72%|███████▏  | 244/338 [06:42<05:07,  3.27s/it]

loss: 2.6888 ||:  72%|███████▏  | 245/338 [06:44<04:42,  3.04s/it]

loss: 2.6873 ||:  73%|███████▎  | 246/338 [06:49<05:11,  3.38s/it]

loss: 2.6858 ||:  73%|███████▎  | 247/338 [06:52<05:11,  3.43s/it]

loss: 2.6842 ||:  73%|███████▎  | 248/338 [06:55<04:48,  3.21s/it]

loss: 2.6827 ||:  74%|███████▎  | 249/338 [06:59<05:09,  3.48s/it]

loss: 2.6813 ||:  74%|███████▍  | 250/338 [07:03<05:09,  3.52s/it]

loss: 2.6798 ||:  74%|███████▍  | 251/338 [07:05<04:51,  3.36s/it]

loss: 2.6784 ||:  75%|███████▍  | 252/338 [07:08<04:36,  3.22s/it]

loss: 2.6770 ||:  75%|███████▍  | 253/338 [07:11<04:08,  2.92s/it]

loss: 2.6755 ||:  75%|███████▌  | 254/338 [07:14<04:24,  3.14s/it]

loss: 2.6741 ||:  75%|███████▌  | 255/338 [07:18<04:25,  3.20s/it]

loss: 2.6727 ||:  76%|███████▌  | 256/338 [07:21<04:35,  3.36s/it]

loss: 2.6713 ||:  76%|███████▌  | 257/338 [07:24<04:09,  3.09s/it]

loss: 2.6698 ||:  76%|███████▋  | 258/338 [07:28<04:28,  3.36s/it]

loss: 2.6683 ||:  77%|███████▋  | 259/338 [07:31<04:16,  3.24s/it]

loss: 2.6668 ||:  77%|███████▋  | 260/338 [07:34<04:21,  3.35s/it]

loss: 2.6652 ||:  77%|███████▋  | 261/338 [07:37<04:08,  3.22s/it]

loss: 2.6639 ||:  78%|███████▊  | 262/338 [07:41<04:16,  3.38s/it]

loss: 2.6625 ||:  78%|███████▊  | 263/338 [07:44<04:03,  3.25s/it]

loss: 2.6610 ||:  78%|███████▊  | 264/338 [07:46<03:42,  3.01s/it]

loss: 2.6596 ||:  78%|███████▊  | 265/338 [07:49<03:39,  3.00s/it]

loss: 2.6583 ||:  79%|███████▊  | 266/338 [07:53<03:44,  3.12s/it]

loss: 2.6570 ||:  79%|███████▉  | 267/338 [07:56<03:46,  3.19s/it]

loss: 2.6556 ||:  79%|███████▉  | 268/338 [07:59<03:34,  3.07s/it]

loss: 2.6543 ||:  80%|███████▉  | 269/338 [08:02<03:36,  3.14s/it]

loss: 2.6529 ||:  80%|███████▉  | 270/338 [08:06<03:41,  3.26s/it]

loss: 2.6515 ||:  80%|████████  | 271/338 [08:09<03:34,  3.20s/it]

loss: 2.6500 ||:  80%|████████  | 272/338 [08:12<03:34,  3.25s/it]

loss: 2.6488 ||:  81%|████████  | 273/338 [08:16<03:37,  3.35s/it]

loss: 2.6474 ||:  81%|████████  | 274/338 [08:18<03:17,  3.09s/it]

loss: 2.6461 ||:  81%|████████▏ | 275/338 [08:22<03:19,  3.16s/it]

loss: 2.6449 ||:  82%|████████▏ | 276/338 [08:25<03:23,  3.28s/it]

loss: 2.6435 ||:  82%|████████▏ | 277/338 [08:29<03:26,  3.39s/it]

loss: 2.6421 ||:  82%|████████▏ | 278/338 [08:32<03:11,  3.20s/it]

loss: 2.6408 ||:  83%|████████▎ | 279/338 [08:35<03:21,  3.42s/it]

loss: 2.6394 ||:  83%|████████▎ | 280/338 [08:38<03:01,  3.13s/it]

loss: 2.6383 ||:  83%|████████▎ | 281/338 [08:42<03:10,  3.35s/it]

loss: 2.6371 ||:  83%|████████▎ | 282/338 [08:46<03:15,  3.48s/it]

loss: 2.6357 ||:  84%|████████▎ | 283/338 [08:49<03:12,  3.51s/it]

loss: 2.6344 ||:  84%|████████▍ | 284/338 [08:53<03:11,  3.55s/it]

loss: 2.6332 ||:  84%|████████▍ | 285/338 [08:55<02:50,  3.21s/it]

loss: 2.6319 ||:  85%|████████▍ | 286/338 [08:59<02:48,  3.23s/it]

loss: 2.6306 ||:  85%|████████▍ | 287/338 [09:01<02:34,  3.04s/it]

loss: 2.6292 ||:  85%|████████▌ | 288/338 [09:04<02:29,  2.98s/it]

loss: 2.6279 ||:  86%|████████▌ | 289/338 [09:07<02:20,  2.88s/it]

loss: 2.6266 ||:  86%|████████▌ | 290/338 [09:09<02:14,  2.80s/it]

loss: 2.6254 ||:  86%|████████▌ | 291/338 [09:12<02:15,  2.88s/it]

loss: 2.6241 ||:  86%|████████▋ | 292/338 [09:15<02:14,  2.93s/it]

loss: 2.6229 ||:  87%|████████▋ | 293/338 [09:19<02:17,  3.07s/it]

loss: 2.6216 ||:  87%|████████▋ | 294/338 [09:22<02:16,  3.10s/it]

loss: 2.6203 ||:  87%|████████▋ | 295/338 [09:25<02:12,  3.08s/it]

loss: 2.6190 ||:  88%|████████▊ | 296/338 [09:28<02:07,  3.03s/it]

loss: 2.6177 ||:  88%|████████▊ | 297/338 [09:32<02:12,  3.23s/it]

loss: 2.6165 ||:  88%|████████▊ | 298/338 [09:34<02:00,  3.01s/it]

loss: 2.6153 ||:  88%|████████▊ | 299/338 [09:37<01:53,  2.91s/it]

loss: 2.6141 ||:  89%|████████▉ | 300/338 [09:39<01:45,  2.78s/it]

loss: 2.6129 ||:  89%|████████▉ | 301/338 [09:42<01:48,  2.93s/it]

loss: 2.6117 ||:  89%|████████▉ | 302/338 [09:46<01:50,  3.06s/it]

loss: 2.6104 ||:  90%|████████▉ | 303/338 [09:49<01:48,  3.11s/it]

loss: 2.6091 ||:  90%|████████▉ | 304/338 [09:52<01:43,  3.04s/it]

loss: 2.6081 ||:  90%|█████████ | 305/338 [09:55<01:38,  2.98s/it]

loss: 2.6068 ||:  91%|█████████ | 306/338 [09:58<01:41,  3.16s/it]

loss: 2.6055 ||:  91%|█████████ | 307/338 [10:01<01:34,  3.04s/it]

loss: 2.6044 ||:  91%|█████████ | 308/338 [10:04<01:32,  3.07s/it]

loss: 2.6032 ||:  91%|█████████▏| 309/338 [10:08<01:38,  3.39s/it]

loss: 2.6019 ||:  92%|█████████▏| 310/338 [10:11<01:27,  3.11s/it]

loss: 2.6008 ||:  92%|█████████▏| 311/338 [10:14<01:27,  3.26s/it]

loss: 2.5997 ||:  92%|█████████▏| 312/338 [10:17<01:18,  3.00s/it]

loss: 2.5986 ||:  93%|█████████▎| 313/338 [10:20<01:17,  3.11s/it]

loss: 2.5975 ||:  93%|█████████▎| 314/338 [10:23<01:14,  3.09s/it]

loss: 2.5963 ||:  93%|█████████▎| 315/338 [10:26<01:10,  3.07s/it]

loss: 2.5952 ||:  93%|█████████▎| 316/338 [10:29<01:07,  3.06s/it]

loss: 2.5941 ||:  94%|█████████▍| 317/338 [10:33<01:05,  3.13s/it]

loss: 2.5929 ||:  94%|█████████▍| 318/338 [10:36<01:04,  3.25s/it]

loss: 2.5917 ||:  94%|█████████▍| 319/338 [10:39<01:01,  3.24s/it]

loss: 2.5905 ||:  95%|█████████▍| 320/338 [10:42<00:53,  2.98s/it]

loss: 2.5894 ||:  95%|█████████▍| 321/338 [10:45<00:50,  2.96s/it]

loss: 2.5882 ||:  95%|█████████▌| 322/338 [10:47<00:44,  2.77s/it]

loss: 2.5871 ||:  96%|█████████▌| 323/338 [10:50<00:43,  2.87s/it]

loss: 2.5859 ||:  96%|█████████▌| 324/338 [10:53<00:38,  2.75s/it]

loss: 2.5846 ||:  96%|█████████▌| 325/338 [10:55<00:35,  2.72s/it]

loss: 2.5835 ||:  96%|█████████▋| 326/338 [10:58<00:32,  2.74s/it]

loss: 2.5825 ||:  97%|█████████▋| 327/338 [11:00<00:28,  2.61s/it]

loss: 2.5813 ||:  97%|█████████▋| 328/338 [11:04<00:29,  2.94s/it]

loss: 2.5802 ||:  97%|█████████▋| 329/338 [11:06<00:24,  2.72s/it]

loss: 2.5790 ||:  98%|█████████▊| 330/338 [11:09<00:21,  2.70s/it]

loss: 2.5779 ||:  98%|█████████▊| 331/338 [11:12<00:19,  2.73s/it]

loss: 2.5768 ||:  98%|█████████▊| 332/338 [11:15<00:17,  2.97s/it]

loss: 2.5757 ||:  99%|█████████▊| 333/338 [11:19<00:15,  3.11s/it]

loss: 2.5746 ||:  99%|█████████▉| 334/338 [11:22<00:12,  3.10s/it]

loss: 2.5735 ||:  99%|█████████▉| 335/338 [11:25<00:09,  3.05s/it]

loss: 2.5723 ||:  99%|█████████▉| 336/338 [11:29<00:06,  3.40s/it]

loss: 2.5713 ||: 100%|█████████▉| 337/338 [11:32<00:03,  3.36s/it]

loss: 2.5702 ||: 100%|██████████| 338/338 [11:34<00:00,  2.86s/it]

loss: 2.5702 ||: 100%|██████████| 338/338 [11:34<00:00,  2.05s/it]


11/19/2019 14:43:11 - INFO - allennlp.training.trainer -   Validating


loss: 2.2029 ||: 100%|██████████| 38/38 [00:22<00:00,  1.66it/s]


11/19/2019 14:43:34 - INFO - allennlp.training.trainer -            Training |  Validation


11/19/2019 14:43:34 - INFO - allennlp.training.trainer -   loss |     2.570  |     2.203


11/19/2019 14:43:34 - INFO - allennlp.training.trainer -   Epoch duration: 00:11:57


{'training_duration': '00:11:57',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_loss': 2.570211327287572,
 'validation_loss': 2.2028739264136865,
 'best_epoch': 0,
 'best_validation_loss': 2.2028739264136865}

# Understanding the dynamics of LSTM learning

Why exactly do LSTMs learn so well? Let's analyze the dynamics of LSTM learning by measuring how the gradient with respect to the initial state changes over time, and comparing it with the same measurement for a simple RNN.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
test_batch = next(iterator(train_ds))
test_embeddings = lm_naive.embedding(test_batch["input_tokens"])

### The gradient dynamics of simple RNNs

First, let's check how the gradient of a simple RNN with respect to its initial hidden state changes as the sequence gets longer.

In [None]:
rnn = SimpleRNN(50, 125)

In [None]:
def rnn_step(x_t, h_t, weight_ih, weight_hh, bias_hh):
    return torch.tanh(x_t @ weight_ih + h_t @ weight_hh + bias_hh)
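
Before running the loop, it helps to see what a single step does to the gradient. The sketch below uses a standalone copy of `rnn_step` with small random weights (both illustrative assumptions, not the notebook's trained model) and checks with `torch.autograd.functional.jacobian` that the Jacobian $ \partial h_t / \partial h_{t-1} $ equals $ \mathrm{diag}(1 - h_t^2)\, W_{hh}^\top $. Every entry of $ 1 - h_t^2 $ lies in $ (0, 1] $, so the tanh can only shrink the gradient; whether it vanishes over many steps then depends on $ W_{hh} $.

```python
import torch

torch.manual_seed(0)

def rnn_step(x_t, h_t, weight_ih, weight_hh, bias_hh):
    # same update rule as above: h_t = tanh(x_t W_ih + h_{t-1} W_hh + b)
    return torch.tanh(x_t @ weight_ih + h_t @ weight_hh + bias_hh)

# illustrative sizes and random weights (assumptions for this sketch)
input_sz, hidden_sz = 4, 3
weight_ih = torch.randn(input_sz, hidden_sz)
weight_hh = torch.randn(hidden_sz, hidden_sz)
bias_hh = torch.randn(hidden_sz)
x_t = torch.randn(input_sz)
h_prev = torch.randn(hidden_sz)

# Jacobian dh_t/dh_{t-1} computed by autograd: rows index h_t, columns index h_{t-1}
jac_autograd = torch.autograd.functional.jacobian(
    lambda h: rnn_step(x_t, h, weight_ih, weight_hh, bias_hh), h_prev
)

# closed form: diag(tanh'(a)) @ W_hh^T, with tanh'(a) = 1 - h_t^2
h_t = rnn_step(x_t, h_prev, weight_ih, weight_hh, bias_hh)
jac_analytic = torch.diag(1 - h_t ** 2) @ weight_hh.T

print(torch.allclose(jac_autograd, jac_analytic, atol=1e-5))
```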

In [None]:
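# create h_0 directly on the target device so it stays a leaf tensor whose .grad gets populated
h_0 = torch.zeros(rnn.hidden_size, device=test_embeddings.device, requires_grad=True)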
h_t = h_0
grads = []

for t in range(100):
    h_t = rnn_step(
        test_embeddings[:, t, :], h_t,
        rnn.weight_ih, rnn.weight_hh, rnn.bias_hh,
    )
    loss = h_t.abs().sum() # we'll use the l1 norm of the current hidden state as the loss
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()

In [None]:
plt.plot(grads)

As you can see, the gradient with respect to $ h_0 $ shrinks as the sequence gets longer: the influence of the initial state fades away. This *vanishing gradient* problem is one of the main reasons simple RNNs are harder to train than LSTMs.
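
The decay can be bounded using the per-step Jacobian $ \mathrm{diag}(1 - h_t^2)\, W_{hh}^\top $: $ \partial h_t / \partial h_0 $ is the product of these Jacobians, so if the largest singular value of $ W_{hh} $ is below 1, the gradient shrinks at least geometrically. The sketch below rescales a random $ W_{hh} $ to spectral norm 0.9 (an assumption chosen to make the bound visible; trained weights need not satisfy it) and tracks the spectral norm of the accumulated Jacobian.

```python
import torch

torch.manual_seed(0)
input_sz, hidden_sz, n_steps = 8, 16, 50

# Assumption: rescale a random recurrent matrix so its largest singular value is 0.9
weight_hh = torch.randn(hidden_sz, hidden_sz)
weight_hh = 0.9 * weight_hh / torch.linalg.matrix_norm(weight_hh, ord=2)
weight_ih = 0.5 * torch.randn(input_sz, hidden_sz)
bias = torch.zeros(hidden_sz)

h_t = torch.zeros(hidden_sz)
jac_product = torch.eye(hidden_sz)  # accumulates dh_t/dh_0
norms = []
for t in range(n_steps):
    x_t = torch.randn(input_sz)
    h_t = torch.tanh(x_t @ weight_ih + h_t @ weight_hh + bias)
    step_jac = torch.diag(1 - h_t ** 2) @ weight_hh.T  # dh_t/dh_{t-1}
    jac_product = step_jac @ jac_product              # chain rule: J_t ... J_1
    norms.append(torch.linalg.matrix_norm(jac_product, ord=2).item())

for t in (0, 9, 49):
    print(f"step {t + 1}: ||dh_t/dh_0||_2 = {norms[t]:.2e} (bound {0.9 ** (t + 1):.2e})")
```

Since $ \|\mathrm{diag}(1 - h_t^2)\|_2 \le 1 $, each factor has spectral norm at most 0.9, so after $ t $ steps the norm is at most $ 0.9^t $.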

### The gradient dynamics of LSTMs

Next, let's make the same plot for an LSTM. Perhaps surprisingly, the original formulation of the LSTM did not have a forget gate; we'll start with that formulation and then see how adding the forget gate changes the dynamics.

In [None]:
lstm = NaiveLSTM(50, 125)
hidden_size = lstm.hidden_size

In [None]:
def lstm_step(x_t, h_t, c_t, W_ii, W_hi, b_i, W_if, W_hf, b_f,
              W_ig, W_hg, b_g, W_io, W_ho, b_o, use_forget_gate=False):
    # use the bias arguments passed in, not the global lstm's biases
    i_t = torch.sigmoid(x_t @ W_ii + h_t @ W_hi + b_i)   # input gate
    g_t = torch.tanh(x_t @ W_ig + h_t @ W_hg + b_g)      # candidate cell state
    o_t = torch.sigmoid(x_t @ W_io + h_t @ W_ho + b_o)   # output gate
    if use_forget_gate:
        f_t = torch.sigmoid(x_t @ W_if + h_t @ W_hf + b_f)  # forget gate
        c_t = f_t * c_t + i_t * g_t
    else:
        # original LSTM: the previous cell state is carried over unchanged
        c_t = c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t

In [None]:
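# initial hidden and cell states; we track the gradient with respect to h_0
h_0, c_0 = (torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True),
            torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True))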
grads = []
h_t, c_t = h_0, c_0
for t in range(100):
    h_t, c_t = lstm_step(
        test_embeddings[:, t, :], h_t, c_t,
        lstm.W_ii, lstm.W_hi, lstm.b_i,
        lstm.W_if, lstm.W_hf, lstm.b_f,
        lstm.W_ig, lstm.W_hg, lstm.b_g,
        lstm.W_io, lstm.W_ho, lstm.b_o,
        use_forget_gate=False,
    )
    loss = h_t.abs().sum()
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()
    lstm.zero_grad()

In [None]:
plt.plot(grads)

Notice how the gradient keeps growing instead of decaying. This is a consequence of the update rule
$$ c_t = c_{t-1} + i_t * g_t $$

If you're familiar with the chain rule, you'll see that along this additive path the gradient for $ c_t $ flows straight back to $ c_{t-1} $ without being scaled. Since $ c_0 $ influences $ c_1 $, which in turn influences $ c_2 $, and so on, the influence of the initial state never disappears, and the gradient with respect to it does not decay.
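
Written out for the direct path through the cell state, the argument looks like this. (This ignores the extra terms that reach $ c_{t-1} $ through $ h_{t-1} = o_{t-1} * \tanh(c_{t-1}) $, on which the gates also depend.) Without a forget gate, each step contributes an identity Jacobian; with a forget gate, each step contributes a diagonal matrix of gate values, so the product over $ T $ steps shrinks unless the gates stay close to 1:

$$
\begin{array}{ll} \\
c_t = c_{t-1} + i_t * g_t & \Rightarrow \quad \left.\frac{\partial c_t}{\partial c_{t-1}}\right|_{\text{direct}} = I, \qquad \left.\frac{\partial c_T}{\partial c_0}\right|_{\text{direct}} = I \\
c_t = f_t * c_{t-1} + i_t * g_t & \Rightarrow \quad \left.\frac{\partial c_t}{\partial c_{t-1}}\right|_{\text{direct}} = \operatorname{diag}(f_t), \qquad \left.\frac{\partial c_T}{\partial c_0}\right|_{\text{direct}} = \operatorname{diag}\left(\prod_{t=1}^{T} f_t\right) \\
\end{array}
$$

Here the product $ \prod_{t=1}^{T} f_t $ is taken elementwise.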

Of course, this can be a mixed blessing: sometimes we don't want the current timestep to influence the hidden state 200 steps into the future. Sometimes, we want to "forget" the information we learned earlier and overwrite it with what we have newly learned. This is where the forget gate comes into play.

### Turning the forget gate on

The forget gate was originally proposed in the paper [Learning to Forget: Continual Prediction with LSTM](https://www.semanticscholar.org/paper/Learning-to-Forget%3A-Continual-Prediction-with-LSTM-Gers-Schmidhuber/11540131eae85b2e11d53df7f1360eeb6476e7f4). Let's see how the gradients change when we turn the forget gate on. Following common practice, we'll initialize the forget-gate bias to 1.
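
In our `NaiveLSTM` the forget-gate bias is its own parameter `b_f`, so setting it takes a single line. For comparison, here is a sketch of the same initialization on PyTorch's built-in `torch.nn.LSTM` (the sizes are illustrative assumptions). The built-in module packs all four gates into one bias vector in the order input, forget, cell, output, and it keeps two bias vectors (`bias_ih_l0` and `bias_hh_l0`) that are added together, so the effective forget-gate bias is the sum of the two forget slices.

```python
import torch
import torch.nn as nn

hidden_sz = 125
lstm_builtin = nn.LSTM(input_size=50, hidden_size=hidden_sz)

# Gate order along dim 0 is (input, forget, cell, output),
# so the forget-gate slice is [hidden_sz : 2 * hidden_sz].
# nn.LSTM adds bias_ih and bias_hh, so set one slice to 1 and the other to 0.
with torch.no_grad():
    lstm_builtin.bias_ih_l0[hidden_sz:2 * hidden_sz].fill_(1.0)
    lstm_builtin.bias_hh_l0[hidden_sz:2 * hidden_sz].fill_(0.0)

effective_forget_bias = (lstm_builtin.bias_ih_l0 + lstm_builtin.bias_hh_l0)[hidden_sz:2 * hidden_sz]
print(effective_forget_bias.min().item(), effective_forget_bias.max().item())
```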

In [None]:
lstm.b_f.data = torch.ones_like(lstm.b_f.data)

In [None]:
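# initial hidden and cell states; we track the gradient with respect to h_0
h_0, c_0 = (torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True),
            torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True))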
grads = []
h_t, c_t = h_0, c_0
for t in range(100):
    h_t, c_t = lstm_step(
        test_embeddings[:, t, :], h_t, c_t,
        lstm.W_ii, lstm.W_hi, lstm.b_i,
        lstm.W_if, lstm.W_hf, lstm.b_f,
        lstm.W_ig, lstm.W_hg, lstm.b_g,
        lstm.W_io, lstm.W_ho, lstm.b_o,
        use_forget_gate=True,
    )
    loss = h_t.abs().sum()
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()

In [None]:
plt.plot(grads)

Notice how the gradients decay much more slowly than in the case of the simple RNN. Now let's see what happens when we initialize the forget-gate bias to 0 instead.

In [None]:
lstm.b_f.data = torch.zeros_like(lstm.b_f.data)

In [None]:
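# initial hidden and cell states; we track the gradient with respect to h_0
h_0, c_0 = (torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True),
            torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True))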
grads = []
h_t, c_t = h_0, c_0
for t in range(100):
    h_t, c_t = lstm_step(
        test_embeddings[:, t, :], h_t, c_t,
        lstm.W_ii, lstm.W_hi, lstm.b_i,
        lstm.W_if, lstm.W_hf, lstm.b_f,
        lstm.W_ig, lstm.W_hg, lstm.b_g,
        lstm.W_io, lstm.W_ho, lstm.b_o,
        use_forget_gate=True,
    )
    loss = h_t.abs().sum()
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()

In [None]:
plt.plot(grads)

The gradient decays much more quickly now. This is why initializing the forget-gate bias to 1 is a good idea, at least in the early stages of training.

Now, let's see what happens when we initialize the forget-gate bias to -1.

In [None]:
lstm.b_f.data = -torch.ones_like(lstm.b_f.data)

In [None]:
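# initial hidden and cell states; we track the gradient with respect to h_0
h_0, c_0 = (torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True),
            torch.zeros(hidden_size, device=test_embeddings.device, requires_grad=True))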
grads = []
h_t, c_t = h_0, c_0
for t in range(100):
    h_t, c_t = lstm_step(
        test_embeddings[:, t, :], h_t, c_t,
        lstm.W_ii, lstm.W_hi, lstm.b_i,
        lstm.W_if, lstm.W_hf, lstm.b_f,
        lstm.W_ig, lstm.W_hg, lstm.b_g,
        lstm.W_io, lstm.W_ho, lstm.b_o,
        use_forget_gate=True,
    )
    loss = h_t.abs().sum()
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()

In [None]:
plt.plot(grads)

The gradients decay even faster now.
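
The three experiments line up with simple arithmetic. Along the direct cell-state path, the gradient is scaled by $ f_t $ at every step, and at initialization $ f_t $ is roughly $ \sigma(b_f) $ if we pretend the $ W_{if} x_t + W_{hf} h_{(t-1)} $ terms are zero (a simplifying assumption; in practice they are not). The sketch below computes $ \sigma(b_f)^{10} $ for the three biases we tried.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

n_steps = 10
# fraction of the direct-path gradient left after n_steps, assuming f_t = sigmoid(b_f) at every step
decay = {bias: sigmoid(bias) ** n_steps for bias in (1.0, 0.0, -1.0)}

for bias, remaining in decay.items():
    print(f"b_f = {bias:+.0f}: f = {sigmoid(bias):.3f}, f**{n_steps} = {remaining:.2e}")
```

Under this approximation, about 4% of the signal survives ten steps with $ b_f = 1 $, about 0.1% with $ b_f = 0 $, and only about 2e-6 with $ b_f = -1 $.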

We looked at a lot of charts, but the most important point is this: through its gates, and the forget gate in particular, the LSTM can control how much of the gradient flows back through each timestep. This is a large part of what makes LSTMs easier to train than simple RNNs.

# Making our LSTM Faster

Remember how slow our LSTM implementation was? Let's see how we can speed it up.

If you look at the code for our LSTM carefully, you'll notice that a lot of the processing could be batched together. All four gates (input, forget, cell candidate, and output) are computed from a linear transformation of the same input $ x_t $ and hidden state $ h_{(t-1)} $.

We can therefore stack the four gates' weights side by side and group these computations into just two matrix multiplications per timestep: one for the input and one for the hidden state. The code now looks like this:

In [None]:
class OptimizedLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.weight_ih = Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.weight_hh = Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)
        
    def forward(self, x: torch.Tensor,
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size, device=x.device),
                        torch.zeros(bs, self.hidden_size, device=x.device))
        else:
            h_t, c_t = init_states

        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations for all four gates into two matrix multiplications
            gates = x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]), # candidate cell state
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))
        # stack along a new leading time dimension: (sequence, batch, feature)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
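
The fused `weight_ih` above is just the four per-gate weight matrices placed side by side. The sketch below (random weights and small sizes are illustrative assumptions) checks that one multiplication by the concatenated matrix, followed by slicing in the same input/forget/cell/output order used in `forward`, gives the same pre-activations as four separate multiplications.

```python
import torch

torch.manual_seed(0)
batch, input_sz, hidden_sz = 2, 5, 3

# four separate (input_sz, hidden_sz) weights, one per gate, as in a naive LSTM
W_i, W_f, W_g, W_o = (torch.randn(input_sz, hidden_sz) for _ in range(4))
x_t = torch.randn(batch, input_sz)

# naive: four matrix multiplications
separate = [x_t @ W for W in (W_i, W_f, W_g, W_o)]

# fused: concatenate along the output dimension, multiply once, then slice per gate
weight_fused = torch.cat([W_i, W_f, W_g, W_o], dim=1)  # (input_sz, 4 * hidden_sz)
gates = x_t @ weight_fused                             # (batch, 4 * hidden_sz)
HS = hidden_sz
fused = [gates[:, k * HS:(k + 1) * HS] for k in range(4)]

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(separate, fused)))
```

The same reasoning applies to `weight_hh`, which is why `forward` needs only two matrix multiplications per timestep.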

In [None]:
lstm = OptimizedLSTM(100, 32)

In [None]:
# dummy input: batch of 5 sequences, 10 timesteps, 100 features each
a = torch.arange(5 * 10 * 100).view((5, 10, 100))

In [None]:
hs, _ = lstm(a.float())

In [None]:
hs.shape

Now, let's see how the training speed changes:

In [None]:
lm_optimized = LanguageModel(OptimizedLSTM(50, 125), vocab)
train(lm_optimized, epochs=N_EPOCHS)

The model is faster now, but still not quite as fast as we might want it to be. To make our LSTM really fast, we'll need to hand the work over to cuDNN, but that's a topic for another post/notebook.
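
PyTorch's built-in `torch.nn.LSTM` can use cuDNN kernels when run on a GPU, so it is the natural point of comparison. As a hedged sanity check, we can copy the built-in module's weights into our layout and confirm that both produce the same outputs on CPU. The sizes and the standalone helper `fused_lstm_forward` below are illustrative assumptions; the helper mirrors the recurrence in `OptimizedLSTM.forward`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_sz, hidden_sz = 3, 7, 10, 8

def fused_lstm_forward(x, weight_ih, weight_hh, bias):
    """Same recurrence as OptimizedLSTM.forward; x has shape (batch, seq, feature)."""
    bs, seq_sz, _ = x.shape
    HS = weight_hh.shape[0]
    h_t = torch.zeros(bs, HS)
    c_t = torch.zeros(bs, HS)
    outputs = []
    for t in range(seq_sz):
        gates = x[:, t, :] @ weight_ih + h_t @ weight_hh + bias
        i_t = torch.sigmoid(gates[:, :HS])            # input
        f_t = torch.sigmoid(gates[:, HS:2 * HS])      # forget
        g_t = torch.tanh(gates[:, 2 * HS:3 * HS])     # candidate cell state
        o_t = torch.sigmoid(gates[:, 3 * HS:])        # output
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        outputs.append(h_t)
    return torch.stack(outputs, dim=1), (h_t, c_t)

reference = nn.LSTM(input_sz, hidden_sz, batch_first=True)
x = torch.randn(batch, seq_len, input_sz)

# nn.LSTM stores weights as (4 * hidden, in) and computes x W^T; our layout is (in, 4 * hidden).
# Both use the gate order (input, forget, cell, output), so a transpose is enough,
# and our single bias corresponds to the sum of nn.LSTM's two bias vectors.
with torch.no_grad():
    ours, (h_ours, c_ours) = fused_lstm_forward(
        x,
        reference.weight_ih_l0.T,
        reference.weight_hh_l0.T,
        reference.bias_ih_l0 + reference.bias_hh_l0,
    )
    theirs, (h_ref, c_ref) = reference(x)

print(torch.allclose(ours, theirs, atol=1e-5))
```

Note that `nn.LSTM` returns `h_n` and `c_n` with a leading `num_layers` dimension, so they are compared against `h_ref[0]` and `c_ref[0]`.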