In [29]:
import gym
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms

# Switch matplotlib to interactive mode for the rest of the notebook.
plt.ion()

In [42]:
# Toy MNIST example: data loaders, a small CNN, and a train/test loop

def get_loader(train, batch_size=64, shuffle=True, num_workers=4):
  """Build a DataLoader over the torchvision MNIST dataset.

  The dataset is stored under the local 'data' directory and is downloaded
  on first use (download=True).

  Args:
    train: If True, load the training split; otherwise the test split.
    batch_size: Number of samples per batch.
    shuffle: Whether the loader shuffles the samples.
    num_workers: Number of worker processes used to load batches.

  Returns:
    A torch.utils.data.DataLoader yielding (image, label) batches.
  """
  # Convert images to tensors, then normalize the single channel.
  # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std.
  preprocess = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,)),
  ])
  mnist = datasets.MNIST(
      'data', train=train, download=True, transform=preprocess)
  return torch.utils.data.DataLoader(
      mnist, batch_size=batch_size, shuffle=shuffle,
      num_workers=num_workers)
# One loader per MNIST split: training (train=True) and test (train=False).
train_loader, test_loader = get_loader(train=True), get_loader(train=False)

class CNN(nn.Module):
  """Two-convolution classifier that outputs per-class log-probabilities.

  Architecture: conv(5x5, 16) -> max-pool(2) -> ReLU ->
  conv(5x5, 32) -> max-pool(2) -> ReLU -> flatten -> linear.

  Args:
    in_d: Number of input channels.
    out_d: Number of output classes.
  """

  def __init__(self, in_d, out_d):
    super(CNN, self).__init__()
    self.conv1 = nn.Conv2d(in_d, 16, kernel_size=5)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
    # The flattened feature size is fixed at 32 channels * 4 * 4, which is
    # what a 28x28 input shrinks to: 28 -> 24 -> 12 -> 8 -> 4.
    self.fc = nn.Linear(32 * 4 * 4, out_d)

  def forward(self, x):
    """Map a (batch, in_d, 28, 28) tensor to (batch, out_d) log-probabilities."""
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    # Flatten everything except the batch dimension before the linear layer.
    x = self.fc(x.view(x.size(0), -1))
    # Normalize over the class dimension explicitly; relying on the implicit
    # dimension is deprecated and emits a warning.
    return F.log_softmax(x, dim=1)

# Classifier for 1-channel inputs with 10 output classes.
model = CNN(in_d=1, out_d=10)
# Move the model to the GPU only when one is available, so the notebook
# does not crash at this point on CPU-only machines.
if torch.cuda.is_available():
  model.cuda()

# Plain SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.01,
                      momentum=0.5)

def train(epoch, log_interval=1000):
  """Run one optimization pass over train_loader.

  Uses the module-level `model`, `optimizer` and `train_loader`.

  Args:
    epoch: Epoch number, used only in the progress message.
    log_interval: Print a progress line every this many batches.
  """
  model.train()
  # Send each batch to whichever device the model's parameters live on,
  # instead of assuming CUDA.
  device = next(model.parameters()).device
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      # loss is a 0-dim tensor; .item() extracts the Python float
      # (indexing it with [0] raises on current PyTorch).
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss.item()))

def test():
  """Evaluate the module-level `model` on `test_loader` and print a summary.

  Prints the average negative log-likelihood per sample and the number and
  percentage of correctly classified samples.
  """
  model.eval()
  # Follow the model's device rather than assuming CUDA.
  device = next(model.parameters()).device
  test_loss = 0.0
  correct = 0
  # Disable autograd during evaluation; this replaces the removed
  # Variable(..., volatile=True) mechanism.
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      # Sum (not mean) the per-sample losses so the division below yields
      # the per-sample average over the whole dataset.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      # Index of the max log-probability is the predicted class.
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))

# Alternate one training pass with one evaluation pass for epochs 1..10.
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
  train(epoch)
  test()

Test set: Average loss: 0.0914, Accuracy: 9718/10000 (97.18%)
Test set: Average loss: 0.0653, Accuracy: 9801/10000 (98.01%)
Test set: Average loss: 0.0560, Accuracy: 9819/10000 (98.19%)
Test set: Average loss: 0.0549, Accuracy: 9810/10000 (98.10%)
Test set: Average loss: 0.0455, Accuracy: 9860/10000 (98.60%)
Test set: Average loss: 0.0405, Accuracy: 9861/10000 (98.61%)
Test set: Average loss: 0.0400, Accuracy: 9867/10000 (98.67%)
Test set: Average loss: 0.0347, Accuracy: 9885/10000 (98.85%)
Test set: Average loss: 0.0364, Accuracy: 9876/10000 (98.76%)
Test set: Average loss: 0.0300, Accuracy: 9896/10000 (98.96%)


CNN (
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (fc): Linear (6272 -> 10)
)