In [46]:
import torch
import torch.nn as nn

In [47]:
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

In [48]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))])

In [49]:
train_dataset=datasets.MNIST(root='../dataset/mnist/',
                             train=True,
                             download=True,
                             transform=transform)

In [50]:
batch_size=64

In [51]:
train_loader=DataLoader(train_dataset,
                       batch_size=batch_size,
                       shuffle=True)

In [52]:
test_dataset=datasets.MNIST(root='../dataset/mnist/',
                           train=False,
                           download=True,
                           transform=transform)

In [53]:
test_loader=DataLoader(test_dataset,
                      batch_size=batch_size,
                      shuffle=True)

In [54]:
class ResidualBlock(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock,self).__init__()
        self.conv1=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        return F.relu(x+y)

In [55]:
class ResidualConstantScaling(nn.Module):
    def __init__(self,channels):
        super(ResidualConstantScaling,self).__init__()
        self.conv1=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        return F.relu(0.5*x+0.5*y)

In [56]:
class ResidualConvShortcut(nn.Module):
    def __init__(self,channels):
        super(ResidualConvShortcut,self).__init__()
        self.conv1=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv3=nn.Conv2d(channels,channels,kernel_size=1)
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        x=self.conv3(x)
        return F.relu(x+y)

In [62]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,16,kernel_size=5)
        self.conv2=nn.Conv2d(16,32,kernel_size=5)
        self.mp=nn.MaxPool2d(2)
        self.rblock1=ResidualBlock(channels=16)
        self.rblock2=ResidualBlock(channels=32)
        self.fc=nn.Linear(512,10)
    def forward(self,x):
        in_size=x.size(0)
        x=self.mp(F.relu(self.conv1(x)))
        x=self.rblock1(x)
        x=self.mp(F.relu(self.conv2(x)))
        x=self.rblock2(x)
        x=x.view(in_size,-1)
        x=self.fc(x)
        return x
model=Net()

In [63]:
criterion=torch.nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.01)

In [64]:
def train(epoch):
    runing_loss=0
    for batch_idx,data in enumerate(train_loader,0):
        inputs,targets=data
        optimizer.zero_grad()
        output=model(inputs)
        loss=criterion(output,targets)
        loss.backward()
        optimizer.step()
        runing_loss+=loss.item()
        if batch_idx%300==299:
            print(epoch+1,batch_idx+1,runing_loss/300)

In [65]:
def test():
    correct=0
    total=0
    with torch.no_grad():
        for data in test_loader:
            inputs,labels=data
            o=model(inputs)
            _,predict=torch.max(o.data,dim=1)
            total+=labels.size(0)
            correct+=(predict==labels).sum().item()
    print(correct/total*100)

In [None]:
for epoch in range(10):
    train(epoch)
    test()

1 300 0.7647064708669981
1 600 0.9867103164518873
1 900 1.1554205532682438
96.03
2 300 0.13012087369337677
2 600 0.24309731864680847
2 900 0.3425486720725894
97.74000000000001
3 300 0.08863083760254085
3 600 0.17200883506486814
3 900 0.24820561928674578
97.89
4 300 0.07005051255847017
4 600 0.14170489803111802
4 900 0.20334518157721806
98.37
5 300 0.06473811823486661
5 600 0.12221906644757836
5 900 0.1751372180801506
97.89
6 300 0.050563594916214545
6 600 0.10493266692462688
6 900 0.15601105702963347
98.72
7 300 0.050648589603370056
7 600 0.09512464746716433
7 900 0.13993277221801692
98.72
8 300 0.03882531273722028
