In [1]:
import argparse
import cv2
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
from IPython.display import Image

import torch
from ultralytics.utils.files import increment_path
from ultralytics import YOLO

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov8 import download_yolov8s_model

In [None]:
# Weights file that is handed to YOLO(...) further down.
model_path = "../sahi/models/20231231_yolov8x-albumentations.pt"
# Folder that is searched recursively for the frames to run detection on.
image_source = "/mnt/ai-storage/jira/DIP555/raw_panos/AIC_2181/pano_sub/frames"

In [None]:
# `model_path` does not name the stock YOLOv8s checkpoint, so the file must
# already be on disk. sahi's download_yolov8s_model() would fetch the stock
# YOLOv8s weights and store them under this file name when it is missing,
# silently swapping in a different model. Fail loudly instead.
if not Path(model_path).is_file():
    raise FileNotFoundError(f"Model weights '{model_path}' not found.")

In [None]:
# Abort early unless the configured frame location is an existing directory.
source_is_directory = Path(image_source).is_dir()
if not source_is_directory:
    raise NotADirectoryError(f"Source path '{image_source}' is not a directory.")

In [None]:
# Load the detector from the weights file configured in `model_path`.
model = YOLO(model_path)

In [None]:
# Detect objects from classes 0 and 32 only
# classes = [0, 32]
# model.overrides["classes"] = classes

In [None]:
# Run on the first CUDA device when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    inference_device = "cuda:0"
else:
    inference_device = "cpu"

# Wrap the already-loaded YOLO model for SAHI, with a 0.5 confidence threshold.
detection_model = AutoDetectionModel.from_pretrained(
    confidence_threshold=0.5,
    device=inference_device,
    model=model,
    model_type="yolov8",
)

In [None]:
# Output location: <parent of image_source>/results_sahi/exp.
# exist_ok (the second positional parameter of increment_path) is spelled out
# by keyword here.
results_root = Path(image_source).parent / "results_sahi"
save_dir = increment_path(results_root / "exp", exist_ok=True)
save_dir

In [None]:
# Create the output directory together with any missing parents; an already
# existing directory is not an error.
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Collect the frames recursively.
# - Matching on the lower-cased suffix accepts .jpg/.jpeg/.png in any letter
#   case; the former "*.[jp][pn]g" glob missed ".jpeg" and upper-case
#   extensions while also matching ".jng" / ".ppg".
# - Sorting makes image_files[0] and the processing order reproducible,
#   because the order produced by rglob depends on the filesystem.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}
image_files = sorted(
    path
    for path in Path(image_source).rglob("*")
    if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
)

if not image_files:
    raise FileNotFoundError(f"No image files found in: {image_source}")

In [None]:
# Detection stores filled in by later cells, keyed by frame identifier.
player_detections, ball_detections = {}, {}

In [None]:
# Use the first collected image as the sample for the single-frame cells below.
img_path = image_files[0]
img_path

In [None]:
# Slicing setup: 512x512 slices with an overlap ratio of 0.2 in both directions.
slice_options = {
    "slice_height": 512,
    "slice_width": 512,
    "overlap_height_ratio": 0.2,
    "overlap_width_ratio": 0.2,
}
results = get_sliced_prediction(str(img_path), detection_model, **slice_options)

In [None]:
# Export an annotated rendering of the sample frame into sahi_sample/ with
# labels and confidences hidden; the next cell reads the resulting PNG back.
results.export_visuals(
    export_dir="sahi_sample/",
    file_name="custom_yolov8_prediction_visual",
    text_size=5,
    rect_th=None,
    hide_labels=True,
    hide_conf=True,
)
# Image("sahi_sample/custom_yolov8_prediction_visual.png")

In [None]:
# Load the exported visualisation. cv2.imread returns None instead of raising
# when the file is missing or cannot be decoded, so check before converting.
visual_path = "sahi_sample/custom_yolov8_prediction_visual.png"
visual = cv2.imread(visual_path)
if visual is None:
    raise FileNotFoundError(f"Could not read exported visual: {visual_path}")

# OpenCV delivers BGR channel order; convert to RGB for matplotlib.
visual = cv2.cvtColor(visual, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20, 10))
plt.imshow(visual);  # trailing ';' hides the AxesImage repr in the cell output

In [None]:
# The file name without its extension serves as the frame identifier
# (kept as a string in this cell).
frame_number = img_path.stem
frame_number

In [None]:
# !pip install -U imantics

In [None]:
# Split the sample frame's predictions into two dicts keyed by the prediction's
# index: category 'player' goes to the player dict, every other category to the
# ball dict.
player_frame_detections = {}
ball_frame_detections = {}

for bboxid, detection in enumerate(results.object_prediction_list):
    if detection.category.name == 'player':
        target = player_frame_detections
    else:
        target = ball_frame_detections

    # Integer [x1, y1, x2, y2] corners plus the score rounded to 5 decimals.
    corners = [int(v) for v in detection.bbox.to_xyxy()]
    target[bboxid] = {
        'bbox': corners,
        'score': round(detection.score.value, 5),
    }

In [None]:
# File the sample frame's detections under its frame identifier.
player_detections[frame_number] = player_frame_detections
ball_detections[frame_number] = ball_frame_detections

In [None]:
# When True, the cells below also write annotated "<frame>_dets.jpg" images
# into save_dir.
debug = True

In [None]:
# Draw the sample frame's detections and save the annotated image in save_dir.
# frame_path is initialised up front so that displaying it afterwards does not
# raise NameError when debug is False.
frame_path = None

if debug:
    frame = cv2.imread(str(img_path))
    if frame is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # Draw on a copy so `frame` keeps the untouched pixels.
    frame_copy = frame.copy()

    for prediction in results.object_prediction_list:
        label = str(prediction.category.name)
        bbox = prediction.bbox
        x1, y1 = int(bbox.minx), int(bbox.miny)
        x2, y2 = int(bbox.maxx), int(bbox.maxy)

        # NOTE(review): the colour switch tests for "person", whereas the cell
        # that splits detections into players/ball tests for 'player' -
        # confirm which class names the model actually emits.
        is_person = label == "person"

        # Box outline.
        cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (56, 56, 255), 2)

        # Filled background sized to the label text, then the label itself.
        t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
        cv2.rectangle(
            frame_copy,
            (x1, y1 - t_size[1] - 3),
            (x1 + t_size[0], y1 + 3),
            (56, 56, 255) if is_person else (56, 255, 56),
            -1,
        )
        cv2.putText(
            frame_copy,
            label,
            (x1, y1 - 2),
            0,
            0.6,
            [255, 255, 255] if is_person else [0, 0, 0],
            thickness=1,
            lineType=cv2.LINE_AA,
        )

    frame_name = f"{frame_number}_dets.jpg"
    frame_path = save_dir / frame_name
    # cv2.imwrite reports a failed write through its return value.
    if not cv2.imwrite(str(frame_path), frame_copy):
        print(f"Could not write {frame_path}")

In [None]:
# Show where the annotated sample frame was written.
frame_path

In [None]:
# Summary structure for the run. Every value is taken from objects created in
# earlier cells (`results`, `model_path`), so the cell runs on a fresh kernel.
output_dict = {
    "players": {},
    "ball": {},
    "debug": {
        # Frames are read from an image directory, so no frame rate is known.
        "fps": None,
        # Size of the sample frame as recorded on the SAHI prediction result.
        "image_h": results.image_height,
        "image_w": results.image_width,
        # File name of the weights the detector was loaded from.
        "model_name": Path(model_path).name,
    },
}


In [None]:
def _draw_detections(frame, object_predictions):
    """Return a copy of ``frame`` with one labelled box per prediction.

    Args:
        frame: Image array as returned by ``cv2.imread``.
        object_predictions: SAHI object predictions; each one is read for
            ``category.name`` and ``bbox.minx/miny/maxx/maxy``.

    Returns:
        A new array with the drawings; ``frame`` itself is left unmodified.
    """
    annotated = frame.copy()

    for prediction in object_predictions:
        label = str(prediction.category.name)
        bbox = prediction.bbox
        x1, y1 = int(bbox.minx), int(bbox.miny)
        x2, y2 = int(bbox.maxx), int(bbox.maxy)

        # NOTE(review): the colour switch tests for "person", whereas the cell
        # that splits detections into players/ball tests for 'player' -
        # confirm which class names the model actually emits.
        is_person = label == "person"

        # Box outline.
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (56, 56, 255), 2)

        # Filled background sized to the label text, then the label itself.
        t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
        cv2.rectangle(
            annotated,
            (x1, y1 - t_size[1] - 3),
            (x1 + t_size[0], y1 + 3),
            (56, 56, 255) if is_person else (56, 255, 56),
            -1,
        )
        cv2.putText(
            annotated,
            label,
            (x1, y1 - 2),
            0,
            0.6,
            [255, 255, 255] if is_person else [0, 0, 0],
            thickness=1,
            lineType=cv2.LINE_AA,
        )

    return annotated


# Frame number -> list of COCO-style annotation dicts for that frame.
detections = {}

for img_path in image_files:
    results = get_sliced_prediction(
        str(img_path),
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
    )

    # Assumes the frame files are named by their number (e.g. "00042.jpg");
    # int() raises ValueError for any other stem.
    frame_number = int(img_path.stem)

    detections[frame_number] = [
        prediction.to_coco_annotation().json
        for prediction in results.object_prediction_list
    ]

    if debug:
        frame = cv2.imread(str(img_path))
        if frame is None:
            # cv2.imread returns None on failure; the detections are already
            # stored, so only the debug image is skipped.
            print(f"Could not read {img_path}; no debug image written")
        else:
            frame_name = f"{frame_number:05d}_dets.jpg"
            frame_path = save_dir / frame_name
            annotated = _draw_detections(frame, results.object_prediction_list)
            # cv2.imwrite reports a failed write through its return value.
            if not cv2.imwrite(str(frame_path), annotated):
                print(f"Could not write {frame_path}")

#     if cv2.waitKey(1) & 0xFF == ord("q"):
#         break

# cv2.destroyAllWindows()

In [None]:
# Persist the per-frame annotations collected in `detections`.
# The path is built before the try block so the failure message below can
# always refer to it.
# Note: pickle files should only be loaded back from trusted locations.
coco_out_path = f"{save_dir}/coco_results.pkl"
print(f"Saving {coco_out_path}...")
try:
    with open(coco_out_path, "wb") as f:
        pickle.dump(detections, f)
except Exception as e:
    # Best effort: report the problem but let the notebook finish.
    print(e)
    print(f"Could not save {coco_out_path}")

print("Inference with SAHI is done.")
