### author: 김강렬

In [1]:
# [albumentations 사용]

import random

import cv2
from matplotlib import pyplot as plt

import albumentations as A

import os

In [2]:
# [Helpers for displaying an image together with its labels]

# Colours used when drawing: the label text, and the box / label background.
TEXT_COLOR = (255, 255, 255)  # White
BOX_COLOR = (255, 0, 0)  # Red


def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Draw a single bounding box and its class label on the image.

    The image is modified in place and also returned.

    Args:
        img: Image array to draw on (passed straight to the OpenCV drawing
            functions).
        bbox: Box as ``(x_min, y_min, width, height)`` in pixels.
        class_name: Text written in the label strip above the box.
        color: Colour of the box outline and of the label background.
        thickness: Line thickness of the box outline.

    Returns:
        The same ``img`` object, with the box and label drawn on it.
    """
    x_min, y_min, w, h = bbox
    x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)

    # Box outline.
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

    # Filled strip just above the box that serves as the label background.
    # It uses the same `color` as the outline (previously it always used the
    # global BOX_COLOR, so a custom colour produced a mismatched label).
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)

    # Class name written inside the strip.
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name):
    """Plot a copy of ``image`` with every bounding box and class name drawn.

    ``bboxes`` and ``category_ids`` are walked in parallel; each id is turned
    into its display name through ``category_id_to_name``. The input image
    itself is left untouched because drawing happens on a copy.
    """
    canvas = image.copy()
    labels = (category_id_to_name[cid] for cid in category_ids)
    for box, label in zip(bboxes, labels):
        canvas = visualize_bbox(canvas, box, label)

    plt.figure(figsize=(12, 12))
    plt.axis('off')
    plt.imshow(canvas)

In [28]:
# [Path configuration]
root_path = "C:/Users/user/fish_disease_dataset/multy_class_with_fish_bbox/set/"

# Image folders
img_path = f"{root_path}aug_without_out_0.2/images/"
img_save_path = f"{root_path}aug/images/"

# Label folders
label_path = f"{root_path}aug_without_out_0.2/labels/"
label_save_path = f"{root_path}aug/labels/"

# Directory listings of the image and label folders
img_list = os.listdir(img_path)
label_list = os.listdir(label_path)

print("증강 대상 이미지 개수: ", len(img_list))
print()
print("증강 대상 라벨 개수: ", len(label_list))

증강 대상 이미지 개수:  548

증강 대상 라벨 개수:  548


In [30]:
# [Read YOLO-format labels, convert them to COCO format, augment, and save
#  the result back in YOLO format]

# Class id -> class name
category_id_to_name = {0: 'bleeding', 1: 'ulcer', 2: 'defect', 3: 'fish'}

# Suffix appended to the file names of every augmented image/label pair
aug_type = '_out_0.2'

# Augmentation pipeline. It does not depend on the image or its labels, so it
# is built once here instead of once per label row.
transform = A.Compose(
    [A.ShiftScaleRotate(shift_limit=0, rotate_limit=0, scale_limit=(-0.4, 0), p=1)],
    bbox_params=A.BboxParams(format='coco', label_fields=['category_ids']),
)

# Make sure the output folders exist before anything is written to them.
os.makedirs(img_save_path, exist_ok=True)
os.makedirs(label_save_path, exist_ok=True)

for img in img_list:
    # File name without its extension (works for any extension length).
    stem = os.path.splitext(img)[0]

    image = cv2.imread(img_path + img)
    if image is None:
        # Not an image, or the file could not be decoded.
        print("오류 이미지: ", img)
        continue

    img_height, img_width = image.shape[:2]

    category_ids = []
    bboxes = []
    # Each label row is "<class> <x_center> <y_center> <width> <height>",
    # with the four box values normalised to [0, 1].
    with open(label_path + stem + ".txt") as f:
        for line in f:
            # split() drops the trailing newline (if any) and tolerates
            # repeated whitespace; blank or short lines are skipped.
            row = line.split()
            if len(row) < 5:
                continue

            x_center = float(row[1]) * img_width
            y_center = float(row[2]) * img_height
            bbox_width = float(row[3]) * img_width
            bbox_height = float(row[4]) * img_height

            # COCO boxes are [x_min, y_min, width, height] in pixels.
            x_min = x_center - (bbox_width / 2)
            y_min = y_center - (bbox_height / 2)

            if x_min < 0 or y_min < 0:
                print("오류 이미지: ", img)
                print("x_min: %s, y_min: %s" % (x_min, y_min))
                print()
                continue

            category_ids.append(int(row[0]))
            bboxes.append([x_min, y_min, bbox_width, bbox_height])

    # Images without any usable box are not augmented (nothing is written
    # for them), matching what this cell produced before.
    if not bboxes:
        continue

    # Run the augmentation once per image, with all of its boxes.
    try:
        transformed = transform(image=image, bboxes=bboxes, category_ids=category_ids)
    except Exception as err:
        print("오류 이미지: ", img)
        print(err)
        continue

    transform_img = transformed["image"]
    transform_mask = transformed["bboxes"]
    # Use the labels returned by the pipeline so they stay aligned with the
    # returned boxes even if the pipeline drops some of them.
    transform_category_ids = transformed["category_ids"]

    # Save the augmented image; the return value is checked so a failed
    # write is reported and no label file is left without its image.
    if not cv2.imwrite(img_save_path + stem + aug_type + ".jpg", transform_img):
        print("오류 이미지: ", img)
        continue

    # Normalise by the size of the transformed image so the conversion stays
    # correct even if the pipeline is changed to one that resizes.
    out_height, out_width = transform_img.shape[:2]

    # Convert COCO boxes back to YOLO rows and write them out.
    with open(label_save_path + stem + aug_type + ".txt", 'w') as fi:
        for category_id, bbox in zip(transform_category_ids, transform_mask):
            x_min, y_min, bbox_width, bbox_height = (float(v) for v in bbox[:4])

            x_center = (x_min + bbox_width / 2) / out_width
            y_center = (y_min + bbox_height / 2) / out_height
            bbox_width = bbox_width / out_width
            bbox_height = bbox_height / out_height

            fi.write(
                f"{int(category_id)} {x_center} {y_center} {bbox_width} {bbox_height}\n"
            )