In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import datasets.date_format as dataset

In [2]:
# TODO: try a one-hot encoding instead of a learned embedding
# TODO: if the sequence is too long, the conv filters won't see the whole context
# TODO: if the sequence is too short, the downsampling convs might shrink it to length 0

In [3]:
# Every input and target sequence is right-padded to exactly this many tokens.
# Note: Net's five unpadded kernel-3 convolutions shrink the length by 10, so this
# must stay above 10.
max_len = 12

def batch_gen(batch_size):
  """Yield an endless stream of padded ``(x, y)`` batches drawn from ``dataset.gen()``.

  Each item produced by the dataset generator is indexed as ``item[0]`` (input
  token ids) and ``item[1]`` (target token ids). Both are right-padded with
  ``dataset.pad`` to ``max_len`` so the batch packs into rectangular tensors.

  Args:
    batch_size: number of samples per yielded batch.

  Yields:
    ``(x, y)``: two ``torch.LongTensor`` of shape ``(batch_size, max_len)``.

  Raises:
    ValueError: (on ``next()``) if a sample is longer than ``max_len``. Such a
      sample could not be padded to a rectangular batch, and ``torch.LongTensor``
      would otherwise fail with an obscure ragged-list error.
  """
  samples = dataset.gen()

  def pad(seq):
    # Right-pad one token sequence to max_len, refusing to silently keep a longer one.
    if len(seq) > max_len:
      raise ValueError(
          'sequence of length {} exceeds max_len={}'.format(len(seq), max_len))
    return list(seq) + [dataset.pad] * (max_len - len(seq))

  while True:
    x, y = [], []
    for _ in range(batch_size):
      sample = next(samples)
      x.append(pad(sample[0]))
      y.append(pad(sample[1]))

    yield torch.LongTensor(x), torch.LongTensor(y)

In [4]:
class ConvBNRelu1d(nn.Module):
  """``Conv1d`` -> ``BatchNorm1d`` -> ReLU.

  The convolution uses PyTorch's defaults (stride 1, no padding), so a
  ``(batch, input, length)`` tensor comes out as
  ``(batch, output, length - kernel + 1)``.

  Args:
    input: number of input channels.
    output: number of output channels (also the BatchNorm feature count).
    kernel: convolution kernel size.
  """

  def __init__(self, input, output, kernel):
    super().__init__()
    # Registration order (conv, then bn) fixes the parameter / state_dict order.
    self.conv = nn.Conv1d(input, output, kernel)
    self.bn = nn.BatchNorm1d(output)

  def forward(self, x):
    return F.relu(self.bn(self.conv(x)))
  
class ConvTransposeBNRelu1d(nn.Module):
  """``ConvTranspose1d`` -> ``BatchNorm1d`` -> ReLU.

  With PyTorch's defaults (stride 1, no padding), the transposed convolution
  lengthens the sequence by ``kernel - 1``:
  ``(batch, input, length)`` -> ``(batch, output, length + kernel - 1)``.

  Args:
    input: number of input channels.
    output: number of output channels (also the BatchNorm feature count).
    kernel: transposed-convolution kernel size.
  """

  def __init__(self, input, output, kernel):
    super().__init__()
    # Registration order (conv_t, then bn) fixes the parameter / state_dict order.
    self.conv_t = nn.ConvTranspose1d(input, output, kernel)
    self.bn = nn.BatchNorm1d(output)

  def forward(self, x):
    upsampled = self.conv_t(x)
    normalised = self.bn(upsampled)
    return F.relu(normalised)
  
class Net(nn.Module):
  """Fully-convolutional encoder/decoder producing per-position log-probabilities.

  The encoder is five unpadded kernel-3 ``ConvBNRelu1d`` layers, each shortening
  the sequence by 2. The decoder is five kernel-3 ``ConvTransposeBNRelu1d``
  layers, each lengthening it by 2. The output therefore has the same length as
  the input, and inputs must be longer than 10 positions or the encoder
  shrinks them to nothing.

  Input:  ``(batch, length)`` tensor of token ids.
  Output: ``(batch, dataset.vocab_size, length)`` log-probabilities
          (``log_softmax`` over dim 1, the vocabulary axis).
  """

  def __init__(self):
    super().__init__()

    # Embedding width equals the vocabulary size: one learned
    # vocab_size-dim vector per token id.
    self.embedding = nn.Embedding(dataset.vocab_size, dataset.vocab_size)

    # Encoder: vocab_size -> 64 channels, then 64 -> 64 (length -2 per layer).
    self.conv1 = ConvBNRelu1d(dataset.vocab_size, 64, 3)
    self.conv2 = ConvBNRelu1d(64, 64, 3)
    self.conv3 = ConvBNRelu1d(64, 64, 3)
    self.conv4 = ConvBNRelu1d(64, 64, 3)
    self.conv5 = ConvBNRelu1d(64, 64, 3)

    # Decoder: 64 -> 64 channels, finally 64 -> vocab_size (length +2 per layer).
    self.conv_t6 = ConvTransposeBNRelu1d(64, 64, 3)
    self.conv_t7 = ConvTransposeBNRelu1d(64, 64, 3)
    self.conv_t8 = ConvTransposeBNRelu1d(64, 64, 3)
    self.conv_t9 = ConvTransposeBNRelu1d(64, 64, 3)
    # NOTE(review): this last layer also applies BatchNorm + ReLU, so the scores
    # fed to log_softmax are non-negative and get zero gradient wherever ReLU
    # clips them. A plain ConvTranspose1d is the usual classifier head; changing it
    # would alter the state_dict layout, so it is left as-is here.
    self.conv_t10 = ConvTransposeBNRelu1d(64, dataset.vocab_size, 3)

  def forward(self, x):
    # (batch, length) -> (batch, length, vocab) -> (batch, vocab, length) for Conv1d.
    h = self.embedding(x).permute(0, 2, 1)

    layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
              self.conv_t6, self.conv_t7, self.conv_t8, self.conv_t9, self.conv_t10)
    for layer in layers:
      h = layer(h)

    return F.log_softmax(h, dim=1)

In [5]:
batch_size = 32
steps = 2000
log_interval = 200


def flatten_log_probs(log_probs):
  """Reshape ``(batch, vocab, length)`` model output into ``(batch * length, vocab)``.

  The rows are ordered batch-major to match ``targets.view(-1)`` of a
  ``(batch, length)`` target tensor, which is the layout ``F.nll_loss`` expects.
  """
  return log_probs.permute(0, 2, 1).contiguous().view(-1, log_probs.size(1))


model = Net()
optimizer = optim.Adam(model.parameters())
gen = batch_gen(batch_size)
test_gen = batch_gen(batch_size * 10)

for i in range(steps + 1):
  model.train()

  x, y = next(gen)
  optimizer.zero_grad()

  y_hat = flatten_log_probs(model(x))
  loss = F.nll_loss(y_hat, y.view(-1))
  loss.backward()
  optimizer.step()

  if i % log_interval == 0:
    model.eval()

    # `Variable(..., volatile=True)` was removed in PyTorch 0.4 and no longer
    # disables autograd; no_grad() avoids building a graph for the eval batch.
    with torch.no_grad():
      x, y = next(test_gen)
      y_hat = flatten_log_probs(model(x))
      y = y.view(-1)

      # nll_loss averages (reduction='mean') over every position, padding included.
      test_loss = F.nll_loss(y_hat, y)
      pred = y_hat.argmax(dim=1)  # index of the max log-probability per position
      # NOTE(review): accuracy also counts padding positions, which inflates it.
      accuracy = (pred == y).float().mean() * 100

    # .item() replaces `.data[0]`, which raises IndexError on 0-dim tensors in
    # PyTorch >= 0.5.
    print('step: {}, loss: {:.4f}, accuracy: {:.2f}%'.format(
        i, test_loss.item(), accuracy.item()))

    # Targets/predictions are flattened row-major, so the first max_len entries
    # belong to the first sample in the batch.
    print('\tsample:    {}\n\ttrue:      {}\n\tpredicted: {}'.format(
      dataset.decode(x[0].tolist()),
      dataset.decode(y[:max_len].tolist()),
      dataset.decode(pred[:max_len].tolist())))

step: 0, loss: 3.6909, accuracy: 2.21%
	sample:    16/9/34<p><p><p><p><p>
	true:      16 sep 1934<p>
	predicted: ppffpppfppff
step: 200, loss: 1.6775, accuracy: 87.50%
	sample:    4/2/78<p><p><p><p><p><p>
	true:      4 feb 1978<p><p>
	predicted: 3 feb 1978<p><p>
step: 400, loss: 1.1341, accuracy: 93.49%
	sample:    14/8/37<p><p><p><p><p>
	true:      14 aug 1937<p>
	predicted: 16 aug 1937<p>
step: 600, loss: 0.7258, accuracy: 97.55%
	sample:    11/2/57<p><p><p><p><p>
	true:      11 feb 1957<p>
	predicted: 11 feb 1957<p>
step: 800, loss: 0.4992, accuracy: 98.57%
	sample:    8/8/72<p><p><p><p><p><p>
	true:      8 aug 1972<p><p>
	predicted: 6 aug 1972<p><p>
step: 1000, loss: 0.3446, accuracy: 99.27%
	sample:    3/8/16<p><p><p><p><p><p>
	true:      3 aug 1916<p><p>
	predicted: 3 aug 1916<p><p>
step: 1200, loss: 0.2499, accuracy: 99.84%
	sample:    7/9/39<p><p><p><p><p><p>
	true:      7 sep 1939<p><p>
	predicted: 7 sep 1939<p><p>
step: 1400, loss: 0.1987, accuracy: 99.69%
	sample:    15/3/37