In [67]:
import onnxruntime as ort
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image

In [68]:
def combine_images(image_path1, image_path2, output_path):
    """Paste two images side by side and save the result.

    The first image is placed on the left and the second on the right of a
    new RGB canvas whose height equals the first image's height.

    Args:
        image_path1: Path of the image placed on the left.
        image_path2: Path of the image placed on the right.
        output_path: Destination path passed to ``Image.save``.

    Note:
        When the heights differ, the second image is resized to the first
        image's height while keeping its own width, so its aspect ratio is
        not preserved.
    """
    # Image.open is lazy and may keep the underlying file open; the context
    # managers guarantee both handles are released even if an error occurs.
    with Image.open(image_path1) as left, Image.open(image_path2) as right:
        right_fitted = right
        if left.height != right.height:
            right_fitted = right.resize((right.width, left.height))

        canvas = Image.new("RGB", (left.width + right_fitted.width, left.height))
        canvas.paste(left, (0, 0))
        canvas.paste(right_fitted, (left.width, 0))

        canvas.save(output_path)

In [69]:
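# Disabled one-off step: for every digit pair (i, j), paste mnist_{i}.png and
# mnist_{j}.png side by side and save the result as mnist_{i}{j}.png.
# for i in range(10):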
#     for j in range(10):
#         image_path1 = f"../data/mnist_{i}.png"
#         image_path2 = f"../data/mnist_{j}.png"
#         output_path = f"../data/mnist_{i}{j}.png"
#         combine_images(image_path1, image_path2, output_path)

In [70]:
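# Superseded earlier version, kept for reference: it created the ONNX session
# from a model path on every call and hard-coded a 28x28 input size. The live
# definition further down takes an existing session and reads the size from it.
# def predict_number(onnx_model_path, image_path):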
#     ort_session = ort.InferenceSession(onnx_model_path)
    
#     transform = transforms.Compose([
#         transforms.Grayscale(num_output_channels=1),
#         transforms.Resize((28, 28)),
#         transforms.ToTensor(),
#         transforms.Normalize((0.1307,), (0.3081,))
#     ])
    
#     image = Image.open(image_path).convert('L')
#     image = transform(image)
    
#     if image.mean() > 0.5:
#         image = 1.0 - image
    
#     image = image.unsqueeze(0).numpy().astype(np.float32)
    
#     ort_inputs = {ort_session.get_inputs()[0].name: image}
#     ort_outs = ort_session.run(None, ort_inputs)
#     predicted = np.argmax(ort_outs[0], axis=1)[0]
    
#     return predicted



def predict_number(ort_session, image_path):
    """Classify the image at ``image_path`` with an ONNX Runtime session.

    The image is converted to single-channel grayscale, resized to the
    spatial size declared by the model's first input, polarity-corrected,
    normalized, and fed to the model as a batch of one.

    Args:
        ort_session: Object exposing ``get_inputs()`` and ``run()`` (e.g. an
            ``onnxruntime.InferenceSession``). Its first input's shape is
            unpacked as four values, with the last two used as the resize
            height and width.
            NOTE(review): assumes a static NCHW input shape; confirm for
            models exported with dynamic axes.
        image_path: Path of the image file to classify.

    Returns:
        The index of the largest value along axis 1 of the model's first
        output, for the single image in the batch.
    """
    model_input = ort_session.get_inputs()[0]
    _, _, target_height, target_width = model_input.shape

    # Stop before normalization so the polarity check below sees raw
    # intensities (ToTensor scales pixel values to the [0, 1] range).
    to_unit_tensor = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((target_height, target_width)),
        transforms.ToTensor(),
    ])
    # presumably the MNIST mean/std; verify against the training pipeline
    normalize = transforms.Normalize((0.1307,), (0.3081,))

    # Close the file handle deterministically; convert() returns a new,
    # fully loaded image, so the tensor does not depend on the open file.
    with Image.open(image_path) as source:
        pixels = to_unit_tensor(source.convert('L'))

    # A mostly-bright image is treated as dark-on-light and inverted. This
    # must run before Normalize: afterwards values are no longer in [0, 1],
    # so the 0.5 threshold and `1.0 - x` would not be a true inversion.
    if pixels.mean() > 0.5:
        pixels = 1.0 - pixels

    batch = normalize(pixels).unsqueeze(0).numpy().astype(np.float32)

    outputs = ort_session.run(None, {model_input.name: batch})

    return np.argmax(outputs[0], axis=1)[0]

In [72]:
# Load the exported model, classify one sample image, and show the result.
onnx_session = ort.InferenceSession('model.onnx')
predicted_digit_onnx = predict_number(onnx_session, '../data/images.png')
print(str(predicted_digit_onnx))

1
