# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time
from torchvision import datasets, transforms
BATCH_SIZE = 128
NUM_EPOCHS = 10

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [2]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

test_dataset = datasets.MNIST(root="./data/",
                           transform = transform,
                           train = False)
# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

Using downloaded and verified file: ./data/MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data/MNIST\raw\train-images-idx3-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/MNIST\raw
Processing...
Done!


Then, we define the model, object function and optimizer that we use to classify.

In [10]:
import torch.nn.functional as F
class SimpleNet(nn.Module):
# TODO:define model

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 500)
        # 第一个线性层输入维度：28*28-图片长*宽；输出维度500-第一层神经元个数
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 10)
        # 预测层输出为10个神经元，代表10个数字
    def forward(self, x):
        x = x.view(-1, 28*28) #把torch tensor先展开成一行，再按照指定的size进行resize
        #此处就是每一列28*28（一个图片）；-1表示的是行数自动算出，行数也就是图片数目
        x = F.relu(self.fc1(x))
        # 激活
        x = F.relu(self.fc2(x))
        # 激活
        x = self.fc3(x)
        return x


# 实例化    
model = SimpleNet()

# TODO:define loss function and optimiter
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
print(model)

SimpleNet(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=10, bias=True)
)


Next, we can start to train and evaluate!

In [18]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    ave_loss = 0
    for X_train, y_train in tqdm(train_loader):
        # TODO:forward + backward + optimize
       optimizer.zero_grad()                       #梯度清零
       out = model(X_train)                        #使用model预测
       loss = criterion(out, y_train)              #计算损失

       ave_loss = ave_loss * 0.9 + loss.item() * 0.1
       loss.backward()                             #backward传递损失
       optimizer.step()                            # 梯度更新
    print("epoch_number:",epoch,"------------loss:",ave_loss)
    correct_cnt= 0
    total_cnt = 0
    for X_test, y_test in tqdm(test_loader):
        out = model(X_test)                     #预测一batch_size 的测试数据
        loss = criterion(out, y_test)       
        _, pred_label = torch.max(out.data, 1)   
            #预测数据结果：维度1（列）上求最大
            #比如 out[1][0] out[1][1] out[1][2]... 求最大；返回最大的值和所在列index
        total_cnt += X_test.data.size()[0]
            #test_data的图片数目
        correct_cnt += (pred_label == y_test).sum()
            #测试正确的图片数目
        # smooth average
    accuracy = float(correct_cnt)/total_cnt
    print("epoch_number:",epoch,"------------test accuracy:",accuracy)
   
        
        
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    
    
    
    


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:19<00:00, 23.45it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 29.20it/s]

epoch_number: 0 ------------loss: 0.034567791192461415


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 35.97it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 25.71it/s]

epoch_number: 0 ------------test accuracy: 0.9798677884615384


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.15it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 27.85it/s]

epoch_number: 1 ------------loss: 0.025283340231180732


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 35.68it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 24.66it/s]

epoch_number: 1 ------------test accuracy: 0.9771634615384616


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 22.61it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 27.85it/s]

epoch_number: 2 ------------loss: 0.026071097033364216


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 34.99it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 25.28it/s]

epoch_number: 2 ------------test accuracy: 0.9808693910256411


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 22.99it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 28.92it/s]

epoch_number: 3 ------------loss: 0.025030184080170836


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.96it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:19, 24.26it/s]

epoch_number: 3 ------------test accuracy: 0.9756610576923077


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.15it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 29.20it/s]

epoch_number: 4 ------------loss: 0.02708594882833014


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.77it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:17, 26.62it/s]

epoch_number: 4 ------------test accuracy: 0.9746594551282052


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.18it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 28.92it/s]

epoch_number: 5 ------------loss: 0.03302292137476507


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 35.26it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 24.66it/s]

epoch_number: 5 ------------test accuracy: 0.9798677884615384


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.11it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 28.65it/s]

epoch_number: 6 ------------loss: 0.019336187061834625


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.79it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:17, 25.93it/s]

epoch_number: 6 ------------test accuracy: 0.9783653846153846


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.18it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 28.92it/s]

epoch_number: 7 ------------loss: 0.033459041519671356


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.73it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:18, 24.86it/s]

epoch_number: 7 ------------test accuracy: 0.9764623397435898


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.25it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 27.60it/s]

epoch_number: 8 ------------loss: 0.021637861194549756


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 35.39it/s]
  1%|▌                                                                                 | 3/468 [00:00<00:19, 23.87it/s]

epoch_number: 8 ------------test accuracy: 0.9809695512820513


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:20<00:00, 23.23it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 27.60it/s]

epoch_number: 9 ------------loss: 0.02880777074876969


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 36.46it/s]

epoch_number: 9 ------------test accuracy: 0.9799679487179487





#### Q5:
Please print the training and testing accuracy.

In [26]:
correct_cnt= 0
total_cnt = 0
for X_test, y_test in tqdm(test_loader):
        out = model(X_test)
        loss = criterion(out, y_test)
        _, pred_label = torch.max(out.data, 1)
        total_cnt += X_test.data.size()[0]
        correct_cnt += (pred_label == y_test).sum()
accuracy = float(correct_cnt)/total_cnt
print(accuracy)


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 39.80it/s]

0.9799679487179487





# 
X_test: torch.Size([128, 1, 28, 28])
pred_label: torch.Size([128])
y_test: torch.Size([128])

In [23]:
correct_cnt= 0
total_cnt = 0
for X_train, y_train in tqdm(train_loader):
        out = model(X_train)
        loss = criterion(out, y_train)
        _, pred_label = torch.max(out.data, 1)
        total_cnt += X_train.data.size()[0]
        correct_cnt += (pred_label == y_train).sum()
accuracy = float(correct_cnt)/total_cnt
print(accuracy)


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:12<00:00, 38.59it/s]

0.9939736912393162



