# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [14]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [15]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

Then, we define the model, object function and optimizer that we use to classify.

In [36]:
def normal_init(m, mean, stddev):
    m.weight.data.normal_(mean, stddev)
    m.bias.data.zero_()
class SimpleNet(nn.Module):
# TODO:define model
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 6, kernel_size = 5)
        self.sigmoid = nn.Sigmoid()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(in_channels = 6, out_channels = 16, kernel_size = 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
#         normal_init(self.conv1, 0, 0.01)
#         normal_init(self.conv2, 0, 0.01)
#         normal_init(self.fc1, 0, 0.01)
#         normal_init(self.fc2, 0, 0.01)
#         normal_init(self.fc3, 0, 0.01)
        
    def forward(self, x):
        out = self.sigmoid(self.conv1(x))
        out = self.max_pool(out)
        out = self.sigmoid(self.conv2(out))
        out = self.max_pool(out)
        out = out.view(x.shape[0], -1)
        out = self.sigmoid(self.fc1(out))
        out = self.sigmoid(self.fc2(out))
        out = self.fc3(out)
        
        return out
        
model = SimpleNet()

# TODO:define loss function and optimiter
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

Next, we can start to train and evaluate!

In [37]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    count = 0
    train_num = 0
    for images, labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
        model.zero_grad()
        pred = model(images)
        loss = criterion(pred, labels)
        _, argmax = torch.max(pred, dim = 1)
        for i in range(pred.shape[0]):
            train_num += 1
            if labels[i] == argmax[i]:
                count += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_accuracy = count / train_num
    print('epoch:%d train accuracy:%.5f' %((epoch+1),train_accuracy))
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    count = 0
    test_num = 0
    for images, labels in tqdm(test_loader):
        with torch.no_grad():
            pred = model(images)
            _, argmax = torch.max(pred, dim = 1)
            for i in range(pred.shape[0]):
                test_num += 1
                if (labels[i] == argmax[i]):
                    count += 1
    test_accuracy = count / test_num
    print('epoch:%d   test accuracy:%.5f' %((epoch+1),test_accuracy))
    
    
    


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.42it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:01, 39.68it/s]

epoch:1 train accuracy:0.14161


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 38.15it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 24.84it/s]

epoch:1   test accuracy:0.23928


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 27.96it/s]
 10%|████████▌                                                                          | 8/78 [00:00<00:01, 39.21it/s]

epoch:2 train accuracy:0.29581


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 37.18it/s]
  1%|▋                                                                                 | 4/468 [00:00<00:14, 31.28it/s]

epoch:2   test accuracy:0.37560


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.44it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 33.81it/s]

epoch:3 train accuracy:0.49566


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 37.52it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:17, 27.07it/s]

epoch:3   test accuracy:0.66166


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.17it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 33.01it/s]

epoch:4 train accuracy:0.79567


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.97it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:17, 26.03it/s]

epoch:4   test accuracy:0.86298


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 27.87it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 33.00it/s]

epoch:5 train accuracy:0.90066


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 35.42it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 24.60it/s]

epoch:5   test accuracy:0.92538


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:17<00:00, 26.60it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 29.69it/s]

epoch:6 train accuracy:0.93670


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 34.91it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:19, 23.95it/s]

epoch:6   test accuracy:0.94892


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 27.82it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 35.98it/s]

epoch:7 train accuracy:0.95206


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 38.06it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:16, 27.46it/s]

epoch:7   test accuracy:0.95703


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.08it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 36.04it/s]

epoch:8 train accuracy:0.96022


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 37.03it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:17, 26.92it/s]

epoch:8   test accuracy:0.96274


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.64it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:02, 35.80it/s]

epoch:9 train accuracy:0.96548


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 40.38it/s]
  1%|▋                                                                                 | 4/468 [00:00<00:14, 31.83it/s]

epoch:9   test accuracy:0.96735


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:15<00:00, 29.80it/s]
  5%|████▎                                                                              | 4/78 [00:00<00:01, 39.32it/s]

epoch:10 train accuracy:0.97002


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 40.05it/s]

epoch:10   test accuracy:0.97226





#### Q5:
Please print the training and testing accuracy.