In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from config import ModelConfig
from dataLoader import DataEngine
from model import Net
from train import train
from test import test
from results import plot_misclassified_images, plot_graph
from utilities import *


# View model config
args = ModelConfig()
args.print_config()

print()
# Set seed
init_seed(args)

data = DataEngine(args)

# Grab one batch from the training loader for a quick visual sanity check.
# NOTE: DataLoader iterators no longer expose a `.next()` method in recent
# PyTorch releases; the builtin `next()` works on every version.
dataiter = iter(data.train_loader)
images, labels = next(dataiter)

# Preview at most four samples; clamp to the batch size so the label
# printout below cannot index past the end of a small batch.
n_preview = min(4, len(labels))

# show images
imshow(torchvision.utils.make_grid(images[:n_preview]))
# print labels (class names for the previewed images, in the same order)
print(' '.join('%5s' % data.classes[labels[j]] for j in range(n_preview)))

device = which_device()
model = Net(args).to(device)
# Input size handed to the summary helper, presumably (channels, height,
# width) for 3x32x32 RGB images — confirm against DataEngine.
show_model_summary(model, device, (3, 32, 32))


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

def run():
    """Train for ``args.epochs`` epochs, evaluating on the test loader after each.

    Uses the module-level ``args``, ``model``, ``device``, ``data``,
    ``criterion`` and ``optimizer``.

    Returns:
        A tuple ``(test_losses, test_accs, misclassified_imgs)`` holding the
        three lists handed to ``test`` every epoch (presumably appended to in
        place by ``test`` — confirm in test.py).
    """
    test_losses = []
    test_accs = []
    misclassified_imgs = []
    last_epoch = args.epochs - 1
    for epoch in range(args.epochs):
        # Epochs are 0-based internally but shown 1-based to the user.
        print("EPOCH:", epoch + 1)
        train(model, device, data.train_loader, criterion, optimizer, epoch)
        # The trailing flag is True only on the final epoch — presumably so
        # `test` collects misclassified images just once; confirm in test.py.
        test(model, device, data.test_loader, criterion, data.classes,
             test_losses, test_accs, misclassified_imgs, epoch == last_epoch)
    return test_losses, test_accs, misclassified_imgs

# Train and evaluate the model, collecting the per-epoch test metrics.
test_losses, test_accs, misclassified_imgs = run()

# Plot the test-loss curve, then the test-accuracy curve.
for metric_values, metric_name in ((test_losses, "Loss"), (test_accs, "Accuracy")):
    plot_graph(metric_values, metric_name)

# Render the misclassified test images (presumably written to the given file).
plot_misclassified_images(misclassified_imgs, data.classes, "misclassified_imgs.png")