In [6]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from functions import get_normalized_fashion_mnist, fmnist1, get_iterators
from model_training import train_cnn_model, evaluation

# Reproducibility: seed the RNGs this notebook can touch and force cuDNN onto
# deterministic kernels. `benchmark` must also be off, otherwise cuDNN's
# autotuner may pick different (non-deterministic) algorithms between runs.
SEED = 1234
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
np.random.seed(SEED)     # numpy is imported above; seed it in case helpers draw from it
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Always a torch.device object on either backend (previously the CPU fallback
# was the bare string 'cpu'), so downstream code can rely on e.g. DEVICE.type.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Specify parameters

In [7]:
# Training hyper-parameters.
LOSS_FN = nn.CrossEntropyLoss()
BATCH_SIZE = 16
L_RATE = 1e-3

# Convolutional block specifications, passed positionally to
# model_training.train_cnn_model. Only the first two entries differ between
# blocks; the remaining seven are shared by all three.
# NOTE(review): the field meanings live in model_training — presumably
# [in_ch, out_ch, kernel, stride, padding, pool_kernel, pool_stride, ?, dropout?];
# confirm against train_cnn_model before relying on this reading.
_shared_tail = (3, 1, 1, 2, 2, 0, 0.0)
L1 = [1, 16, *_shared_tail]    # each unpack builds a fresh, independent list
L2 = [16, 32, *_shared_tail]
L3 = [32, 32, *_shared_tail]
del _shared_tail               # keep the notebook namespace unchanged

# Presumably the flattened feature count fed to the first linear layer.
# NOTE(review): 288 == 32 * 3 * 3, consistent with three 2x2 poolings of a
# 28x28 input under the hedged field reading above — confirm.
LINEAR_SIZE = 288

## Load data and get dataloaders

In [8]:
# Load the full, normalised Fashion-MNIST train/validation/test splits.
trainxs, trainys, valxs, valys, testxs, testys = get_normalized_fashion_mnist()

# Wrap each split with fmnist1 (per the original note this takes a subset of
# the data; the selection rule itself lives in functions.fmnist1).
# Calls run in train -> val -> test order.
train_dat, val_dat, test_dat = (
    fmnist1(xs, ys)
    for xs, ys in ((trainxs, trainys), (valxs, valys), (testxs, testys))
)

# Build the train/validation/test iterators with the configured batch size.
train_loader, val_loader, test_loader = get_iterators(
    train_dat, val_dat, test_dat, BATCH_SIZE
)

## Train model

In [9]:
# Train the CNN with the settings from the parameters cell; the function
# returns a 5-tuple that is unpacked below.
# NOTE(review): the short names suggest (train losses, train accs, val losses,
# val accs, trained model) — confirm against model_training.train_cnn_model.
training_run = train_cnn_model(
    train_loader,
    val_loader,
    L_RATE,
    LINEAR_SIZE,
    DEVICE,
    LOSS_FN,
    L1,
    L2,
    L3,
)
tl, ta, vl, va, m = training_run

Epoch: 1, training loss: 1.585845517852534, validation loss: 1.5385540852471002
Epoch: 2, training loss: 1.2373815961062793, validation loss: 0.6457939690067654
Epoch: 3, training loss: 0.40178086608171004, validation loss: 0.2892819206392954
Epoch: 4, training loss: 0.2704877909988852, validation loss: 0.24246694630100613
Epoch: 5, training loss: 0.23820306534337776, validation loss: 0.22401040167444283
Epoch: 6, training loss: 0.21769862887160202, validation loss: 0.19535216546602666
Epoch: 7, training loss: 0.20190542767243191, validation loss: 0.18413185988745046
Epoch: 8, training loss: 0.18949280443322689, validation loss: 0.17462615908995743
Epoch: 9, training loss: 0.17818314020810105, validation loss: 0.1720314114635426
Epoch: 10, training loss: 0.16778719945811332, validation loss: 0.15715618456846903
Epoch: 11, training loss: 0.15753821557993875, validation loss: 0.14929993993469648
Epoch: 12, training loss: 0.1493497966803028, validation loss: 0.15190528974764878
Epoch: 13,

## Evaluate test set

In [11]:
# Score the trained model on the held-out test loader and report accuracy
# rounded to four decimals (test_loss is kept for later inspection).
test_loss, test_acc = evaluation(m, LOSS_FN, test_loader, DEVICE)
acc_rounded = np.round(test_acc, 4)
print(f"Final test acc: {acc_rounded}.")

Final test acc: 0.9782.
