In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import os

from model import CNN
from torchsummary import summary
import copy

In [None]:
# Training hyperparameters used by the data loaders and the training loop.
batch_size = 32   # samples per mini-batch for the train/validation/test loaders
total_epoch = 50  # number of full passes over the training split

In [None]:
# Training pipeline: random 32x32 crop (after 4 pixels of padding) and a
# random horizontal flip as augmentation, then conversion to a tensor.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: tensor conversion only, with no augmentation.
# Neither pipeline applies transforms.Normalize.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
# Human-readable CIFAR-10 class names; the position in the tuple is the
# integer label used to index it when printing predictions/labels.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
# Seed for the train/validation split so the partition is identical after a
# kernel restart (Restart & Run All reproducibility).
SPLIT_SEED = 42

# Two dataset objects over the same CIFAR-10 training images:
#   * trainset applies the augmenting transform and feeds the optimiser;
#   * valset applies the plain evaluation transform, so validation metrics
#     are not computed on randomly cropped/flipped images.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
# The files were fetched (or verified) by the call above, so no second
# download attempt is needed here.
valset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transform_test)

# 80/20 split of the sample indices; both dataset objects share the same
# indexing, so the index sets select disjoint images.
train_indices, val_indices = train_test_split(
    np.arange(len(trainset)), test_size=0.2, random_state=SPLIT_SEED)

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    num_workers=4,
    sampler=train_sampler
)

# Validation batches come from the non-augmented view of the training data.
val_loader = DataLoader(
    valset,
    batch_size=batch_size,
    num_workers=4,
    sampler=valid_sampler
)

# Held-out CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Sizes of the two index subsets produced by the train/validation split,
# taken from the samplers that wrap them.
train_size = len(train_sampler)
val_size = len(valid_sampler)

In [None]:
# Report dataset sizes. len(trainset) counts the whole training dataset
# object, i.e. the training and validation indices together.
print(
    'length trainset : {}, testset : {}'.format(
        len(trainset),
        len(testset),
    )
)

#### Display 3 Images from the First Test Batch




In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image

def imshow(img):
    """Display an image tensor with matplotlib.

    Parameters
    ----------
    img : torch.Tensor
        Image in channel-first layout (C, H, W), e.g. the grid returned by
        ``torchvision.utils.make_grid``. It is shown as-is, without any
        rescaling of its values.
    """
    # detach() and cpu() make the conversion safe for tensors that live on a
    # GPU or are attached to the autograd graph; both are no-ops otherwise.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channel-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Preview the first few images of the first test batch. The test loader is
# created with shuffle=False, so the same images are shown on every run.
num_preview = 3

dataiter = iter(testloader)
# Use the built-in next(): DataLoader iterators in recent PyTorch releases
# do not provide a .next() method.
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images[:num_preview]))
# print labels
print(' '.join('%10s' % classes[labels[j]] for j in range(num_preview)))

In [None]:
# Instantiate the network defined in model.py. It is moved to the selected
# device in a later cell, after the summary has been printed.
model = CNN()

In [None]:
# Multi-class classification loss. CrossEntropyLoss expects unnormalized
# class scores -- assumes CNN.forward returns logits (TODO confirm in model.py).
criterion = nn.CrossEntropyLoss()

# Stochastic gradient descent with momentum over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Print a layer-by-layer summary for a 3x32x32 input. The summary is run on
# the CPU because the model is only moved to `device` in the next cell.
summary(model, input_size = (3,32,32), device = 'cpu')

In [None]:
# Move the model onto the selected device (the GPU when one is available).
model = model.to(device)

In [None]:
# Directory (relative to the notebook's working directory) in which the best
# checkpoint is stored.
model_folder = os.path.abspath('./checkpoints')
# exist_ok=True makes this safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test; makedirs also creates any
# missing parent directories.
os.makedirs(model_folder, exist_ok=True)
model_path = os.path.join(model_folder, 'cifar10.pth')

### training

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(loss, accuracy)``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train or evaluate; it must already be on ``device``.
    loader : iterable of (inputs, labels) batches
        Batches are moved to ``device`` before the forward pass.
    criterion : callable
        Loss function called as ``criterion(outputs, labels)``.
    device : torch.device
        Device on which the batches are processed.
    optimizer : torch.optim.Optimizer, optional
        When given, the model is put in training mode and updated after every
        batch. When ``None``, the model is put in evaluation mode and no
        gradients are computed.

    Returns
    -------
    (float, float)
        ``loss`` is the sum of the per-batch loss values divided by the
        number of samples seen; ``accuracy`` is the fraction of samples whose
        arg-max prediction equals the label. Both are plain Python floats.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    # Gradients are only tracked while training.
    with torch.set_grad_enabled(is_training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_training:
                # zero the parameter gradients
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)

            # .item() turns the 0-dim results into Python numbers, so the
            # running statistics never hold device tensors.
            loss_sum += loss.item()
            num_correct += torch.sum(preds == labels).item()
            num_samples += inputs.size(0)

    if num_samples == 0:
        raise ValueError('loader produced no samples; cannot compute epoch statistics')

    # NOTE: each loss value is one number per batch (a batch mean with the
    # default CrossEntropyLoss reduction), yet the sum is divided by the
    # number of *samples*. The reported loss is therefore roughly the mean
    # loss divided by the batch size. This scale is kept on purpose because
    # the loss plot's y-range is tuned to it.
    # TODO: weight each batch loss by inputs.size(0) for a true per-sample
    # mean and adjust the plot range accordingly.
    return loss_sum / num_samples, num_correct / num_samples


# Per-epoch history, consumed by the plotting cells below.
train_loss = []
train_acc = []
val_loss = []
val_acc = []

# Best validation accuracy seen so far; a checkpoint is written on improvement.
best_acc = 0.0

for epoch in range(total_epoch):

    epoch_loss, epoch_acc = _run_epoch(model, trainloader, criterion, device, optimizer)

    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)

    print('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

    # evaluate
    print('Finished epoch {}, starting evaluation'.format(epoch+1))

    epoch_loss, epoch_acc = _run_epoch(model, val_loader, criterion, device)

    print('Validation Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

    val_loss.append(epoch_loss)
    val_acc.append(epoch_acc)

    # Save a CPU copy of the weights whenever validation accuracy improves.
    if epoch_acc > best_acc:
        print("saving best model val_acc : {:.4f}".format(epoch_acc))
        best_acc = epoch_acc

        # state_dict() returns a fresh dict on every call, so replacing its
        # values with CPU tensors leaves the live model untouched and avoids
        # deep-copying the whole model just to write a checkpoint.
        best_state = model.state_dict()
        for name, tensor in best_state.items():
            best_state[name] = tensor.cpu()
        torch.save(best_state, model_path)

        del best_state

print('==> Finished Training ...')

### loss

In [None]:
# Per-epoch loss curves recorded by the training cell.
hist = list(val_loss)
hist2 = list(train_loss)

plt.title("train vs Validation loss")
plt.xlabel("Training Epochs")
plt.ylabel("Loss")  # this figure shows loss, not accuracy
plt.plot(range(1,len(val_loss)+1),hist,label="Validation")
plt.plot(range(1,len(train_loss)+1),hist2,label="Train")
# The recorded losses are sums of per-batch loss values divided by the sample
# count (see the training cell), hence the small y-range.
plt.ylim((0,0.08))
plt.xticks(np.arange(1, len(train_loss)+1, 2))
plt.legend()
plt.show()

### accuracy 

In [None]:
# Per-epoch accuracy curves recorded by the training cell. float() converts
# 0-dim tensors (including tensors that live on a GPU) into plain Python
# numbers, which matplotlib can always plot; plain floats pass through
# unchanged.
hist = [float(h) for h in val_acc]
hist2 = [float(h) for h in train_acc]

plt.title("train vs Validation accuracy")
plt.xlabel("Training Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1,len(val_acc)+1),hist,label="Validation")
plt.plot(range(1,len(train_acc)+1),hist2,label="Train")
plt.ylim((0,1.0))
plt.xticks(np.arange(1, len(train_acc)+1, 2))
plt.legend()
plt.show()