In [5]:
# person, rider pre-image / pre-mask dict 형태로 로드
# mask 합성 : 랜덤 좌표 계산 / height 계산 / image, mask resize
# 이미지 harmonization
# Crop & Paste
# modify annotation

import os
import random
import json
import copy
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2
import torch
import onnxruntime

import sys
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
harmonization_path = os.path.join(parent_directory, "harmonization")
sys.path.insert(0, harmonization_path)

from harmonization import Harmonization

def make_dirs(paths):
    for path in paths:
        os.makedirs(path, exist_ok=True)

def get_object_names(base_path):
    image_dir_path = os.path.join(base_path,'images')
    names = []
    
    for name in os.listdir(image_dir_path):
        names.append(name)
    
    return names

def mask_to_polygon(mask):
    # 윤곽선 찾기
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 윤곽선을 다각형으로 변환
    polygons = []
    for contour in contours:
        contour = contour.squeeze(axis=1)  # 차원 축소
        polygon = contour[:, [0, 1]].tolist()  # (y, x) 순서로 변환하여 리스트로 저장
        polygons.append(polygon)

    return polygons

def polygon_to_mask(mask, polygons, color=255):
    polygons = np.array(polygons, dtype=np.int32)
    state = False

    try:
        mask = cv2.fillPoly(mask.astype("uint8"), [polygons], color)
        state = True
    except:
        print("mask passed!")

    return mask, state

def check_spot(spots, min, max):
    target_spots = [spot for spot in spots if min <= spot[0] and spot[0] < max]
    
    if not len(target_spots):
        return True
    else:
        return False    
    

def get_spot(spots, min, max):
    target_spots = [spot for spot in spots if min <= spot[0] and spot[0] < max]    
    spot = random.choice(target_spots)
    return spot

def get_mask(im, matter, threshold=220):
    # Get x_scale_factor & y_scale_factor to resize image
    def get_scale_factor(im_h, im_w, ref_size=512):

        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32

        x_scale_factor = im_rw / im_w
        y_scale_factor = im_rh / im_h

        return x_scale_factor, y_scale_factor
    
    # unify image channels to 3
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # normalize values to scale it between -1 to 1
    im = (im - 127.5) / 127.5

    im_h, im_w, im_c = im.shape
    x, y = get_scale_factor(im_h, im_w)

    # resize image
    im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)

    # prepare input shape
    im = np.transpose(im)
    im = np.swapaxes(im, 1, 2)
    im = np.expand_dims(im, axis=0).astype('float32')

    # Initialize session and get prediction
    input_name = matter.get_inputs()[0].name
    output_name = matter.get_outputs()[0].name
    session_result = matter.run([output_name], {input_name: im})

    # refine matte
    matte = (np.squeeze(session_result[0]) * 255).astype('uint8')
    matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)

    # obtain predicted foreground mask
    mask = np.where(matte>127, 255, 0)
    return mask.astype('uint8')

def crop_from_mask(image, mask):
    # Find contours in the mask
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # If no contours found, return original image
    if not len(contours):
        return None

    # Find the bounding box of the largest contour
    max_area=0
    max_x, max_y, max_w, max_h = 0, 0, 0, 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        
        if w*h > max_area:
            max_x = x
            max_y = y
            max_w = w
            max_h = h

    # Crop the image using the bounding box
    cropped_image = image[max_y : max_y + max_h, max_x : max_x + max_w]
    cropped_mask = mask[max_y : max_y + max_h, max_x : max_x + max_w]

    return cropped_image, cropped_mask

def add_object(image, new_image, mask, new_mask, right, bottom, iou_threshold=0.2):
    mask_cp = mask.copy()
    image_cp = image.copy()
    
    # 더할 위치 계산
    left = right - new_mask.shape[1]
    top = bottom - new_mask.shape[0]
    n_left = 0
    n_top = 0

    if left<0:
        n_left = -left
        left = 0
    
    if top<0:
        n_top = -top
        top = 0
        

    # 마스크 영역에 새로운 마스크 더하기
    mask_cp[top:bottom, left:right] += new_mask[n_top:, n_left:]
    
    # truncation에 대한 iou 계산
    mask_spots = np.argwhere(mask_cp==255).tolist()
    new_mask_spots = np.argwhere(new_mask==255).tolist()
    ratio = float(len(mask_spots)) / (len(new_mask_spots))
    if ratio < iou_threshold:
        return None

    for spot in mask_spots:
        image_cp[spot[0], spot[1], :] = new_image[n_top+spot[0]-top, n_left+spot[1]-left, :]
    
    return (image_cp, mask_cp)

def modify_annotation(annotation, gen_annotation):
    height, width = annotation['imgHeight'], annotation['imgWidth']
    generated_mask = np.zeros((height, width))
    modified_annotation = copy.deepcopy(annotation)
    modified_annotation['objects'] = []
    
    for ann in gen_annotation:
        generated_mask, state = polygon_to_mask(generated_mask, ann['polygon'], 255)
    
    for ann in annotation['objects']:
        ori_mask = np.zeros((height, width))
        ori_mask, state = polygon_to_mask(ori_mask, ann['polygon'], 255)
        
        if not state:
            continue
        
        ori_mask = np.where((ori_mask == 255) & (generated_mask == 255), 0, ori_mask)
        ori_polygons = mask_to_polygon(ori_mask)
        
        for polygon in ori_polygons:
            m_ann = copy.deepcopy(ann)
            m_ann['polygon'] = polygon
            modified_annotation['objects'].append(m_ann)
    
    modified_annotation['objects'].extend(gen_annotation)
    return modified_annotation

def make_result(image, mask):
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    image_with_mask_contours = np.copy(image)
    cv2.drawContours(image_with_mask_contours, contours, -1, (0, 0, 255), 2)
    return image_with_mask_contours

def set_step(min_val, max_val, steps):
    infos = []
    interval = (max_val-min_val)//steps
    
    for step in range(steps):
        infos.append({'min':min_val+interval*step, 'max':min_val+interval*(step+1)})
    
    return infos

def check_occlusion(pre_mask, gen_mask, threshold=0.5):
    pre_spots = np.argwhere(pre_mask==255)
    overlapped_spots = np.argwhere(np.bitwise_and(pre_mask.astype('uint8'), gen_mask.astype('uint8'))==255)
    occlusion_ratio = float(len(overlapped_spots))/len(pre_spots)
    
    if occlusion_ratio>=threshold:
        return True
    else:
        return False

# setting
device_id = 3
cls_names = ["road", "sidewalk", "railroad", "terrain", "snow", "leaves", "lane", "crosswalk", "parking_line", "stop_line", "parking_slot"]
mu_cls_names = ["sky"]
pre_path = '/data/noah/inference/magna_object'
image_path = "/data/noah/inference/magna_inference/10_RGB"
annotation_path = "/data/noah/inference/magna_inference/11_GT"
output_path = '/data/noah/inference/magna_insertion'
output_result_path = os.path.join(output_path, 'result')
output_result_with_mask_path = os.path.join(output_path, 'result_with_mask')
output_ann_path = os.path.join(output_path, 'annotations')
output_ann_mask_path = os.path.join(output_path, 'annotations_mask')

make_dirs(
    [
        output_path,
        output_result_path,
        output_ann_path,
        output_result_with_mask_path,
        output_ann_mask_path
    ]
)

harmonizer = Harmonization("/data/noah/ckpt/pretrain_ckpt/duconet/duconet1024.pth", device='cuda:{}'.format(device_id))
matter = onnxruntime.InferenceSession('/data/noah/ckpt/pretrain_ckpt/matting/modnet.onnx', None)

height_coefficient = [7.62995538e-14, -2.57068472e-10, 3.25925629e-07, -1.90207658e-04, 5.09169229e-02, -5.35772215e+00, 2.35348832e+02]
height_poly = np.poly1d(height_coefficient)

# load pre-image and pre-mask
pre_infos = {
    'person' : [],
    'rider' : []
}

pre_infos['person'] = get_object_names(os.path.join(pre_path, 'person'))
# pre_infos['rider'] = get_object_names(os.path.join(pre_path, 'rider'))

# object insertion
for ann_name in tqdm(os.listdir(annotation_path)[:100]):
    # load annotation
    annotation_file_path = os.path.join(annotation_path, ann_name)

    with open(annotation_file_path, 'r') as f:
        annotation = json.load(f)

    # setting height information parameter
    gen_annotations = []
    height_info = []
    height_info.extend(set_step(600,650,8)) # Far
    height_info.extend(set_step(650,750,4)) # Mid
    height_info.extend(set_step(750,900,2)) # Near

    # load image and mask
    height, width = annotation['imgHeight'], annotation['imgWidth']
    image_name = ann_name.split('_gt')[0]+'_rgb.jpg'
    image = Image.open(os.path.join(image_path, image_name))
    image = np.array(image).astype("uint8")    
    mask = np.zeros((len(height_info), height, width))

    # get target class index    
    cls_mask = np.zeros((height, width))
    for idx, ann in enumerate(annotation['objects']):
        if ann['label'] in cls_names:
            cls_mask, state = polygon_to_mask(cls_mask, ann['polygon'], color=255)
        
    for idx, ann in enumerate(annotation['objects']):
        if ann['label'] not in cls_names and ann['label'] not in mu_cls_names:
            cls_mask, state = polygon_to_mask(cls_mask, ann['polygon'], color=0)

    cls_spots = np.argwhere(cls_mask==255).tolist()
    spot_check = False
    
    for hinfo in height_info:
        if check_spot(cls_spots, hinfo['min'], hinfo['max']):
            spot_check = True
            break
    
    if spot_check:
        print('{} can not generate human'.format(annotation_file_path))
        continue        
    
    pre_check=False
    for h_idx, hinfo in enumerate(height_info):
        pre_check = not len(pre_infos['person']) and not len(pre_infos['rider'])
        if pre_check:
            break
        
        while True:
            # get random (bottom, right) spot and instance height
            spot = get_spot(cls_spots, hinfo['min'], hinfo['max']) #bottom, right 순서
            ins_height = int(height_poly(spot[0]))
            
            if not len(pre_infos['person']):
                pre_cls = 'rider'
            elif not len(pre_infos['rider']):
                pre_cls = 'person'
            else:
                pre_cls = random.choice(['person', 'rider'])

            # load pre-mask and pre-image
            pre_idx = random.randint(0, len(pre_infos[pre_cls])-1)
            pre_name = pre_infos[pre_cls][pre_idx]
            pre_image = cv2.imread(os.path.join(pre_path, pre_cls, 'images', pre_name))
            pre_image = cv2.cvtColor(pre_image, cv2.COLOR_BGR2RGB)
            pre_mask = get_mask(pre_image, matter)
            
            if len(np.argwhere(pre_mask==255))==(pre_image.shape[0]*pre_image.shape[1]):
                print(pre_name)
                                    
            # mask refinement & crop from the mask
            crop_result = crop_from_mask(pre_image, pre_mask)            
            
            if crop_result is None:
                pre_infos[pre_cls].pop(pre_idx)
                continue
            else:
                pre_image, pre_mask = crop_result

            if len(np.argwhere(pre_mask==255))==(pre_image.shape[0]*pre_image.shape[1]):
                print(pre_name)
            
            ratio = float(ins_height) / pre_image.shape[0]
            pre_image = cv2.resize(pre_image, (int(pre_image.shape[1]*ratio), ins_height))
            pre_mask = cv2.resize(pre_mask, (int(pre_mask.shape[1]*ratio), ins_height))
            
            # add pre-mask and pre-image
            paste_result = add_object(image, pre_image, np.zeros((height, width)), pre_mask, spot[1], spot[0])
            
            # truncation iou threshold check
            if paste_result is None:
                continue
            
            # check if overlapped with other generated instance
            paste_image, paste_mask = paste_result
            occlusion_check = False
            
            for idx in range(h_idx):
                if check_occlusion(mask[idx], paste_mask):
                    occlusion_check=True
                    break
            
            if occlusion_check:
                continue
                                    
            # pre-image / pre-mask pop
            pre_infos[pre_cls].pop(pre_idx)
            
            # assign pasted image and pasted mask
            image = np.copy(paste_image)            
            mask[h_idx] = paste_mask
            
            # add generated annotation
            gen_annotations.append(
                {
                    "label": pre_cls,
                    "polygon": mask_to_polygon(paste_mask)[0],
                    "sub_label": None,
                    "attribute_1": None,
                    "attribute_2": None,
                }
            )
            break
    
    if pre_check:
        print('{} person and rider mask length is zero'.format(annotation_file_path))
        break
    
    mask = np.sum(mask, axis=0)
    mask = np.where(mask==0, 0, 255).astype('uint8')
    
    # harmonization
    harmo_image = harmonizer.harmonize(image, mask).astype('uint8')
    
    gen_spots = np.argwhere(mask==255).tolist()
    for spot in gen_spots:
        image[spot[0], spot[1], :] = harmo_image[spot[0], spot[1], :]
    
    
    # save result image
    Image.fromarray(image.astype('uint8')).convert('RGB').save(os.path.join(output_result_path, image_name))
    # display(Image.fromarray(image.astype('uint8')))
            
    # save result image and mask
    image_with_mask = make_result(image.astype('uint8'), mask.astype('uint8'))
    Image.fromarray(image_with_mask.astype('uint8')).convert('RGB').save(os.path.join(output_result_with_mask_path, image_name))

    # modify and save annotation
    modified_annotation = modify_annotation(
        annotation,
        gen_annotations,
    )
        
    with open(os.path.join(output_ann_path, ann_name), 'w') as f:
        json.dump(modified_annotation, f)
    
    # save modified annotation mask
    modified_mask = np.zeros((height,width,3))
    for m_ann in modified_annotation['objects']:
        modified_mask, state = polygon_to_mask(
            modified_mask,
            m_ann['polygon'],
            color=(
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255),
            )
        )
    Image.fromarray(modified_mask.astype('uint8')).convert('RGB').save(os.path.join(output_ann_mask_path, image_name))

Load checkpoint from path: /data/noah/ckpt/pretrain_ckpt/duconet/duconet1024.pth
  1%|          | 1/100 [00:01<02:03,  1.25s/it]

/data/noah/inference/magna_inference/11_GT/2022-06-01-18-19-48_002886_left_rectilinear_gt_panoptic.json can not generate human


 17%|█▋        | 17/100 [03:23<17:18, 12.52s/it]

231_a young person_1.png


 20%|██        | 20/100 [03:50<12:19,  9.25s/it]

/data/noah/inference/magna_inference/11_GT/2022-06-02-08-38-48_000202_left_rectilinear_gt_panoptic.json can not generate human


 29%|██▉       | 29/100 [05:44<14:13, 12.02s/it]

838_a old person_0.png


 31%|███       | 31/100 [05:56<09:38,  8.39s/it]

/data/noah/inference/magna_inference/11_GT/2022-06-03-14-30-29_002958_right_rectilinear_gt_panoptic.json can not generate human


 39%|███▉      | 39/100 [07:41<13:09, 12.94s/it]

231_a young person_1.png


 49%|████▉     | 49/100 [09:53<11:06, 13.06s/it]

2936_a man_1.png


 53%|█████▎    | 53/100 [10:45<10:10, 12.98s/it]

838_a old person_0.png


 61%|██████    | 61/100 [12:27<08:45, 13.48s/it]

2936_a man_1.png


 64%|██████▍   | 64/100 [12:54<05:43,  9.54s/it]

/data/noah/inference/magna_inference/11_GT/2022-06-03-10-33-01_000658_right_rectilinear_gt_panoptic.json can not generate human
231_a young person_1.png


 66%|██████▌   | 66/100 [13:19<06:17, 11.09s/it]

231_a young person_1.png


 67%|██████▋   | 67/100 [13:31<06:15, 11.38s/it]

399_a old person_1.png


 70%|███████   | 70/100 [14:15<06:40, 13.36s/it]

353_a man_0.png


 71%|███████   | 71/100 [14:27<06:17, 13.02s/it]

353_a man_0.png


 72%|███████▏  | 72/100 [14:40<06:02, 12.94s/it]

353_a man_0.png


 74%|███████▍  | 74/100 [15:08<05:50, 13.49s/it]

2936_a man_1.png


 81%|████████  | 81/100 [16:43<04:11, 13.24s/it]

399_a old person_1.png


 83%|████████▎ | 83/100 [17:08<03:37, 12.81s/it]

399_a old person_1.png


 84%|████████▍ | 84/100 [17:23<03:36, 13.56s/it]

353_a man_0.png


 85%|████████▌ | 85/100 [17:36<03:20, 13.40s/it]

399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old person_1.png
399_a old p

 85%|████████▌ | 85/100 [33:39<05:56, 23.76s/it]


KeyboardInterrupt: 