In [1]:
import numpy as np
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random

from torch.optim import Adam
from torch.autograd import Variable

In [2]:
train_data = np.load('../dataset/wiki.train.npy')
val_data = np.load('../dataset/wiki.valid.npy')
vocab = np.load('../dataset/vocab.npy') 

In [73]:
class CustomDataLoader(DataLoader):
    def __init__(self, array, batch_size):
        random.shuffle(array)
        data = np.concatenate((array))
        m = len(data) // batch_size
        
        data = data[: m*batch_size+1]
        self.inputs = data[:-1].reshape(batch_size, m).T
        self.labels = data[1:].reshape(batch_size, m).T
        
        self.inputs = torch.from_numpy(self.inputs).long()
        self.labels = torch.from_numpy(self.labels).long()
        
        if torch.cuda.is_available():
            self.inputs = self.inputs.cuda()
            self.labels = self.labels.cuda()
            
    def __iter__(self):
        for i in range(self.len):
            start = i*self.seq_length
            end = (i+1)*self.seq_length

            yield (self.inputs[start:end], self.labels[start:end])
        
    def __len__(self):
        # Generate Random Length for each epoch
        len1 = np.random.normal(70,5,1)[0]
        len2 = np.random.normal(35,5,1)[0]
        random_len = np.random.choice([len1, len2], size=1, p=[0.95, 0.05])[0]
        
        self.seq_length = int(random_len) if random_len > 0 and random_len < 100 else 70
        self.len = self.inputs.shape[0] // self.seq_length
        
        print("seq_length", self.seq_length)
        return self.len

In [74]:
batch_size = 80
train_loader = CustomDataLoader(train_data, batch_size)
val_loader = CustomDataLoader(val_data, batch_size)

In [75]:
def sample_gumbel(shape, eps=1e-10, out=None):
    """
    Sample from Gumbel(0, 1)
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
    return - torch.log(eps - torch.log(U + eps))

In [81]:
class RLSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(RLSTM, self).__init__()        
        self.encoder = nn.Embedding(vocab_size, embed_size)
        self.rnns = nn.ModuleList([
            nn.LSTM(input_size=embed_size, hidden_size=hidden_size, batch_first=True),
            nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, batch_first=True),
            nn.LSTM(input_size=hidden_size, hidden_size=embed_size, batch_first=True)
        ])
        self.decoder = nn.Linear(embed_size, vocab_size)
        
        self.decoder.weight = self.encoder.weight

    def forward(self, inputs, forward=0, stochastic=False):
        h = inputs  # (n, t)
        h = self.encoder(h)  # (n, t, c)
        states = []
        for rnn in self.rnns:
            h, state = rnn(h)
            states.append(state)
        h = self.decoder(h)
        if stochastic:
            gumbel = Variable(sample_gumbel(shape=h.size(), out=h.data.new()))
            h += gumbel
        logits = h
        
        if forward > 0:
            outputs = []
            print(logits[:, -1:, :].shape)
            h = torch.max(logits[:, -1:, :], dim=2)[1]
            for i in range(forward):
                h = self.encoder(h)
                for j, rnn in enumerate(self.rnns):
                    h, state = rnn(h, states[j])
                    states[j] = state
                h = self.decoder(h)
                if stochastic:
                    gumbel = Variable(sample_gumbel(shape=h.size(), out=h.data.new()))
                    h += gumbel
                outputs.append(h)
                h = torch.max(h, dim=2)[1]
            logits = torch.cat([logits] + outputs, dim=1)
        return logits

In [82]:
learning_rate = 0.01
embed_size = 400
hidden_size = 100

model = RLSTM(len(vocab), embed_size, hidden_size)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    length = len(train_loader)
    
    for batch_idx, data in enumerate(train_loader):
        inputs, labels = data
        inputs = Variable(inputs, volatile=False)
        labels = Variable(labels.view(-1), volatile=False)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, len(vocab)), labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)
        
        optimizer.step()
        
        train_loss += loss.data[0]
        
        if batch_idx % 100 == 0:
            print(batch_idx, train_loss/(batch_idx+1))
    
    print("Epoch ", epoch, " Train Loss ", train_loss/length)

    model.eval()
    val_loss = 0
    length = len(val_loader)
    
    for batch_idx, data in enumerate(val_loader):
        inputs, labels = data
        inputs = Variable(inputs, volatile=False)
        labels = Variable(labels.view(-1), volatile=False)
        
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, len(vocab)), labels)
        
        val_loss += loss.data[0]
    
    print("Epoch ", epoch, " Val Loss ", val_loss/length)
    torch.save(model.state_dict(), str(epoch)+'-model.pkl')

In [83]:
inp = np.load('../fixtures/generation.npy')
forward = 20

model.load_state_dict(torch.load('model-cpu.pkl'))

In [84]:
input = torch.from_numpy(inp).long()
if torch.cuda.is_available():
    input = input.cuda()

input = Variable(input)
model.eval()
logits = model(input, forward=20, stochastic=True)
classes = torch.max(logits, dim=2)[1].data.cpu().numpy()
print(classes[20:40])

torch.Size([32, 1, 33278])
[[ 1417 21579 31353 30597 32084 15340 18743 15659    76    79 31352    79
  22968 15340    76 22968    79    76 31519    73    76    79 21626    79
   1424    79 15340    79     6    79    79  7033 15773  4494 20566    79
  20481    79 15659    79]
 [20683 32747 32978 15340    79    79    76 29624  1419    79    79    76
     76 13456    79 15340    76    79 21415    72 32427  1509    64    79
      1    79    79 15340  7597 15340    79 15659    79 16786  7867  1415
     76 20787 22968    76]
 [25821 15340    79  9820    76    76  7268 22968 24117    76    76    76
     79 18119 15659  1419    79    79    76 29798    79    76 22968    79
     73 15310 21626    79    79 18957    76    79 28891 17253  1415    76
     76  1419 15340    76]
 [28808    85 14860 15773    76    79    76 17194    76 32084 24697    76
  31467    79 25949    79    79 15340 13647 12685    79    76    76    76
  19874 31352 15340    76 29004  5768 21626    76 25871  1415 24118    79
  31

In [None]:
model.load_state_dict(torch.load('4-model.pkl'))
torch.save(model.cpu().state_dict(), '4-model-cpu.pkl')