In [None]:
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from torchvision.datasets.vision import StandardTransform
from torchvision.models import resnet18

from PIL import Image
from tqdm import tqdm

from Cifar100.Cifar100 import Cifar100

**Set Arguments**

In [None]:
# Compute device. Prefer the GPU, but fall back to the CPU when CUDA is not
# available so that the notebook still runs end to end on a CPU-only machine.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

NUM_CLASSES = 100        # Size of the network's output layer
CLASSES_EACH_TRAIN = 10  # NUM_CLASSES / CLASSES_EACH_TRAIN gives the number of training rounds

# Higher batch sizes allow for larger learning rates. An empirical heuristic
# suggests that, when changing the batch size, the learning rate should change
# by the same factor to have comparable results.
BATCH_SIZE = 32

LR = 1e-3            # The initial learning rate
MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD
WEIGHT_DECAY = 5e-5  # Regularization, you can keep this at the default

NUM_EPOCHS = 25      # Number of training epochs (iterations over the dataset) per round
STEP_SIZE = 12       # How many epochs before decreasing the learning rate (step-down policy)
GAMMA = 0.1          # Multiplicative factor for the learning rate step-down

LOG_FREQUENCY = 100  # The training loss is printed every LOG_FREQUENCY optimizer steps


**Prepare Network**

In [None]:
# ResNet-18 built with torchvision's default arguments.
net = resnet18()

# Replace the last fully connected layer of ResNet-18 with a new one that has
# NUM_CLASSES outputs.
net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)

# Placeholder for the best model found during training. It is copied from `net`
# *after* the head replacement so that both networks have the same output size
# even if no training round ever improves on the initial best accuracy.
best_net = copy.deepcopy(net)

**Prepare Training**

In [None]:
# Loss: cross entropy, the usual choice for classification.
criterion = nn.CrossEntropyLoss()

# All parameters of the network are optimized.
parameters_to_optimize = net.parameters()

# Optimizer: SGD with momentum and weight decay.
# (Alternative: optim.Adam(parameters_to_optimize, lr=LR, weight_decay=WEIGHT_DECAY))
optimizer = optim.SGD(
    parameters_to_optimize,
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

# Scheduler: step-down policy driven by STEP_SIZE and GAMMA.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=STEP_SIZE,
    gamma=GAMMA,
)

**Train and Test**

In [None]:
# Put the network on the configured device.
net = net.to(DEVICE)

# Enabling the cuDNN benchmark mode optimizes runtime.
cudnn.benchmark = True

# Counters and running values shared with the training cell.
current_step, best_accuracy, loss, index = 0, 0, 0, 0

# Per-epoch training history.
loss_train, accuracy_train = [], []


In [None]:
# Helper object from the project's Cifar100 class; it provides the data loaders
# (load), the evaluation routine (test) and the plots (plot).
cifar100 = Cifar100(BATCH_SIZE, NUM_EPOCHS, DEVICE, LR, STEP_SIZE, GAMMA)

# Test metrics collected after each training round, in round order.
test_accuracy_history = []
test_loss_history = []

# NOTE(review): `optimizer` and `scheduler` are created once, outside this loop,
# so the learning-rate decay keeps accumulating across rounds. Confirm whether
# they should be re-created at the start of every round.
for index in range(NUM_CLASSES // CLASSES_EACH_TRAIN):

  # Load data from Cifar100.
  # NOTE(review): the round index is not passed to load(); presumably every
  # round therefore sees the same data -- confirm against Cifar100.load.
  train_dataloader = cifar100.load('train')#, index)
  test_dataloader = cifar100.load('test')#, index)

  # Start iterating over the epochs
  for epoch in range(NUM_EPOCHS):

    # get_last_lr() reports the rate actually in use; get_lr() is not meant to
    # be called outside scheduler.step() and misreports it at decay boundaries.
    print('Starting epoch {}/{}, LR = {}'.format(epoch+1, NUM_EPOCHS, scheduler.get_last_lr()))

    # Training mode and device placement do not change within an epoch, so they
    # are set once here instead of once per batch.
    net.train().to(DEVICE)

    running_correct_train = 0
    samples_seen = 0

    # Iterate over the dataset
    for images, labels in train_dataloader:

      # Bring data over the device of choice
      images = images.to(DEVICE)
      labels = labels.to(DEVICE)

      optimizer.zero_grad()
      outputs = net(images)

      # Count correct predictions and processed samples for the accuracy curve.
      # detach() keeps this bookkeeping out of the autograd graph.
      preds = outputs.detach().argmax(dim=1)
      running_correct_train += (preds == labels).sum().item()
      samples_seen += labels.size(0)

      # Compute loss based on output and ground truth
      loss = criterion(outputs, labels)

      # Log loss
      if current_step % LOG_FREQUENCY == 0:
        print('Step {}, Loss {}'.format(current_step, loss.item()))

      # Compute gradients for each layer and update weights
      loss.backward()  # backward pass: computes gradients
      optimizer.step() # update weights based on accumulated gradients

      current_step += 1

    # Loss of the last batch of the epoch.
    loss_train.append(loss.item())
    # Accuracy is the fraction of correctly classified *samples*; dividing by
    # len(train_dataloader) would divide by the number of batches instead.
    accuracy_train.append(running_correct_train / float(samples_seen))

    # Step the scheduler
    scheduler.step()

  # Evaluate after the round.
  # NOTE(review): assumes test() returns scalar (accuracy, loss), as the print
  # and the comparison below require -- confirm against Cifar100.test.
  accuracy_test, loss_test = cifar100.test(net, test_dataloader, DEVICE, criterion)

  # Record the test metrics themselves (the training loss/accuracy must not be
  # mixed into the test results).
  test_accuracy_history.append(accuracy_test)
  test_loss_history.append(loss_test)
  print('Test Accuracy: {}'.format(accuracy_test))

  # Keep a snapshot of the best model seen so far.
  if accuracy_test > best_accuracy:
    best_net = copy.deepcopy(net)
    best_accuracy = accuracy_test


**Plots**

In [None]:
# Draw the plots provided by the Cifar100 helper object.
cifar100.plot()

# Report the highest test accuracy recorded by the training cell.
print('Best accuracy {!s}'.format(best_accuracy))
