In [29]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [14]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST', num_workers = 4):
    """Download Fashion-MNIST (if needed) and wrap it in train/test DataLoaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        resize: optional size passed to ``transforms.Resize``; when falsy the
            images are not resized.
        path: directory torchvision uses to store / look up the dataset.
        num_workers: worker processes per DataLoader (default 4, as before).

    Returns:
        (train_iter, test_iter): the training loader reshuffles every epoch;
        the test loader iterates in a fixed order.
    """
    trans = []
    if resize:
        trans.append(transforms.Resize(size=resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(root=path, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=path, train=False, download=True, transform=transform)

    train_iter = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Accuracy does not depend on sample order, so shuffling the test set only
    # adds overhead and run-to-run nondeterminism.
    test_iter = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter

In [15]:
def evaluate_accuracy(data_iter, network, device = None):
    """Return the classification accuracy of ``network`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(X, y)`` batches; ``y`` holds integer class labels
            compared against the arg-max of the network output along dim 1.
        network: the model to evaluate.
        device: where to run the forward pass; defaults to the device of the
            network's first parameter.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if ``data_iter`` yields no samples.

    The network is switched to eval mode for the whole pass and then restored
    to whichever mode (train or eval) it was in before, even on error.
    """
    if device is None:
        device = list(network.parameters())[0].device
    was_training = network.training
    network.eval()
    acc_num, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                y = y.to(device)
                y_hat = network(X.to(device))
                acc_num += (y_hat.argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        # Restore the caller's mode rather than unconditionally calling train().
        network.train(was_training)
    if n == 0:
        raise ValueError('data_iter yielded no samples')
    return acc_num / n

In [16]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs, device = None):
    """Train ``network`` with cross-entropy loss, reporting metrics after each epoch.

    Args:
        network: model to train; it is moved to ``device`` first.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches scored with ``evaluate_accuracy``.
        optimizer: optimizer over ``network``'s parameters.
        num_epochs: number of full passes over ``train_iter``.
        device: device to train on; defaults to CUDA when available, else CPU.

    Prints one line per epoch: mean per-batch loss, training accuracy,
    test accuracy and the epoch's wall-clock time in seconds.
    """
    if device is None:
        # is_available must be *called*: the bare function object is always
        # truthy, which would force 'cuda' even on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Explicitly enter training mode (e.g. dropout active) for every epoch.
        network.train()
        train_loss_sum, train_acc_num, batch_count, n, start_time = 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() already copies the scalar to the host; no .cpu() needed.
            train_loss_sum += loss.item()
            train_acc_num += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, network)
        print('epoch: %d, train_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              % (epoch + 1, train_loss_sum / batch_count, train_acc_num / n, test_acc, time.time() - start_time))

In [18]:
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    """Build one Network-in-Network block as an ``nn.Sequential`` of six layers.

    A spatial convolution (``kernel_size`` / ``stride`` / ``padding``) maps
    ``in_channels`` to ``out_channels``; two 1x1 convolutions that keep the
    channel count follow. Every convolution is followed by a ReLU.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU()]
    for _ in range(2):
        # A 1x1 convolution mixes channels independently at each pixel.
        layers += [nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU()]
    return nn.Sequential(*layers)


In [19]:
class GlobalAvgPool2d(nn.Module):
    """Average-pool each channel over its entire spatial extent.

    The pooling window is the input's full size from dim 2 onward, so the
    spatial dims of the result all have size 1.
    """
    def forward(self, x):
        spatial_extent = x.shape[2:]
        return F.avg_pool2d(x, kernel_size=spatial_extent)

In [21]:
# NiN for single-channel images: three nin_blocks, each followed by 3x3 /
# stride-2 max pooling, then dropout and a final nin_block with 10 output
# channels that global average pooling turns into 10 scores per sample.
# For 224x224 inputs the spatial size goes 224 -> 54 -> 26 -> 26 -> 12 -> 12 -> 5 -> 5 -> 1.
Nin_net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # Takes the place of fully connected layers: one output channel per class.
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    GlobalAvgPool2d(),
    # (N, 10, 1, 1) -> (N, 10)
    nn.Flatten(),
)


In [22]:
# Show the nested layer structure (nn.Module's repr lists every sub-module).
print(repr(Nin_net))

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): Sequential(
    (0): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (5): MaxPool2d(kernel_size=3, stri

In [25]:
# Hyper-parameters for this run.
lr, num_epochs = 0.002, 5
# Images are resized to 224x224 before being fed to Nin_net.
train_iter, test_iter = load_data_fashion_mnist(batch_size=128, resize=224)
optimizer = optim.Adam(Nin_net.parameters(), lr=lr)
train_network(Nin_net, train_iter, test_iter, optimizer, num_epochs)

epoch: 1, train_loss: 1.270, train_acc: 0.509, test_acc: 0.701, time: 41.81
epoch: 2, train_loss: 0.583, train_acc: 0.787, test_acc: 0.817, time: 40.59
epoch: 3, train_loss: 0.479, train_acc: 0.824, test_acc: 0.835, time: 40.58
epoch: 4, train_loss: 0.426, train_acc: 0.843, test_acc: 0.849, time: 40.54
epoch: 5, train_loss: 0.396, train_acc: 0.853, test_acc: 0.846, time: 40.71
