## 1. GPU 확인

In [None]:
# IPython shell escape: run `nvidia-smi` to show the GPU/driver status of this machine.
!nvidia-smi

## 2. Lib Import

In [None]:
# arrange an instance segmentation model for test
from sahi.utils.yolov5 import (
    download_yolov5s6_model,
)

# import required functions, classes
from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.utils.file import download_from_url
from sahi.predict import get_prediction, get_sliced_prediction, predict
from sahi.scripts.coco_error_analysis import analyse
from sahi.scripts.coco_evaluation import evaluate
from IPython.display import Image
from pathlib import Path
import json
import os
from pathlib import Path
from PIL import Image

## 3. Data & Model Path 선언

In [None]:
# Validation images and their label files (one label file per image).
source_image_dir = "../resources/FLL_VAL/images/"
source_label_dir = "../resources/FLL_VAL/labels/"

# FLL target model trained with imgsz 960
fll_model_221024_960_path = '../resources/models/221024_960/best.pt'

## 4. Model Load

In [None]:
# Build the detection model wrapper from the checkpoint path declared above,
# using model type 'yolov5', a 0.25 confidence threshold and device "cuda:0".
fll_model_221024_960 = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=fll_model_221024_960_path,
    confidence_threshold=0.25,
    device="cuda:0"
)

In [None]:
# Generic aliases used by the rest of the notebook; repoint these two names
# to evaluate a different checkpoint.
model = fll_model_221024_960
model_path = fll_model_221024_960_path

## 5. Gt Json 생성

In [None]:
def initial_extract(img_dir, label_dir, out_dir):
    """Convert a directory of YOLO-style label files into a COCO-style ``val.json``.

    Every entry of ``label_dir`` is processed in sorted order. For each one the
    image with the same stem and a ``.jpg`` extension is opened from
    ``img_dir`` to read its pixel size, and every non-blank label line
    ``<class> <xc> <yc> <w> <h>`` (values normalized by the image size) is
    turned into an absolute ``[xmin, ymin, width, height]`` box.

    Args:
        img_dir: Directory holding the ``.jpg`` images.
        label_dir: Directory holding one label file per image.
        out_dir: Directory that receives ``val.json``; created when missing.

    Returns:
        The path of the written file, ``os.path.join(out_dir, 'val.json')``.

    Raises:
        FileNotFoundError: If the ``.jpg`` matching a label file does not exist.
        ValueError: If a non-blank label line does not hold exactly five
            numeric fields.
    """
    out_path = os.path.join(out_dir, 'val.json')

    # Drop any previous output up front (this also keeps it out of the
    # directory listing below should out_dir and label_dir be the same).
    if os.path.exists(out_path):
        os.remove(out_path)

    licenses = [
        {
            "name": "",
            "id": 0,
            "url": ""
        }
    ]

    info_ = [
        {
            "contributor": "",
            "date_created": "",
            "description": "",
            "url": "",
            "version": "",
            "year": ""
        }
    ]

    categories = [
        {
            "id": 0,
            "name": "Buoy",
            "supercategory": ""
        },
        {
            "id": 1,
            "name": "Boat",
            "supercategory": ""
        },
        {
            "id": 2,
            "name": "Channel Marker",
            "supercategory": ""
        },
        {
            "id": 3,
            "name": "Speed Warning Sign",
            "supercategory": ""
        }
    ]

    imgs_list = []
    annots_list = []

    # Image ids are the positions in the sorted label listing.
    for img_idx, label_file in enumerate(sorted(os.listdir(label_dir))):
        file_name = f'{os.path.splitext(label_file)[0]}.jpg'

        # Only the size is needed; the context manager releases the file
        # handle right away instead of leaking one per image.
        with Image.open(os.path.join(img_dir, file_name)) as img:
            image_w, image_h = img.size

        imgs_list.append({
            'id': img_idx,
            'width': image_w,
            'height': image_h,
            'file_name': file_name,
            "license": 0,
            "flickr_url": "",
            "coco_url": "",
            "date_captured": 0
        })

        with open(os.path.join(label_dir, label_file), 'r') as label_f:
            for label in label_f:
                fields = label.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue

                # The first field is the class id and is always an integer,
                # whatever its number of digits; the other four are floats.
                cat_token, *box_tokens = fields
                cat = int(float(cat_token))
                xc, yc, label_normalized_w, label_normalized_h = map(float, box_tokens)

                label_w = image_w * label_normalized_w
                label_h = image_h * label_normalized_h
                xmin = (image_w * xc) - (label_w / 2)
                ymin = (image_h * yc) - (label_h / 2)

                # NOTE(review): only the origin is clamped; width/height are
                # kept as-is, so a box starting outside the image is shifted
                # rather than clipped. Preserved from the original behavior.
                xmin = max(xmin, 0)
                ymin = max(ymin, 0)

                annots_list.append({
                    # Annotation ids run continuously across all images.
                    'id': len(annots_list),
                    'image_id': img_idx,
                    'category_id': cat,
                    'area': int(label_h * label_w),
                    'bbox': [
                        xmin,
                        ymin,
                        label_w,
                        label_h
                    ],
                    'iscrowd': 0,
                    'attributes': {
                        'type': '',
                        'occluded': False
                    },
                    'segmentation': []
                })

    out_dict = {
        'licenses': licenses,
        'info': info_,
        'categories': categories,
        'images': imgs_list,
        'annotations': annots_list
    }

    os.makedirs(out_dir, exist_ok=True)

    with open(out_path, 'w') as out_f:
        print(out_path)
        json.dump(out_dict, out_f)

    return out_path

In [None]:
# Signature: initial_extract(img_dir, label_dir, out_dir).
# The ground-truth json is written into the parent directory of the image folder.
gt_json_path = initial_extract(source_image_dir, source_label_dir, str(Path(source_image_dir).parent))

## 6. Eval hyper-param Setting

In [None]:
# Named inference presets; the entry selected by INFERENCE_SETTING below is
# unpacked into the predict(...) call of section 7.
INFERENCE_SETTING_TO_PARAMS = {
    "AVIKUS_FL": {
        "model_confidence_threshold": 0.25,
        "model_device": "cuda:0",
        "image_size": 960,
        "postprocess_type": "GREEDYNMM",
        "postprocess_match_metric": "IOS",
        "no_standard_prediction": False, # False keeps standard (full-image) prediction enabled
        "no_sliced_prediction": True, # True disables sliced (tiled) prediction
        "slice_size": 512, # slice width/height, used when tiling is active
        "overlap_ratio": 0.15, # overlap ratio between slices, used when tiling is active
        "match_threshold": 0.5, # postprocess match threshold, used when tiling is active
        "postprocess_class_agnostic": False,  # class-agnostic postprocess, used when tiling is active
        # NOTE(review): the predict(...) call in section 7 passes
        # custom_slice_y_start=0 literally and never reads this value -- confirm
        # which one is intended.
        "custom_slice_y_start": 200,  # Y start point of the slices, when tiling is active
    },
}

MODEL_TYPE = "yolov5" # detection model type
MODEL_PATH = model_path # checkpoint path
MODEL_CONFIG_PATH = ""
EVAL_IMAGES_FOLDER_DIR = source_image_dir # directory of the images to evaluate
EVAL_DATASET_JSON_PATH = gt_json_path # ground-truth COCO json path
INFERENCE_SETTING = "AVIKUS_FL"
EXPORT_VISUAL = False
MAX_DETECTIONS = 300

setting_params = INFERENCE_SETTING_TO_PARAMS[INFERENCE_SETTING]

## 참고. Interact로 인퍼런스 확인

In [None]:
from sahi.utils.cv import Colors
import numpy as np
import copy

def visualize_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    rect_th: int = None,
    text_size: float = None,
    text_th: float = None,
    color: tuple = None,
):
    """
    Draws category names, scores, bounding boxes (and masks, if present) of the
    given predictions over a copy of the source image and returns that copy.
    Arguments:
        image: source image array; it is converted with cv2.COLOR_BGR2RGB
            before being returned, i.e. it is treated as BGR
        object_prediction_list: a list of prediction.ObjectPrediction
        rect_th: rectangle thickness (derived from the image shape if not given)
        text_size: size of the category name over box (rect_th / 3 if not given)
        text_th: text thickness (max(rect_th - 1, 1) if not given)
        color: annotation color in the form: (0, 255, 0); when None a
            per-category color from the Colors palette is used
    Returns:
        the annotated copy of the image after the BGR->RGB conversion
    """
    # cv2 is only imported by a later cell of this notebook; importing it here
    # keeps the function usable when cells are executed top to bottom.
    import cv2

    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.001), 1)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3
    # add bbox and mask to image if present
    for object_prediction in object_prediction_list:
        # deepcopy the prediction so that the original is not altered
        object_prediction = object_prediction.deepcopy()

        bbox = object_prediction.bbox.to_voc_bbox()
        category_name = object_prediction.category.name
        score = object_prediction.score.value

        # set color
        if colors is not None:
            color = colors(object_prediction.category.id)
        # visualize masks if present
        if object_prediction.mask is not None:
            # apply_color_mask was never imported in this notebook; pulled in
            # lazily so box-only models do not depend on it.
            # NOTE(review): assumes sahi.utils.cv exposes apply_color_mask in
            # the installed SAHI version -- confirm.
            from sahi.utils.cv import apply_color_mask

            mask = object_prediction.mask.bool_mask
            # draw mask
            rgb_mask = apply_color_mask(mask, color)
            image = cv2.addWeighted(image, 1, rgb_mask, 0.4, 0)
        # set bbox points
        p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        # visualize boxes
        cv2.rectangle(
            image,
            p1,
            p2,
            color=color,
            thickness=rect_th
        )
        # arrange bounding box text location
        label = f"{category_name} {score:.2f}"
        w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]  # label width, height
        outside = p1[1] - h - 3 >= 0  # label fits above the box
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        # add bounding box text
        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(
            image,
            label,
            (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
            0,
            text_size,
            (255, 255, 255),
            thickness=text_th,
        )

    # callers display the result with matplotlib, which expects RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return image

In [None]:
from ipywidgets import interact
from sahi.slicing import slice_image
import matplotlib.pyplot as plt
%matplotlib inline

# Sorted listing of the evaluation images; the widgets below index into it.
image_files = sorted(os.listdir(source_image_dir))

# NOTE(review): the slice_size slider allows 0 -- confirm slice_image and
# get_sliced_prediction cope with that value.
@interact(index=(0, len(image_files)-1),
          slice_size=(0, 640),
          overlap_ratio=(0, 0.5, 0.05),
          custom_slice_x_start=(0, 640),
          custom_slice_y_start=(0, 512),
          custom_slice_mode=(0,3),
          only_full_inference=(0,1))
def show_sample(index=0, slice_size=640, overlap_ratio=0.25,
                custom_slice_x_start=640, custom_slice_y_start=360, custom_slice_mode=2,
                only_full_inference=0):
    """Interactively preview the predictions for one evaluation image.

    With ``only_full_inference`` falsy the image is sliced with the given
    parameters, every slice is outlined on the image and sliced prediction is
    run; otherwise a single prediction over the whole image is run. The
    predictions are then drawn with ``visualize_object_predictions`` and shown
    with matplotlib.

    Arguments:
        index: position of the image in ``image_files``
        slice_size: width and height of each slice
        overlap_ratio: overlap ratio between slices (both directions)
        custom_slice_x_start, custom_slice_y_start, custom_slice_mode:
            forwarded unchanged to ``slice_image`` / ``get_sliced_prediction``
        only_full_inference: non-zero skips slicing and predicts on the full image
    """
    # cv2 is only imported by a later cell of this notebook; importing it here
    # keeps this cell runnable when the notebook is executed top to bottom.
    import cv2

    image_file = image_files[index]
    image_path = os.path.join(source_image_dir, image_file)
    image = cv2.imread(image_path)

    if not only_full_inference:
        slice_result = slice_image(image_path,
                                  slice_width=slice_size,
                                  slice_height=slice_size,
                                  overlap_height_ratio=overlap_ratio,
                                  overlap_width_ratio=overlap_ratio,
                                  custom_slice_x_start=custom_slice_x_start,
                                  custom_slice_y_start=custom_slice_y_start,
                                  custom_slice_mode=custom_slice_mode,
                                  verbose=1)

        # outline every slice, from its starting pixel to start + slice_size
        for start_pixel in slice_result.starting_pixels:
            cv2.rectangle(image,
                          start_pixel,
                          [s1+s2 for s1, s2 in zip(start_pixel,[slice_size,slice_size])],
                          color=(255, 255, 0),
                          thickness=2)

        result = get_sliced_prediction(image_path,
                                       model,
                                       slice_height=slice_size,
                                       slice_width=slice_size,
                                       postprocess_match_threshold=0.5,
                                       overlap_height_ratio=overlap_ratio,
                                       overlap_width_ratio=overlap_ratio,
                                       custom_slice_x_start=custom_slice_x_start,
                                       custom_slice_y_start=custom_slice_y_start,
                                       custom_slice_mode=custom_slice_mode
                                      )
    else:
        result = get_prediction(image_path, model)

    canvas = visualize_object_predictions(image, result.object_prediction_list)
    plt.figure(figsize=(16,16))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

## 7. Execute Full Inference

In [None]:
# Batch inference over the evaluation folder with the selected preset.
result = predict(
    # model
    model_type=MODEL_TYPE,
    model_path=MODEL_PATH,
    model_config_path=MODEL_CONFIG_PATH,
    model_confidence_threshold=setting_params["model_confidence_threshold"],
    model_device=setting_params["model_device"],
    model_category_mapping=None,
    model_category_remapping=None,
    image_size=setting_params["image_size"],
    # input data
    source=EVAL_IMAGES_FOLDER_DIR,
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    # standard vs. sliced prediction
    no_standard_prediction=setting_params["no_standard_prediction"],
    no_sliced_prediction=setting_params["no_sliced_prediction"],
    slice_height=setting_params["slice_size"],
    slice_width=setting_params["slice_size"],
    overlap_height_ratio=setting_params["overlap_ratio"],
    overlap_width_ratio=setting_params["overlap_ratio"],
    # NOTE(review): these three are literals; the preset's
    # "custom_slice_y_start" entry is not used here -- confirm intent.
    custom_slice_mode=0,
    custom_slice_x_start=0,
    custom_slice_y_start=0,
    # postprocess
    postprocess_type=setting_params["postprocess_type"],
    postprocess_match_metric=setting_params["postprocess_match_metric"],
    postprocess_match_threshold=setting_params["match_threshold"],
    postprocess_class_agnostic=setting_params["postprocess_class_agnostic"],
    force_postprocess_type=True,
    # export / visualization
    project="runs/FLL_FULL_INFERENCE",
    name=INFERENCE_SETTING,
    novisual=not EXPORT_VISUAL,
    visual_bbox_thickness=None,
    visual_text_size=None,
    visual_text_thickness=None,
    visual_export_format="png",
    # reporting
    verbose=2,
    return_dict=True,
)

# The predictions are read back from result.json inside the run's export directory.
result_json_path = str(Path(result["export_dir"]).joinpath("result.json"))

## 8. mAP evaluation 

In [None]:
# Score the result.json of the run above against the ground-truth json,
# per class, with at most MAX_DETECTIONS detections.
evaluate_dict = evaluate(
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    result_json_path=result_json_path,
    classwise=True,
    max_detections=MAX_DETECTIONS,
    return_dict=True,
)

## 9. Export DSL Standard result.json

In [None]:
# Placeholder: set this to the path of the DSL standard result json before
# running the cells below (the empty string cannot be opened).
dsl_standard_result_json = ""

In [None]:
# Same evaluation as in section 8, but for the externally produced
# DSL standard result json.
evaluate_dict = evaluate(
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    result_json_path=dsl_standard_result_json,
    classwise=True,
    max_detections=MAX_DETECTIONS,
    return_dict=True,
)

## 10. Visualize json file

In [None]:
# Load the prediction records; the viewer below filters them by 'image_id'.
with open(dsl_standard_result_json, 'r') as f:
    datas = json.load(f)

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline
import cv2

@interact(index=(0, len(image_files)-1))
def show_sample(index=0):
    """Plot image number ``index`` together with the boxes stored for it in ``datas``.

    Records are selected with ``data['image_id'] == index``. Each ``bbox`` is
    used as ``[x, y, w, h]``: the first two values are the top-left corner and
    the last two are added to them to get the opposite corner.
    NOTE(review): this assumes the position in ``image_files`` equals the
    ``image_id`` used in the json -- confirm against how the json was produced.
    """
    matches = list(filter(lambda rec: rec['image_id'] == index, datas))

    frame = cv2.imread(os.path.join(source_image_dir, image_files[index]))
    # matplotlib expects RGB, OpenCV loads BGR
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    for rec in matches:
        box = [int(value) for value in rec['bbox']]
        top_left = box[:2]
        bottom_right = [sum(pair) for pair in zip(top_left, box[2:])]
        caption = f"{rec['category_name']} {rec['score']:.2f}"

        cv2.rectangle(frame, top_left, bottom_right, color=(255, 0, 0), thickness=2)
        # caption sits 10 px above the top-left corner
        cv2.putText(frame, caption, [box[0], box[1]-10], 0, 1, (255, 255, 255), thickness=3)

    plt.figure(figsize=(16,16))
    plt.imshow(frame)
    plt.axis('off')
    plt.show()


## 11. Export DSL Preproc Full inference result.json

In [None]:
# Placeholder: set this to the path of the DSL preproc full-inference result
# json before running the cells below (the empty string cannot be opened).
dsl_preproc_full_inference_json = ""

In [None]:
# Same evaluation as in section 8, but for the externally produced
# DSL preproc full-inference result json.
evaluate_dict = evaluate(
    dataset_json_path=EVAL_DATASET_JSON_PATH,
    result_json_path=dsl_preproc_full_inference_json,
    classwise=True,
    max_detections=MAX_DETECTIONS,
    return_dict=True,
)

## 12. Visualize json file

In [None]:
# Load the prediction records; this rebinds `datas`, which the viewer below
# filters by 'image_id'.
with open(dsl_preproc_full_inference_json, 'r') as f:
    datas = json.load(f)

In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt
%matplotlib inline
import cv2

@interact(index=(0, len(image_files)-1))
def show_sample(index=0):
    """Plot image number ``index`` together with the boxes stored for it in ``datas``.

    Records are selected with ``data['image_id'] == index``. Each ``bbox`` is
    used as ``[x, y, w, h]``: the first two values are the top-left corner and
    the last two are added to them to get the opposite corner.
    NOTE(review): this assumes the position in ``image_files`` equals the
    ``image_id`` used in the json -- confirm against how the json was produced.
    """
    matches = list(filter(lambda rec: rec['image_id'] == index, datas))

    frame = cv2.imread(os.path.join(source_image_dir, image_files[index]))
    # matplotlib expects RGB, OpenCV loads BGR
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    for rec in matches:
        box = [int(value) for value in rec['bbox']]
        top_left = box[:2]
        bottom_right = [sum(pair) for pair in zip(top_left, box[2:])]
        caption = f"{rec['category_name']} {rec['score']:.2f}"

        cv2.rectangle(frame, top_left, bottom_right, color=(255, 0, 0), thickness=2)
        # caption sits 10 px above the top-left corner
        cv2.putText(frame, caption, [box[0], box[1]-10], 0, 1, (255, 255, 255), thickness=3)

    plt.figure(figsize=(16,16))
    plt.imshow(frame)
    plt.axis('off')
    plt.show()