In [1]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [2]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST', num_workers = 4):
    """Download Fashion-MNIST (if needed) and wrap both splits in DataLoaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional size forwarded to ``transforms.Resize``; when falsy the
            images are not resized.
        path: Root directory the dataset is downloaded to / read from.
        num_workers: Worker processes used by each DataLoader. Defaults to 4,
            the value that was previously hard-coded; pass 0 on platforms where
            multi-process loading is problematic.

    Returns:
        ``(train_iter, test_iter)``: the training loader is reshuffled every
        epoch, the test loader is iterated in dataset order.
    """
    steps = []
    if resize:
        steps.append(transforms.Resize(size = resize))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(root = path, train = True, download = True, transform = transform)
    mnist_test = torchvision.datasets.FashionMNIST(root = path, train = False, download = True, transform = transform)

    train_iter = DataLoader(mnist_train, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    # Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
    test_iter = DataLoader(mnist_test, batch_size = batch_size, shuffle = False, num_workers = num_workers)

    return train_iter, test_iter

In [3]:
def evaluate_accuracy(data_iter, network, device = None):
    """Return the fraction of samples in ``data_iter`` that ``network`` classifies correctly.

    NOTE(review): a second definition with the same name appears in the next
    notebook cell and shadows this one at runtime.

    Args:
        data_iter: Iterable yielding ``(inputs, labels)`` tensor pairs.
        network: Module whose output is arg-maxed over dim 1 to get predictions.
        device: Device the batches are moved to; defaults to the device of the
            network's first parameter.

    Returns:
        Accuracy as a float (correct predictions / number of labels seen).

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    if device is None:
        device = next(network.parameters()).device
    correct, total = 0.0, 0
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode network back into training mode.
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            for X_test, y_test in data_iter:
                X_test = X_test.to(device)
                y_test = y_test.to(device)
                y_hat = network(X_test)
                correct += (y_hat.argmax(dim = 1) == y_test).sum().item()
                total += y_test.shape[0]
    finally:
        network.train(was_training)
    return correct / total

In [4]:
def evaluate_accuracy(data_iter, network, device = None):
    """Return the fraction of samples in ``data_iter`` that ``network`` classifies correctly.

    NOTE(review): a definition with the same name also appears in the previous
    notebook cell; this later one is the definition in effect at runtime.

    Args:
        data_iter: Iterable yielding ``(inputs, labels)`` tensor pairs.
        network: Module whose output is arg-maxed over dim 1 to get predictions.
        device: Device the batches are moved to; defaults to the device of the
            network's first parameter.

    Returns:
        Accuracy as a float (correct predictions / number of labels seen).

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    if device is None:
        device = next(network.parameters()).device
    correct, total = 0.0, 0
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode network back into training mode.
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            for X_test, y_test in data_iter:
                X_test = X_test.to(device)
                y_test = y_test.to(device)
                y_hat = network(X_test)
                correct += (y_hat.argmax(dim = 1) == y_test).sum().item()
                total += y_test.shape[0]
    finally:
        network.train(was_training)
    return correct / total

In [5]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs):
    """Train ``network`` with cross-entropy loss and print per-epoch metrics.

    The network is moved to CUDA when available, otherwise it stays on the CPU.
    After every epoch the test accuracy is computed with ``evaluate_accuracy``
    and a summary line (mean batch loss, train accuracy, test accuracy, elapsed
    seconds) is printed.

    Args:
        network: Module to train; moved to the selected device.
        train_iter: Iterable of ``(inputs, labels)`` training batches.
        test_iter: Iterable of ``(inputs, labels)`` batches used for test accuracy.
        optimizer: Optimizer already bound to ``network``'s parameters.
        num_epochs: Number of passes over ``train_iter``.

    Raises:
        ZeroDivisionError: if ``train_iter`` yields no batches in an epoch.
    """
    # is_available must be *called*: the bare function object is always truthy,
    # which previously forced 'cuda' even on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_loss_sum, train_acc_num, batch_count, n = 0.0, 0, 0, 0
        start_time = time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            train_acc_num += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, network)
        print('epoch: %d, train_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              %(epoch+1, train_loss_sum/batch_count, train_acc_num/n, test_acc, (time.time()-start_time)))

In [6]:
class Residual(nn.Module):
    """Residual block: two 3x3 convolutions plus a shortcut connection.

    The main path is conv1 -> bn1 -> ReLU -> conv2. The shortcut is the input
    itself, or the input passed through a 1x1 convolution when
    ``use_1_1conv`` is set. The two are summed, then bn2 and ReLU are applied.

    NOTE(review): bn2 is applied *after* the shortcut addition; the usual
    ResNet formulation normalises conv2's output before adding. Left unchanged
    here because moving it would alter the model -- confirm it is intentional.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        use_1_1conv: Whether the shortcut goes through a 1x1 convolution. It is
            required whenever the block changes the channel count or uses a
            stride other than 1, and rejected when neither changes.
        stride: Stride of conv1 (and of the 1x1 shortcut convolution).

    Raises:
        AssertionError: if ``use_1_1conv`` is inconsistent with the requested
            channel counts / stride.
    """
    def __init__(self, in_channels, out_channels, use_1_1conv=False, stride=1):
        super().__init__()
        # Check that the arguments are consistent. The projection is needed
        # exactly when the main path changes the tensor shape; the previous
        # check rejected it for equal channel counts even with stride != 1,
        # which made that configuration impossible to construct.
        shape_changes = (in_channels != out_channels) or (stride != 1)
        if shape_changes:
            assert use_1_1conv, 'must adjust the shape of original matrix, please let the use_1_1conv be True'
        else:
            assert not use_1_1conv, 'if in_channels == out_channels and stride == 1, use_1_1conv must be False'

        # padding must be 1 so the spatial size stays in step with conv3
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        # the second convolution never reduces height or width
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1_1conv:
            # same stride as conv1 so both paths end up with the same height/width
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.ac = nn.ReLU()

    def forward(self, X):
        """Apply the block to ``X`` and return the activated sum of both paths."""
        Y = self.ac(self.bn1(self.conv1(X)))
        Y = self.conv2(Y)
        if self.conv3 is not None:
            # Project the shortcut so it matches the main path's shape.
            X = self.conv3(X)
        return self.ac(self.bn2(Y + X))

In [7]:
class GlobalAvgPool2d(nn.Module):
    """Average a tensor over dims 2 and 3, keeping both as size-1 dims."""

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, X):
        """Return ``X`` averaged over dim 2 and then over dim 3."""
        pooled = X
        # Reduce one axis at a time, in the same order as before (2, then 3).
        for axis in (2, 3):
            pooled = pooled.mean(dim = axis, keepdim = True)
        return pooled

In [8]:
def res_block(in_channels, out_channels, num_res, first_block = False):
    """Stack ``num_res`` Residual units into one ``nn.Sequential`` stage.

    Unless ``first_block`` is set, the first unit maps ``in_channels`` to
    ``out_channels`` with stride 2 and a 1x1 shortcut convolution; every other
    unit keeps ``out_channels`` channels with the default stride.

    Raises:
        AssertionError: if ``first_block`` is set and the channel counts differ.
    """
    if first_block:
        assert in_channels == out_channels, 'the in_channel and out_channels must be same in the first block' 
    units = [
        Residual(in_channels, out_channels, use_1_1conv=True, stride=2)
        if (idx == 0 and not first_block)
        else Residual(out_channels, out_channels)
        for idx in range(num_res)
    ]
    return nn.Sequential(*units)
    

In [9]:
# Stem: 7x7 stride-2 convolution on a single-channel input, batch norm, ReLU,
# then a 3x3 stride-2 max pool.
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# Head: average over the spatial dims, flatten, then map 512 features to 10 outputs.
fc_block = nn.Sequential(
    GlobalAvgPool2d(),
    nn.Flatten(),
    nn.Linear(in_features=512, out_features=10),
)




In [10]:
# Assemble the network: stem, four residual stages (64 -> 64 -> 128 -> 256 -> 512
# channels, two Residual units each), then the classification head.
Res_Net = nn.Sequential()
for stage_name, stage in (
    ('conv_block', conv_block),
    ('resnet_block1', res_block(64, 64, 2, first_block=True)),
    ('resnet_block2', res_block(64, 128, 2)),
    ('resnet_block3', res_block(128, 256, 2)),
    ('resnet_block4', res_block(256, 512, 2)),
    ('fc_block', fc_block),
):
    Res_Net.add_module(stage_name, stage)

In [11]:
# Sanity check: push one random single-channel 224x224 input through each
# top-level stage of Res_Net in turn and print the shape it produces.
# NOTE(review): the loaders built later are created without `resize`, so the
# training inputs may not be 224x224 -- confirm the sizes used here are intended.
X = torch.rand(1,1,224,224)
for name, layer in Res_Net.named_children():
    # Feed the previous stage's output into the next stage.
    X = layer(X)
    print(name, 'outshape:', X.shape)

conv_block outshape: torch.Size([1, 64, 56, 56])
resnet_block1 outshape: torch.Size([1, 64, 56, 56])
resnet_block2 outshape: torch.Size([1, 128, 28, 28])
resnet_block3 outshape: torch.Size([1, 256, 14, 14])
resnet_block4 outshape: torch.Size([1, 512, 7, 7])
fc_block outshape: torch.Size([1, 10])


In [12]:
# Build the Fashion-MNIST train/test loaders; `resize` is left at its default
# (None), so no Resize transform is applied.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [13]:
# Train Res_Net with Adam; train_network prints loss, train/test accuracy and
# elapsed time once per epoch.
lr, num_epochs = 0.001, 5
optimizer = optim.Adam(Res_Net.parameters(), lr = lr)
train_network(Res_Net, train_iter, test_iter, optimizer, num_epochs)

epoch: 1, train_loss: 0.395, train_acc: 0.855, test_acc: 0.844, time: 14.70
epoch: 2, train_loss: 0.281, train_acc: 0.897, test_acc: 0.889, time: 13.71
epoch: 3, train_loss: 0.239, train_acc: 0.912, test_acc: 0.885, time: 13.72
epoch: 4, train_loss: 0.212, train_acc: 0.921, test_acc: 0.899, time: 13.79
epoch: 5, train_loss: 0.188, train_acc: 0.930, test_acc: 0.902, time: 13.88
