In [2]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import onnx
import onnxruntime
import time
import torch
import torch
import torch.nn as nn
import torch.onnx
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [3]:
device = torch.device("cpu")

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.maxpool2(self.relu2(self.conv2(x)))
        x = self.maxpool3(self.relu3(self.conv3(x)))
        x = self.flatten(x)
        x = self.relu4(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net().to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=0.001)

num_epochs = 5
log_interval = 100
train_losses = []
accuracies_per_class = {c: [] for c in classes}
train_steps = []
current_step = 0
best_accuracy = 0.0
best_model_state = None

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % log_interval == 0:
            train_losses.append(running_loss / log_interval)
            train_steps.append(current_step + log_interval)
            running_loss = 0.0

            with torch.no_grad():
                class_correct = list(0. for i in range(10))
                class_total = list(0. for i in range(10))
                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                c = (predicted == labels).squeeze()
                for k in range(labels.size(0)):
                    label_idx = labels[k]
                    class_correct[label_idx] += c[k].item()
                    class_total[label_idx] += 1
                for j in range(10):
                    if class_total[j] > 0:
                        accuracies_per_class[classes[j]].append(class_correct[j] / class_total[j])
                    else:
                        accuracies_per_class[classes[j]].append(0.0)

        current_step += 1
        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {train_losses[-1]:.3f}')

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_accuracy = 100 * correct / total
    print(f'Epoch {epoch+1} - Accuracy on the test set: {epoch_accuracy:.2f}%')

    if epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        best_model_state = net.state_dict()
        torch.save(best_model_state, 'best_model.pth')
        print(f'Saved best model with accuracy: {best_accuracy:.2f}%')

print(f'Finished Training. Best accuracy on the test set: {best_accuracy:.2f}%')

[1,   100] loss: 1.896
[1,   200] loss: 1.583
[1,   300] loss: 1.479
Epoch 1 - Accuracy on the test set: 53.61%
Saved best model with accuracy: 53.61%
[2,   100] loss: 1.316
[2,   200] loss: 1.267
[2,   300] loss: 1.201
Epoch 2 - Accuracy on the test set: 62.14%
Saved best model with accuracy: 62.14%
[3,   100] loss: 1.091
[3,   200] loss: 1.065
[3,   300] loss: 1.032
Epoch 3 - Accuracy on the test set: 68.97%
Saved best model with accuracy: 68.97%
[4,   100] loss: 0.938
[4,   200] loss: 0.931
[4,   300] loss: 0.909
Epoch 4 - Accuracy on the test set: 70.48%
Saved best model with accuracy: 70.48%
[5,   100] loss: 0.855
[5,   200] loss: 0.842
[5,   300] loss: 0.842
Epoch 5 - Accuracy on the test set: 72.63%
Saved best model with accuracy: 72.63%
Finished Training. Best accuracy on the test set: 72.63%


In [9]:
def evaluate_model(model, testloader, device, is_onnx=False, ort_session=None):
    start_time = time.time()
    correct = 0
    total = 0
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    
    testset = testloader.dataset
    consistent_testloader = DataLoader(testset, batch_size=testloader.batch_size, shuffle=False, num_workers=testloader.num_workers)

    with torch.no_grad():
        for data in consistent_testloader:
            images, labels = data[0].to(device), data[1].to(device)

            if is_onnx:
                onnx_inputs = {ort_session.get_inputs()[0].name: images.cpu().numpy()}
                onnx_outs = ort_session.run(None, onnx_inputs)
                outputs = torch.from_numpy(onnx_outs[0]).to(device)
            else:
                outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for label, prediction in zip(labels, predicted):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    total_accuracy = 100 * correct / total
    class_accuracy = {
        classname: 100 * float(correct_pred[classname]) / total_pred[classname]
        for classname, correct_count in correct_pred.items()
    }
    evaluation_time = time.time() - start_time

    return total_accuracy, class_accuracy, evaluation_time

def load_torch_model(Model, path):
    print(f'\nLoading torch model')
    best_model = Model().to(device)
    best_model.load_state_dict(torch.load(path))
    best_model.eval()
    print(f'\nTorch model loaded')
    return best_model

def export_torch_model_to_onnx(Model, torch_model_path, onnx_export_path):
    print(f'\nExport torch model to ONNX')
    toch_model = load_torch_model(Model, torch_model_path)
    example_input = torch.randn(1, 3, 32, 32, device=device)
    torch.onnx.export(toch_model,
                    example_input,
                    onnx_export_path,
                    export_params=True,
                    opset_version=11,
                    do_constant_folding=True,
                    input_names=['input'],
                    output_names=['output'],
                    dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})
    print(f'\nTorch model exported')

def load_onnx_model_and_start_ort(path):
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)
    print("Loaded ONNX model. ONNX model checked and is well formed.")
    return onnxruntime.InferenceSession("best_model.onnx", providers=['CPUExecutionProvider'])

In [10]:
torch_model_path = 'best_model.pth'
onnx_model_path = "best_model.onnx"

best_model = load_torch_model(Net, torch_model_path)
total_accuracy, class_accuracy, evaluation_time = evaluate_model(best_model, testloader, device)
print(f'\nEvaluation of the best PyTorch model:')
print(f'Accuracy of the best model on the 10000 test images: {total_accuracy:.2f}%')
print(f'Evaluation time: {evaluation_time:.2f} seconds')
print("\nAccuracy per class for the best model on the test set:")
for classname, accuracy in class_accuracy.items():
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

ort_session = load_onnx_model_and_start_ort(onnx_model_path)
total_accuracy_onnx, class_accuracy_onnx, evaluation_time_onnx = evaluate_model(None, testloader, device, is_onnx=True, ort_session=ort_session)
print(f'\nEvaluation of the ONNX model on the test images:')
print(f'Accuracy of the ONNX model on the 10000 test images: {total_accuracy_onnx:.2f}%')
print(f'Evaluation time: {evaluation_time_onnx:.2f} seconds')
print("\nAccuracy per class for the ONNX model on the test set:")
for classname, accuracy in class_accuracy_onnx.items():
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')


Loading torch model

Torch model loaded


  best_model.load_state_dict(torch.load(path))



Evaluation of the best PyTorch model:
Accuracy of the best model on the 10000 test images: 72.63%
Evaluation time: 17.58 seconds

Accuracy per class for the best model on the test set:
Accuracy for class: plane is 85.8 %
Accuracy for class: car   is 94.2 %
Accuracy for class: bird  is 60.0 %
Accuracy for class: cat   is 56.9 %
Accuracy for class: deer  is 57.4 %
Accuracy for class: dog   is 62.7 %
Accuracy for class: frog  is 84.2 %
Accuracy for class: horse is 76.3 %
Accuracy for class: ship  is 84.9 %
Accuracy for class: truck is 63.9 %
Loaded ONNX model. ONNX model checked and is well formed.

Evaluation of the ONNX model on the test images:
Accuracy of the ONNX model on the 10000 test images: 72.22%
Evaluation time: 15.59 seconds

Accuracy per class for the ONNX model on the test set:
Accuracy for class: plane is 82.0 %
Accuracy for class: car   is 84.3 %
Accuracy for class: bird  is 59.1 %
Accuracy for class: cat   is 49.1 %
Accuracy for class: deer  is 67.0 %
Accuracy for class: