# PyTorch Implementation for CIFAR-10 Model Deployment - By Rasool Vahdati

## Import Libraries

In [18]:
!pip install tf2onnx
!pip install onnx
!pip install onnxruntime



In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F
import onnx
import onnxruntime
import numpy as np
from pathlib import Path
import os

## Dataset Loading and Preprocessing

In [20]:
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


## Model Definition

In [22]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(-1, 256 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = SimpleCNN()

## Training the Model

In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}")

Epoch 1/5, Loss: 1.4274
Epoch 2/5, Loss: 1.0540
Epoch 3/5, Loss: 0.9104
Epoch 4/5, Loss: 0.7968
Epoch 5/5, Loss: 0.7107


## Saving the Model

In [24]:
# Save state dictionary (.pt)
torch.save(model.state_dict(), 'cifar10_pt_model.pt')

# Save entire model (.pth)
torch.save(model, 'cifar10_pth_model.pth')

# Save model in ONNX format
dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(model, dummy_input, 'cifar10_onnx_model.onnx', input_names=['input'], output_names=['output'], opset_version=11)

## Inference on a Single Image

In [26]:
# Load models
model_state_dict = SimpleCNN()
model_state_dict.load_state_dict(torch.load('cifar10_pt_model.pt'))
model_state_dict.eval()

model_pth = torch.load('cifar10_pth_model.pth')
model_pth.eval()

ort_session = onnxruntime.InferenceSession('cifar10_onnx_model.onnx')

# Perform inference on a single image
sample_input, _ = test_dataset[0]
sample_input = sample_input.unsqueeze(0)

# PyTorch inference
output_pt = model_state_dict(sample_input).argmax(dim=1).item()
output_pth = model_pth(sample_input).argmax(dim=1).item()

# ONNX inference
ort_inputs = {'input': sample_input.numpy()}
ort_output = ort_session.run(None, ort_inputs)[0].argmax(axis=1)[0]

print(f"Prediction from .pt: {output_pt}")
print(f"Prediction from .pth: {output_pth}")
print(f"Prediction from ONNX: {ort_output}")

Prediction from .pt: 3
Prediction from .pth: 3
Prediction from ONNX: 3


  model_state_dict.load_state_dict(torch.load('cifar10_pt_model.pt'))
  model_pth = torch.load('cifar10_pth_model.pth')


## Model Accuracy Evaluation

In [30]:
import time

def evaluate_model(model, data_loader):
    correct = 0
    total = 0
    start_time = time.time()  # Start timing
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    end_time = time.time()  # End timing
    inference_time = end_time - start_time
    accuracy = 100 * correct / total
    return accuracy, inference_time

# Evaluate PyTorch models
accuracy_pt, time_pt = evaluate_model(model_state_dict, test_loader)
accuracy_pth, time_pth = evaluate_model(model_pth, test_loader)

# Evaluate ONNX model
correct_onnx = 0
total_onnx = 0
start_time_onnx = time.time()  # Start timing ONNX inference

for inputs, labels in test_loader:
    # Adjust the batch size for ONNX inference
    for i in range(inputs.size(0)):
        single_input = inputs[i].unsqueeze(0).numpy()  # Add a batch dimension
        ort_inputs = {'input': single_input}
        ort_outputs = ort_session.run(None, ort_inputs)
        prediction = np.argmax(ort_outputs[0], axis=1)[0]
        total_onnx += 1
        correct_onnx += (prediction == labels[i].item())

end_time_onnx = time.time()  # End timing ONNX inference
time_onnx = end_time_onnx - start_time_onnx
accuracy_onnx = 100 * correct_onnx / total_onnx

# Print Results
print(f"Accuracy (.pt): {accuracy_pt:.2f}%, Inference Time: {time_pt:.4f} seconds")
print(f"Accuracy (.pth): {accuracy_pth:.2f}%, Inference Time: {time_pth:.4f} seconds")
print(f"Accuracy ONNX: {accuracy_onnx:.2f}%, Inference Time: {time_onnx:.4f} seconds")

Accuracy (.pt): 77.20%, Inference Time: 29.8904 seconds
Accuracy (.pth): 77.20%, Inference Time: 29.0775 seconds
Accuracy ONNX: 77.20%, Inference Time: 24.9061 seconds


## Format Size Comparison

In [29]:
file_sizes = {
    '.pt': os.path.getsize('cifar10_pt_model.pt') / 1024,
    '.pth': os.path.getsize('cifar10_pth_model.pth') / 1024,
    'ONNX': os.path.getsize('cifar10_onnx_model.onnx') / 1024
}

print("Model File Sizes (in KB):")
for fmt, size in file_sizes.items():
    print(f"{fmt}: {size:.2f} KB")

Model File Sizes (in KB):
.pt: 9677.33 KB
.pth: 9680.67 KB
ONNX: 9664.56 KB
