## Fashion MNIST MLP Classification using PyTorch


In [None]:
import torch
import torch.nn.functional as F

from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler  # For validation set

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Preprocessing applied to every image:
# - transforms.ToTensor() converts the PIL image to a float tensor and scales
#   the pixel values from [0, 255] down to [0, 1].
# - transforms.Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5 on the
#   single channel, which maps the values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_set = datasets.FashionMNIST('./data', download=True, train=True, transform=transform)
test_set = datasets.FashionMNIST('./data', download=True, train=False, transform=transform)

# Prepare our validation set: shuffle the training indices, then carve a
# slice out of them.
indices = list(range(len(train_set)))
print('Before shuffle:', indices[:15])

np.random.shuffle(indices)
print('After shuffle:',indices[:15])

# Put away 20% for the validation set; the remaining 80% is used for training.
split = int(np.floor(0.2 * len(train_set)))
valid_sample = SubsetRandomSampler(indices[:split])
train_sample = SubsetRandomSampler(indices[split:])

# batch_size defines how many samples per batch to load.
# The train/validation loaders draw from the same underlying dataset but are
# restricted to disjoint index subsets by their samplers, which also randomize
# the order. The test loader shuffles so that each batch is a random sample.
train_loader = torch.utils.data.DataLoader(train_set, sampler=train_sample, batch_size=64)
valid_loader = torch.utils.data.DataLoader(train_set, sampler=valid_sample, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

In [None]:
# Let's get a sense of our data: pull one batch from the training loader,
# print its shapes and display its first 20 images with their labels.
images, labels = next(iter(train_loader))
print(images.shape)
print(labels.shape)

# Maps the numeric class index to a human-readable name.
data_dictionary = {
    0: 'T-shirt/Top',
    1: 'Trouser',
    2: 'PullOver',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot'
}

fig = plt.figure(figsize=(15, 5))
for i in range(20):
    ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
    # Drop the singleton channel dimension so imshow receives a 2-D image.
    ax.imshow(np.squeeze(images[i]), cmap='gray')
    label = labels[i].item()
    ax.set_title(data_dictionary[label] + ' - ' + str(label))

# Lay the grid out once, after every subplot has been added, instead of
# recomputing the layout on each loop iteration.
fig.tight_layout()

In [None]:
input_size = 784  # images are 28 * 28
hidden_sizes = [512, 256, 128, 64]
output_size = 10


class Classifier(nn.Module):
    """Multi-layer perceptron with batch normalization and dropout.

    Four hidden linear layers (widths taken from ``hidden_sizes``), each
    followed by batch normalization and ReLU, feed a final linear layer of
    ``output_size`` units. ``forward`` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()

        # Batch normalization drastically decreases the number of epochs
        # and also helps avoid overfitting:
        # https://www.youtube.com/watch?v=yXOMHOpbon8
        #
        # Modules are registered as batch_norm_0..4, then fc0..4, then dropout.
        # NOTE(review): batch_norm_0 (sized for the raw input) is registered
        # but never applied in forward(); it still appears in the state dict.
        norm_widths = [input_size, *hidden_sizes]
        for idx, width in enumerate(norm_widths):
            setattr(self, f'batch_norm_{idx}', nn.BatchNorm1d(width))

        # fc{i} maps layer width i to layer width i + 1; the last one maps to
        # the class scores.
        fan_outs = [*hidden_sizes, output_size]
        for idx, (fan_in, fan_out) in enumerate(zip(norm_widths, fan_outs)):
            setattr(self, f'fc{idx}', nn.Linear(fan_in, fan_out))

        self.dropout = nn.Dropout(p=0.1)  # 10% dropout

    def forward(self, input):
        """Flatten ``input`` to rows of ``input_size`` values and classify them.

        Returns the log-softmax over the ``output_size`` class scores
        (computed along dim 1).
        """
        # Dropout is applied to the flattened input itself, and then again
        # after the second and fourth hidden layers.
        x = self.dropout(input.view(-1, input_size))

        # Each entry: (linear layer, its batch norm, apply dropout afterwards?)
        stages = (
            (self.fc0, self.batch_norm_1, False),
            (self.fc1, self.batch_norm_2, True),
            (self.fc2, self.batch_norm_3, False),
            (self.fc3, self.batch_norm_4, True),
        )
        for linear, norm, drop_after in stages:
            x = F.relu(norm(linear(x)))
            if drop_after:
                x = self.dropout(x)

        logits = self.fc4(x)
        return F.log_softmax(logits, dim=1)

In [None]:
class EarlyStopper:
    """Signal that training should stop once validation loss keeps exceeding training loss.

    An epoch counts as "bad" when ``valid_loss - train_loss`` is greater than
    ``min_delta``. Bad epochs are counted cumulatively (the counter is never
    reset by a good epoch); once ``patience`` bad epochs have been seen, the
    next bad epoch makes ``early_stop`` return ``True``.
    """

    def __init__(self, patience, min_delta):
        """Store the thresholds and start with zero bad epochs.

        Args:
            patience: Number of bad epochs required before stopping.
            min_delta: Gap between validation and training loss that must be
                exceeded for an epoch to count as bad.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0

    def early_stop(self, valid_loss, train_loss):
        """Record this epoch's losses and report whether training should stop.

        Args:
            valid_loss: Validation loss for the epoch.
            train_loss: Training loss for the epoch.

        Returns:
            bool: ``True`` if this epoch was bad and the number of bad epochs
            has reached ``patience``; ``False`` otherwise.
        """
        # Always return a real bool: previously a good epoch fell off the end
        # of the method and returned None.
        if not (valid_loss - train_loss) > self.min_delta:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
# Instantiate the network.
model = Classifier()
# Negative log-likelihood loss: pairs with the log-probabilities (log_softmax)
# that the model's forward() returns.
criterion = nn.NLLLoss()

# File the trained weights are written to and later reloaded from.
model_filename = 'model.pt'

In [None]:
# Plain SGD with momentum.
# NOTE: L2 regularization comes out of the box with PyTorch through the
# optimizer's 'weight_decay' parameter, but it is not passed here, so no L2
# penalty is currently applied.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Stop after 5 epochs whose (validation - training) loss gap exceeds 1.
# These numbers were found by trial and error on the scaled losses computed
# below; NOTE(review): they may need re-tuning whenever the way the losses
# are measured changes.
early_stopper = EarlyStopper(5, 1)

epoches = 30
train_losses = []
valid_losses = []

for epoch in range(epoches):
    train_loss = 0
    valid_loss = 0

    # --- Training phase: dropout active, batch norm uses batch statistics.
    # Set at the start of every epoch because validation below switches the
    # model to eval mode.
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # BEGIN the implementation of the L1 regularization (disabled)

        # l1_lambda = 0.001
        # l1_norm = sum(torch.linalg.norm(p, 1) for p in model.parameters())
        # train_loss += loss.item() + l1_lambda * l1_norm.item()

        # END the implementation of the L1 regularization

    # --- Validation phase: eval mode disables dropout and stops batch norm
    # from updating its running statistics with validation data; no_grad
    # avoids building an autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            output = model(images)
            loss = criterion(output, labels)

            valid_loss += loss.item()

            # BEGIN the implementation of the L1 regularization (disabled)

            # valid_loss += loss.item() + l1_lambda * l1_norm.item()

            # END the implementation of the L1 regularization

    # Reported value = (sum of per-batch losses / number of samples) * 1000.
    # This is a scaled figure rather than a true per-sample mean; it is used
    # consistently for both curves and for the early-stopping threshold.
    train_loss = (train_loss / len(train_loader.sampler)) * 1000
    valid_loss = (valid_loss / len(valid_loader.sampler)) * 1000
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print('Epoch: ' + str(epoch) + '\t Training loss: ' + str(train_loss) + '\t Validation loss: ' + str(valid_loss))
    if early_stopper.early_stop(valid_loss, train_loss):
        break

# Persist the weights of the model as it is after the last completed epoch.
torch.save(model.state_dict(), model_filename)

In [None]:
# Draw the per-epoch training and validation loss curves on the same axes.
for history, caption in ((train_losses, 'Train Loss'), (valid_losses, 'Valid Loss')):
    plt.plot(history, label=caption)

plt.legend()

In [None]:
# Restore the weights that the training cell saved to model_filename.
model.load_state_dict(torch.load(model_filename))

In [None]:
test_loss = 0
# Per-class tallies, indexed by class label (0-9).
class_correct = [0.] * 10
class_total = [0.] * 10

# Switch dropout and batchnorm layers to their inference behavior
model.eval()

# Evaluation needs no gradients, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in test_loader:
        output = model(images)
        loss = criterion(output, labels)
        test_loss += loss.item()

        # The index of the largest log-probability is the predicted class
        _, pred = torch.max(output, 1)

        # Compare predictions to the true labels (one boolean per sample)
        correct = pred.eq(labels.view_as(pred))

        for label, hit in zip(labels.tolist(), correct.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

# Same scaling as the training/validation losses:
# (sum of per-batch losses / number of samples) * 1000
test_loss = (test_loss / len(test_loader.sampler)) * 1000

print('Test Loss: ' + str(test_loss))
print('Test Accuracy (Overall): ' + str(100 * np.sum(class_correct) / np.sum(class_total)) + '%')

In [None]:
# Show the model's predictions for the first 20 images of one test batch.
# Titles read "<predicted> - <actual>" and are green when they match, red
# otherwise.
images, labels = next(iter(test_loader))

# Make this cell independent of the previous one: force inference behavior
# for dropout/batch norm and skip gradient tracking.
model.eval()
with torch.no_grad():
    output = model(images)
_, preds = torch.max(output, 1)
images = images.numpy()

fig = plt.figure(figsize=(25, 4))
for i in range(20):
    ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
    # Drop the singleton channel dimension so imshow receives a 2-D image.
    ax.imshow(np.squeeze(images[i]), cmap='gray')
    predicted = preds[i].item()
    actual = labels[i].item()
    ax.set_title(
        data_dictionary[predicted] + ' - ' + data_dictionary[actual],
        color="green" if predicted == actual else "red",
    )