# In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# Convert images to tensors, then standardise them with the widely used
# MNIST pixel mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Single source of truth for where the MNIST files are cached/downloaded.
DATA_ROOT = '../dataset/mnist/'
batch_size = 64

train_dataset = datasets.MNIST(root=DATA_ROOT,
                               train=True,
                               download=True,
                               transform=transform)
# Shuffle training batches so each epoch sees the samples in a new order.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_dataset = datasets.MNIST(root=DATA_ROOT,
                              train=False,
                              download=True,
                              transform=transform)
# Evaluation does not benefit from shuffling (accuracy is order-independent);
# a fixed order makes per-batch inspection reproducible.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False)
class ResidualConvShortcut(nn.Module):
    """Residual block whose skip connection is a learned 1x1 convolution.

    Main path: 3x3 conv -> ReLU -> 3x3 conv. Skip path: 1x1 conv.
    The two paths are summed and passed through a final ReLU. Every conv maps
    ``channels`` to ``channels``, and the 3x3 convs use ``padding=1``, so the
    spatial size of the input is preserved.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size_3x3 = dict(kernel_size=3, padding=1)
        # Creation order conv1, conv2, conv3 determines both the state_dict
        # layout and the order of random initialisation under a fixed seed.
        self.conv1 = nn.Conv2d(channels, channels, **same_size_3x3)
        self.conv2 = nn.Conv2d(channels, channels, **same_size_3x3)
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        shortcut = self.conv3(x)
        residual = self.conv2(F.relu(self.conv1(x)))
        return F.relu(shortcut + residual)
class Net(nn.Module):
    """MNIST classifier: two conv/pool stages, each followed by a residual block.

    Stage 1: 5x5 conv (1 -> 16 channels), ReLU, 2x2 max-pool, residual block.
    Stage 2: 5x5 conv (16 -> 32 channels), ReLU, 2x2 max-pool, residual block.
    The result is flattened per sample and mapped to 10 class logits.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # registration (and seeded initialisation) is unchanged.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.rblock1 = ResidualConvShortcut(channels=16)
        self.rblock2 = ResidualConvShortcut(channels=32)
        # 512 = 32 channels * 4 * 4, which holds only for 28x28 inputs:
        # 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        batch = x.size(0)
        stage1 = self.rblock1(self.mp(F.relu(self.conv1(x))))
        stage2 = self.rblock2(self.mp(F.relu(self.conv2(stage1))))
        return self.fc(stage2.view(batch, -1))
# Module-level training objects read by train() and test().
model = Net()
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
def train(epoch, log_interval=300):
    """Run one training epoch over ``train_loader``, printing progress.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. Every ``log_interval`` batches it prints
    ``epoch + 1``, the 1-based batch number, and the mean loss over the
    most recent ``log_interval`` batches.

    Args:
        epoch: zero-based epoch index (printed 1-based).
        log_interval: number of batches between progress prints
            (default 300, matching the previous hard-coded value).

    Raises:
        ValueError: if ``log_interval`` is not a positive integer.
    """
    if log_interval <= 0:
        raise ValueError('log_interval must be positive, got %r' % (log_interval,))
    # Explicitly enable training-mode behaviour (matters if dropout or
    # batch-norm layers are ever added; a no-op for the current layers).
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_interval == log_interval - 1:
            print(epoch + 1, batch_idx + 1, running_loss / log_interval)
            # Reset so each printed value is the mean of the last window only,
            # not a cumulative sum divided by the window size.
            running_loss = 0.0
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy in percent.

    Uses the module-level ``model`` and ``test_loader``. The model is put in
    eval mode for the pass, and its previous train/eval mode is restored
    afterwards.

    Returns:
        The accuracy as a percentage (float), or ``None`` if the loader
        yielded no samples (in which case a notice is printed instead).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                # Index of the largest logit per row is the predicted class.
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        # Avoid ZeroDivisionError on an empty evaluation set.
        print('no test samples; accuracy is undefined')
        return None
    accuracy = correct / total * 100
    print(accuracy)
    return accuracy
# Alternate one full training epoch with one evaluation pass.
NUM_EPOCHS = 6
for epoch in range(NUM_EPOCHS):
    train(epoch)
    test()