In [1]:
import os
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as transforms

# Define paths.
# The study root defaults to the original workstation location; set the
# GAZE_STUDY_DIR environment variable to run the notebook against a copy of
# the study folder somewhere else without editing this cell.
base_dir = os.environ.get('GAZE_STUDY_DIR', '/mnt/c/Users/rober/Desktop/GazeDetectionStudy')
real_eyes_dir = os.path.join(base_dir, 'dataset', 'real_eyes_for_testing')
model_save_dir = os.path.join(base_dir, 'models', 'alexnet')
onnx_model_path = os.path.join(model_save_dir, 'alexnet_gaze.onnx')
output_image_path = os.path.join(model_save_dir, 'inference_results.png')

# Ensure the output directory exists (the result image and the accuracy
# report are written next to the model).
os.makedirs(model_save_dir, exist_ok=True)

# Debug: Print paths to verify correctness
print(f"Base directory: {base_dir}")
print(f"Model save directory: {model_save_dir}")
print(f"ONNX model path: {onnx_model_path}")
print(f"Real eyes directory: {real_eyes_dir}")

# Verify that the ONNX model file exists.
# Raise instead of calling exit(): `exit` is an interactive-shell helper, and
# inside a notebook it requests a kernel shutdown rather than reporting an
# error. An exception stops the run with a traceback that names the path.
if not os.path.isfile(onnx_model_path):
    raise FileNotFoundError(f"ONNX model file not found at {onnx_model_path}")

# Gaze classes. The list order determines the index -> name mapping applied to
# the model output as well as the row order of the result grid.
# NOTE(review): this order must equal the label indices used during training.
# A loader that indexes class folders alphabetically would produce
# Center/Down/Up instead -- confirm against the training code.
class_names = ['Up', 'Down', 'Center']

# Preprocessing pipeline (should match the one used during training/testing):
# resize to 224x224, convert to a tensor, then normalise each channel.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
_resize_step = transforms.Resize((224, 224))
_normalize_step = transforms.Normalize(mean=_norm_mean, std=_norm_std)
data_transforms = transforms.Compose([_resize_step, transforms.ToTensor(), _normalize_step])

# ONNX Runtime session for the exported model
session = ort.InferenceSession(onnx_model_path)

# Parallel per-image result lists: annotated image, true label, predicted
# label and the confidence (in percent) of the predicted label.
images, labels, predictions, probabilities = [], [], [], []

# Font used for every annotation; fall back to Pillow's built-in font when
# Arial cannot be loaded.
try:
    font = ImageFont.truetype("arial.ttf", size=16)
except IOError:
    font = ImageFont.load_default()

# Process each class: run the model on up to 10 images from every class folder
# and keep an annotated copy of each image together with its true label, the
# predicted label and the confidence of the prediction.

_IMAGE_SIZE = 224            # side length of the square image fed to the model
_MAX_IMAGES_PER_CLASS = 10   # only the first N (sorted by file name) are used
# Extensions accepted as images; anything else in a class folder is ignored.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')


def _softmax(logits):
    """Return the softmax of a 1-D array of raw model scores.

    The maximum score is subtracted before exponentiating so that np.exp
    cannot overflow for large scores; the result is mathematically unchanged.
    """
    shifted = np.asarray(logits, dtype=np.float64)
    shifted = shifted - np.max(shifted)
    exps = np.exp(shifted)
    return exps / np.sum(exps)


def _annotate(image, text, font):
    """Return an RGB copy of `image` with `text` on a translucent white strip.

    The strip sits in the bottom-left corner, 5 px from the left and bottom
    edges, with a 2 px margin around the text.
    """
    annotated = image.copy()
    # Opening the drawing context in 'RGBA' mode on an RGB image makes Pillow
    # blend fills that carry an alpha value into the picture. Drawing directly
    # on an RGBA image would overwrite the pixels instead, and the strip would
    # turn opaque as soon as the image is flattened back to RGB.
    draw = ImageDraw.Draw(annotated, 'RGBA')
    left, top, right, bottom = font.getbbox(text)
    text_width = right - left
    text_height = bottom - top
    box_x = 5
    box_y = annotated.height - text_height - 5
    draw.rectangle(
        (box_x - 2, box_y - 2, box_x + text_width + 2, box_y + text_height + 2),
        fill=(255, 255, 255, 128),
    )
    # getbbox() reports the glyph box relative to the drawing anchor, so shift
    # the anchor by the left/top bearing to land the glyphs inside the strip.
    draw.text((box_x - left, box_y - top), text, fill='red', font=font)
    return annotated


# Ask the model for its input name instead of assuming it was exported as 'input'.
input_name = session.get_inputs()[0].name

for cls in class_names:
    class_dir = os.path.join(real_eyes_dir, cls)
    # Keep only image files so stray entries (thumbnail caches, sub-folders)
    # neither crash Image.open nor use up one of the per-class slots.
    image_files = sorted(
        name for name in os.listdir(class_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )[:_MAX_IMAGES_PER_CLASS]
    for image_file in image_files:
        image_path = os.path.join(class_dir, image_file)
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB').resize((_IMAGE_SIZE, _IMAGE_SIZE))
        # Preprocess: add the batch dimension and hand a NumPy array to ONNX Runtime
        input_image = data_transforms(image).unsqueeze(0).numpy()
        # Run inference; take the first output for the single image in the batch
        outputs = session.run(None, {input_name: input_image})
        output = outputs[0][0]
        # Prediction and its softmax confidence in percent
        predicted_idx = int(np.argmax(output))
        predicted_class = class_names[predicted_idx]
        prob = _softmax(output)[predicted_idx]
        prob_percent = prob * 100
        # Annotate a copy for the result grid
        text = f'Pred: {predicted_class} ({prob_percent:.1f}%)'
        display_image = _annotate(image, text, font)
        # Store
        images.append(display_image)
        labels.append(cls)
        predictions.append(predicted_class)
        probabilities.append(prob_percent)

# Compute accuracy per class.
# For every class, tally how many test images carry that true label and how
# many of those the model also labelled with it.
label_pairs = list(zip(labels, predictions))
total_counts = {
    name: sum(1 for truth, _ in label_pairs if truth == name)
    for name in class_names
}
correct_counts = {
    name: sum(1 for truth, guess in label_pairs if truth == name and guess == name)
    for name in class_names
}

# Create a grid of images: one row per class, a label column followed by up to
# 10 annotated images of that class.
grid_rows = len(class_names)
grid_cols = 10
image_width, image_height = 224, 224
label_width = 200  # Width for the label area at the beginning of each row

# Create a new image with white background
grid_image = Image.new('RGB', (label_width + grid_cols * image_width, grid_rows * image_height), color=(255, 255, 255))

# Line height for the row labels (the same for every row, so computed once).
# FreeType fonts provide getmetrics(); fall back to a measured glyph height
# for font objects that do not (e.g. a bitmap default font).
if hasattr(font, 'getmetrics'):
    ascent, descent = font.getmetrics()
    line_height = ascent + descent
else:
    sample_bbox = font.getbbox('Ag')
    line_height = sample_bbox[3] - sample_bbox[1]

for row, cls in enumerate(class_names):
    accuracy = 100 * correct_counts[cls] / total_counts[cls] if total_counts[cls] > 0 else 0
    # Create label image for the row
    label_text = f'Class: {cls}\nAccuracy: {accuracy:.1f}%'
    label_image = Image.new('RGB', (label_width, image_height), color=(255, 255, 255))
    draw = ImageDraw.Draw(label_image)
    # Centre the block of lines vertically and each line horizontally
    text_lines = label_text.split('\n')
    total_text_height = len(text_lines) * line_height
    y_text = (image_height - total_text_height) // 2
    for line in text_lines:
        bbox = font.getbbox(line)
        text_width = bbox[2] - bbox[0]
        x_text = (label_width - text_width) // 2
        draw.text((x_text, y_text), line, fill='black', font=font)
        y_text += line_height
    # Paste the label image
    grid_image.paste(label_image, (0, row * image_height))
    # Select this row's tiles by true label rather than by position in the
    # flat list: a class folder with fewer than grid_cols images would
    # otherwise shift every later class into the wrong row.
    row_images = [tile for tile, true_label in zip(images, labels) if true_label == cls]
    for col, tile in enumerate(row_images[:grid_cols]):
        grid_image.paste(tile, (label_width + col * image_width, row * image_height))

# Save the grid image
grid_image.save(output_image_path)
print(f"Inference results saved to {output_image_path}")

# Report accuracy per class and overall.
# The report lines are built once so the console output and the saved text
# file cannot drift apart.
report_lines = ["Accuracy per class:"]
for cls in class_names:
    accuracy = 100 * correct_counts[cls] / total_counts[cls] if total_counts[cls] > 0 else 0
    report_lines.append(f"{cls}: {accuracy:.1f}%")

# Guard the overall figure the same way as the per-class ones: with no images
# processed at all the total is zero and the division would raise.
total_correct = sum(correct_counts.values())
total_seen = sum(total_counts.values())
overall_accuracy = 100 * total_correct / total_seen if total_seen > 0 else 0
report_lines.append(f"Overall Accuracy: {overall_accuracy:.1f}%")

print("\n".join(report_lines))

# Save the accuracy results to a text file
accuracy_results_path = os.path.join(model_save_dir, 'inference_accuracy.txt')
with open(accuracy_results_path, 'w', encoding='utf-8') as f:
    f.write("\n".join(report_lines) + "\n")
print(f"Accuracy results saved to {accuracy_results_path}")


Base directory: /mnt/c/Users/rober/Desktop/GazeDetectionStudy
Model save directory: /mnt/c/Users/rober/Desktop/GazeDetectionStudy/models/alexnet
ONNX model path: /mnt/c/Users/rober/Desktop/GazeDetectionStudy/models/alexnet/alexnet_gaze.onnx
Real eyes directory: /mnt/c/Users/rober/Desktop/GazeDetectionStudy/dataset/real_eyes_for_testing
Inference results saved to /mnt/c/Users/rober/Desktop/GazeDetectionStudy/models/alexnet/inference_results.png
Accuracy per class:
Up: 0.0%
Down: 100.0%
Center: 0.0%
Overall Accuracy: 33.3%
Accuracy results saved to /mnt/c/Users/rober/Desktop/GazeDetectionStudy/models/alexnet/inference_accuracy.txt
