In [4]:
import onnxruntime as rt
import argparse
from imutils import paths
import cv2
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

In [5]:
# --- Evaluation configuration -------------------------------------------------
# Identifier of the run whose exported model is evaluated.
runID = "66a5e35cfbde4006bda9e3325e4a4ae8"

# Spatial size (two entries) the images are resized to before inference.
image_shape = [224, 224]

# Minimum score on output index 1 for an image to be labelled "staff".
threshold = 0.85

# Per-channel normalisation statistics, one value per colour channel.
mean = np.asarray((0.485, 0.456, 0.406))
std = np.asarray((0.229, 0.224, 0.225))

In [6]:
# Evaluation-time preprocessing applied to every PIL image before inference:
# resize, convert to a tensor, then normalise each channel with `mean` / `std`.
test_transforms = transforms.Compose(
	[
		# Use the size from the config cell so the resize target and the
		# configured `image_shape` cannot silently drift apart.
		transforms.Resize(image_shape),
		transforms.ToTensor(),
		transforms.Normalize(mean, std),
	]
)

In [11]:
# Build the model path from the run id in the config cell instead of repeating
# the id by hand, so changing `runID` is enough to evaluate another run.
onnx_model = f"mlruns/{runID}/artifacts/model/data/last.onnx"
sess = rt.InferenceSession(onnx_model)

# Print the model signature (name, shape and dtype of every input and output)
# so the expected tensor layout is visible next to the evaluation code.
for title, nodes in (("INPUT", sess.get_inputs()), ("OUTPUT", sess.get_outputs())):
	print(f"===={title}====")
	for node in nodes:
		print(f"Name: {node.name}, Shape: {node.shape}, Dtype: {node.type}")

====INPUT====
Name: input, Shape: ['batch', 3, 224, 224], Dtype: tensor(float)
====OUTPUT====
Name: output, Shape: ['batch', 2], Dtype: tensor(float)


In [12]:
# Evaluate the ONNX classifier on a directory tree of images and report the
# overall and per-class accuracy. The ground-truth label of an image is the
# name of the directory that contains it ("staff" or "customer").
dataset_dir = "../datasets/customer_staff_20220816"
imagePaths = sorted(paths.list_images(dataset_dir))

total_sample = len(imagePaths)
if total_sample == 0:
	# Fail with a clear message instead of a ZeroDivisionError further down.
	raise FileNotFoundError(f"No images found under {dataset_dir!r}")

# Ask the session for its input name rather than hard-coding "input".
input_name = sess.get_inputs()[0].name


def _percent(hits, total):
	"""Return hits/total as a percentage rounded to 2 places (NaN if total is 0)."""
	return round((hits / total) * 100, 2) if total else float("nan")


correct = 0
staff_correct = 0
customer_correct = 0
total_staff = 0
total_customer = 0

for image_path in imagePaths:
	# os.path handles the platform's separator, unlike splitting on "/".
	true_label = os.path.basename(os.path.dirname(image_path))
	if true_label == "staff":
		total_staff += 1
	elif true_label == "customer":
		total_customer += 1
	# NOTE(review): images under any other directory still count towards
	# total_sample and can never be "correct" - confirm this is intended.

	ori = cv2.imread(image_path)
	if ori is None:
		# cv2.imread signals an unreadable file by returning None.
		raise ValueError(f"Could not read image: {image_path}")

	# OpenCV loads BGR; the preprocessing pipeline works on an RGB PIL image.
	img = Image.fromarray(cv2.cvtColor(ori, cv2.COLOR_BGR2RGB))
	# Add the leading batch axis the model input declares.
	batch = test_transforms(img).unsqueeze(0).numpy()

	# First output, first (only) batch element: one score per class.
	pred = sess.run(None, {input_name: batch})[0][0].tolist()

	# Index 1 is treated as the "staff" score; anything not above the
	# threshold falls back to "customer".
	# NOTE(review): assumes the model emits probabilities - TODO confirm that
	# the exported graph ends in a softmax, otherwise the threshold is applied
	# to raw logits.
	if pred[1] > threshold:
		pred_label = "staff"
		score = pred[1]
	else:
		pred_label = "customer"
		score = pred[0]

	if true_label == pred_label:
		correct += 1
		if pred_label == "staff":
			staff_correct += 1
		elif pred_label == "customer":
			customer_correct += 1

	# Optional visual check (uncomment to inspect individual predictions):
	# plt.title(f"{pred_label}_{score}")
	# plt.imshow(cv2.cvtColor(ori, cv2.COLOR_BGR2RGB))
	# plt.show()

accuracy = _percent(correct, total_sample)
staff_accuracy = _percent(staff_correct, total_staff)
customer_accuracy = _percent(customer_correct, total_customer)
print(f"Total Accuracy = {accuracy}%")
print(f"Staff Accuracy = {staff_accuracy}%")
print(f"Customer Accuracy = {customer_accuracy}%")

Total Accuracy = 98.54
Staff Accuracy = 98.53
Customer Accuracy = 98.56
