In [None]:
!pip install d2l==0.15.1
!pip install ipython-autotime
%load_ext autotime

Rajesh Sakhamuru

12-8-2020
# LSTM vs GRU for Next Char Prediction


In [2]:
from d2l import torch as d2l
import math
import torch
from torch import nn
from torch.nn import functional as F

time: 450 ms


In [3]:
class LSTMModel(nn.Module):
    """Character-level language model: a stacked LSTM plus a linear decoder.

    Token indices are one-hot encoded, passed through ``nn.LSTM`` and then
    projected to vocabulary-sized logits, one row per (time step, sequence).
    """

    def __init__(self, vocabSize, numHiddens, numLayers, **kwargs):
        """Create the recurrent layer and the output projection.

        Args:
            vocabSize: Number of distinct tokens; used as the one-hot width,
                the LSTM input size and the number of output logits.
            numHiddens: Size of the LSTM hidden state.
            numLayers: Number of stacked LSTM layers.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.vocabSize = vocabSize
        self.numHiddens = numHiddens

        # The recurrent layer is created before the decoder so that parameter
        # registration (and seeded initialisation) order stays the same.
        self.lstmLayer = nn.LSTM(input_size=vocabSize, hidden_size=numHiddens, num_layers=numLayers)
        self.dense = nn.Linear(numHiddens, vocabSize)

    def forward(self, X, state):
        """Compute next-token logits for a batch of index sequences.

        Args:
            X: Integer token indices; transposed before encoding so the
                recurrent layer receives the time axis first.
            state: ``(hidden, cell)`` tuple as produced by ``begin_state``
                or by a previous call.

        Returns:
            ``(logits, new_state)`` where ``logits`` has ``vocabSize``
            columns and one row per element of ``X``.
        """
        stepsFirst = X.T.long()
        encoded = F.one_hot(stepsFirst, self.vocabSize).to(torch.float32)

        hiddenSeq, hiddenState = self.lstmLayer(encoded, state)
        # Collapse every leading axis so the decoder sees a 2-D matrix.
        flatHidden = hiddenSeq.reshape((-1, hiddenSeq.shape[-1]))

        return self.dense(flatHidden), hiddenState

    def begin_state(self, device, batch_size=1):
        """Return an all-zero ``(hidden, cell)`` state tuple on ``device``."""
        stateShape = (self.lstmLayer.num_layers, batch_size, self.numHiddens)
        hidden = torch.zeros(stateShape, device=device)
        cell = torch.zeros(stateShape, device=device)
        return (hidden, cell)

time: 13.7 ms


In [4]:
class GRUModel(nn.Module):
    """Character-level language model: a stacked GRU plus a linear decoder.

    Token indices are one-hot encoded, passed through ``nn.GRU`` and then
    projected to vocabulary-sized logits, one row per (time step, sequence).
    """

    def __init__(self, vocabSize, numHiddens, numLayers, **kwargs):
        """Create the recurrent layer and the output projection.

        Args:
            vocabSize: Number of distinct tokens; used as the one-hot width,
                the GRU input size and the number of output logits.
            numHiddens: Size of the GRU hidden state.
            numLayers: Number of stacked GRU layers.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.vocabSize = vocabSize
        self.numHiddens = numHiddens

        # The recurrent layer is created before the decoder so that parameter
        # registration (and seeded initialisation) order stays the same.
        self.gruLayer = nn.GRU(input_size=vocabSize, hidden_size=numHiddens, num_layers=numLayers)
        self.dense = nn.Linear(numHiddens, vocabSize)

    def forward(self, X, state):
        """Compute next-token logits for a batch of index sequences.

        Args:
            X: Integer token indices; transposed before encoding so the
                recurrent layer receives the time axis first.
            state: Hidden-state tensor as produced by ``begin_state`` or by
                a previous call.

        Returns:
            ``(logits, new_state)`` where ``logits`` has ``vocabSize``
            columns and one row per element of ``X``.
        """
        stepsFirst = X.T.long()
        encoded = F.one_hot(stepsFirst, self.vocabSize).to(torch.float32)

        hiddenSeq, hiddenState = self.gruLayer(encoded, state)
        # Collapse every leading axis so the decoder sees a 2-D matrix.
        flatHidden = hiddenSeq.reshape((-1, hiddenSeq.shape[-1]))

        return self.dense(flatHidden), hiddenState

    def begin_state(self, device, batch_size=1):
        """Return an all-zero hidden-state tensor on ``device``."""
        stateShape = (self.gruLayer.num_layers, batch_size, self.numHiddens)
        return torch.zeros(stateShape, device=device)

time: 14.9 ms


In [5]:
def trainEpoch(model, train_iter, loss, optimizer, device):
    """Run one pass over ``train_iter`` and return the training perplexity.

    The recurrent state is created from the first batch and then carried
    from batch to batch; it is detached in place before each later batch so
    gradients do not flow back into earlier batches.

    Args:
        model: Module exposing ``begin_state(batch_size=..., device=...)``
            and returning ``(predictions, state)`` when called.
        train_iter: Iterable of ``(X, Y)`` tensor pairs.
        loss: Callable ``loss(predictions, targets)`` returning a tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batch tensors and initial state are placed on.

    Returns:
        ``exp`` of the token-weighted average of the per-batch loss values.
    """
    carried = None
    lossSum, tokenCount = 0.0, 0.0

    for X, Y in train_iter:
        if carried is None:
            carried = model.begin_state(batch_size=X.shape[0], device=device)
        elif isinstance(carried, tuple):
            # Tuple state (e.g. hidden and cell): detach each part.
            for part in carried:
                part.detach_()
        else:
            carried.detach_()

        # Targets are flattened time-major to line up with the model output.
        targets = Y.T.reshape(-1).to(device)
        X = X.to(device)

        pred, carried = model(X, carried)
        batchLoss = loss(pred, targets.long())

        optimizer.zero_grad()
        batchLoss.backward()
        optimizer.step()

        numTargets = len(targets)
        lossSum = lossSum + float(batchLoss.mean() * numTargets)
        tokenCount = tokenCount + float(numTargets)

    return math.exp(lossSum / tokenCount)

time: 17.2 ms


In [6]:
def predictChars(prefix, numPreds, model, vocab, device):
    """Generate ``numPreds`` characters that continue ``prefix``.

    The model is first warmed up on ``prefix`` one character at a time so
    its recurrent state reflects the prefix, then it is fed its own greedy
    (arg-max) prediction repeatedly.

    Generation runs under ``torch.no_grad()``: no gradients are needed here,
    so this avoids building an autograd graph at every step. The returned
    text is unaffected.

    Args:
        prefix: Non-empty string of seed characters.
        numPreds: Number of characters to generate after the prefix.
        model: Module exposing ``begin_state(batch_size=..., device=...)``
            and returning ``(logits, state)`` when called.
        vocab: Maps a token to its index via ``vocab[token]`` and back via
            ``vocab.idx_to_token[index]``.
        device: Device the input tensors and state are created on.

    Returns:
        The prefix followed by the generated characters, as one string.
    """

    def lastTokenInput(indices):
        # The model takes a 2-D batch of indices; feed one token at a time.
        return torch.tensor([[indices[-1]]], device=device)

    with torch.no_grad():
        state = model.begin_state(batch_size=1, device=device)
        outputs = [vocab[prefix[0]]]

        # Warm-up: advance the state over the prefix, ignoring predictions.
        for y in prefix[1:]:
            _, state = model(lastTokenInput(outputs), state)
            outputs.append(vocab[y])

        # Generation: feed back the most likely next token each step.
        for _ in range(numPreds):
            y, state = model(lastTokenInput(outputs), state)
            outputs.append(int(y.argmax(dim=1).reshape(1)))

    return ''.join([vocab.idx_to_token[i] for i in outputs])

time: 6.3 ms


In [7]:
# Data pipeline: mini-batch size and sequence length passed to d2l's
# time-machine loader, which returns the batch iterator and the vocabulary.
batch_size = 32
num_steps = 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# Model hyperparameters shared by the LSTM and GRU experiments below.
vocabSize = len(vocab)
numHiddens = 256
numLayers = 2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

time: 118 ms


## LSTM vs GRU Performance

Both the LSTM and the GRU models are trained below. The LSTM model took ~23 sec to train for 500 epochs, and the GRU took ~19 sec to train for 500 epochs. The GRU takes less time because it has fewer gates (reset and update) than the LSTM, which has three gates (input, output and forget). Having fewer gates makes the GRU more efficient, which results in a shorter computation time.

Instead of accuracy, the GRU and LSTM models can be compared using the perplexity of their output. The GRU also reached a lower perplexity score at the 500th epoch, very slightly lower than the score achieved by the LSTM model. From a purely subjective point of view, the text produced by the GRU model is also slightly more readable. One reason it may perform better than the LSTM for this type of character-level language model is that the GRU exposes its full hidden content without any control via a memory unit, whereas the LSTM controls which information is passed on to the next time steps, which in a character-level model could result in less readable output.

All of these trends hold for any seed characters/prefix phrase that is in English. When given a gibberish prefix, the results can be unpredictable.

In [8]:
# Train the LSTM language model with plain SGD and cross-entropy loss,
# printing the training perplexity and a sample continuation every 50 epochs.
model = LSTMModel(vocabSize, numHiddens, numLayers)
model = model.to(device)
ceLoss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=2)

for epoch in range(500):
    perplexity = trainEpoch(model, train_iter, ceLoss, optimizer, device)
    if (epoch + 1) % 50 != 0:
        continue  # only report on every 50th epoch
    prediction = predictChars('story ', 50, model, vocab, device)
    print('epoch:', epoch, ': perplexity:', round(perplexity, 3), ":", prediction)

epoch: 49 : perplexity: 14.015 : story the the the the the the the the the the the the th
epoch: 99 : perplexity: 9.25 : story the the the the the the the the the the the the th
epoch: 149 : perplexity: 5.767 : story three dimensions of thick and this is a cont that 
epoch: 199 : perplexity: 1.886 : story this that is the time travellerit s against rearly
epoch: 249 : perplexity: 1.075 : story he dige traveller same bery regarded as something 
epoch: 299 : perplexity: 1.057 : story the time traveller proceeded anyreal body must hav
epoch: 349 : perplexity: 1.04 : story thend the lantt at a coal in the fire iftime is re
epoch: 399 : perplexity: 1.03 : story then thick sust wither still seeming and theinequa
epoch: 449 : perplexity: 1.024 : story there is a pouttar is they couldmaster the perspec
epoch: 499 : perplexity: 1.033 : story surioned froly in the butter of matter to bass tht
time: 23.3 s


In [9]:
# Train the GRU language model with plain SGD and cross-entropy loss,
# printing the training perplexity and a sample continuation every 50 epochs.
model = GRUModel(vocabSize, numHiddens, numLayers)
model = model.to(device)
ceLoss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=2)

for epoch in range(500):
    perplexity = trainEpoch(model, train_iter, ceLoss, optimizer, device)
    if (epoch + 1) % 50 != 0:
        continue  # only report on every 50th epoch
    prediction = predictChars('story ', 50, model, vocab, device)
    print('epoch:', epoch, ': perplexity:', round(perplexity, 3), ":", prediction)

epoch: 49 : perplexity: 8.621 : story the the the the the the the the the the the the th
epoch: 99 : perplexity: 3.867 : story said the medical man a cube thing seing in and dir
epoch: 149 : perplexity: 1.12 : story sime there isno difference between time and any of
epoch: 199 : perplexity: 1.052 : story sime there isno difference between time and any of
epoch: 249 : perplexity: 1.041 : story sime there were also perhaps a dozen candles about
epoch: 299 : perplexity: 1.026 : story sime there were also perhaps a dozen candles about
epoch: 349 : perplexity: 1.03 : story sime there were also perhaps a dozen candles about
epoch: 399 : perplexity: 1.026 : story side there were also perhaps a dozen candles about
epoch: 449 : perplexity: 1.024 : story simonst reith of hand trick or other said the medi
epoch: 499 : perplexity: 1.017 : story simonst reason said filbywhat is all right said th
time: 19.3 s
