[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_yolov5.ipynb)

## 0. Preparation

- Install the latest version of SAHI and YOLOv5:
- SAHI and YOLOv5 must be installed from the current local source, not from pip, because the code has been customized.

In [1]:
#!pip install -U sahi yolov5 scikit-image imagecodecs pycocotools

In [2]:
import os
os.getcwd()

'/home/yyj/sahi/demo'

## 1. Export frame

- When the data arrives as a video, extract and save one frame every `every` frames.

In [7]:
import cv2
import os

def extract_frame(video_path, frame_dir, overwrite=False, start=-1, end=-1, every=1):
    """
    Extract frames from a video using OpenCV's VideoCapture.
    :param video_path: path of the video
    :param frame_dir: root directory for the frames (a subfolder named after the video is created)
    :param overwrite: overwrite frames that already exist
    :param start: first frame index (negative means 0)
    :param end: frame index to stop before (negative means the video's frame count)
    :param every: save only frames whose index is a multiple of `every`
    :return: count of images saved
    """
    
    video_path = os.path.normpath(video_path)
    frame_dir = os.path.normpath(frame_dir)
    
    video_filename = os.path.splitext(os.path.basename(video_path))[0]
    assert os.path.exists(video_path), f"video not found: {video_path}"
    
    save_dir = os.path.join(frame_dir, video_filename)
    os.makedirs(save_dir, exist_ok=True)
    
    capture = cv2.VideoCapture(video_path)
    
    if start < 0:
        start = 0
    if end < 0:
        end = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    
    capture.set(cv2.CAP_PROP_POS_FRAMES, start)  # seek to the start frame
    frame = start
    while_safety = 0  # consecutive failed reads
    saved_count = 0
    
    while frame < end:
        ok, image = capture.read()
        
        if while_safety > 10:  # give up after too many consecutive failed reads
            break
            
        if not ok or image is None:
            while_safety += 1
            continue
        while_safety = 0
            
        if frame % every == 0:
            save_path = os.path.join(save_dir, "{:010d}.jpg".format(frame))
            if not os.path.exists(save_path) or overwrite:
                cv2.imwrite(save_path, image)
                saved_count += 1
                
        frame += 1
        
    capture.release()
        
    return saved_count

In [8]:
!ls ../resources/

DSL		 fll_221016.mp4       log_221017_2
export_frame	 FLL_VAL	      log_221017_2_narrow
FL_221017_2.mp4  FLL_VAL_OLD	      models
FL_221017.mp4	 FLL_VAL.zip	      new_train_from_val.zip
FL_221021.mp4	 hf_spaces_badge.svg  save_test
FL_221022_2.mp4  lidar01.mp4	      save_val
FL_221022.mp4	 lidar02.mp4	      sliced_inference.gif
fll_221014.mp4	 log
fll_221015.mp4	 log_221017


In [10]:
video_path = "../resources/FL_221022_2.mp4"
frame_dirs = "../resources/export_frame"

extract_frame(video_path, frame_dirs, every=30)

1025
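`extract_frame` keeps a frame only when its absolute index is a multiple of `every`; it does not count `every` frames from `start`. The stdlib-only sketch below reproduces that index arithmetic and the zero-padded file name. `sampled_frame_indices` and `frame_file_name` are hypothetical helpers, and the sketch assumes every read succeeds.

```python
def sampled_frame_indices(total_frames, start=-1, end=-1, every=1):
    """Frame indices extract_frame would save, assuming every read succeeds."""
    start = 0 if start < 0 else start
    end = total_frames if end < 0 else end
    # absolute index test, mirroring `frame % every == 0` in extract_frame
    return [f for f in range(start, end) if f % every == 0]


def frame_file_name(frame):
    """Zero-padded file name used when a frame is saved."""
    return "{:010d}.jpg".format(frame)


# starting at 0, every 30th frame is kept
print(sampled_frame_indices(100, every=30))           # [0, 30, 60, 90]
# starting at 5 does not shift the grid: 35, 65, 95 are not saved
print(sampled_frame_indices(100, start=5, every=30))  # [30, 60, 90]
print(frame_file_name(30))                            # 0000000030.jpg
```

Under the same assumptions (no failed reads, no pre-existing files), the `1025` returned above implies the video had between 30,721 and 30,750 frames, because ceil(N / 30) = 1025 only in that range.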

## 3. Import required modules

In [None]:
# helper for downloading a pretrained YOLOv5 detection model (not used below; local weights are loaded instead)
from sahi.utils.yolov5 import (
    download_yolov5s6_model,
)

# import required functions, classes
from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.utils.file import download_from_url
from sahi.predict import get_prediction, get_sliced_prediction, predict
from IPython.display import Image
from pathlib import Path

In [None]:
coco_m_path = '../resources/models/yolov5m.pt'
fll_model_221007_path = '../resources/models/221007/best.pt'
fll_model_221012_path = '../resources/models/221012/best.pt'
fll_model_221014_path = '../resources/models/221014/best.pt'

## 4. Standard Inference with a YOLOv5 Model

- Instantiate a detection model by defining model weight path and other parameters:

In [None]:
coco_base_model = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=coco_m_path,
    confidence_threshold=0.25,
    device="cuda:0"
)

fll_model_221007 = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=fll_model_221007_path,
    confidence_threshold=0.25,
    device="cuda:0"
)

fll_model_221012 = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=fll_model_221012_path,
    confidence_threshold=0.25,
    device="cuda:0"
)

fll_model_221014 = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=fll_model_221014_path,
    confidence_threshold=0.25,
    device="cuda:0"
)

In [None]:
model = fll_model_221014
model_path = fll_model_221014_path

- Perform prediction by feeding the get_prediction function with an image path and a DetectionModel instance:

In [None]:
# WANGSAN test image path
test_image_path = "../resources/export_frame/lidar01/0000018960.jpg"

In [None]:
result = get_prediction(test_image_path, model)

- Or perform prediction by feeding the get_prediction function with a numpy image and a DetectionModel instance:

In [None]:
# result = get_prediction(read_image(test_image_path), model)

- Visualize predicted bounding boxes over the original image (YOLOv5 detection models produce no masks):

In [None]:
result = get_prediction(read_image(test_image_path), fll_model_221007)
result.export_visuals(export_dir="demo_data/", file_name="prediction_visual1")

Image("demo_data/prediction_visual1.png")

## 4-1. Predict on All Images

In [None]:
model_type = "yolov5"
model_device = "cuda:0" # or 'cpu'
model_confidence_threshold = 0.4

slice_height = 512
slice_width = 512
overlap_height_ratio = 0.2
overlap_width_ratio = 0.2

source_image_dir = "../resources/FLL_VAL/images/"
source_label_dir = "../resources/FLL_VAL/labels/"

no_sliced_prediction = False
no_standard_prediction = False
custom_slice_mode=1
custom_slice_x_start=200
custom_slice_y_start=200
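With 512x512 slices and overlap ratios of `0.2`, neighbouring windows share `int(0.2 * 512) = 102` pixels, and a window that would run past the border is shifted back inside the image. The sketch below is a simplified approximation of the standard (non-custom) SAHI slicing grid, not the library code; it does not model the customized `custom_slice_mode`, `custom_slice_x_start`, or `custom_slice_y_start` options, and the 1280x720 frame size is an assumption.

```python
def slice_windows(image_h, image_w, slice_h=512, slice_w=512,
                  overlap_h_ratio=0.2, overlap_w_ratio=0.2):
    """Return (x_min, y_min, x_max, y_max) windows covering the image."""
    y_overlap = int(overlap_h_ratio * slice_h)  # pixels shared by vertical neighbours
    x_overlap = int(overlap_w_ratio * slice_w)  # pixels shared by horizontal neighbours
    windows = []
    y_min = 0
    while True:
        y_max = y_min + slice_h
        x_min = 0
        while True:
            x_max = x_min + slice_w
            # clamp a window that overruns the border by shifting it back inside
            x2, y2 = min(image_w, x_max), min(image_h, y_max)
            windows.append((max(0, x2 - slice_w), max(0, y2 - slice_h), x2, y2))
            if x_max >= image_w:
                break
            x_min = x_max - x_overlap
        if y_max >= image_h:
            break
        y_min = y_max - y_overlap
    return windows


for window in slice_windows(720, 1280):
    print(window)
```

Six tiles cover a 1280x720 frame this way; the last column and row overlap their neighbours by more than 102 px because they are shifted to end exactly at the border.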

- Test calling the `predict` function

In [None]:
# Note: if `dataset_json_path` (str) is given, detection results are also
# exported in COCO JSON format. It is not passed in this call.

result = predict(
    model_type=model_type,
    model_path=model_path,
    model_device=model_device,
    model_confidence_threshold=model_confidence_threshold,
    no_sliced_prediction=no_sliced_prediction,
    no_standard_prediction=no_standard_prediction,
    source=source_image_dir,
    slice_height=slice_height,
    slice_width=slice_width,
    overlap_height_ratio=overlap_height_ratio,
    overlap_width_ratio=overlap_width_ratio,
    custom_slice_mode=custom_slice_mode,
    custom_slice_x_start=custom_slice_x_start,
    custom_slice_y_start=custom_slice_y_start,
    return_dict=True
)

## 4-2. Interactive show_sample

In [None]:
image_files = sorted([fn for fn in os.listdir(source_image_dir) if fn.endswith("jpg")])
label_files = sorted([fn for fn in os.listdir(source_label_dir) if fn.endswith("txt")])

In [None]:
len(image_files), len(label_files)
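`image_files` and `label_files` are sorted independently, and later cells pair them by position (`label_files[index]`). That pairing is only valid when every image has exactly one label with the same file stem. A small stdlib check, using hypothetical helper names `unpaired_stems` and `positions_align`:

```python
from pathlib import Path


def unpaired_stems(image_files, label_files):
    """Return (images without labels, labels without images), by file stem."""
    image_stems = {Path(f).stem for f in image_files}
    label_stems = {Path(f).stem for f in label_files}
    return sorted(image_stems - label_stems), sorted(label_stems - image_stems)


def positions_align(image_files, label_files):
    """True if the i-th image and the i-th label share a stem for every i."""
    return len(image_files) == len(label_files) and all(
        Path(i).stem == Path(l).stem for i, l in zip(image_files, label_files)
    )


print(unpaired_stems(["a.jpg", "b.jpg", "c.jpg"], ["a.txt", "c.txt"]))  # (['b'], [])
```

In the notebook this would be called as `unpaired_stems(image_files, label_files)`; two empty lists plus `positions_align(image_files, label_files)` returning `True` mean index-based pairing is safe.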

### draw bbox util

In [None]:
import copy

import cv2
import numpy as np
from sahi.utils.cv import Colors, apply_color_mask

def visualize_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    rect_th: int = None,
    text_size: float = None,
    text_th: float = None,
    color: tuple = None,
):
    """
    Draws prediction category names, scores, and bounding boxes (plus masks,
    if present) over a copy of the source image and returns it.
    Arguments:
        image: source image as a BGR numpy array (e.g. from cv2.imread)
        object_prediction_list: a list of prediction.ObjectPrediction
        rect_th: rectangle thickness
        text_size: size of the category name over box
        text_th: text thickness
        color: annotation color in the form: (0, 255, 0); a per-class palette is used if None
    Returns:
        the annotated image converted to RGB, ready for plt.imshow
    """
    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.001), 1)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3
    # add bbox and mask to image if present
    for object_prediction in object_prediction_list:
        # deepcopy each prediction so that the original is not altered
        object_prediction = object_prediction.deepcopy()

        bbox = object_prediction.bbox.to_voc_bbox()
        category_name = object_prediction.category.name
        score = object_prediction.score.value

        # set color
        if colors is not None:
            color = colors(object_prediction.category.id)
        # visualize masks if present
        if object_prediction.mask is not None:
            # boolean mask (only read here, so it is not copied)
            mask = object_prediction.mask.bool_mask
            # draw mask
            rgb_mask = apply_color_mask(mask, color)
            image = cv2.addWeighted(image, 1, rgb_mask, 0.4, 0)
        # set bbox points
        p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        # visualize boxes
        cv2.rectangle(
            image,
            p1,
            p2,
            color=color,
            thickness=rect_th
        )
        # arrange bounding box text location
        label = f"{category_name} {score:.2f}"
        w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]  # label width, height
        outside = p1[1] - h - 3 >= 0  # label fits outside box
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        # add bounding box text
        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(
            image,
            label,
            (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
            0,
            text_size,
            (255, 255, 255),
            thickness=text_th,
        )
        
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
    return image
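When `rect_th`, `text_th`, and `text_size` are left as `None`, `visualize_object_predictions` derives them from the image shape, and it places each label above the box unless that would leave the image. The pure-Python sketch below copies that arithmetic without calling OpenCV; `text_w` and `text_h` stand in for what `cv2.getTextSize` would return, and the helper names are hypothetical.

```python
def default_styles(shape):
    """rect_th, text_th, text_size as derived in visualize_object_predictions."""
    rect_th = max(round(sum(shape) / 2 * 0.001), 1)
    text_th = max(rect_th - 1, 1)
    text_size = rect_th / 3
    return rect_th, text_th, text_size


def label_placement(p1, text_w, text_h):
    """Return (outside, background corner p2, text origin) for a box whose top-left is p1."""
    outside = p1[1] - text_h - 3 >= 0  # is there room above the box?
    bg_p2 = (p1[0] + text_w, p1[1] - text_h - 3 if outside else p1[1] + text_h + 3)
    text_org = (p1[0], p1[1] - 2 if outside else p1[1] + text_h + 2)
    return outside, bg_p2, text_org


print(default_styles((720, 1280, 3)))      # (1, 1, 0.3333333333333333)
print(default_styles((2160, 3840, 3)))     # (3, 2, 1.0)
print(label_placement((10, 100), 40, 12))  # (True, (50, 85), (10, 98))
print(label_placement((10, 5), 40, 12))    # (False, (50, 20), (10, 19))
```

For a 1280x720 frame the boxes are therefore 1 pixel thick unless `rect_th` is passed explicitly.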

### FULL INFERENCE

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1))
def show_sample(index=0):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    
    result = get_prediction(image_path, model)
    
    canvas = visualize_object_predictions(image, result.object_prediction_list)
    
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

### FULL INFERENCE: Add confidence_threshold with interact

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1), confidence_score=(0, 1, 0.05))
def show_sample(index=0, confidence_score=0.25):
    sample_model = AutoDetectionModel.from_pretrained(
        model_type='yolov5',
        model_path=model_path,
        confidence_threshold=confidence_score,
        device="cuda:0"
    )

    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    
    result = get_prediction(image_path, sample_model)
    
    canvas = visualize_object_predictions(image, result.object_prediction_list)
    
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()
    

## 4-3. Interactive show_sample with AutoShape (NMS disabled)

In [None]:
from yolov5.models.common import AutoShape, DetectMultiBackend
from yolov5.utils.torch_utils import select_device

device="cuda:0"
autoshape=True

device = select_device(device)  # resolve the device string to a torch.device
AutoShapeModel = AutoShape(DetectMultiBackend(model_path, device=device, fuse=autoshape))  # for file/URI/PIL/cv2/np inputs and NMS

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1))
def show_sample(index=0):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    
    result = AutoShapeModel(image_path, nms=False)
    canvas = result.render()[0]
    
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

### Compare NMS and non-NMS results

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1))
def show_sample(index=0):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    
    print("Image Path:", image_path)
    result_with_nms = get_prediction(image_path, model)

    canvas1 = visualize_object_predictions(image, result_with_nms.object_prediction_list)
    
    result_with_no_nms = AutoShapeModel(image_path, nms=False)
    
    canvas2 = result_with_no_nms.render()[0]
    
    fig, axes = plt.subplots(1, 2, figsize=(32,32)) 
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    
    axes[0].axis('off')
    axes[1].axis('off')
    
    axes[0].imshow(canvas1)
    axes[1].imshow(canvas2)
    
    plt.axis('off')
    plt.show()
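The right-hand panel shows the raw, overlapping detections that remain when NMS is skipped. Greedy NMS keeps the highest-scoring box, drops every remaining box whose IoU with it reaches a threshold, and repeats. The class-agnostic, pure-Python sketch below illustrates the idea; YOLOv5's own `non_max_suppression` is vectorized, per-class by default, and also applies a confidence threshold, and the `0.45` IoU threshold here is an assumed example value.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thr=0.45):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # suppress every remaining box that overlaps the kept one too much
        order = [j for j in order if iou(boxes[best], boxes[j]) < iou_thr]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 with IoU of about 0.68
```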

## 4-4. Interactive show_sample with sliced prediction

### Check that the single row is sliced correctly, and print the slice shapes

In [None]:
custom_slice_mode=2
custom_slice_x_start=640
custom_slice_y_start=360
slice_size=512

In [None]:
from typing import List
import copy
import matplotlib.pyplot as plt
import cv2
%matplotlib inline 
from sahi import slicing
from sahi.slicing import slice_image

slicing.logger.setLevel(slicing.logging.INFO)

# single_row_y_start: int = 200,
@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), single_row_y_start=(0, 512))
def visualize_slice_rect(index=0, slice_size=512, overlap_ratio=0.2, single_row_y_start=200):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    
    res = slice_image(image_path, 
                      slice_width=slice_size,
                      slice_height=slice_size,
                      overlap_height_ratio=overlap_ratio,
                      overlap_width_ratio=overlap_ratio,
                      custom_slice_mode=custom_slice_mode,
                      custom_slice_x_start=custom_slice_x_start,
                      custom_slice_y_start=custom_slice_y_start,
                      verbose=1)

    image = cv2.imread(image_path)
    image = copy.deepcopy(image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    for start_pixel in res.starting_pixels:
        cv2.rectangle(image,
                      start_pixel,
                      [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
                      color=(255, 0, 0),
                      thickness=2)
    
    for s_im in res.images:
        print(s_im.shape)
    
    plt.figure(figsize=(16,16))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

### Draw slice image

In [None]:
from typing import List
import copy
import matplotlib.pyplot as plt
import cv2
%matplotlib inline 
from sahi import slicing
from sahi.slicing import slice_image

slicing.logger.setLevel(slicing.logging.INFO)

# single_row_y_start: int = 200,
@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), single_row_y_start=(0, 512))
def visualize_slice_rect(index=0, slice_size=512, overlap_ratio=0.2, single_row_y_start=200):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    
    res = slice_image(image_path, 
                      slice_width=slice_size,
                      slice_height=slice_size,
                      overlap_height_ratio=overlap_ratio,
                      overlap_width_ratio=overlap_ratio,
                      custom_slice_mode=custom_slice_mode,
                      custom_slice_x_start=custom_slice_x_start,
                      custom_slice_y_start=custom_slice_y_start,
                      verbose=1)
    
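    print(len(res.images))
    # two slices per row; round up so an odd count fits without an extra empty row
    nrows = max(1, (len(res.images) + 1) // 2)
    fig, axes = plt.subplots(nrows=nrows, ncols=2, figsize=(12, 16))
    
#     plt.subplots_adjust(left=0.05, bottom=0.01, right=0.99, 
#                     top=0.99, wspace=None, hspace=0.2)
    
    ax = axes.flatten()
    
    for img_idx, s_im in enumerate(res.images):
#         s_im = cv2.cvtColor(s_im, cv2.COLOR_BGR2RGB)
        ax[img_idx].imshow(s_im)
    # hide all axes, including the unused cell left over when the count is odd
    for a in ax:
        a.axis('off')
        
    plt.show()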

### Interactive sliced inference with slice_size, overlap_ratio, and only_full_inference

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), only_full_inference=(0,1))
def show_sample(index=0, slice_size=512, overlap_ratio=0.2, only_full_inference=0):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    
    if not only_full_inference:
        slice_result = slice_image(image_path, 
                                  slice_width=slice_size,
                                  slice_height=slice_size,
                                  overlap_height_ratio=overlap_ratio,
                                  overlap_width_ratio=overlap_ratio,
                                  custom_slice_mode=custom_slice_mode,
                                  custom_slice_x_start=custom_slice_x_start,
                                  custom_slice_y_start=custom_slice_y_start,
                                  verbose=1)

        for start_pixel in slice_result.starting_pixels:
            cv2.rectangle(image,
                          start_pixel,
                          [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
                          color=(255, 255, 0),
                          thickness=2)
        
        result = get_sliced_prediction(image_path,
                                       model,
                                       slice_height=slice_size,
                                       slice_width=slice_size,
                                       postprocess_match_threshold=0.5,
                                       overlap_height_ratio=overlap_ratio,
                                       overlap_width_ratio=overlap_ratio,
                                       custom_slice_mode=custom_slice_mode,
                                       postprocess_type="GREEDYNMM",
                                       custom_slice_x_start=custom_slice_x_start,
                                       custom_slice_y_start=custom_slice_y_start
                                      )
    else:
        result = get_prediction(image_path, model)
    
    canvas = visualize_object_predictions(image, result.object_prediction_list)
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()
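This cell calls `get_sliced_prediction` with `postprocess_type="GREEDYNMM"` and `postprocess_match_threshold=0.5`. Instead of discarding overlapping boxes, greedy non-maximum merging (NMM) fuses them, which helps when one object is cut in two by a slice border. Section 5-2 pairs it with the `IOS` metric (intersection over the smaller box's area), which stays high when a small fragment lies inside a larger box even though their IoU is low. The sketch below is a simplified, class-agnostic illustration of that idea, not SAHI's implementation (which also handles categories and masks); `greedy_nmm` and `match_score` are hypothetical names.

```python
def match_score(a, b, metric="IOS"):
    """IoU or IOS (intersection over smaller area) of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    denom = min(area_a, area_b) if metric == "IOS" else area_a + area_b - inter
    return inter / denom if denom > 0 else 0.0


def greedy_nmm(boxes, scores, match_threshold=0.5, metric="IOS"):
    """Merge boxes that match a higher-scoring box; return [(box, score), ...]."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    merged = []
    while order:
        best = order.pop(0)
        x1, y1, x2, y2 = boxes[best]
        remaining = []
        for j in order:
            if match_score(boxes[best], boxes[j], metric) >= match_threshold:
                # grow the kept box so it encloses the matched one
                bx1, by1, bx2, by2 = boxes[j]
                x1, y1, x2, y2 = min(x1, bx1), min(y1, by1), max(x2, bx2), max(y2, by2)
            else:
                remaining.append(j)
        order = remaining
        merged.append(((x1, y1, x2, y2), scores[best]))
    return merged


# one object split by a slice border: a large piece and a small fragment
parts = [(100, 100, 200, 150), (180, 100, 210, 150)]
print(greedy_nmm(parts, [0.9, 0.6]))                # [((100, 100, 210, 150), 0.9)]
print(greedy_nmm(parts, [0.9, 0.6], metric="IOU"))  # both kept: IoU is only about 0.18
```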

## Show YOLO Input Image (LETTERBOX)

In [None]:
# Pre-process
import os
import random
from glob import glob
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from yolov5.utils.dataloaders import exif_transpose, letterbox
from yolov5.utils.general import make_divisible

ims = random.choice(glob(os.path.join(source_image_dir, "*.jpg")))
size = (640, 640)
device = "cuda:0"

n, ims = (1, [ims])  # number, list of images

shape1 = [] # image and inference shapes, filenames
for i, im in enumerate(ims):
    f = f'image{i}'  # filename
    if isinstance(im, (str, Path)):  # filename or uri
        im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
        im = np.asarray(exif_transpose(im))
    elif isinstance(im, Image.Image):  # PIL Image
        im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
    if im.shape[0] < 5:  # image in CHW
        im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
    im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
    
    s = im.shape[:2]  # HWC
    g = max(size) / max(s)  # gain
    shape1.append([y * g for y in s])
    ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update

shape1 = [make_divisible(x, 32) for x in np.array(shape1).max(0)] # inf shape
x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
x = torch.from_numpy(x).to(torch.device(device)).type(torch.float32) / 255  # uint8 to fp16/32

plt.imshow(x[0].permute(1,2,0).cpu())
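The cell above scales the longer side to `max(size)`, rounds each side up to a multiple of 32 with `make_divisible`, and letterboxes the image into that shape with `auto=False`. The pure-Python sketch below redoes that arithmetic for one frame size; it mirrors the computation as written in the cell and assumes YOLOv5's `letterbox` defaults (upscaling allowed, padding split evenly between both sides). `letterbox_geometry` is a hypothetical name, and the 1280x720 frame size is an assumption.

```python
import math


def make_divisible(x, divisor=32):
    # round up to the next multiple of `divisor`
    return math.ceil(x / divisor) * divisor


def letterbox_geometry(image_hw, size=640, stride=32):
    """Return (input_hw, resized_hw, (top, bottom, left, right)) for one image."""
    h, w = image_hw
    g = size / max(h, w)  # gain so the longer side becomes `size`
    in_h, in_w = make_divisible(h * g, stride), make_divisible(w * g, stride)
    r = min(in_h / h, in_w / w)  # letterbox resize ratio
    new_w, new_h = int(round(w * r)), int(round(h * r))
    dw, dh = (in_w - new_w) / 2, (in_h - new_h) / 2  # padding per side
    pad = (int(round(dh - 0.1)), int(round(dh + 0.1)),
           int(round(dw - 0.1)), int(round(dw + 0.1)))
    return (in_h, in_w), (new_h, new_w), pad


print(letterbox_geometry((720, 1280)))  # ((384, 640), (360, 640), (12, 12, 0, 0))
```

Under these assumptions a 1280x720 frame reaches the network as a 384x640 tensor with 12 padded rows above and below the resized image.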

## 5. SAHI Benchmark

### 5-1. Create val.json (GT) with category IDs starting at 0

In [None]:
import json
import os

from PIL import Image  # initial_extract below uses PIL's Image.open

In [None]:
def initial_extract(img_dir, label_dir, out_dir):
    if os.path.exists(os.path.join(out_dir, 'val.json')):
        os.remove(os.path.join(out_dir, 'val.json'))
    
    licenses = [
        {
            "name": "",
            "id": 0,
            "url": ""
        }
    ]

    # COCO expects "info" to be a single object, not a list
    info_ = {
        "contributor": "",
        "date_created": "",
        "description": "",
        "url": "",
        "version": "",
        "year": ""
    }

    categories = [
        {
            "id": 0,
            "name": "Buoy",
            "supercategory": ""
        },
        {
            "id": 1,
            "name": "Boat",
            "supercategory": ""
        },
        {
            "id": 2,
            "name": "Channel Marker",
            "supercategory": ""
        },
        {
            "id": 3,
            "name": "Speed Warning Sign",
            "supercategory": ""
        }
    ]

    img_idx = 0
    annot_idx = 0

    imgs_list = []
    annots_list = []

    for label_file in sorted(os.listdir(label_dir)):
        label_file_ = os.path.join(label_dir, label_file)
        img_file_ = os.path.join(img_dir, f'{os.path.splitext(label_file)[0]}.jpg')
        img = Image.open(img_file_)
        image_w, image_h = img.size

        imgs_list.append({
            'id': img_idx,
            'width': image_w,
            'height': image_h,
            'file_name': f'{os.path.splitext(label_file)[0]}.jpg',
            "license": 0,
            "flickr_url": "",
            "coco_url": "",
            "date_captured": 0
        })

        with open(label_file_, 'r') as label_f:
            labels = label_f.readlines()

            for label in labels:
                parts = label.split()
                if not parts:  # skip blank lines
                    continue
                # YOLO format: class_id x_center y_center width height, coordinates normalized to [0, 1]
                cat = int(parts[0])
                xc, yc, label_normalized_w, label_normalized_h = map(float, parts[1:5])
                label_w, label_h = image_w * label_normalized_w, image_h * label_normalized_h
                xmin, ymin = (image_w * xc) - (label_w / 2), (image_h * yc) - (label_h / 2)
                
                xmin = 0 if xmin < 0 else xmin
                ymin = 0 if ymin < 0 else ymin

                annots_list.append({
                    'id': annot_idx,
                    'image_id': img_idx,
                    'category_id': cat,
                    'area': int(label_h * label_w),
                    'bbox': [
                        xmin,
                        ymin,
                        label_w,
                        label_h
                    ],
                    'iscrowd': 0,
                    'attributes': {
                        'type': '',
                        'occluded': False
                    },
                    'segmentation': []
                })

                annot_idx += 1

        img_idx += 1

    out_dict = {
        'licenses': licenses,
        'info': info_,
        'categories': categories,
        'images': imgs_list,
        'annotations': annots_list
    }
    
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    
    with open(os.path.join(out_dir, 'val.json'), 'w') as out_f:
        print(os.path.join(out_dir, 'val.json'))
        json.dump(out_dict, out_f)
        
    return os.path.join(out_dir, 'val.json')

In [None]:
# initial_extract(img_dir, label_dir, out_dir)
gt_json_path = initial_extract(source_image_dir, source_label_dir, str(Path(source_image_dir).parent))
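`initial_extract` converts each YOLO label line (`class x_center y_center width height`, normalized to [0, 1]) into a COCO `bbox` of `[x_min, y_min, width, height]` in pixels. The stdlib sketch below isolates that per-line conversion, including the clamp of negative corners to 0; `yolo_line_to_coco_bbox` is a hypothetical helper and the 1280x720 image size is only an example.

```python
def yolo_line_to_coco_bbox(line, image_w, image_h):
    """Convert one YOLO label line to (category_id, [x_min, y_min, w, h], area)."""
    parts = line.split()
    cat = int(parts[0])
    xc, yc, nw, nh = map(float, parts[1:5])
    w, h = image_w * nw, image_h * nh       # normalized size -> pixels
    x_min = max(0.0, image_w * xc - w / 2)  # center -> top-left, clamped like initial_extract
    y_min = max(0.0, image_h * yc - h / 2)
    return cat, [x_min, y_min, w, h], int(w * h)


print(yolo_line_to_coco_bbox("1 0.5 0.5 0.25 0.5", 1280, 720))
# (1, [480.0, 180.0, 320.0, 360.0], 115200)
```

As in `initial_extract`, clamping moves only the corner: a box that starts left of the image keeps its full width, so it extends further right than the original label.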

### Draw bboxes from the generated GT val.json

In [None]:
from pycocotools.coco import COCO
plt.rcParams['figure.figsize'] = (16, 8)
coco=COCO(gt_json_path)
y_offset=-20

@interact(index=(0, len(image_files)-1))
def draw_gt_bbox(index=0):
    fig, ax = plt.subplots()
    img = coco.loadImgs(ids=[index])[-1]
    I = cv2.imread(os.path.join(source_image_dir, img['file_name']))
    I = cv2.cvtColor(I, cv2.COLOR_BGR2RGB)
    ax.imshow(I); plt.axis('off')
    annIds = coco.getAnnIds(imgIds=img['id'], iscrowd=None)
    anns = coco.loadAnns(annIds)
    coco.showAnns(anns, draw_bbox=True)
    for i, ann in enumerate(anns):
        ax.text(anns[i]['bbox'][0], anns[i]['bbox'][1]+y_offset, anns[i]['category_id'], style='italic', 
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})

### 5-2. Predict: Parameter Sweep 

In [None]:
from pathlib import Path

from sahi.predict import predict
from sahi.scripts.coco_error_analysis import analyse
from sahi.scripts.coco_evaluation import evaluate

MODEL_TYPE = "yolov5"
MODEL_PATH = model_path
MODEL_CONFIG_PATH = ""
EVAL_IMAGES_FOLDER_DIR = source_image_dir
EVAL_DATASET_JSON_PATH = gt_json_path
INFERENCE_SETTING = "AVIKUS_FL"
EXPORT_VISUAL = False
MAX_DETECTIONS = 500

In [None]:
INFERENCE_SETTING_TO_PARAMS = {
    "AVIKUS_FL": {
        "no_standard_prediction": False,
        "no_sliced_prediction": False,
        "slice_size": 512,
        "overlap_ratio": 0.15,
        "match_threshold": 0.5,
        "postprocess_class_agnostic": False,
        "single_row_y_start": 200,
    },
}

setting_params = INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]

In [None]:
result = predict(
    model_type=MODEL_TYPE,
    model_path=MODEL_PATH,
    model_config_path=MODEL_CONFIG_PATH,
#     model_confidence_threshold=0.01,
    model_device="cuda:0",
    source=EVAL_IMAGES_FOLDER_DIR,
    no_standard_prediction=setting_params["no_standard_prediction"],
    no_sliced_prediction=setting_params["no_sliced_prediction"],
#     image_size=None,
    slice_height=setting_params["slice_size"],
    slice_width=setting_params["slice_size"],
    overlap_height_ratio=setting_params["overlap_ratio"],
    overlap_width_ratio=setting_params["overlap_ratio"],
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOS",
#     postprocess_type="NMS",
#     postprocess_match_metric="IOU",
    postprocess_match_threshold=setting_params["match_threshold"],
    postprocess_class_agnostic=setting_params["postprocess_class_agnostic"],
    novisual=not EXPORT_VISUAL,
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    project="runs/mAP_TEST",
    name=INFERENCE_SETTING,
    visual_bbox_thickness=None,
    visual_text_size=None,
    visual_text_thickness=None,
    visual_export_format="png",
    verbose=1,
    return_dict=True,
    force_postprocess_type=True,
    single_row_predict=True,
    single_row_y_start=setting_params["single_row_y_start"]
)

result_json_path = str(Path(result["export_dir"]) / "result.json")

In [None]:
evaluate_dict = evaluate(
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    result_json_path=result_json_path,
    classwise=True,
    max_detections=MAX_DETECTIONS,
    return_dict=True,
)

analyse_dict = analyse(
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    result_json_path=result_json_path,
    max_detections=MAX_DETECTIONS,
    return_dict=True,
)

### 5-3. Draw TP, FP, FN Bbox

In [None]:
import cv2
import numpy as np
import torch
from yolov5.utils.metrics import box_iou

In [None]:
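def process_one_image(preds, targets, match_metric):
    """
    Split predictions and targets into TP, FP, and FN boxes for one image.
    preds: torch.Tensor[N, 6] as [x1, y1, x2, y2, conf, class_id]
    targets: torch.Tensor[M, 5] as [class_id, x1, y1, x2, y2]
    match_metric: IoU threshold for a prediction to count as a match
    """
    tp_boxes = None
    fp_boxes = None
    fn_boxes = None
    iou = box_iou(targets[:, 1:], preds[:, :4])
    correct_class = targets[:, 0:1] == preds[:, 5]
    x = torch.where((iou >= match_metric) & (correct_class))
    if x[0].numel():
        # rows of [target_idx, pred_idx, iou]
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
        if x[0].shape[0] > 1:
            # keep the highest-IoU pair per prediction, then per target (one-to-one)
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        # indices come back as floats from the stacked array; cast before indexing
        matched_targets = matches[:, 0].astype(int)
        matched_preds = matches[:, 1].astype(int)
        tp_boxes = preds[matched_preds]
        fp_idxs = [idx for idx in range(preds.shape[0]) if idx not in matched_preds]
        fp_boxes = preds[fp_idxs]
        fn_idxs = [idx for idx in range(targets.shape[0]) if idx not in matched_targets]
        fn_boxes = targets[fn_idxs]
    else:
        # nothing matched: every prediction is a FP and every target is a FN
        fp_boxes = preds
        fn_boxes = targets
    empty = torch.tensor([])
    return (tp_boxes if tp_boxes is not None else empty,
            fp_boxes if fp_boxes is not None else empty,
            fn_boxes if fn_boxes is not None else empty)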


def eval_one_image(img, preds, targets, conf_metric=0.25, match_metrics=0.5):
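    """
    Draw TP, FP, and FN boxes on img in place and return it.
    img: np.ndarray image drawn on with cv2 (colors are BGR tuples:
         TP (255, 0, 0), FP (0, 0, 255), FN (0, 255, 0))
    preds: torch.Tensor[x1, y1, x2, y2, conf, class_id]
    targets: torch.Tensor[class_id, x1, y1, x2, y2]
    conf_metric: confidence filter (unused while the filter line below is commented out)
    match_metrics: IoU threshold passed to process_one_image
    """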
    
#     (save_dir / 'tps').mkdir(parents=True, exist_ok=True)
#     (save_dir / 'fps').mkdir(parents=True, exist_ok=True)
#     (save_dir / 'fns').mkdir(parents=True, exist_ok=True)
#     (save_dir / 'combined').mkdir(parents=True, exist_ok=True)
#     preds = preds[preds[:, 4] > conf_metric]
    tps, fps, fns = process_one_image(preds, targets, match_metrics)
    
    rect_th = max(round(sum(img.shape) / 2 * 0.001), 1)

    # visualize boxes    
    if tps.numel():
        for tp in tps:
            x1, y1, x2, y2, conf, class_id = tp
            p1, p2 = (int(x1), int(y1)), (int(x2), int(y2))
            cv2.rectangle(
                img,
                p1,
                p2,
                color=(255,0,0),
                thickness=rect_th
            )
            label = f"{int(class_id)} {conf:.2f}"
            text_th = max(rect_th - 1, 3)
            text_size = rect_th / 3
            w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
            # label fits outside box
            outside = p1[1] - h - 3 >= 0
            
            cv2.putText(
                img,
                label,  
                (p1[0], p1[1] - 12 if outside else p1[1] + h + 2),
                0,
                text_size,
                (255, 255, 255),
                thickness=text_th,
            )
            
    if fps.numel():
        for fp in fps:
            x1, y1, x2, y2, conf, class_id = fp
            p1, p2 = (int(x1), int(y1)), (int(x2), int(y2))
            cv2.rectangle(
                img,
                p1,
                p2,
                color=(0,0,255),
                thickness=rect_th
            )
            label = f"{int(class_id)} {conf:.2f}"
            text_th = max(rect_th - 1, 3)
            text_size = rect_th / 3
            w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
            # label fits outside box
            outside = p1[1] - h - 3 >= 0
            
            cv2.putText(
                img,
                label,  
                (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                0,
                text_size,
                (255, 255, 255),
                thickness=text_th,
            )
            
    if fns.numel():
        for fn in fns:
            class_id, x1, y1, x2, y2 = fn
            p1, p2 = (int(x1), int(y1)), (int(x2), int(y2))
            cv2.rectangle(
                img,
                p1,
                p2,
                color=(0,255,0),
                thickness=rect_th
            )

    return img
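`process_one_image` turns IoU and class agreement into a one-to-one assignment: each prediction and each target is used in at most one match, higher-IoU pairs win, unmatched predictions become FPs, and unmatched targets become FNs. The torch-free sketch below shows the same idea as a plain greedy assignment over pairs sorted by IoU. It is an illustrative re-implementation with hypothetical names and may differ from the `np.unique`-based deduplication above in rare tie or chain cases.

```python
def box_iou_xyxy(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def split_tp_fp_fn(preds, targets, iou_thr=0.5):
    """preds: [x1, y1, x2, y2, conf, cls]; targets: [cls, x1, y1, x2, y2]."""
    pairs = []
    for ti, t in enumerate(targets):
        for pi, p in enumerate(preds):
            if int(t[0]) == int(p[5]):  # classes must agree
                v = box_iou_xyxy(t[1:5], p[:4])
                if v >= iou_thr:
                    pairs.append((v, ti, pi))
    pairs.sort(reverse=True)  # best IoU first
    used_t, used_p = set(), set()
    for _, ti, pi in pairs:
        if ti not in used_t and pi not in used_p:
            used_t.add(ti)
            used_p.add(pi)
    tps = [p for i, p in enumerate(preds) if i in used_p]
    fps = [p for i, p in enumerate(preds) if i not in used_p]
    fns = [t for i, t in enumerate(targets) if i not in used_t]
    return tps, fps, fns


preds = [[0, 0, 10, 10, 0.9, 0], [1, 1, 11, 11, 0.8, 0], [50, 50, 60, 60, 0.7, 1]]
targets = [[0, 0, 0, 10, 10], [1, 100, 100, 110, 110]]
tps, fps, fns = split_tp_fp_fn(preds, targets)
print(len(tps), len(fps), len(fns))  # 1 2 1
```

The second prediction overlaps the first target well (IoU about 0.68) but still counts as a FP, because that target was already claimed by the exact match.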

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), single_row_y_start=(0, 512), only_full_inference=(0,1))
def show_sample(index=1780, slice_size=512, overlap_ratio=0.2, single_row_y_start=200, only_full_inference=0):
    image_file = image_files[index]
    label_path = os.path.join(source_label_dir, label_files[index])
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    image_h, image_w, _ = image.shape
    
    
    if not only_full_inference:
        slice_result = slice_image(image_path, 
                                  slice_width=slice_size,
                                  slice_height=slice_size,
                                  overlap_height_ratio=overlap_ratio,
                                  overlap_width_ratio=overlap_ratio,
                                  single_row_y_start=single_row_y_start,
                                  single_row_predict=True,
                                  verbose=2)


        for start_pixel in slice_result.starting_pixels:
            cv2.rectangle(image,
                          start_pixel,
                          [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
                          color=(255, 255, 0),
                          thickness=2)

        result = get_sliced_prediction(image_path,
                                       model,
                                       slice_height=slice_size,
                                       slice_width=slice_size,
                                       postprocess_match_threshold=0.5,
                                       overlap_height_ratio=overlap_ratio,
                                       overlap_width_ratio=overlap_ratio,
                                       single_row_y_start=single_row_y_start,
                                       single_row_predict=True,
                                       verbose=2
                                      )
    else:
        result = get_prediction(image_path, model)
    
#     print(result.to_coco_annotations())
    preds = [[ann['bbox'][0], 
              ann['bbox'][1], 
              ann['bbox'][0]+ann['bbox'][2], 
              ann['bbox'][1]+ann['bbox'][3],
              ann['score'],
              ann['category_id']]
              for ann in result.to_coco_annotations()]
    
    targets = []
    
    with open(label_path, 'r') as label_f:
        labels = label_f.readlines()

        for label in labels:
            parts = label.split()
            cat = int(parts[0])  # YOLO class id; int() also handles ids with two or more digits
            xc, yc, label_normalized_w, label_normalized_h = map(float, parts[1:5])
            label_w, label_h = image_w * label_normalized_w, image_h * label_normalized_h
            xmin, ymin = (image_w * xc) - (label_w / 2), (image_h * yc) - (label_h / 2)
            xmax, ymax = xmin + label_w, ymin + label_h

            # clip to the image; computing xmax/ymax first keeps clipping from shifting the far edge
            xmin, ymin = max(xmin, 0), max(ymin, 0)
            xmax, ymax = min(xmax, image_w), min(ymax, image_h)

            targets.append([cat, xmin, ymin, xmax, ymax])
    
    
#     canvas = visualize_object_predictions(image, result.object_prediction_list)
    canvas = eval_one_image(image, torch.as_tensor(preds), torch.as_tensor(targets))
    canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()
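The cell above converts two box conventions into the `[x1, y1, x2, y2]` corner format used for `preds` and `targets`: COCO annotations (`bbox = [x, y, width, height]` in pixels) and YOLO label lines (`class x_center y_center width height`, normalized to [0, 1]). The same conversions as standalone helpers are sketched below; `coco_xywh_to_xyxy` and `yolo_line_to_xyxy` are hypothetical names, not part of SAHI.

```python
from typing import List


def coco_xywh_to_xyxy(bbox: List[float]) -> List[float]:
    """COCO [x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def yolo_line_to_xyxy(line: str, image_w: int, image_h: int) -> List[float]:
    """Hypothetical: 'class xc yc w h' (normalized) -> [class_id, x1, y1, x2, y2] in pixels."""
    parts = line.split()
    class_id = int(parts[0])
    xc, yc, w, h = (float(v) for v in parts[1:5])
    box_w, box_h = w * image_w, h * image_h
    x1, y1 = xc * image_w - box_w / 2, yc * image_h - box_h / 2
    x2, y2 = x1 + box_w, y1 + box_h
    # Clip each edge independently so clipping one side does not move the other
    return [class_id, max(x1, 0.0), max(y1, 0.0), min(x2, float(image_w)), min(y2, float(image_h))]
```

Parsing the class id with `int()` and the four coordinates with `float()` avoids guessing the type from the token length, which breaks for class ids with two or more digits.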

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), single_row_y_start=(0, 512))
def show_sample(index=1780, slice_size=512, overlap_ratio=0.2, single_row_y_start=200):
    fig, axes = plt.subplots(nrows=2,ncols=3,figsize=(32,16))
    plt.subplots_adjust(left=0.05, bottom=0.01, right=0.99, 
                    top=0.99, wspace=None, hspace=0.2)
    ax = axes.flatten()
    
    image_file = image_files[index]
    label_path = os.path.join(source_label_dir, label_files[index])
    image_path = os.path.join(source_image_dir, image_file)
    print(image_path)
    image = cv2.imread(image_path)
    image_h, image_w, _ = image.shape
    
    slice_result = slice_image(image_path, 
                              slice_width=slice_size,
                              slice_height=slice_size,
                              overlap_height_ratio=overlap_ratio,
                              overlap_width_ratio=overlap_ratio,
                              single_row_y_start=single_row_y_start,
                              single_row_predict=True,
                              verbose=2)

#     for start_pixel in slice_result.starting_pixels:
#         cv2.rectangle(image,
#                       start_pixel,
#                       [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
#                       color=(255, 255, 0),
#                       thickness=2)
        
#     result = get_sliced_prediction(image_path,
#                                    model,
#                                    slice_height=slice_size,
#                                    slice_width=slice_size,
#                                    postprocess_match_threshold=0.5,
# #                                    postprocess_type="NMM",
#                                    overlap_height_ratio=overlap_ratio,
#                                    overlap_width_ratio=overlap_ratio,
#                                    single_row_y_start=single_row_y_start,
#                                    single_row_predict=True,
#                                    verbose=2
#                                   )

    result = get_prediction(image_path, model)
    
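    canvas = visualize_object_predictions(image, result.object_prediction_list)
    canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)  # image was loaded with cv2.imread (BGR); matplotlib expects RGB
    ax[5].imshow(canvas)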
    ax[5].axis('off')
        
#     print(slice_result.images[0].shape)
    img_idx = 0 
    for s_im in slice_result.images[:5]:  # only ax[0]..ax[4] are free; ax[5] holds the full-image prediction
        
        result = get_prediction(s_im, model)
    
# #     print(result.to_coco_annotations())
#     preds = [[ann['bbox'][0], 
#               ann['bbox'][1], 
#               ann['bbox'][0]+ann['bbox'][2], 
#               ann['bbox'][1]+ann['bbox'][3],
#               ann['score'],
#               ann['category_id']]
#               for ann in result.to_coco_annotations()]
    
#     targets = []
    
#     with open(label_path, 'r') as label_f:
#         labels = label_f.readlines()

#         for label in labels:
#             cat, xc, yc, label_normalized_w, label_normalized_h = list(map(lambda x: int(x) if len(x) == 1 else float(x), label.split()))
#             label_w, label_h = image_w * label_normalized_w, image_h * label_normalized_h
#             xmin, ymin = (image_w * xc) - (label_w / 2), (image_h * yc) - (label_h / 2)

#             xmin = 0 if xmin < 0 else xmin
#             ymin = 0 if ymin < 0 else ymin
            
#             targets.append([cat, xmin, ymin, xmin+label_w, ymin+label_h])
    
    
        canvas = visualize_object_predictions(s_im, result.object_prediction_list)
#     canvas = eval_one_image(image, torch.as_tensor(preds), torch.as_tensor(targets))
        canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    
        ax[img_idx].imshow(canvas)
        ax[img_idx].axis('off')
        
        img_idx += 1
    
    plt.show()
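The slices displayed in this cell come from the customized `slice_image(..., single_row_predict=True, single_row_y_start=...)`, whose implementation is not part of this section. Judging only from the parameter names, it appears to place a single horizontal row of square tiles starting at `single_row_y_start`. The sketch below models that assumed behavior with the usual sliding-window rule (step = `slice_size * (1 - overlap_ratio)`, last tile shifted back so it ends at the right border). `single_row_starting_pixels` is a hypothetical name, and the real fork may place tiles differently.

```python
from typing import List


def single_row_starting_pixels(image_w: int, image_h: int, slice_size: int,
                               overlap_ratio: float, y_start: int) -> List[List[int]]:
    """Hypothetical: [x, y] top-left corners for one row of square slices."""
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    # Keep the row inside the image vertically
    y = max(0, min(y_start, image_h - slice_size))
    step = max(1, int(slice_size * (1 - overlap_ratio)))
    xs = []
    x = 0
    while True:
        if x + slice_size >= image_w:
            # Shift the final slice left so it ends exactly at the image border
            xs.append(max(0, image_w - slice_size))
            break
        xs.append(x)
        x += step
    return [[x, y] for x in xs]
```

For a 1280×720 frame with `slice_size=512`, `overlap_ratio=0.2` and `single_row_y_start=200`, this sketch yields three tiles starting at x = 0, 409 and 768, which fits within the five slice panels available above.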

In [None]:
from pycocotools.coco import COCO
plt.rcParams['figure.figsize'] = (16, 8)
coco=COCO(gt_json_path)
y_offset=-20

@interact(index=(0, len(image_files)-1))
def draw_gt_bbox(index=0):
    fig, ax = plt.subplots()
    img = coco.loadImgs(ids=[index])[-1]
    I = cv2.imread(os.path.join(source_image_dir, img['file_name']))
    I = cv2.cvtColor(I, cv2.COLOR_BGR2RGB)
    ax.imshow(I); plt.axis('off')
    annIds = coco.getAnnIds(imgIds=img['id'], iscrowd=None)
    anns = coco.loadAnns(annIds)
    coco.showAnns(anns, draw_bbox=True)
    for ann in anns:
        # write the category id just above the top-left corner of each box
        ax.text(ann['bbox'][0], ann['bbox'][1] + y_offset, ann['category_id'], style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})
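`draw_gt_bbox` passes the slider index straight to `coco.loadImgs(ids=[index])`, which is only correct if the COCO image ids happen to run from 0 to N-1 in the same order as `image_files`. The sketch below builds a lookup keyed on file name instead. It works on the raw COCO JSON dict (`images` entries with `id` and `file_name`, `annotations` entries with `image_id`, as defined by the COCO format); `build_filename_index` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def build_filename_index(coco_dict: dict) -> Tuple[Dict[str, int], Dict[int, List[dict]]]:
    """Hypothetical: map file_name -> image id, and image id -> its annotations."""
    name_to_id = {img["file_name"]: img["id"] for img in coco_dict["images"]}
    anns_by_image = defaultdict(list)
    for ann in coco_dict.get("annotations", []):
        anns_by_image[ann["image_id"]].append(ann)
    return name_to_id, dict(anns_by_image)
```

With pycocotools, the loaded JSON is available as `coco.dataset`, so `name_to_id, _ = build_filename_index(coco.dataset)` followed by `coco.loadImgs(ids=[name_to_id[image_files[index]]])` would select the image by name. This assumes the `file_name` values in the JSON match the entries of `image_files` exactly, without extra directory prefixes.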

In [None]:
evaluate_dict['eval_results']['bbox_mAP50']

In [None]:
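from pathlib import Path

# The loops below update INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING] in place;
# alias that same dict so predict() reads the values of the current iteration
setting_params = INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]

### slice_size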
for p_slice in [256, 384, 512, 640]:
    INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["slice_size"] = p_slice
    ### overlap_ratio
    for p_overlap_ratio in [0, 0.1, 0.2, 0.25]:
        INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["overlap_ratio"] = p_overlap_ratio
        ### match_threshold
        for p_match_threshold in [0.5, 0.7, 0.9]:
            INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["match_threshold"] = p_match_threshold
            ### postprocess_class_agnostic
            for p_postprocess_class_agnostic in [True, False]:
                INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["postprocess_class_agnostic"] = p_postprocess_class_agnostic
                ### no_sliced_prediction
                for p_no_sliced_prediction in [True, False]:
                    INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["no_sliced_prediction"] = p_no_sliced_prediction
                    ### single_row_y_start
                    for p_single_row_y_start in [0, 100, 200]:
                        INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]["single_row_y_start"] = p_single_row_y_start
                        
                        setting_info = (
                            f"slice_{p_slice}_overlap_{p_overlap_ratio}_match_{p_match_threshold}"
                            f"_agnostic_{p_postprocess_class_agnostic}_no_slice_{p_no_sliced_prediction}"
                            f"_ystart_{p_single_row_y_start}"
                        )
                        
                        result = predict(
                            model_type=MODEL_TYPE,
                            model_path=MODEL_PATH,
                            model_config_path=MODEL_CONFIG_PATH,
                            model_confidence_threshold=0.01,
                            model_device="cuda:0",
                            model_category_mapping=None,
                            model_category_remapping=None,
                            source=EVAL_IMAGES_FOLDER_DIR,
                            no_standard_prediction=setting_params["no_standard_prediction"],
                            no_sliced_prediction=setting_params["no_sliced_prediction"],
                            image_size=None,
                            slice_height=setting_params["slice_size"],
                            slice_width=setting_params["slice_size"],
                            overlap_height_ratio=setting_params["overlap_ratio"],
                            overlap_width_ratio=setting_params["overlap_ratio"],
                            postprocess_type="GREEDYNMM",
                            postprocess_match_metric="IOS",
                            postprocess_match_threshold=setting_params["match_threshold"],
                            postprocess_class_agnostic=setting_params["postprocess_class_agnostic"],
                            novisual=not EXPORT_VISUAL,
                            dataset_json_path=EVAL_DATASET_JSON_PATH,
                            project="runs/221007",
                            name=setting_info,
                            visual_bbox_thickness=None,
                            visual_text_size=None,
                            visual_text_thickness=None,
                            visual_export_format="png",
                            verbose=0,
                            return_dict=True,
                            force_postprocess_type=True,
                            single_row_predict=True,
                            single_row_y_start=setting_params["single_row_y_start"]
                        )
                        
                        print("settings", setting_info)
                        
                        result_json_path = str(Path(result["export_dir"]) / "result.json")
                        
                        evaluate_dict = evaluate(
                            dataset_json_path=EVAL_DATASET_JSON_PATH,
                            result_json_path=result_json_path,
                            classwise=True,
                            max_detections=MAX_DETECTIONS,
                            return_dict=True,
                        )
                        
                    ### model_confidence_threshold
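The nested loops above run 4 × 4 × 3 × 2 × 2 × 3 = 576 settings. This sketch assumes `predict` skips slicing entirely when `no_sliced_prediction=True`, so that `slice_size`, `overlap_ratio` and `single_row_y_start` have no effect in that case. Under that assumption, many runs are duplicates. The sketch enumerates the same grid with `itertools.product` and collapses those duplicates, conservatively keeping `match_threshold` and `postprocess_class_agnostic` as distinct settings. `GRID`, `SLICE_ONLY_KEYS` and `build_grid` are hypothetical names, and the assumption should be checked against the customized fork before relying on it.

```python
from itertools import product
from typing import Dict, List

# Same value lists as the grid-search cell above
GRID = {
    "slice_size": [256, 384, 512, 640],
    "overlap_ratio": [0, 0.1, 0.2, 0.25],
    "match_threshold": [0.5, 0.7, 0.9],
    "postprocess_class_agnostic": [True, False],
    "no_sliced_prediction": [True, False],
    "single_row_y_start": [0, 100, 200],
}
# Assumption: these only matter when sliced prediction is enabled
SLICE_ONLY_KEYS = ("slice_size", "overlap_ratio", "single_row_y_start")


def build_grid(grid: Dict[str, list]) -> List[dict]:
    """Hypothetical: enumerate all settings, collapsing slice-only params when slicing is off."""
    keys = list(grid)
    seen, settings = set(), []
    for values in product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        if params["no_sliced_prediction"]:
            for k in SLICE_ONLY_KEYS:
                params[k] = grid[k][0]  # placeholder; assumed to be ignored without slicing
        key = tuple(params[k] for k in keys)
        if key not in seen:
            seen.add(key)
            settings.append(params)
    return settings
```

Under this assumption the grid shrinks from 576 to 294 runs: 288 sliced settings plus 6 full-image settings.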
                    

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline

@interact(index=(0, len(image_files)-1), slice_size=(0, 512), overlap_ratio=(0, 0.5, 0.05), single_row_y_start=(0, 512))
def show_sample(index=0, slice_size=256, overlap_ratio=0.2, single_row_y_start=200):
    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)
    
    slice_result = slice_image(image_path, 
                              slice_width=slice_size,
                              slice_height=slice_size,
                              overlap_height_ratio=overlap_ratio,
                              overlap_width_ratio=overlap_ratio,
                              single_row_y_start=single_row_y_start,
                              single_row_predict=True,
                              verbose=1)

    for start_pixel in slice_result.starting_pixels:
        cv2.rectangle(image,
                      start_pixel,
                      [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
                      color=(255, 255, 0),
                      thickness=2)
        
    result = get_sliced_prediction(image_path,
                                   model,
                                   slice_height=slice_size,
                                   slice_width=slice_size,
                                   postprocess_match_threshold=0.5,
                                   overlap_height_ratio=overlap_ratio,
                                   overlap_width_ratio=overlap_ratio,
                                   single_row_y_start=single_row_y_start,
                                   single_row_predict=True)
    
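    canvas = visualize_object_predictions(image, result.object_prediction_list)
    canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)  # image was loaded with cv2.imread (BGR); matplotlib expects RGB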
    
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()
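`get_sliced_prediction` runs the model on each slice, moves each slice's boxes back into full-image coordinates, and then merges duplicates produced by overlapping slices. The grid-search cell selects `postprocess_type="GREEDYNMM"` with `postprocess_match_metric="IOS"` for that merge. The sketch below illustrates those two steps in simplified form. IOS is the intersection divided by the smaller box's area. Greedy non-maximum merging keeps the highest-scoring box, absorbs every lower-scoring box (of the same class unless class-agnostic) whose IOS with it exceeds the threshold, and replaces the kept box with the union of the group. This is an illustration of the technique, not SAHI's exact implementation; `shift_box`, `ios` and `greedy_nmm` are hypothetical names.

```python
from typing import List, Sequence

# A detection is [x1, y1, x2, y2, score, class_id]


def shift_box(det: Sequence[float], shift_x: float, shift_y: float) -> List[float]:
    """Move a slice-level detection into full-image coordinates."""
    x1, y1, x2, y2, score, cls = det
    return [x1 + shift_x, y1 + shift_y, x2 + shift_x, y2 + shift_y, score, cls]


def ios(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection over the area of the smaller box."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    smaller = min((a[2] - a[0]) * (a[3] - a[1]), (b[2] - b[0]) * (b[3] - b[1]))
    return iw * ih / smaller if smaller > 0 else 0.0


def greedy_nmm(dets: List[List[float]], match_threshold: float = 0.5,
               class_agnostic: bool = False) -> List[List[float]]:
    """Simplified greedy non-maximum merging with the IOS metric."""
    remaining = sorted(dets, key=lambda d: d[4], reverse=True)
    merged = []
    while remaining:
        keep, rest = remaining[0], remaining[1:]
        group, remaining = [keep], []
        for d in rest:
            same_class = class_agnostic or d[5] == keep[5]
            if same_class and ios(keep, d) > match_threshold:
                group.append(d)
            else:
                remaining.append(d)  # stays in score order for the next round
        # Union of the group's boxes; keep the top score and the top box's class
        merged.append([
            min(d[0] for d in group), min(d[1] for d in group),
            max(d[2] for d in group), max(d[3] for d in group),
            keep[4], keep[5],
        ])
    return merged
```

IOS suits sliced inference because an object cut by a slice border produces a small partial box inside a larger one. Their IoU can be low while their IOS is close to 1, so IOS still lets the two boxes be merged.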

## Appendix

- Predictions are returned as [sahi.prediction.PredictionResult](sahi/prediction.py). You can access the object prediction list as follows:

In [None]:
object_prediction_list = result.object_prediction_list

In [None]:
object_prediction_list[0]

- ObjectPrediction objects can be converted to [COCO annotation](https://cocodataset.org/#format-data) format:

In [None]:
result.to_coco_annotations()[:3]

- ObjectPrediction objects can be converted to [COCO prediction](https://github.com/i008/COCO-dataset-explorer) format:

In [None]:
result.to_coco_predictions(image_id=1)[:3]

- ObjectPrediction objects can be converted to [imantics](https://github.com/jsbroks/imantics) annotation format:

In [None]:
result.to_imantics_annotations()[:3]

- ObjectPrediction objects can be converted to [fiftyone](https://github.com/voxel51/fiftyone) detection format:

In [None]:
result.to_fiftyone_detections()[:3]
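The COCO prediction format follows the COCO results convention: one dict per detection with `image_id`, `category_id`, `bbox` as `[x, y, width, height]`, and `score`. The sketch below filters such a list by score and serializes it to the JSON list that COCO tooling (for example pycocotools' `COCO.loadRes`) reads from a file. `filter_coco_predictions`, `to_coco_results_json` and the 0.3 default threshold are hypothetical, and only those four standard keys are assumed to exist; any extra keys are dropped.

```python
import json
from typing import Iterable, List


def filter_coco_predictions(predictions: Iterable[dict], min_score: float = 0.3) -> List[dict]:
    """Hypothetical: keep detections scoring >= min_score, reduced to the standard COCO results keys."""
    return [
        {
            "image_id": p["image_id"],
            "category_id": p["category_id"],
            "bbox": [float(v) for v in p["bbox"]],  # [x, y, width, height]
            "score": float(p["score"]),
        }
        for p in predictions
        if p["score"] >= min_score
    ]


def to_coco_results_json(predictions: Iterable[dict], min_score: float = 0.3) -> str:
    """Serialize the filtered list; write this string to a .json file for COCO tools."""
    return json.dumps(filter_coco_predictions(predictions, min_score))
```

For example, `open("predictions.json", "w").write(to_coco_results_json(result.to_coco_predictions(image_id=1)))` would save the detections of one image with score 0.3 or higher.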