Imports

In [158]:
import torch
import torch.nn.functional as F
from typing import Tuple
from torch import Tensor
import math
import torch.nn as nn
import os
import time
import numpy

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

Corpus

In [159]:
class Dictionary(object):
    def __init__(self):
        self.word2idx = {} # word: index
        self.idx2word = [] # position(index): word

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.dictionary.add_word('<pad>')
        self.train = self.tokenize(os.path.join(path, 'ptb.train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'ptb.valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'ptb.test.txt'))

    def tokenize(self, path):
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                # line to list of token + eos
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            sentences = []
            token = 0
            for line in f:
                temp = []
                words = line.split() + ['<eos>']
                for word in words:
                    temp.append(self.dictionary.word2idx[word])
                sentences.append(torch.LongTensor(temp).to('cuda'))

        return sentences

GRU Cell

In [160]:
class MyGru(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        self.input_size = input_size
        self.hidden_size = hidden_size

        super(MyGru, self).__init__()

        # Reset gate:
        self.reset_ht = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.reset_xt = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.reset_bias = nn.Parameter(torch.Tensor(hidden_size))

        # Update gate:
        self.update_ht = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.update_xt = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.update_bias = nn.Parameter(torch.Tensor(hidden_size))

        # Output gate:
        self.output_qt = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.output_xt = nn.Parameter(torch.Tensor(input_size, hidden_size))
        self.output_bias = nn.Parameter(torch.Tensor(hidden_size))

        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, input, hidden_states):
        hidden_seq = []
        seq_size, batch_size, _ = input.size()

        if hidden_states is None:
            hidden_states = torch.zeros(
                batch_size, self.hidden_size).to(input.device)
        else:
            pass

        for t in range(seq_size):
            x_t = input[t, :, :]

            rt = torch.sigmoid(x_t @ self.reset_xt +
                               hidden_states @ self.reset_ht)

            zt = torch.sigmoid(x_t @ self.update_xt +
                               hidden_states @ self.update_ht)

            qt = rt * hidden_states

            ht_1 = torch.tanh(x_t @ self.output_xt +
                              qt @ self.output_qt)

            ht_2 = (1-zt) * hidden_states

            ht_3 = zt * ht_1

            ht = ht_2 + ht_3

            hidden_seq.append(ht.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()

        return hidden_seq, ht


LSTM Cell

In [161]:
class LSTMCell(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_sz = hidden_sz

        self.Ui = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.Vi = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.Bi = nn.Parameter(torch.Tensor(hidden_sz))

        ##################################################################################

        self.Uf = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.Vf = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.Bf = nn.Parameter(torch.Tensor(hidden_sz))

        self.Uc = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.Vc = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.Bc = nn.Parameter(torch.Tensor(hidden_sz))

        ##################################################################################

        self.Uo = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.Vo = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.Bo = nn.Parameter(torch.Tensor(hidden_sz))

        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_sz)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
    
    def forward(self, x, init_states=None):
        hidden_seq = []
        sequence_sz, batch_sz, _ = x.size()

        if init_states is None:
            Ht, Ct = (
                torch.zeros(batch_sz, self.hidden_sz).to(x.device),
                torch.zeros(batch_sz, self.hidden_sz).to(x.device)
            )

        else:
            Ht, Ct = init_states

        for t in range(sequence_sz):
            Xt = x[t, :, :]
            
            # Math inside the cell:

            It = torch.sigmoid(Xt @ self.Ui + Ht @ self.Vi + self.Bi)

            Ft = torch.sigmoid(Xt @ self.Uf + Ht @ self.Vf + self.Bf)

            Gt = torch.tanh(Xt @ self.Uc + Ht @ self.Vc + self.Bc)

            Ot = torch.sigmoid(Xt @ self.Uo + Ht @ self.Vo + self.Bo)

            Ct = Ft * Ct + It * Gt

            Ht = Ot * torch.tanh(Ct)

            hidden_seq.append(Ht.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.transpose(0,1).contiguous()
        return hidden_seq, (Ht, Ct)


Model

In [162]:
class RNNModel(nn.Module):
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.nlayers = nlayers
        self.nhid = nhid
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp) # Token2Embeddings
        self.rnn1 = LSTMCell(ninp, nhid)
        self.rnn2 = LSTMCell(ninp, nhid)
        self.fc = nn.Linear(nhid, 5000)
        self.decoder = nn.Linear(5000, ntoken)
        # self.decoder = nn.Linear(nhid, ntoken) # Originally, it was like this.
        self.init_weights()

    def init_weights(self):
        initrange = 0.05
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):

        emb = self.drop(self.encoder(input))

        output, hidden = self.rnn1(emb, hidden)
        output = self.drop(output)

        output, hidden = self.rnn2(output[0, :, :, :], hidden)
        output = self.drop(output)

        output = output[1, :, :, :]

        output = self.fc(output)

        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))

        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, bsz):

        weight = next(self.parameters()).data
        return weight.new_zeros(self.nlayers, bsz, self.nhid), weight.new_zeros(self.nlayers, bsz, self.nhid)

Main

In [163]:
data = "./drive/MyDrive/input"
checkpoint = ""
interval = 200

# Network parameters:
emsize = 650
nhid = 650
nlayers = 2
lr = 0.001
clip = 0.35
epochs = 64
batch_size = 64
eval_batch_size = 1
bptt = 32
dropout = 0.5

save = './drive/MyDrive/input/output/model_test.pt'

torch.manual_seed(1111)

def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to('cuda')
    
# Load data
corpus = Corpus(data)

train_data = corpus.train
val_data = corpus.valid
test_data = corpus.test

# Build the model
ntokens = len(corpus.dictionary) # Around 10000 words
model = RNNModel(ntokens, emsize, nhid, nlayers, dropout).to('cuda')

# Criteria
opt = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
criterion = loss = nn.NLLLoss(ignore_index = 0)
best_val_loss = None

# Load checkpoint
if checkpoint != '':
    model = torch.load(checkpoint, map_location=lambda storage, loc: storage)

print(model)

############################################################

# def repackage_hidden(h): # For GRU Implementation
#     # detach
#     return h.clone().detach()

def repackage_hidden(h): # For LSTM Implementation
    return tuple(e.detach() for e in h)

def get_batch(source, i, batch_size):
    maxLen = 0
    data = []
    target = []
    for sentence in source[batch_size * i: batch_size * (i+1)]:
        data.append(sentence[:-1])
        target.append(sentence[1:])

    pad_data = pad_sequence(data, padding_value = 0) # 0 is '<pad>'
    pad_target = pad_sequence(target, padding_value = 0)
    
    return pad_data, pad_target 

def evaluate(data_source):
    with torch.no_grad():
        model.eval()
        total_loss = 0
        ntokens = len(corpus.dictionary)
        hidden = model.init_hidden(eval_batch_size)
        counter = 0

        for bindex in range(math.floor(len(data_source) / eval_batch_size)):
            data, target = get_batch(data_source, bindex, eval_batch_size)
            counter += data.shape[1]
            output, hidden = model(data, hidden)
            total_loss += criterion(output, target.view(-1)).data
            hidden = repackage_hidden(hidden)

        return total_loss / counter

############################################################

def train():
    model.train()
    total_loss = 0
    start_time = time.time()
    hidden = model.init_hidden(batch_size)

    for bindex in range(int(len(train_data) / batch_size)):
        data, target = get_batch(train_data, bindex, batch_size)
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = criterion(output, target.view(-1))
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

        total_loss += loss.data

        if bindex % interval == 0 and bindex > 0:
            cur_loss = total_loss / interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch,
                bindex,
                len(train_data) // batch_size,
                lr,
                elapsed * 1000 / interval,
                cur_loss,
                math.exp(cur_loss)
                )
              )
            total_loss = 0
            start_time = time.time()

try:

    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch,
                                           (time.time() - epoch_start_time),
                                           val_loss,
                                           math.exp(val_loss)
                                           )
                )
        print('-' * 89)

        if not best_val_loss or val_loss < best_val_loss:
            with open(save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss

############################################################

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

with open(save, 'rb') as f:
    model = torch.load(f)

test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss,
    math.exp(test_loss)
    )
)
print('=' * 89)


RNNModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(10001, 650)
  (rnn1): LSTMCell()
  (rnn2): LSTMCell()
  (fc): Linear(in_features=650, out_features=5000, bias=True)
  (decoder): Linear(in_features=5000, out_features=10001, bias=True)
)
| End of training | test loss  4.34 | test ppl    77.04
