In [6]:
import numpy as np
import onnxruntime as ort
from PIL import Image

class PotatoPredictor:
    """Binary "potato / not potato" image classifier backed by an ONNX model.

    The model is loaded once into an ONNX Runtime session; each call to
    :meth:`predict` preprocesses one image file and runs a single forward pass.
    """

    # Target (width, height) handed to PIL's ``resize`` before inference.
    # Override in a subclass if the exported model expects another resolution.
    INPUT_SIZE = (384, 384)

    def __init__(self, model_path):
        """Create the inference session.

        Args:
            model_path: Path to the ``.onnx`` model file.
        """
        self.session = ort.InferenceSession(model_path)

    @staticmethod
    def _sigmoid(logit):
        """Numerically stable logistic function for a scalar logit.

        The textbook form ``1 / (1 + exp(-x))`` overflows in ``exp`` when ``x``
        is a large negative number. Branching on the sign keeps the argument
        of ``exp`` non-positive, so it can only underflow harmlessly to 0.
        """
        if logit >= 0:
            return 1 / (1 + np.exp(-logit))
        exp_logit = np.exp(logit)
        return exp_logit / (1 + exp_logit)

    def preprocess(self, image_path):
        """Load an image file and convert it to a model input tensor.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            A ``float32`` array of shape ``(1, 3, H, W)`` with channel values
            scaled to ``[0, 1]``, where ``(W, H)`` is ``INPUT_SIZE``.
        """
        # The context manager releases the underlying file handle
        # deterministically, even if decoding or conversion fails.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB").resize(self.INPUT_SIZE)
            arr = np.asarray(rgb).astype(np.float32) / 255.0
        arr = np.transpose(arr, (2, 0, 1))  # (H, W, C) -> (C, H, W)
        arr = np.expand_dims(arr, axis=0)  # add batch axis -> (1, C, H, W)
        return arr

    def predict(self, image_path, threshold=0.5):
        """Classify a single image.

        Args:
            image_path: Path to the image file to classify.
            threshold: Probability strictly above which the image is
                reported as a potato.

        Returns:
            A tuple ``(is_potato, prob)`` where ``prob`` is the sigmoid of the
            model's logit and ``is_potato`` is ``prob > threshold``.
        """
        input_array = self.preprocess(image_path)
        input_name = self.session.get_inputs()[0].name
        output = self.session.run(None, {input_name: input_array})
        # Assumes the first model output holds one logit per image, indexable
        # as [batch][0] -- TODO confirm against the exported model.
        logit = output[0][0][0]
        prob = self._sigmoid(logit)
        is_potato = prob > threshold
        return is_potato, prob

In [17]:
# Example usage: classify one sample image and report the verdict.
if __name__ == "__main__":
    # Inputs for the demo run.
    model_path = "./quantized_model.onnx"
    image_path = "./potato_dataset/not_potato/000016.jpg"

    # Load the model, then score the image at the default threshold.
    predictor = PotatoPredictor(model_path)
    is_potato, prob = predictor.predict(image_path)

    if not is_potato:
        # Report the complement so the number reads as confidence in "not potato".
        print(f"The image is not a potato with probability {1- prob:.2f}")
    else:
        print(f"The image is a potato with probability {prob:.2f}")

The image is not a potato with probability 0.76
