In [53]:
from PIL import Image
import torch
from torchvision import transforms
import onnxruntime
import numpy as np
import json


In [47]:
# Load the sample image.
# - The `with` block closes the underlying file handle (PIL opens images lazily).
# - .convert("RGB") forces exactly 3 channels. Some ImageNet JPEGs are grayscale
#   or CMYK, which would break the 3-channel Normalize below.
# - convert() loads the pixel data, so `img` stays usable after the file is closed.
with Image.open('/workspace/quant/sample_imgs/n02391049_zebra.JPEG') as raw_img:
    img = raw_img.convert("RGB")

# Image preprocessing: standard ImageNet pipeline
# (resize short side to 256, center-crop 224x224, scale to [0, 1], normalize per channel).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(img)             # (3, 224, 224)
input_batch = input_tensor.unsqueeze(0)    # add batch dimension -> (1, 3, 224, 224)

# Move tensor to the appropriate device
device = torch.device("cpu")
input_batch = input_batch.to(device)

# ONNX Runtime consumes NumPy arrays. .cpu() keeps this line valid even if
# `device` is switched to a GPU, where .numpy() would otherwise raise.
input_batch = input_batch.cpu().numpy().astype(np.float32)  # ensure float32 type


In [48]:

# Specify device for ONNX Runtime
device_name = "cpu"  # or "cuda:0"

# Execution providers per supported device, in priority order.
# For CUDA, the CPU provider is listed after CUDA as a fallback.
PROVIDERS_BY_DEVICE = {
    "cpu": ["CPUExecutionProvider"],
    "cuda:0": ["CUDAExecutionProvider", "CPUExecutionProvider"],
}

if device_name not in PROVIDERS_BY_DEVICE:
    # Fail loudly. The previous if/elif left `providers` undefined (NameError)
    # or silently reused a value left over from an earlier run of this cell.
    raise ValueError(
        f"Unsupported device_name {device_name!r}; "
        f"expected one of {sorted(PROVIDERS_BY_DEVICE)}"
    )
providers = PROVIDERS_BY_DEVICE[device_name]

# Create inference session
onnx_model = onnxruntime.InferenceSession(
    "/workspace/quant/models/quantized_googlenet.onnx", providers=providers
)


In [50]:
# Query the session's first input once and read both its name and its
# declared shape (symbolic dims such as 'batch_size' are returned as strings).
first_input = onnx_model.get_inputs()[0]
input_name, input_shape = first_input.name, first_input.shape

print(f"Input name: {input_name}")
print(f"Input shape: {input_shape}")

# Passing None as the output list asks ONNX Runtime for every model output.
# The feed maps the input name to the preprocessed float32 batch.
outputs = onnx_model.run(None, {input_name: input_batch})

# Top-1 prediction: highest score along dim 1 of the first output.
# .item() requires a single result, i.e. a batch of one.
output_array = outputs[0]
output_tensor = torch.tensor(output_array)
predicted_class = output_tensor.argmax(dim=1).item()

print(f'Predicted class: {predicted_class}')


Input name: input
Input shape: ['batch_size', 3, 224, 224]
Predicted class: 340


In [54]:
# Load ImageNet labels.
# Read with an explicit UTF-8 encoding so decoding does not depend on the
# platform/locale default.
# NOTE(review): this path is relative to the kernel's working directory,
# unlike the absolute /workspace paths used above — confirm the file lives there.
with open('imagenet-simple-labels.json', encoding='utf-8') as f:
    labels = json.load(f)  # presumably a list of class names indexed by class id — verify

# Print predicted class label
predicted_label = labels[predicted_class]
print(f'Predicted class: {predicted_class}, Label: {predicted_label}')

Predicted class: 340, Label: zebra
