In [3]:
import sys
# Insert the parent directory into the module search path at index 1.
# NOTE(review): presumably this is what lets the project-local `model` module
# (imported in the next cell) resolve when the notebook runs from a subfolder — confirm.
sys.path.insert(1, '..')

In [5]:
import os    
import copy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.models as models

from model import EfficientNet_B0
# from efficientnet_pytorch import EfficientNet

%load_ext autoreload
%autoreload 2

In [8]:
# Show what is in the data directory that the CIFAR-10 loaders below use as `root`.
os.listdir('../data')

['cifar-10-batches-py', 'cifar-10-python.tar.gz']

In [21]:
# Mini-batch size shared by the training and test loaders.
batch = 256

# Per-channel normalisation with mean 0.5 / std 0.5, applied after ToTensor().
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: random horizontal flip is data augmentation and must only
# be applied to the training set.
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     normalize])

# Evaluation pipeline: deterministic (no random flip), so test loss/accuracy
# are reproducible from run to run.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch,
                                         shuffle=False, num_workers=4)

# Human-readable names for the ten class indices.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


Training and plotting helpers

In [35]:
def train_cifar(classifier, optimizer, trainloader, testloader, epochs=100, print_freq=1, device=None,
                name='model', criterion=None):
    """Train ``classifier`` and evaluate it on ``testloader`` after every epoch.

    The weights from the epoch with the lowest test loss are written to
    ``saved_models/<name>/best.pth`` (relative to the working directory) and
    loaded back into ``classifier`` before returning.

    Args:
        classifier: the ``nn.Module`` to train.
        optimizer: optimizer built from ``classifier.parameters()``.
        trainloader: iterable of batches; ``batch[0]`` is the input tensor and
            ``batch[1]`` the label tensor.
        testloader: iterable of evaluation batches, same layout.
        epochs: number of passes over ``trainloader``.
        print_freq: print a progress line every ``print_freq`` epochs
            (``0``/``None`` disables printing).
        device: device to run on; defaults to ``cuda:0`` when available,
            otherwise CPU.
        name: sub-directory of ``saved_models/`` for the checkpoint.
        criterion: loss function called as ``criterion(outputs, labels)``;
            defaults to ``nn.CrossEntropyLoss()``. Per-epoch averages weight
            each batch loss by its batch size, which is exact for losses that
            return a per-batch mean.

    Returns:
        ``(train_losses, test_losses, test_acc, classifier)`` — three lists
        with one entry per epoch and the model holding the best weights.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Inputs are moved to `device` below, so the model has to live there too.
    classifier.to(device)

    save_dir = os.path.join('saved_models', name)
    os.makedirs(save_dir, exist_ok=True)

    # Best-checkpoint tracking; start at +inf so the first epoch always wins.
    best_epoch = 0
    best_loss = float('inf')
    best_weights = copy.deepcopy(classifier.state_dict())

    train_losses = []
    test_losses = []
    test_acc = []

    for epoch in range(epochs):  # loop over the dataset multiple times
        # ---- training ----
        classifier.train()  # enable dropout / batch-norm statistic updates
        running_loss = 0.0
        seen = 0
        for data in trainloader:
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch is averaged correctly.
            running_loss += loss.item() * len(inputs)
            seen += len(inputs)
        # Divide by the number of samples actually seen, not a hard-coded size.
        train_loss = running_loss / seen if seen else float('nan')

        # ---- evaluation ----
        classifier.eval()  # freeze dropout / batch-norm for deterministic scoring
        correct = 0
        total = 0
        test_loss_sum = 0.0
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(device), data[1].to(device)
                outputs = classifier(images)
                # Compute the loss on *this* test batch before accumulating it.
                test_loss_sum += criterion(outputs, labels).item() * len(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = test_loss_sum / total if total else float('nan')
        accuracy = correct / total if total else float('nan')

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        test_acc.append(accuracy)

        if print_freq and epoch % print_freq == 0:
            print(f'epoch: {epoch+1}/{epochs}, train loss: {train_losses[-1]:.4f}, test loss: {test_losses[-1]:.4f}, test acc: {accuracy:.4f}')

        # Keep a snapshot whenever the test loss improves (lower is better).
        if test_loss < best_loss:
            best_loss = test_loss
            best_epoch = epoch + 1  # 1-based, matching the progress line
            best_weights = copy.deepcopy(classifier.state_dict())

    torch.save(best_weights, os.path.join(save_dir, 'best.pth'))
    print('best epoch: {}'.format(best_epoch))

    classifier.load_state_dict(best_weights)

    return train_losses, test_losses, test_acc, classifier

In [9]:
def plot_history(train_losses,test_losses,test_acc,saveto):
    """Plot per-epoch loss curves and test accuracy, save the figure, then show it.

    Args:
        train_losses: sequence of training losses, one value per epoch.
        test_losses: sequence of test losses, one value per epoch.
        test_acc: sequence of test accuracies, one value per epoch.
        saveto: destination passed to ``Figure.savefig``; when it is a path,
            missing parent directories are created first.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(8, 10))

    # Top panel: train vs. test loss.
    loss_ax.plot(train_losses, label='train')
    loss_ax.plot(test_losses, label='test')
    loss_ax.set_title('Loss history')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Bottom panel: the values plotted here are test-set accuracies.
    acc_ax.plot(test_acc, label='test')
    acc_ax.set_title('Classification accuracy history')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Classification accuracy')
    acc_ax.legend()

    if isinstance(saveto, (str, os.PathLike)):
        out_dir = os.path.dirname(saveto)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Save before show(): once the figure has been shown and released, a later
    # pyplot-level savefig would write an empty canvas.
    fig.savefig(saveto)
    plt.show()

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()


def make_optimizer(model, lr=0.001, momentum=0.9):
    """Return an SGD optimizer bound to ``model``'s parameters.

    An optimizer only updates the parameters it was constructed with, so it
    has to be created after the model exists, and separately for every model
    that is trained. Call this right after building each model.

    Args:
        model: the ``nn.Module`` whose parameters will be optimised.
        lr: learning rate.
        momentum: SGD momentum factor.
    """
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# EfficientNet-B0

In [None]:
# NOTE(review): `.to(device)` is called on the imported name directly. If
# `EfficientNet_B0` in model.py is a class or factory rather than a ready-made
# module instance, it must be called first (`EfficientNet_B0(...)`) — confirm.
Eff_B0 = EfficientNet_B0.to(device)

# Build the optimizer from *this* model's parameters; an optimizer created for
# another network would leave these weights untouched.
optimizer = torch.optim.SGD(Eff_B0.parameters(), lr=0.001, momentum=0.9)

# Train the EfficientNet model itself; `name` is passed by keyword because a
# positional argument may not follow keyword arguments.
train_losses, test_losses, test_acc, classifier = train_cifar(Eff_B0, optimizer, trainloader, testloader,
                                                    epochs=100, print_freq=1, name='EfficientNet-B0')

# Make sure the plot directory exists before the figure is written.
os.makedirs('../save_plot', exist_ok=True)
plot_history(train_losses, test_losses, test_acc, saveto='../save_plot/EfficientNet-B0.png')

# ResNet-18

In [36]:
ResNet = models.resnet18(pretrained=False)
# Swap the classification head for one sized to this dataset; read the input
# width from the existing layer instead of hard-coding it.
ResNet.fc = nn.Linear(ResNet.fc.in_features, len(classes))
ResNet = ResNet.to(device)

# Build the optimizer from *this* model's parameters (after the head swap), so
# the weights being trained are the ones being updated.
optimizer = torch.optim.SGD(ResNet.parameters(), lr=0.001, momentum=0.9)

# `name` is passed by keyword because a positional argument may not follow
# keyword arguments.
train_losses, test_losses, test_acc, classifier = train_cifar(ResNet, optimizer, trainloader, testloader,
                                                    epochs=10, print_freq=1, name='ResNet-18')

# Make sure the plot directory exists before the figure is written.
os.makedirs('../save_plot', exist_ok=True)
plot_history(train_losses, test_losses, test_acc, saveto='../save_plot/ResNet-18.png')