In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import datasets.python_format as dataset

In [2]:
# TODO: try one-hot instead of embedding
# TODO: if too long - filter wont see all context
# TODO: if too small - downsampling might sample to 0 size

In [3]:
max_len = 100

def batch_gen(batch_size):
  gen = dataset.gen(1, 7)
  
  while True:
    x, y = [], []
    x_len, y_len = [], []
    for i in range(batch_size):
      batch = next(gen)
      x.append(batch[0])
      y.append(batch[1])
      x_len.append(len(batch[0]))
      y_len.append(len(batch[1]))

    x_max_len = max_len
    y_max_len = max_len

    for i in range(batch_size):
      x[i] = x[i] + [dataset.pad] * (x_max_len - len(x[i]))
      y[i] = y[i] + [dataset.pad] * (y_max_len - len(y[i]))

    x = torch.LongTensor(x)
    y = torch.LongTensor(y)
    
    yield x, y

In [4]:
class ConvBNRelu(nn.Module):
  def __init__(self, input, output, kernel):
    super().__init__()
    
    self.conv = nn.Conv2d(input, output, kernel)
    self.bn = nn.BatchNorm2d(output)
    
  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    x = F.relu(x)
    
    return x
  
class ConvTransposeBNRelu(nn.Module):
  def __init__(self, input, output, kernel):
    super().__init__()
    
    self.conv_t = nn.ConvTranspose2d(input, output, kernel)
    self.bn = nn.BatchNorm2d(output)
    
  def forward(self, x):
    x = self.conv_t(x)
    x = self.bn(x)
    x = F.relu(x)
    
    return x
  
class ResidualBlock(nn.Module):
  def __init__(self, size, kernel):
    self.conv1 = nn.Conv2d(size, size, kernel)
    self.bn1 = nn.BatchNorm2d(size)
    self.conv2 = nn.Conv2d(size, size, kernel)
    self.bn2 = nn.BatchNorm2d(size)
    
  def forward(self, x):
    saved = x
    x = self.conv1(x)
    x = self.bn1(x)
    x = F.relu(x)
    
    x = self.conv2(x)
    x = self.bn2(x)
    x = x + saved
    x = F.relu(x)
    
    return x
    

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    
    self.embedding = nn.Embedding(dataset.vocab_size, dataset.vocab_size)
    
    self.conv_bn_relu1 = ConvBNRelu(1, 16, 3)
    self.conv_bn_relu2 = ConvBNRelu(16, 16, 3)
    self.conv_bn_relu3 = ConvBNRelu(16, 16, 3)
    self.conv_bn_relu4 = ConvBNRelu(16, 16, 3)
    self.conv_bn_relu5 = ConvBNRelu(16, 16, 3)
    self.conv_bn_relu6 = ConvBNRelu(16, 16, 3)
    
    self.conv_t_bn_relu7 = ConvTransposeBNRelu(16, 16, 3)
    self.conv_t_bn_relu8 = ConvTransposeBNRelu(16, 16, 3)
    self.conv_t_bn_relu9 = ConvTransposeBNRelu(16, 16, 3)
    self.conv_t_bn_relu10 = ConvTransposeBNRelu(16, 16, 3)
    self.conv_t_bn_relu11 = ConvTransposeBNRelu(16, 16, 3)
    self.conv_t_bn_relu12 = ConvTransposeBNRelu(16, 1, 3)
    
  def forward(self, x):
    x = self.embedding(x)
    x = x.unsqueeze(1)
    
    x = self.conv_bn_relu1(x)
    x = self.conv_bn_relu2(x)
    x, pi2 = F.max_pool2d(x, 2, 2, return_indices=True)
    
    x = self.conv_bn_relu3(x)
    x = self.conv_bn_relu4(x)
    x, pi4 = F.max_pool2d(x, 2, 2, return_indices=True)
    
    x = self.conv_bn_relu5(x)
    x = self.conv_bn_relu6(x)
#     x, pi6 = F.max_pool2d(x, 2, 2, return_indices=True)
    
#     x = F.max_unpool2d(x, pi6, 2, 2)
    x = self.conv_t_bn_relu7(x)
    x = self.conv_t_bn_relu8(x)
    
    x = F.max_unpool2d(x, pi4, 2, 2)
    x = self.conv_t_bn_relu8(x)
    x = self.conv_t_bn_relu10(x)
    
    x = F.max_unpool2d(x, pi2, 2, 2)
    x = self.conv_t_bn_relu11(x)
    x = self.conv_t_bn_relu12(x)
    
    x = x.squeeze()
    x = F.log_softmax(x, dim=-1)
    return x
  

model = Net()
print(model)

Net(
  (embedding): Embedding(72, 72)
  (conv_bn_relu1): ConvBNRelu(
    (conv): Conv2d (1, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (conv_bn_relu2): ConvBNRelu(
    (conv): Conv2d (16, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (conv_bn_relu3): ConvBNRelu(
    (conv): Conv2d (16, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (conv_bn_relu4): ConvBNRelu(
    (conv): Conv2d (16, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (conv_bn_relu5): ConvBNRelu(
    (conv): Conv2d (16, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (conv_bn_relu6): ConvBNRelu(
    (conv): Conv2d (16, 16, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=T

In [5]:
learning_rate = 0.001
batch_size = 32
steps = 2000
log_interval = 10

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
gen = batch_gen(batch_size)
test_gen = batch_gen(batch_size * 10)

for i in range(steps + 1):
  model.train()

  x, y = next(gen)
  x, y = Variable(x), Variable(y)
  optimizer.zero_grad()
  y_hat = model(x)
  y_hat = y_hat.view(-1, y_hat.size(-1))
  y = y.view(-1)
  loss = F.nll_loss(y_hat, y)
  loss.backward()
  optimizer.step()

  if i % log_interval == 0:
    model.eval()

    x, y = next(test_gen)
    x, y = Variable(x, volatile=True), Variable(y)
    y_hat = model(x)
    y_hat = y_hat.view(-1, y_hat.size(-1))
    y = y.view(-1)
    test_loss = F.nll_loss(y_hat, y) # sum up batch loss
    pred = y_hat.max(1, keepdim=True)[1].squeeze() # get the index of the max log-probability
    accuracy = (pred == y).float().mean() * 100
    
    print('step: {}, loss: {:.4f}, accuracy: {:.2f}%'.format(
        i, test_loss.data[0], accuracy.data[0]))
    
    print('\tsample:    {}\n\ttrue:      {}\n\tpredicted: {}'.format(
      dataset.decode(x[0].tolist()),
      dataset.decode(y[:max_len].tolist()),
      dataset.decode(pred[:max_len].tolist())))

step: 0, loss: 4.2741, accuracy: 74.93%
	sample:    Quae Molestiae<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	true:      Quae Molestiae<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	predicted: Aw<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>}<p>
step: 10, loss: 4.2410, accuracy: 75.17%
	sample:    maxime {}% {} {} {} {}-th<f>47

step: 100, loss: 1.3089, accuracy: 75.85%
	sample:    {} {} {}: {}<f>Dolores<n>maxime<n>Temporibus<n>982<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	true:      Dolores maxime Temporibus: 982<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	predicted: <p><p><p>    <p><p><p> <p><p><p><p><p><p><p><p><p><p><p> <p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
step: 110, loss: 1.3479, accuracy: 75.28%
	sample:    {} cum deserunt voluptatibus {}<f>601<n>dolor<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p

step: 200, loss: 1.1589, accuracy: 78.09%
	sample:    {} QUIS: 905 {}: {} 602%<f>facere<n>Ipsam<n>365<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	true:      facere QUIS: 905 Ipsam: 365 602%<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	predicted: <p>  <p>   Bhh       Na h   <p>   <p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
step: 210, loss: 1.2347, accuracy: 76.32%
	sample:    {} {}: {} EXPLICABO {} 770-th {}<f>Officia<n>neque<n>312<n>QUO<n>QUAERAT<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p

step: 300, loss: 1.0885, accuracy: 78.93%
	sample:    sunt 35% {}<f>accusantium<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	true:      sunt 35% accusantium<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
	predicted: af<p> <p>     <p><p>u%hhAA<p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p><p>
step: 310, loss: 1.0869, accuracy: 78.62%
	sample:    {} {}: {} {} {}<f>Saepe<n>Voluptas<n>998<n>Culpa<n>Labore<p><p><p><p><p><p><p

KeyboardInterrupt: 