# Intro to PyTorch

In this example, we'll take a look at how to use the PyTorch framework to create a simple 3-layer artificial neural network (ANN) that classifies the MNIST data set.

## PyTorch MNIST classification

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from IPython.core.debugger import set_trace

In [2]:
# Directory the MNIST files are downloaded to (and read from on later runs).
image_folder = './images'

# Per-image preprocessing: convert to a tensor, then normalize the single
# channel with mean 0.5 and std 1.0.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])

# Arguments shared by the train and test splits; only `train` differs.
mnist_kwargs = dict(root=image_folder, transform=transformer, download=True)
train_set = datasets.MNIST(train=True, **mnist_kwargs)
test_set = datasets.MNIST(train=False, **mnist_kwargs)

# Number of samples per mini-batch for both loaders.
batch_size = 64

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)



In [3]:
class MLP(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    The input is flattened to ``(-1, in_features)``, passed through
    ``fc1 -> ReLU -> fc2 -> ReLU -> fc3`` and returned as log-probabilities
    (``log_softmax`` over dimension 1).

    Args:
        in_features: Number of values per sample after flattening
            (default ``28*28``).
        hidden1: Width of the first hidden layer (default 500).
        hidden2: Width of the second hidden layer (default 256).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_features=28*28, hidden1=500, hidden2=256, num_classes=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(N, num_classes)`` for input ``x``."""
        # Flatten to the width fc1 was built with rather than a hard-coded
        # 28*28; reshape (unlike view) also accepts non-contiguous tensors.
        x = x.reshape(-1, self.fc1.in_features)
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        logits = self.fc3(h2)
        return F.log_softmax(logits, 1)

    def name(self):
        """Return the short model name."""
        return "MLP"

In [4]:
model = MLP()

optimizer = optim.SGD(model.parameters(), lr=0.01)

# NOTE: MLP.forward already returns log-probabilities (log_softmax).
# CrossEntropyLoss applies log_softmax once more, which leaves
# log-probabilities unchanged, so the loss value is still correct;
# nn.NLLLoss() would be the exact counterpart for this model's output.
criterion = nn.CrossEntropyLoss()

num_epochs = 20      # full passes over the training set
log_interval = 100   # print running statistics every this many batches

for epoch in range(num_epochs):
    # training
    model.train()  # make sure train-mode behaviour is active after the eval pass below
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, target)
        # Predictions are only used for the accuracy counter, so take them
        # from a detached tensor instead of the legacy `.data` attribute.
        _, pred_label = torch.max(out.detach(), 1)
        total_cnt += x.shape[0]
        correct_cnt += (pred_label == target).sum().item()
        # exponentially smoothed loss (starts at 0, so early values are biased low)
        ave_loss = ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        if (batch_idx+1) % log_interval == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}, acc: {:.3f}'.format(
                epoch, batch_idx+1, ave_loss, correct_cnt*1.0/total_cnt))
    # testing
    model.eval()
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    # No parameter update happens here, so skip building the autograd graph.
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_loader):
            out = model(x)
            loss = criterion(out, target)
            _, pred_label = torch.max(out, 1)
            total_cnt += x.shape[0]
            correct_cnt += (pred_label == target).sum().item()
            # smooth average
            ave_loss = ave_loss * 0.9 + loss.item() * 0.1

            if (batch_idx+1) % log_interval == 0 or (batch_idx+1) == len(test_loader):
                print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                    epoch, batch_idx+1, ave_loss, correct_cnt * 1.0 / total_cnt))

==>>> epoch: 0, batch index: 100, train loss: 2.241422, acc: 0.228
==>>> epoch: 0, batch index: 200, train loss: 2.139243, acc: 0.340
==>>> epoch: 0, batch index: 300, train loss: 1.938830, acc: 0.422
==>>> epoch: 0, batch index: 400, train loss: 1.655323, acc: 0.474
==>>> epoch: 0, batch index: 500, train loss: 1.300390, acc: 0.519
==>>> epoch: 0, batch index: 600, train loss: 1.014521, acc: 0.559
==>>> epoch: 0, batch index: 700, train loss: 0.840627, acc: 0.592
==>>> epoch: 0, batch index: 800, train loss: 0.727700, acc: 0.619
==>>> epoch: 0, batch index: 900, train loss: 0.635672, acc: 0.642
==>>> epoch: 0, batch index: 938, train loss: 0.608992, acc: 0.650
==>>> epoch: 0, batch index: 100, test loss: 0.516358, acc: 0.809
==>>> epoch: 0, batch index: 157, test loss: 0.596428, acc: 0.834
==>>> epoch: 1, batch index: 100, train loss: 0.554468, acc: 0.838
==>>> epoch: 1, batch index: 200, train loss: 0.555623, acc: 0.845
==>>> epoch: 1, batch index: 300, train loss: 0.549781, acc: 0.8