In [13]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

In [18]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

#get test data
train_x = Variable(torch.unsqueeze(train_dataset.data,dim=1)).type(torch.FloatTensor)[2000:]/255.
train_y = train_dataset.targets[2000:]
test_x = Variable(torch.unsqueeze(test_dataset.data,dim=1)).type(torch.FloatTensor)[:2000]/255.
test_y = test_dataset.targets[:2000]


In [19]:
class SimpleNet(nn.Module):
# TODO:define model
    def __init__(self):
        super(SimpleNet,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(    #(1,28,28)1是channel的维度，28*28为图片的长宽
                in_channels=1,#图片的层数，RGB=3,灰度=1
                out_channels=16,#filter的个数
                kernel_size=5,#filter的长宽
                stride=1,#每隔多少个移动
                padding=2,#图片补0.if stride = 1,padding =(kernel_size-1)/2 = (5-1)/2 = 2
            ),   #-->（16，28，28）
            nn.ReLU(),#-->（16，28，28）
            nn.MaxPool2d(kernel_size=2),#可以看成2*2的filter    #-->（16，14，14）kernel_size = 2,减小一半
        )
        self.conv2 = nn.Sequential(#-->（16，14，14）
            nn.Conv2d(16,32,5,1,2),#-->（32，14，14）
            nn.ReLU(),#-->（32，14，14）
            nn.MaxPool2d(2)#-->（32，7，7）

        )
        self.out = nn.Linear(32*7*7,10)#输出为10类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)   #-->(batch,32,7,7)
        x = x.view(x.size(0),-1) #-->(batch,32*7*7)
        output = self.out(x)
        return output


    
model = SimpleNet()
print(model)
# TODO:define loss function and optimiter
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)

SimpleNet(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)


In [21]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    for images,labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
        b_x = Variable(images)
        b_y = Variable(labels)

        out = model(b_x)
        loss = criterion(out,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        
    test_out = model(test_x)
    train_out = model(train_x)
    pred_train_y = torch.max(train_out,1)[1].data.squeeze()
    pred_test_y = torch.max(test_out,1)[1].data.squeeze()
    
    accuracy_train = float(sum(pred_train_y == train_y)) / float(sum(train_y == train_y))    
    accuracy_test = float(sum(pred_test_y == test_y))/ float(sum(test_y == test_y))
    
    print('Epoch:',epoch,'|train accuracy:'+str(accuracy_train),'|test accuracy:'+str(accuracy_test))

        
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset


  0%|          | 0/468 [00:00<?, ?it/s][A
  0%|          | 1/468 [00:00<01:27,  5.32it/s][A
  0%|          | 2/468 [00:00<01:16,  6.11it/s][A
  1%|          | 3/468 [00:00<01:07,  6.86it/s][A
  1%|          | 5/468 [00:00<00:59,  7.76it/s][A
  1%|▏         | 7/468 [00:00<00:53,  8.62it/s][A
  2%|▏         | 9/468 [00:00<00:48,  9.37it/s][A
  2%|▏         | 11/468 [00:01<00:46,  9.83it/s][A
  3%|▎         | 13/468 [00:01<00:45, 10.07it/s][A
  3%|▎         | 15/468 [00:01<00:43, 10.43it/s][A
  4%|▎         | 17/468 [00:01<00:42, 10.67it/s][A
  4%|▍         | 19/468 [00:01<00:41, 10.91it/s][A
  4%|▍         | 21/468 [00:01<00:39, 11.21it/s][A
  5%|▍         | 23/468 [00:02<00:39, 11.30it/s][A
  5%|▌         | 25/468 [00:02<00:38, 11.41it/s][A
  6%|▌         | 27/468 [00:02<00:38, 11.41it/s][A
  6%|▌         | 29/468 [00:02<00:38, 11.43it/s][A
  7%|▋         | 31/468 [00:02<00:38, 11.35it/s][A
  7%|▋         | 33/468 [00:03<00:38, 11.32it/s][A
  7%|▋         | 35/468 [0

Epoch: 0 |train accuracy:0.9730689655172414 |test accuracy:0.967



  0%|          | 1/468 [00:00<03:45,  2.07it/s][A
  0%|          | 2/468 [00:00<02:57,  2.63it/s][A
  1%|          | 3/468 [00:00<02:22,  3.27it/s][A
  1%|          | 4/468 [00:00<01:56,  3.99it/s][A
  1%|          | 5/468 [00:00<01:37,  4.74it/s][A
  1%|▏         | 6/468 [00:01<01:25,  5.38it/s][A
  2%|▏         | 8/468 [00:01<01:13,  6.29it/s][A
  2%|▏         | 10/468 [00:01<01:04,  7.14it/s][A
  3%|▎         | 12/468 [00:01<00:57,  7.88it/s][A
  3%|▎         | 14/468 [00:01<00:52,  8.68it/s][A
  3%|▎         | 16/468 [00:02<00:48,  9.36it/s][A
  4%|▍         | 18/468 [00:02<00:45,  9.95it/s][A
  4%|▍         | 20/468 [00:02<00:43, 10.31it/s][A
  5%|▍         | 22/468 [00:02<00:41, 10.68it/s][A
  5%|▌         | 24/468 [00:02<00:40, 11.00it/s][A
  6%|▌         | 26/468 [00:02<00:39, 11.18it/s][A
  6%|▌         | 28/468 [00:03<00:39, 11.22it/s][A
  6%|▋         | 30/468 [00:03<00:38, 11.30it/s][A
  7%|▋         | 32/468 [00:03<00:38, 11.34it/s][A
  7%|▋         | 3

Epoch: 1 |train accuracy:0.9592586206896552 |test accuracy:0.9445



  0%|          | 1/468 [00:00<02:55,  2.66it/s][A
  0%|          | 2/468 [00:00<02:22,  3.28it/s][A
  1%|          | 3/468 [00:00<01:55,  4.02it/s][A
  1%|          | 4/468 [00:00<01:37,  4.78it/s][A
  1%|          | 5/468 [00:00<01:23,  5.54it/s][A
  1%|▏         | 7/468 [00:01<01:11,  6.44it/s][A
  2%|▏         | 9/468 [00:01<01:02,  7.40it/s][A
  2%|▏         | 11/468 [00:01<00:55,  8.27it/s][A
  3%|▎         | 13/468 [00:01<00:50,  8.95it/s][A
  3%|▎         | 15/468 [00:01<00:48,  9.30it/s][A
  4%|▎         | 17/468 [00:01<00:45,  9.90it/s][A
  4%|▍         | 19/468 [00:02<00:43, 10.29it/s][A
  4%|▍         | 21/468 [00:02<00:42, 10.47it/s][A
  5%|▍         | 23/468 [00:02<00:41, 10.72it/s][A
  5%|▌         | 25/468 [00:02<00:40, 10.92it/s][A
  6%|▌         | 27/468 [00:02<00:40, 10.79it/s][A
  6%|▌         | 29/468 [00:03<00:40, 10.92it/s][A
  7%|▋         | 31/468 [00:03<00:40, 10.90it/s][A
  7%|▋         | 33/468 [00:03<00:41, 10.58it/s][A
  7%|▋         | 3

Epoch: 2 |train accuracy:0.9302758620689655 |test accuracy:0.918



  0%|          | 1/468 [00:00<02:35,  3.01it/s][A
  0%|          | 2/468 [00:00<02:12,  3.51it/s][A
  1%|          | 3/468 [00:00<01:52,  4.14it/s][A
  1%|          | 4/468 [00:00<01:35,  4.88it/s][A
  1%|          | 5/468 [00:00<01:23,  5.56it/s][A
  1%|▏         | 6/468 [00:00<01:13,  6.29it/s][A
  1%|▏         | 7/468 [00:01<01:06,  6.90it/s][A
  2%|▏         | 8/468 [00:01<01:00,  7.55it/s][A
  2%|▏         | 9/468 [00:01<00:56,  8.07it/s][A
  2%|▏         | 10/468 [00:01<00:54,  8.42it/s][A
  3%|▎         | 12/468 [00:01<00:50,  9.08it/s][A
  3%|▎         | 14/468 [00:01<00:47,  9.62it/s][A
  3%|▎         | 16/468 [00:01<00:46,  9.82it/s][A
  4%|▍         | 18/468 [00:02<00:45,  9.92it/s][A
  4%|▍         | 20/468 [00:02<00:43, 10.31it/s][A
  5%|▍         | 22/468 [00:02<00:41, 10.63it/s][A
  5%|▌         | 24/468 [00:02<00:40, 10.87it/s][A
  6%|▌         | 26/468 [00:02<00:40, 10.92it/s][A
  6%|▌         | 28/468 [00:03<00:40, 10.90it/s][A
  6%|▋         | 30/

Epoch: 3 |train accuracy:0.969396551724138 |test accuracy:0.9505



  0%|          | 1/468 [00:00<02:55,  2.65it/s][A
  0%|          | 2/468 [00:00<02:24,  3.23it/s][A
  1%|          | 3/468 [00:00<01:59,  3.88it/s][A
  1%|          | 4/468 [00:00<01:38,  4.69it/s][A
  1%|          | 5/468 [00:00<01:24,  5.48it/s][A
  1%|▏         | 6/468 [00:01<01:15,  6.15it/s][A
  1%|▏         | 7/468 [00:01<01:07,  6.86it/s][A
  2%|▏         | 9/468 [00:01<00:59,  7.67it/s][A
  2%|▏         | 11/468 [00:01<00:55,  8.27it/s][A
  3%|▎         | 13/468 [00:01<00:51,  8.89it/s][A
  3%|▎         | 15/468 [00:01<00:47,  9.52it/s][A
  4%|▎         | 17/468 [00:02<00:45,  9.99it/s][A
  4%|▍         | 19/468 [00:02<00:43, 10.43it/s][A
  4%|▍         | 21/468 [00:02<00:41, 10.71it/s][A
  5%|▍         | 23/468 [00:02<00:41, 10.84it/s][A
  5%|▌         | 25/468 [00:02<00:39, 11.10it/s][A
  6%|▌         | 27/468 [00:02<00:39, 11.10it/s][A
  6%|▌         | 29/468 [00:03<00:39, 11.01it/s][A
  7%|▋         | 31/468 [00:03<00:40, 10.85it/s][A
  7%|▋         | 33

Epoch: 4 |train accuracy:0.957655172413793 |test accuracy:0.944



  0%|          | 1/468 [00:00<02:56,  2.64it/s][A
  0%|          | 2/468 [00:00<02:28,  3.13it/s][A
  1%|          | 3/468 [00:00<02:04,  3.74it/s][A
  1%|          | 4/468 [00:00<01:42,  4.51it/s][A
  1%|          | 5/468 [00:00<01:27,  5.27it/s][A
  1%|▏         | 6/468 [00:01<01:15,  6.10it/s][A
  1%|▏         | 7/468 [00:01<01:07,  6.87it/s][A
  2%|▏         | 8/468 [00:01<01:01,  7.54it/s][A
  2%|▏         | 9/468 [00:01<00:57,  8.04it/s][A
  2%|▏         | 10/468 [00:01<00:55,  8.29it/s][A
  3%|▎         | 12/468 [00:01<00:52,  8.70it/s][A
  3%|▎         | 14/468 [00:01<00:49,  9.22it/s][A
  3%|▎         | 15/468 [00:01<00:48,  9.34it/s][A
  3%|▎         | 16/468 [00:02<00:47,  9.52it/s][A
  4%|▍         | 18/468 [00:02<00:45,  9.96it/s][A
  4%|▍         | 20/468 [00:02<00:44, 10.15it/s][A
  5%|▍         | 22/468 [00:02<00:43, 10.25it/s][A
  5%|▌         | 24/468 [00:02<00:43, 10.25it/s][A
  6%|▌         | 26/468 [00:03<00:42, 10.28it/s][A
  6%|▌         | 28/

Epoch: 5 |train accuracy:0.961103448275862 |test accuracy:0.941



  0%|          | 2/468 [00:00<01:17,  6.02it/s][A
  1%|          | 3/468 [00:00<01:09,  6.70it/s][A
  1%|          | 4/468 [00:00<01:02,  7.38it/s][A
  1%|▏         | 6/468 [00:00<00:57,  8.04it/s][A
  1%|▏         | 7/468 [00:00<00:54,  8.46it/s][A
  2%|▏         | 9/468 [00:01<00:51,  8.96it/s][A
  2%|▏         | 11/468 [00:01<00:49,  9.32it/s][A
  3%|▎         | 12/468 [00:01<00:48,  9.40it/s][A
  3%|▎         | 14/468 [00:01<00:47,  9.66it/s][A
  3%|▎         | 16/468 [00:01<00:45,  9.88it/s][A
  4%|▍         | 18/468 [00:01<00:45, 10.00it/s][A
  4%|▍         | 20/468 [00:02<00:44, 10.03it/s][A
  5%|▍         | 22/468 [00:02<00:44, 10.02it/s][A
  5%|▌         | 24/468 [00:02<00:44,  9.99it/s][A
  5%|▌         | 25/468 [00:02<00:44,  9.89it/s][A
  6%|▌         | 26/468 [00:02<00:44,  9.90it/s][A
  6%|▌         | 27/468 [00:02<00:44,  9.84it/s][A
  6%|▌         | 29/468 [00:03<00:44,  9.93it/s][A
  7%|▋         | 31/468 [00:03<00:43, 10.00it/s][A
  7%|▋         | 

Epoch: 6 |train accuracy:0.8931379310344828 |test accuracy:0.866



  1%|          | 3/468 [00:00<00:59,  7.83it/s][A
  1%|          | 4/468 [00:00<00:55,  8.31it/s][A
  1%|          | 5/468 [00:00<00:53,  8.68it/s][A
  1%|▏         | 6/468 [00:00<00:52,  8.88it/s][A
  2%|▏         | 8/468 [00:00<00:49,  9.27it/s][A
  2%|▏         | 10/468 [00:01<00:47,  9.71it/s][A
  3%|▎         | 12/468 [00:01<00:45,  9.99it/s][A
  3%|▎         | 14/468 [00:01<00:44, 10.31it/s][A
  3%|▎         | 16/468 [00:01<00:43, 10.48it/s][A
  4%|▍         | 18/468 [00:01<00:43, 10.43it/s][A
  4%|▍         | 20/468 [00:01<00:42, 10.54it/s][A
  5%|▍         | 22/468 [00:02<00:42, 10.50it/s][A
  5%|▌         | 24/468 [00:02<00:43, 10.25it/s][A
  6%|▌         | 26/468 [00:02<00:44, 10.03it/s][A
  6%|▌         | 28/468 [00:02<00:48,  9.05it/s][A
  6%|▌         | 29/468 [00:02<00:48,  9.08it/s][A
  6%|▋         | 30/468 [00:03<00:48,  9.07it/s][A
  7%|▋         | 31/468 [00:03<00:48,  8.96it/s][A
  7%|▋         | 32/468 [00:03<00:48,  9.00it/s][A
  7%|▋         |

Epoch: 7 |train accuracy:0.9287758620689656 |test accuracy:0.899



  0%|          | 1/468 [00:00<02:18,  3.37it/s][A
  0%|          | 2/468 [00:00<01:56,  4.00it/s][A
  1%|          | 3/468 [00:00<01:36,  4.82it/s][A
  1%|          | 4/468 [00:00<01:23,  5.57it/s][A
  1%|          | 5/468 [00:00<01:12,  6.39it/s][A
  1%|▏         | 6/468 [00:00<01:04,  7.16it/s][A
  1%|▏         | 7/468 [00:00<00:59,  7.77it/s][A
  2%|▏         | 8/468 [00:01<00:56,  8.17it/s][A
  2%|▏         | 10/468 [00:01<00:52,  8.75it/s][A
  3%|▎         | 12/468 [00:01<00:48,  9.43it/s][A
  3%|▎         | 14/468 [00:01<00:45,  9.92it/s][A
  3%|▎         | 16/468 [00:01<00:44, 10.23it/s][A
  4%|▍         | 18/468 [00:01<00:42, 10.48it/s][A
  4%|▍         | 20/468 [00:02<00:42, 10.54it/s][A
  5%|▍         | 22/468 [00:02<00:42, 10.53it/s][A
  5%|▌         | 24/468 [00:02<00:42, 10.51it/s][A
  6%|▌         | 26/468 [00:02<00:42, 10.50it/s][A
  6%|▌         | 28/468 [00:02<00:42, 10.34it/s][A
  6%|▋         | 30/468 [00:03<00:42, 10.26it/s][A
  7%|▋         | 32

Epoch: 8 |train accuracy:0.7829482758620689 |test accuracy:0.7315



  0%|          | 1/468 [00:00<01:35,  4.90it/s][A
  0%|          | 2/468 [00:00<01:23,  5.55it/s][A
  1%|          | 3/468 [00:00<01:13,  6.33it/s][A
  1%|          | 4/468 [00:00<01:05,  7.08it/s][A
  1%|▏         | 6/468 [00:00<00:58,  7.96it/s][A
  2%|▏         | 8/468 [00:00<00:52,  8.70it/s][A
  2%|▏         | 10/468 [00:01<00:48,  9.42it/s][A
  3%|▎         | 12/468 [00:01<00:45,  9.94it/s][A
  3%|▎         | 14/468 [00:01<00:43, 10.44it/s][A
  3%|▎         | 16/468 [00:01<00:42, 10.75it/s][A
  4%|▍         | 18/468 [00:01<00:41, 10.85it/s][A
  4%|▍         | 20/468 [00:01<00:40, 11.01it/s][A
  5%|▍         | 22/468 [00:02<00:40, 11.11it/s][A
  5%|▌         | 24/468 [00:02<00:40, 11.06it/s][A
  6%|▌         | 26/468 [00:02<00:39, 11.12it/s][A
  6%|▌         | 28/468 [00:02<00:39, 11.02it/s][A
  6%|▋         | 30/468 [00:02<00:39, 11.00it/s][A
  7%|▋         | 32/468 [00:03<00:41, 10.53it/s][A
  7%|▋         | 34/468 [00:03<00:41, 10.57it/s][A
  8%|▊         | 

Epoch: 9 |train accuracy:0.9527758620689655 |test accuracy:0.9345
