In [60]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [61]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST', num_workers = 4):
    """Download Fashion-MNIST (if needed) and wrap both splits in DataLoaders.

    Args:
        batch_size: number of samples per mini-batch.
        resize: optional size forwarded to ``transforms.Resize``; when falsy the
            images are only converted to tensors.
        path: directory the dataset is stored in / downloaded to.
        num_workers: worker processes used by each DataLoader.

    Returns:
        (train_iter, test_iter): DataLoaders over the training and test splits.
        The training loader reshuffles every epoch; the test loader iterates in a
        fixed order.
    """
    trans = []
    if resize:
        trans.append(transforms.Resize(size = resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(root = path, train = True, download = True, transform = transform)
    mnist_test = torchvision.datasets.FashionMNIST(root = path, train = False, download = True, transform = transform)

    train_iter = DataLoader(mnist_train, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    # Sample order does not affect accuracy, so shuffling the test split only adds
    # non-determinism; iterate it in a fixed order instead.
    test_iter = DataLoader(mnist_test, batch_size = batch_size, shuffle = False, num_workers = num_workers)

    return train_iter, test_iter

In [62]:
class DenseBlock(nn.Module):
    """DenseNet dense block: stacked BN-ReLU-Conv layers whose outputs are concatenated.

    Each layer receives the concatenation of the block input and the outputs of
    all previous layers in the block, and adds ``out_channels`` new feature maps.

    Args:
        num_convs: number of BN-ReLU-Conv layers in the block.
        in_channels: channel count of the block's input.
        out_channels: channels produced by each layer (the growth rate).

    Attributes:
        out_channels: channel count of the block's output,
            ``in_channels + num_convs * out_channels``.
    """
    def __init__(self, num_convs, in_channels, out_channels):
        super().__init__()
        net = []
        for i in range(num_convs):
            # Layer i sees the block input plus the i feature groups added so far.
            in_c = in_channels + i*out_channels
            # conv_layer is a staticmethod: it must be reached through self/the class,
            # a bare `conv_layer(...)` raises NameError on a fresh kernel.
            net.append(self.conv_layer(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs*out_channels

    def forward(self, X):
        """Run each layer and concatenate its output onto X along dim 1 (channels)."""
        for layer in self.net:
            Y = layer(X)
            X = torch.cat((X, Y), dim = 1)
        return X

    @staticmethod
    def conv_layer(in_channels, out_channels):
        """Build one BN -> ReLU -> 3x3 Conv unit; padding=1 keeps the spatial size."""
        blk = nn.Sequential(nn.BatchNorm2d(in_channels),
                            nn.ReLU(),
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        return blk
    

In [63]:
class TransitionBlock(nn.Module):
    """DenseNet transition block: BN -> ReLU -> 1x1 Conv -> 2x2 average pool.

    The 1x1 convolution changes the channel count to ``out_channels`` and the
    stride-2 average pool halves the spatial size.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # transition_layer is a staticmethod: call it through self, a bare
        # `transition_layer(...)` raises NameError on a fresh kernel.
        self.trans_layer = self.transition_layer(in_channels, out_channels)

    def forward(self, X):
        """Apply the transition layer to X."""
        Y = self.trans_layer(X)
        return Y

    @staticmethod
    def transition_layer(in_channels, out_channels):
        """Build the BN -> ReLU -> 1x1 Conv -> AvgPool(2, stride 2) sequence."""
        layer = nn.Sequential(nn.BatchNorm2d(in_channels),
                              nn.ReLU(),
                              nn.Conv2d(in_channels, out_channels, kernel_size =1),
                              nn.AvgPool2d(kernel_size=2, stride=2))
        return layer

In [64]:
class GlobalAvgPool2d(nn.Module):
    """Average each feature map over dims 2 and 3, keeping them as size-1 dims.

    For a 4-D input laid out as (N, C, H, W) the result has shape (N, C, 1, 1).
    """
    def forward(self, X):
        # Reduce dim 2 first, then dim 3; keepdim preserves the tensor rank.
        pooled_over_dim2 = X.mean(dim = 2, keepdim = True)
        return pooled_over_dim2.mean(dim = 3, keepdim = True)

In [65]:
# Stem: 7x7 stride-2 conv on the single input channel, then BN, ReLU and a stride-2 max-pool.
conv_block = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

Dense_Net = nn.Sequential()
Dense_Net.add_module('conv_block', conv_block)

# growth_rate: channels each conv layer inside a dense block adds to its input.
num_channels, growth_rate = 64, 32
num_convs_in_dense_block = [4, 4, 4, 4]
last_index = len(num_convs_in_dense_block) - 1
for idx, convs in enumerate(num_convs_in_dense_block):
    dense_block = DenseBlock(convs, num_channels, growth_rate)
    Dense_Net.add_module('dense_block%d' % idx, dense_block)
    num_channels = dense_block.out_channels
    # A transition block halves the channel count between dense blocks (none after the last one).
    if idx < last_index:
        halved = num_channels // 2
        Dense_Net.add_module('transition_block%d' % idx, TransitionBlock(num_channels, halved))
        num_channels = halved

# Head: final BN + ReLU, global average pool, flatten, linear classifier over 10 classes.
Dense_Net.add_module('bn_block', nn.Sequential(nn.BatchNorm2d(num_channels), nn.ReLU()))
Dense_Net.add_module('fc_block', nn.Sequential(GlobalAvgPool2d(), nn.Flatten(), nn.Linear(num_channels, 10)))

In [66]:
def evaluate_accuracy(data_iter, network, device = None):
    """Compute the classification accuracy of ``network`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(X, y)`` batches (e.g. a DataLoader).
        network: model whose output's argmax over dim 1 is taken as the prediction.
        device: device to evaluate on; defaults to the device of the network's
            first parameter.

    Returns:
        Fraction (float) of samples whose predicted class equals the label.

    Raises:
        ValueError: if ``data_iter`` yields no samples.

    The network is put in eval mode once for the whole pass, and its previous
    train/eval mode is restored afterwards, even if an exception is raised.
    """
    if device is None:
        device = next(network.parameters()).device
    was_training = network.training
    network.eval()
    acc_num, n = 0, 0
    try:
        with torch.no_grad():
            for X_test, y_test in data_iter:
                X_test = X_test.to(device)
                y_test = y_test.to(device)
                y_hat = network(X_test)
                acc_num += (y_hat.argmax(dim = 1) == y_test).sum().item()
                n += y_test.shape[0]
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        network.train(was_training)
    if n == 0:
        raise ValueError('data_iter yielded no samples; accuracy is undefined')
    return acc_num / n

In [67]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs, device = None):
    """Train ``network`` with cross-entropy loss and print per-epoch metrics.

    Args:
        network: classification model; moved to ``device`` before training.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches scored by ``evaluate_accuracy``
            after every epoch.
        optimizer: optimizer over the network's parameters.
        num_epochs: number of passes over ``train_iter``.
        device: device to train on; defaults to CUDA when available, else CPU.

    After each epoch, prints the mean batch loss, training accuracy, test accuracy
    and the epoch's wall-clock time.
    """
    if device is None:
        # is_available must be *called*: the bare function object is always truthy,
        # which would select CUDA even on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Make sure layers such as BatchNorm are in training mode for this epoch.
        network.train()
        train_loss_sum, train_acc_num, batch_count, n, start_time = 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            train_acc_num += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, network)
        print('epoch: %d, train_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              %(epoch+1, train_loss_sum/batch_count, train_acc_num/n, test_acc, (time.time()-start_time)))

In [69]:
# Training hyperparameters; resize=96 is forwarded to transforms.Resize in the loader.
batch_size = 256
lr = 0.001
num_epochs = 5

train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size, resize=96)
# NOTE: Dense_Net is built in the model cell above; re-running only this cell
# keeps training the same (already-trained) weights.
optimizer = optim.Adam(Dense_Net.parameters(), lr = lr)
train_network(Dense_Net, train_iter, test_iter, optimizer, num_epochs)

epoch: 1, train_loss: 0.456, train_acc: 0.840, test_acc: 0.813, time: 23.79
epoch: 2, train_loss: 0.272, train_acc: 0.899, test_acc: 0.877, time: 23.30
epoch: 3, train_loss: 0.235, train_acc: 0.914, test_acc: 0.852, time: 23.33
epoch: 4, train_loss: 0.213, train_acc: 0.921, test_acc: 0.907, time: 23.32
epoch: 5, train_loss: 0.195, train_acc: 0.927, test_acc: 0.921, time: 23.44
