In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
from utils import device, get_num_correct, RunBuilder
from network import Network

In [2]:
# Preprocessing pipeline: ToTensor turns each PIL image into a float tensor scaled to
# [0, 1]; Normalize then maps every RGB channel through (x - 0.5) / 0.5, i.e. to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split. download=True fetches the archive into ./data/ only when it is
# not already present there; the preprocessing above is applied lazily per sample.
train_set = torchvision.datasets.CIFAR10(
    root='./data/', train=True, download=True, transform=transform
)

Files already downloaded and verified


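Before starting the full training procedure, it is good practice to try to overfit a single batch of data. This confirms that the network is implemented correctly and has enough capacity to learn: if the model cannot drive the loss down and the accuracy up on one small batch it sees over and over, it will not learn the full dataset either.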

In [3]:
# Sanity check: repeatedly train on ONE fixed batch. The loss should trend towards 0 and the
# accuracy towards 1.0; if it does not, the model / loss / optimizer setup is suspect.
SEED = 0          # fixes the shuffled batch and the weight initialization for reproducible output
BATCH_SIZE = 32
N_EPOCHS = 50
# NOTE(review): in the recorded run the loss jumped from ~2.3 to ~6.4 in the first epochs,
# which hints this lr is on the high side; Adam's default is 1e-3 if that recurs.
LR = 0.01

torch.manual_seed(SEED)  # seeds the DataLoader shuffle and Network's parameter init

# load the train_set for trying out the model
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=1
)

model = Network()  # initialize the NN
# CrossEntropyLoss expects raw class scores; presumably Network outputs logits — TODO confirm
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)  # specify the optimizer
images, labels = next(iter(train_loader))  # the single batch reused every epoch

for epoch in range(N_EPOCHS):
    preds = model(images)  # forward pass
    loss = criterion(preds, labels)  # calculate loss
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()  # backward pass
    optimizer.step()  # perform a single optimization step

    # Correct predictions of this epoch's forward pass (i.e. before the weight update).
    # Divide by the actual batch length rather than a hard-coded size so the accuracy
    # stays correct if BATCH_SIZE is changed.
    correct = get_num_correct(preds, labels)
    print(f'epoch: {epoch:2d} loss:{loss.item():2.4f} acc:{(correct / len(labels)):2.4f}')

epoch:  0 loss:2.3195 acc:0.0938
epoch:  1 loss:4.9806 acc:0.1562
epoch:  2 loss:6.4253 acc:0.1562
epoch:  3 loss:6.3442 acc:0.0938
epoch:  4 loss:4.6450 acc:0.1250
epoch:  5 loss:3.2097 acc:0.1562
epoch:  6 loss:2.4960 acc:0.2500
epoch:  7 loss:2.2880 acc:0.1250
epoch:  8 loss:2.1752 acc:0.2188
epoch:  9 loss:2.1317 acc:0.0938
epoch: 10 loss:2.0971 acc:0.2500
epoch: 11 loss:2.0512 acc:0.2812
epoch: 12 loss:1.9756 acc:0.2500
epoch: 13 loss:1.8806 acc:0.3125
epoch: 14 loss:1.7805 acc:0.3438
epoch: 15 loss:1.6941 acc:0.3750
epoch: 16 loss:1.6683 acc:0.3750
epoch: 17 loss:1.5856 acc:0.4688
epoch: 18 loss:1.4691 acc:0.4375
epoch: 19 loss:1.6006 acc:0.4375
epoch: 20 loss:1.3912 acc:0.5000
epoch: 21 loss:1.3030 acc:0.5000
epoch: 22 loss:1.5074 acc:0.4062
epoch: 23 loss:1.4688 acc:0.4062
epoch: 24 loss:1.4631 acc:0.5625
epoch: 25 loss:1.1601 acc:0.5938
epoch: 26 loss:1.1292 acc:0.5625
epoch: 27 loss:1.2183 acc:0.4688
epoch: 28 loss:0.9641 acc:0.6250
epoch: 29 loss:1.0707 acc:0.5625
epoch: 30 