In [1]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [11]:
def VGG_Block(num_convs, in_channels, out_channels):
    Conv_list = []
    for i in range(num_convs):
        if i == 0:
            Conv_list.append(nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))
        else:
            Conv_list.append(nn.Conv2d(out_channels, out_channels, kernel_size= 3, stride = 1, padding = 1))
        Conv_list.append(nn.ReLU())  
    Conv_list.append(nn.MaxPool2d(kernel_size=2,stride=2))
    return nn.Sequential(*Conv_list)    

In [12]:
conv_arch = [(1,1,64),(1,64,128),(2,128,256),(2,256,512),(2,512,512)]
fc_inputs = 512*7*7
fc_hiddens = 4096

In [15]:
def vgg(conv_arch, fc_inputs, fc_hiddens):
    network = nn.Sequential()
    for i, param_tuple in enumerate(conv_arch):
        network.add_module('vgg_block_'+str(i+1),VGG_Block(*param_tuple))
    network.add_module('fc', nn.Sequential(nn.Flatten(),
                                    nn.Linear(fc_inputs, fc_hiddens),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hiddens, fc_hiddens),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hiddens,10)))
    return network

In [21]:
ratio = 8
conv_arch = [(1,1,64//ratio),(1,64//ratio,128//ratio),(2,128//ratio,256//ratio),(2,256//ratio,512//ratio),(2,512//ratio,512//ratio)]
fc_inputs = (512//8)*7*7
fc_hiddens = 4096//8
Vgg_net = vgg(conv_arch, fc_inputs, fc_hiddens )

In [17]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST'):
    """Download the fashion mnist dataset and then load into memory"""
    trans = []
    if resize:
        trans.append(transforms.Resize(size = resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)
    
    mnist_train = torchvision.datasets.FashionMNIST(root = path, train = True, download = True, transform = transform) 
    mnist_test = torchvision.datasets.FashionMNIST(root = path,  train = False, download = True, transform = transform)
    
    train_iter = DataLoader(mnist_train, batch_size= batch_size, shuffle = True, num_workers = 4) 
    test_iter = DataLoader(mnist_test, batch_size = batch_size, shuffle = True, num_workers = 4)
    
    return train_iter, test_iter

In [18]:
def evaluate_accuracy(data_iter, network, device = None):
    if device is None:
        device = list(network.parameters())[0].device
    with torch.no_grad():
        acc_num, n = 0.0, 0 
        for X, y in data_iter:
            network.eval()
            y_hat = network(X.to(device))
            acc_num  += (y_hat.argmax(dim = 1)==y.to(device)).float().sum().cpu().item()
            network.train()
            n += y.shape[0]
    acc = acc_num/n
    return acc 

In [20]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs):
    device = 'cuda:0' if torch.cuda.is_available else 'cpu'
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs): 
        train_loss_sum, test_loss_sum, train_acc_num, batch_count, batch_count_test, n, start_time = 0.0, 0.0, 0.0, 0, 0, 0, time.time() 
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                train_acc_num += (y_hat.argmax(dim=1) == y).float().sum().cpu().item()
                train_loss_sum += loss.cpu().item()
                n += y.shape[0]
                batch_count += 1
        train_acc = train_acc_num/n
        train_loss = train_loss_sum/batch_count
        test_acc = evaluate_accuracy(test_iter, network)
        for X_test, y_test in test_iter:
            y_test_hat = network(X_test.to(device))
            loss_test = loss_fn(y_test_hat, y_test.to(device))
            test_loss_sum += loss_test.cpu().item()
            batch_count_test += 1
        test_loss = test_loss_sum/batch_count_test
        print('epoch: %d, train_loss: %.3f, test_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              %(epoch+1, train_loss, test_loss, train_acc, test_acc, (time.time()-start_time)))

In [22]:
train_iter, test_iter = load_data_fashion_mnist(batch_size = 256, resize = 224)
lr, num_epochs = 0.001, 5
optimizer = optim.Adam(Vgg_net.parameters(), lr=lr)
train_network(Vgg_net, train_iter, test_iter, optimizer, num_epochs)

epoch: 1, train_loss: 0.770, test_loss: 0.428, train_acc: 0.703, test_acc: 0.857, time: 35.42
epoch: 2, train_loss: 0.367, test_loss: 0.363, train_acc: 0.866, test_acc: 0.880, time: 33.46
epoch: 3, train_loss: 0.308, test_loss: 0.317, train_acc: 0.887, test_acc: 0.894, time: 34.32
epoch: 4, train_loss: 0.274, test_loss: 0.301, train_acc: 0.899, test_acc: 0.904, time: 34.75
epoch: 5, train_loss: 0.252, test_loss: 0.316, train_acc: 0.908, test_acc: 0.908, time: 34.33
