In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import os

import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
from dotenv import load_dotenv
from ultralytics import YOLO

# Configure the root logger so INFO-level messages emitted by the cells below are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Environment variables (ROOT_DIR, PARAMS) are read from the .envrc file one level up.
# The path is relative to the current working directory, which this cell changes below;
# on a re-run the variables already live in the process environment.
load_dotenv("../.envrc")

ROOT_DIR = os.getenv("ROOT_DIR")
PARAMS = os.getenv("PARAMS")

# Fail with an actionable message instead of the opaque TypeError that
# os.chdir(None) raises when the variable is missing.
if not ROOT_DIR:
    raise RuntimeError(
        "ROOT_DIR environment variable is not set; check that ../.envrc exists and defines it."
    )
# The project-local `src.*` imports that follow are resolved from the project root.
os.chdir(ROOT_DIR)

sns.set_style("darkgrid")

from src.libs.utils import calculate_intersection_over_union  # noqa: E402
from src.libs.visualize import perform_object_detection_and_plot_results  # noqa: E402

## Load model parameters and settings

In [None]:
def plot_image(image):
    """Display ``image`` with matplotlib, without axes, grid lines or outer margins.

    Args:
        image: Anything accepted by ``plt.imshow``.
    """
    plt.imshow(image)
    # Hide all decorations so only the picture itself is visible.
    plt.grid(False)
    plt.axis(False)
    # Stretch the axes over the whole figure area.
    plt.subplots_adjust(bottom=0, top=1, left=0, right=1)
    # To export the figure instead of only showing it, call before plt.show():
    #   plt.savefig("snapshot.png", bbox_inches="tight", pad_inches=0)
    plt.show()


# Load the parameters file whose path comes from the PARAMS environment variable.
# Fail early with a clear message rather than the cryptic error open(None) produces.
if not PARAMS:
    raise RuntimeError(
        "PARAMS environment variable is not set; it must point to the parameters YAML file."
    )

# Explicit encoding keeps parsing independent of the platform's default locale.
with open(PARAMS, encoding="utf-8") as conf_file:
    config = yaml.safe_load(conf_file)

## Validate IOU (Intersection over Union)

In [None]:
# Draw the annotated (ground-truth) bounding boxes of each selected test image.
# Label files are read as one box per line: "class x_center y_center width height",
# with the four geometry values scaled by the image size (normalised coordinates).
selected_image_files = ["random-plug.png"]

image_directory = config["data_split"]["test_path"]
label_directory = config["data_split"]["test_labels"]

for image_file in selected_image_files:
    logging.info(f"Processing image: {image_file}")

    current_image_path = os.path.join(image_directory, image_file)
    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(current_image_path)
    if image is None:
        logging.warning(f"Failed to load image: {image_file}")
        continue

    # The label file shares the image's base name, with a .txt extension.
    associated_label_file = os.path.splitext(image_file)[0] + ".txt"
    current_label_path = os.path.join(label_directory, associated_label_file)
    try:
        with open(current_label_path, "r", encoding="utf-8") as f:
            # Drop blank lines so an empty file yields zero labels rather than [""].
            labels = [line for line in f.read().splitlines() if line.strip()]
    except FileNotFoundError:
        logging.warning(f"Label file not found: {associated_label_file}")
        continue
    logging.info(f"Found {len(labels)} labels in file: {associated_label_file}")

    image_height, image_width = image.shape[:2]
    for label in labels:
        parts = label.split()
        if len(parts) != 5:
            logging.warning(f"Invalid label format in file: {associated_label_file}")
            continue
        try:
            _class_id, x_center, y_center, width, height = map(float, parts)
        except ValueError:
            # Non-numeric fields: skip this line instead of aborting the whole cell.
            logging.warning(f"Invalid label format in file: {associated_label_file}")
            continue

        # Convert centre/size fractions into pixel corner coordinates.
        x_min = int((x_center - width / 2) * image_width)
        y_min = int((y_center - height / 2) * image_height)
        x_max = int((x_center + width / 2) * image_width)
        y_max = int((y_center + height / 2) * image_height)
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (36, 255, 12), 1)

        # Later cells read this name; it holds only the most recently parsed box.
        ground_truth_bbox = (x_min, y_min, x_max, y_max)
        logging.info(f"Ground_truth_bbox: {ground_truth_bbox}")

    # OpenCV loads images in BGR channel order while matplotlib expects RGB,
    # so convert for display to avoid swapped red/blue channels.
    plot_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

## Load the model

In [None]:
# Load the trained detector from the path configured under evaluate.best_model.
model_path = config["evaluate"]["best_model"]

# Stop here when the weights file is missing: logging alone would let execution
# continue into YOLO() with a path already known not to exist.
if not os.path.isfile(model_path):
    logging.error(f"Model not found: {model_path}")
    raise FileNotFoundError(f"Model not found: {model_path}")

model = YOLO(model_path)


## Perform plug detection and localization

In [None]:
# Run the detector on every selected image and overlay the ground-truth box.
# Iterate over the file names directly: the previous `enumerate(...)` loop bound an
# unused (index, name) tuple and the body silently reused `image_file` left over
# from an earlier cell, so the same image was processed on every iteration.
for image_file in selected_image_files:
    detected_image, boxes, confidence = perform_object_detection_and_plot_results(
        model, os.path.join(image_directory, image_file)
    )
    # NOTE(review): ground_truth_bbox is a single value set by the validation cell
    # (the last label it parsed), so the overlay is only meaningful for the image
    # that label belongs to.
    cv2.rectangle(detected_image, ground_truth_bbox[:2], ground_truth_bbox[2:], (36, 255, 12), 1)
    plot_image(detected_image)

In [None]:
# Compare the first detected box with the ground-truth box.
# Guard against an empty detection result, where boxes[0] would raise IndexError.
if len(boxes) == 0:
    iou = None
    logging.warning("No detections returned by the model; IOU cannot be computed.")
else:
    iou = calculate_intersection_over_union(tuple(boxes[0]), ground_truth_bbox)
    logging.info(f"IOU: {round(iou, 2)}")