In [1]:
%matplotlib inline

In [2]:
import random
import cv2
from matplotlib import pyplot as plt

import albumentations as A

import sys
import os
import numpy as np

In [21]:
"""File path configuration."""
# Dataset root. Every path below is built by plain string concatenation,
# so each directory string must keep its trailing slash.
root_dir = "C:/Users/user/fish_disease_dataset/multy_class_with_fish_bbox/set/"
img_dir = f"{root_dir}origin/valid/images/"
label_dir = f"{root_dir}origin/valid/labels/"

In [None]:
# Horizontal flip (rot_hf)
aug_type = "_rot_hf"
transform = A.Compose(
    transforms=[A.HorizontalFlip(p=1)],
    bbox_params=A.BboxParams(format="yolo", label_fields=["category_ids"]),
)

In [None]:
# Rotate by 15 degrees (rot_15)
aug_type = "_rot_15"
transform = A.Compose(
    transforms=[
        # limit=(15, 15) pins the sampled angle to 15; BORDER_REPLICATE fills the
        # exposed corners by repeating the neighbouring edge pixels.
        A.Rotate(limit=(15, 15), border_mode=cv2.BORDER_REPLICATE, p=1),
    ],
    bbox_params=A.BboxParams(format="yolo", label_fields=["category_ids"]),
)

In [None]:
# Brightness adjustment: slightly darker (bright_down)
# NOTE(review): contrast_limit is not passed, so the library default applies and
# contrast may vary as well - confirm that is intended.
aug_type = "_bright_down"
transform = A.Compose(
    transforms=[
        A.RandomBrightnessContrast(brightness_limit=(-0.15, -0.15), p=1),
    ],
    bbox_params=A.BboxParams(format="yolo", label_fields=["category_ids"]),
)

In [8]:
# Zoom out by 0.2 (out_0.2)
aug_type = "_out_0.2"
transform = A.Compose(
    transforms=[
        # Shift and rotation are disabled; only the fixed -0.2 scale is applied.
        A.ShiftScaleRotate(
            shift_limit=0,
            scale_limit=(-0.2, -0.2),
            rotate_limit=0,
            border_mode=cv2.BORDER_CONSTANT,  # uncovered border is filled with black
            p=1,
        ),
    ],
    bbox_params=A.BboxParams(format="yolo", label_fields=["category_ids"]),
)

In [18]:
# Compute the crop region
def get_crop_region(bbox, img_w, img_h):
    """Convert a centre-format, normalised box into a pixel crop rectangle.

    Args:
        bbox: Sequence ``(x_center, y_center, width, height)`` expressed as
            fractions of the image size (the YOLO label layout).
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)`` of integer pixel coordinates.
        Values are truncated toward zero and clamped to ``[0, img_w]`` /
        ``[0, img_h]`` so the rectangle never extends outside the image.
    """
    x, y, w, h = bbox
    w_half, h_half = w / 2, h / 2

    def _to_pixel(fraction, size):
        # int() truncates toward zero; the clamp keeps a box that overshoots the
        # image border from producing a negative or oversized crop coordinate.
        return min(max(int(fraction * size), 0), size)

    return (
        _to_pixel(x - w_half, img_w),
        _to_pixel(y - h_half, img_h),
        _to_pixel(x + w_half, img_w),
        _to_pixel(y + h_half, img_h),
    )

In [22]:
"""Augmentation driver: crop every image to its fish bounding box and save the result."""

# List the label directory
label_file_list = os.listdir(label_dir)

# List the image directory
img_file_list = os.listdir(img_dir)

category_id_to_name = {0: 'bleeding', 1: 'defect', 2: 'necrosis', 3: 'fish'}

# Tag appended to every output file name for this augmentation.
aug_type = "_crop"

# Output locations. They do not change per image, so resolve them once and make
# sure they exist (cv2.imwrite only returns False when the folder is missing).
img_save_path = root_dir + "valid/images/"
label_save_path = root_dir + "valid/labels/"
os.makedirs(img_save_path, exist_ok=True)
os.makedirs(label_save_path, exist_ok=True)

for img in img_file_list:
    # splitext copes with any extension length (".jpg", ".jpeg", ...).
    stem = os.path.splitext(img)[0]

    image = cv2.imread(img_dir + img)
    if image is None:
        # cv2.imread returns None for files it cannot decode.
        print("Skipping unreadable image: ", img)
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]

    # Parse the label file: one "<class> <x> <y> <w> <h>" record per line.
    bboxes = []
    category_ids = []
    with open(label_dir + stem + ".txt", 'r') as f:
        for line in f:
            data = line.split()  # tolerant of trailing newline / repeated spaces
            if not data:
                continue  # blank line
            category_ids.append(int(data[0]))
            bboxes.append([float(d) for d in data[1:]])

    # The crop rectangle is taken from the second label line, which this
    # notebook treats as the fish box.
    if len(bboxes) < 2:
        print("Skipping image with fewer than two boxes in its label: ", img)
        continue

    # Configure the augmentation pipeline: crop to the fish bbox (crop).
    x_min, y_min, x_max, y_max = get_crop_region(bboxes[1], img_width, img_height)
    transform = A.Compose(
        [A.Crop(x_min, y_min, x_max, y_max, p=1)],
        bbox_params=A.BboxParams(format='yolo', label_fields=['category_ids']),
    )

    # Apply the augmentation
    transformed = transform(image=image, bboxes=bboxes, category_ids=category_ids)

    # Save the augmented image. The pixels are RGB at this point, while
    # cv2.imwrite expects BGR, so convert back to keep the colours correct.
    out_image = cv2.cvtColor(transformed["image"], cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_save_path + stem + aug_type + ".jpg", out_image)

    # Save the augmented bounding boxes as a label file. Class ids come from the
    # transform output so they stay aligned with the boxes it returned.
    with open(label_save_path + stem + aug_type + ".txt", 'w') as fi:
        for category_id, box in zip(transformed["category_ids"], transformed["bboxes"]):
            fields = [int(category_id), *box[:4]]
            fi.write(" ".join(str(v) for v in fields) + "\n")
    

In [23]:
# Remove the fish bbox from the labels of the cropped images.
# Only the first line of each label file is kept; this relies on the fish box
# not being on the first line (the augmentation cell reads it from the second).
# NOTE: this rebinds label_dir to the output folder. Re-run the path
# configuration cell before running the augmentation cell again.
label_dir = root_dir + "valid/labels/"
label_file_list = os.listdir(label_dir)
for label in label_file_list:
    with open(label_dir + label, "r") as f:
        lines = f.readlines()
    if len(lines) < 2:
        # Empty file (lines[0] would raise) or already reduced to a single
        # line: nothing to strip, so leave the file untouched.
        continue
    with open(label_dir + label, "w") as f:
        f.write(lines[0])

In [15]:
# Prepare set1: copy the files that carry no augmentation tag out of the crop dataset
import shutil
data_type = "labels" # images, labels
source_dir = f"{root_dir}aug/_crop/{data_type}/"
target_dir = f"{root_dir}set1/train/{data_type}/"
file_list = os.listdir(source_dir)
for file in file_list:
    # A file is copied only when none of the augmentation tags occurs in its name.
    if not any(tag in file for tag in ("_bright_down", "rot_15", "rot_hf")):
        shutil.copy(source_dir + file, target_dir + file)