## import package

In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

## load mnist data

In [4]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, ), (0.1, ))])

train_set = dset.MNIST(root='./train', train=True, transform=trans, download=False)
test_set = dset.MNIST(root='./test', train=False, transform=trans, download=False)

In [5]:
batch_size = 128

train_loader = torch.utils.data.DataLoader(
                dataset = train_set,
                batch_size = batch_size)

test_loader = torch.utils.data.DataLoader(
                dataset = test_set,
                batch_size = batch_size)

## LeNet

In [6]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        
        self.fullyconnected1 = nn.Linear(16*4*4, 120)
        self.fullyconnected2 = nn.Linear(120, 10)
        
        self.ceriation = nn.CrossEntropyLoss()
        
    def forward(self, x, target):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        
        x = self.conv2(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        
        x = x.view(-1, 16*4*4)
        
        x = self.fullyconnected1(x)
        x = self.fullyconnected2(x)
        
        loss = self.ceriation(x, target)
        return x, loss

## optimizer

In [7]:
model = LeNet().cpu()
optimizer = optim.SGD(model.parameters(), lr=0.01)

## training and testing

In [8]:
for epoch in range(10):
    # training
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = Variable(x), Variable(target)
        _, loss = model(x, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print ('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(epoch, batch_idx, loss.data[0]))
            
    # testing
    correct_cnt, ave_loss = 0, 0
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = Variable(x, volatile=True), Variable(target, volatile=True)
        score, loss = model(x, target)
        _, pred_label = torch.max(score.data, 1)
        correct_cnt += (pred_label == target.data).sum()
        ave_loss += loss.data[0]
    
    accuracy = correct_cnt / len(test_loader) / batch_size
    ave_loss /= len(test_loader)
    print ('==>>> epoch: {}, test loss: {:.6f}, accuracy: {:.4f}'.format(epoch, ave_loss, accuracy))
        

==>>> epoch: 0, batch index: 0, train loss: 2.377161
==>>> epoch: 0, batch index: 100, train loss: 0.451708
==>>> epoch: 0, batch index: 200, train loss: 0.212491
==>>> epoch: 0, batch index: 300, train loss: 0.208182
==>>> epoch: 0, batch index: 400, train loss: 0.259217
==>>> epoch: 0, test loss: 0.216903, accuracy: 0.9207
==>>> epoch: 1, batch index: 0, train loss: 0.168528
==>>> epoch: 1, batch index: 100, train loss: 0.115911
==>>> epoch: 1, batch index: 200, train loss: 0.121022
==>>> epoch: 1, batch index: 300, train loss: 0.130226
==>>> epoch: 1, batch index: 400, train loss: 0.199053
==>>> epoch: 1, test loss: 0.123624, accuracy: 0.9503
==>>> epoch: 2, batch index: 0, train loss: 0.119901
==>>> epoch: 2, batch index: 100, train loss: 0.080362
==>>> epoch: 2, batch index: 200, train loss: 0.099456
==>>> epoch: 2, batch index: 300, train loss: 0.113397
==>>> epoch: 2, batch index: 400, train loss: 0.191024
==>>> epoch: 2, test loss: 0.094908, accuracy: 0.9598
==>>> epoch: 3, bat