In [8]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from collections import OrderedDict
import numpy as np
import time 

In [4]:
class AlexNetwork(nn.Module):
    """AlexNet-style CNN classifier.

    Args:
        in_channels: number of channels of the input images (default 1, grayscale).
        num_classes: number of output logits (default 10).

    The first Linear layer expects 256*5*5 flattened features. That is what the
    convolutional stack produces for a 224x224 input (224 -> 54 -> 26 -> 26 -> 12
    -> 12 -> 12 -> 12 -> 5); inputs whose spatial size does not reduce to 5x5 will
    fail with a shape mismatch in ``self.fc``.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, 96, 11, 4),  # 224 -> 54
                                  nn.ReLU(),
                                  nn.MaxPool2d(3, 2),                 # 54 -> 26
                                  nn.Conv2d(96, 256, 5, 1, 2),        # 26 -> 26
                                  nn.ReLU(),
                                  nn.MaxPool2d(3, 2),                 # 26 -> 12
                                  nn.Conv2d(256, 384, 3, 1, 1),       # 12 -> 12
                                  nn.ReLU(),
                                  nn.Conv2d(384, 384, 3, 1, 1),       # 12 -> 12
                                  nn.ReLU(),
                                  nn.Conv2d(384, 256, 3, 1, 1),       # 12 -> 12
                                  nn.ReLU(),
                                  nn.MaxPool2d(3, 2))                 # 12 -> 5
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(nn.Linear(256 * 5 * 5, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, num_classes))

    def forward(self, X):
        """Map a batch of images X (N, in_channels, H, W) to class logits (N, num_classes)."""
        conv_X = self.conv(X)
        flat_X = self.flatten(conv_X)
        output_X = self.fc(flat_X)
        return output_X
        

In [7]:
# Instantiate the model and print its layer-by-layer summary.
Alex_net = AlexNetwork()
print(str(Alex_net))

AlexNetwork(
  (conv): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc): Sequential(
    (0): Linear(in_features=6400, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): R

In [13]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST', num_workers = 4):
    """Build Fashion-MNIST train/test DataLoaders (torchvision download=True into `path`).

    Args:
        batch_size: number of samples per batch for both loaders.
        resize: if truthy, images are passed through transforms.Resize(size=resize)
            before being converted to tensors.
        path: dataset root directory handed to torchvision.
        num_workers: DataLoader worker processes; 0 loads batches in the main process.

    Returns:
        (train_iter, test_iter): DataLoaders that read samples lazily. The training
        loader reshuffles every epoch; the test loader keeps a fixed order because
        evaluation metrics do not depend on sample order.
    """
    trans = []
    if resize:
        trans.append(transforms.Resize(size = resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(root = path, train = True, download = True, transform = transform)
    mnist_test = torchvision.datasets.FashionMNIST(root = path, train = False, download = True, transform = transform)

    train_iter = DataLoader(mnist_train, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    test_iter = DataLoader(mnist_test, batch_size = batch_size, shuffle = False, num_workers = num_workers)

    return train_iter, test_iter

In [14]:
# resize=224 gives the input size that AlexNetwork's first Linear layer (256*5*5 features) was sized for.
train_iter, test_iter = load_data_fashion_mnist(
    batch_size=256,
    resize=224,
)

In [17]:
def evaluate_accuracy(data_iter, network, device = None):
    """Return the classification accuracy of `network` over every batch in `data_iter`.

    The network is switched to eval mode (so Dropout is disabled) for the pass and
    restored to its previous train/eval mode afterwards; no gradients are tracked.

    Args:
        data_iter: iterable of (X, y) batches, where y holds integer class labels.
        network: nn.Module whose output's argmax over dim 1 is the predicted class.
        device: device each batch is moved to; defaults to the device of the
            network's first parameter.

    Returns:
        Fraction (float) of samples whose predicted class equals the label.

    Raises:
        ValueError: if `data_iter` yields no samples.
    """
    if device is None:
        device = next(iter(network.parameters())).device
    was_training = network.training
    network.eval()
    correct, n = 0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                X, y = X.to(device), y.to(device)
                y_hat = network(X)
                # Tensor.item() converts the 0-d count tensor to a Python number.
                correct += (y_hat.argmax(dim = 1) == y).sum().item()
                n += y.shape[0]
    finally:
        # Restore whatever mode the caller had the network in.
        network.train(was_training)
    if n == 0:
        raise ValueError('data_iter yielded no samples')
    return correct / n
    

In [19]:
def train_network(network, train_iter, test_iter, num_epochs, optimizer, device = None):
    """Train `network` with cross-entropy loss and print per-epoch metrics.

    Args:
        network: nn.Module to train; moved to `device` in place.
        train_iter: iterable of (X, y) training batches.
        test_iter: iterable of (X, y) evaluation batches.
        num_epochs: number of full passes over `train_iter`.
        optimizer: optimizer already constructed over `network.parameters()`.
        device: device to train on; defaults to 'cuda:0' when CUDA is available,
            otherwise 'cpu'.

    Each epoch prints the mean per-batch train/test loss, train/test accuracy and
    the epoch's wall-clock time in seconds.

    Raises:
        ValueError: if `train_iter` or `test_iter` yields no samples.
    """
    if device is None:
        # is_available must be *called*: the bare function object is always truthy.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss, train_acc = _train_one_epoch(network, train_iter, loss_fn, optimizer, device)
        test_loss, test_acc = _evaluate_loss_and_accuracy(network, test_iter, loss_fn, device)
        print('epoch: %d, train_loss: %.3f, test_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              % (epoch, train_loss, test_loss, train_acc, test_acc, time.time() - start_time))


def _train_one_epoch(network, data_iter, loss_fn, optimizer, device):
    """Run one optimization pass over `data_iter`; return (mean batch loss, accuracy)."""
    # Re-enable Dropout: the previous epoch's evaluation may have changed the mode.
    network.train()
    loss_sum, correct, n, batch_count = 0.0, 0, 0, 0
    for X, y in data_iter:
        X, y = X.to(device), y.to(device)
        y_hat = network(X)
        loss = loss_fn(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields Python numbers so no tensors are kept alive across batches.
        loss_sum += loss.item()
        correct += (y_hat.argmax(dim = 1) == y).sum().item()
        n += y.shape[0]
        batch_count += 1
    if n == 0:
        raise ValueError('train_iter yielded no samples')
    return loss_sum / batch_count, correct / n


def _evaluate_loss_and_accuracy(network, data_iter, loss_fn, device):
    """Return (mean batch loss, accuracy) on `data_iter` with Dropout off and no gradients."""
    was_training = network.training
    network.eval()
    loss_sum, correct, n, batch_count = 0.0, 0, 0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                X, y = X.to(device), y.to(device)
                y_hat = network(X)
                loss_sum += loss_fn(y_hat, y).item()
                correct += (y_hat.argmax(dim = 1) == y).sum().item()
                n += y.shape[0]
                batch_count += 1
    finally:
        network.train(was_training)
    if n == 0:
        raise ValueError('test_iter yielded no samples')
    return loss_sum / batch_count, correct / n
            


In [None]:
# Train the AlexNetwork instance created above with Adam.
# train_network's signature is (network, train_iter, test_iter, num_epochs, optimizer).
lr, num_epochs = 0.001, 5
optimizer = optim.Adam(Alex_net.parameters(), lr=lr)
train_network(Alex_net, train_iter, test_iter, num_epochs, optimizer)