In [10]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [11]:
def VGG_Block(num_convs, in_channels, out_channels):
    """Build one VGG stage: ``num_convs`` conv+ReLU pairs followed by a max-pool.

    Args:
        num_convs: how many ``Conv2d`` + ``ReLU`` pairs to stack.
        in_channels: channel count fed to the first convolution.
        out_channels: channel count produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` holding the conv/ReLU pairs and a final
        ``MaxPool2d(kernel_size=2, stride=2)``.
    """
    layers = []
    # Only the first convolution sees `in_channels`; later ones chain
    # out_channels -> out_channels.
    channels = in_channels
    for _ in range(num_convs):
        # 3x3 kernel with stride 1 and padding 1 keeps the spatial size unchanged.
        layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1))
        layers.append(nn.ReLU())
        channels = out_channels
    # The pooling layer halves height and width.
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

In [12]:
# Full-width configuration: one (num_convs, in_channels, out_channels) tuple
# per VGG block, unpacked positionally into VGG_Block by vgg().
conv_arch = [
    (1, 1, 64),
    (1, 64, 128),
    (2, 128, 256),
    (2, 256, 512),
    (2, 512, 512),
]
# Flattened size entering the classifier: 512 channels on a 7x7 feature map.
fc_inputs = 7 * 7 * 512
# Width of the two hidden fully connected layers.
fc_hiddens = 4096

In [13]:
def vgg(conv_arch, fc_inputs, fc_hiddens, num_classes=10):
    """Assemble a VGG-style network: stacked conv blocks plus a 3-layer classifier.

    Args:
        conv_arch: iterable of ``(num_convs, in_channels, out_channels)``
            tuples; each one is unpacked into ``VGG_Block`` and registered as
            ``vgg_block_1``, ``vgg_block_2``, ... in order.
        fc_inputs: number of features after flattening the last block's
            output (must match channels * height * width of that output).
        fc_hiddens: width of the two hidden fully connected layers.
        num_classes: number of output logits. Defaults to 10, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        An ``nn.Sequential`` whose children are the conv blocks followed by a
        child named ``fc`` containing the classifier head.
    """
    network = nn.Sequential()
    for index, block_spec in enumerate(conv_arch, start=1):
        network.add_module('vgg_block_' + str(index), VGG_Block(*block_spec))

    # Classifier head: Flatten -> (Linear, ReLU, Dropout) x 2 -> Linear.
    classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(fc_inputs, fc_hiddens),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(fc_hiddens, fc_hiddens),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(fc_hiddens, num_classes),
    )
    network.add_module('fc', classifier)
    return network

In [14]:
# Build a slimmed-down network: every channel count and the hidden width of the
# classifier are divided by `ratio`.
ratio = 8
conv_arch = [
    (1, 1, 64 // ratio),
    (1, 64 // ratio, 128 // ratio),
    (2, 128 // ratio, 256 // ratio),
    (2, 256 // ratio, 512 // ratio),
    (2, 512 // ratio, 512 // ratio),
]
# Derive the classifier sizes from `ratio` as well (they previously used a
# literal 8), so changing `ratio` cannot leave the first Linear layer out of
# sync with the channel count of the last conv block.
fc_inputs = (512 // ratio) * 7 * 7
fc_hiddens = 4096 // ratio
Vgg_net = vgg(conv_arch, fc_inputs, fc_hiddens)

In [15]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST', num_workers = 4):
    """Download Fashion-MNIST if necessary and wrap both splits in DataLoaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        resize: optional target size passed to ``transforms.Resize``; when
            falsy the images keep their original size.
        path: directory where the dataset is stored / downloaded to.
        num_workers: worker processes per loader. Defaults to 4, the value
            that was previously hard-coded; pass 0 to load in the main process.

    Returns:
        ``(train_iter, test_iter)``: a shuffled loader over the training split
        and an unshuffled loader over the test split.
    """
    steps = []
    if resize:
        steps.append(transforms.Resize(size = resize))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(root = path, train = True, download = True, transform = transform)
    mnist_test = torchvision.datasets.FashionMNIST(root = path, train = False, download = True, transform = transform)

    train_iter = DataLoader(mnist_train, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    # Shuffling brings nothing at evaluation time; a fixed order keeps the
    # evaluation pass reproducible from run to run.
    test_iter = DataLoader(mnist_test, batch_size = batch_size, shuffle = False, num_workers = num_workers)

    return train_iter, test_iter

In [21]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs):
    """Train ``network`` with cross-entropy loss, printing metrics after every epoch.

    Args:
        network: ``nn.Module`` to train; it is moved to the GPU when CUDA is
            available and to the CPU otherwise.
        train_iter: iterable yielding ``(X, y)`` batches of inputs and class
            labels.
        test_iter: iterable passed unchanged to ``evaluate_accuracy``.
        optimizer: optimizer already constructed over ``network``'s parameters.
        num_epochs: number of passes over ``train_iter``.

    Returns:
        None. After each epoch one line is printed with the mean batch loss,
        the training accuracy, the value returned by
        ``evaluate_accuracy(test_iter, network)`` and the elapsed seconds.

    Note:
        ``evaluate_accuracy`` is not defined in this function; it must already
        exist in the notebook namespace when this function is called.
    """
    # `is_available` has to be *called*: the bare function object is always
    # truthy, which selected 'cuda' unconditionally and crashed on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Make sure Dropout is active for training, even if the evaluation at
        # the end of the previous epoch left the model in eval mode.
        network.train()
        train_loss_sum, train_acc_num, batch_count, n, start_time = 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() already copies the scalar to the host; no explicit .cpu() needed.
            train_loss_sum += loss.item()
            train_acc_num += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, network)
        # An empty train_iter would otherwise raise ZeroDivisionError here.
        mean_loss = train_loss_sum / batch_count if batch_count else float('nan')
        train_acc = train_acc_num / n if n else float('nan')
        print('epoch: %d, train_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              %(epoch+1, mean_loss, train_acc, test_acc, (time.time()-start_time)))

In [28]:
# Hyperparameters for this run.
lr = 0.001
num_epochs = 5

# Fashion-MNIST batches of 256 images, resized to 224 for the network.
train_iter, test_iter = load_data_fashion_mnist(batch_size=256, resize=224)

# Adam over all parameters of the reduced VGG network, then train.
optimizer = optim.Adam(Vgg_net.parameters(), lr=lr)
train_network(Vgg_net, train_iter, test_iter, optimizer, num_epochs)

epoch: 1, train_loss: 0.200, train_acc: 0.928, test_acc: 0.924, time: 31.10
epoch: 2, train_loss: 0.176, train_acc: 0.937, test_acc: 0.928, time: 30.89
epoch: 3, train_loss: 0.161, train_acc: 0.942, test_acc: 0.927, time: 30.97
epoch: 4, train_loss: 0.151, train_acc: 0.945, test_acc: 0.930, time: 31.13
epoch: 5, train_loss: 0.137, train_acc: 0.950, test_acc: 0.928, time: 31.30


In [29]:
# Shell escape: print the NVIDIA driver's GPU status table (memory usage, utilization).
!nvidia-smi

Mon Apr 26 16:46:29 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.25       Driver Version: 470.25       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| 22%   41C    P8     9W / 215W |   6396MiB /  8192MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces