In [None]:
from PIL import Image
import numpy as np
import cv2
import json
import os
from tqdm import tqdm


def mask_to_polygon(mask):
    """Convert every external contour of a mask into a polygon.

    Args:
        mask: 2-D array; any non-zero pixel counts as foreground.

    Returns:
        A list of polygons, one per external contour. Each polygon is a list of
        ``[x, y]`` integer points in OpenCV (column, row) order.
    """
    # Binarize before casting: a plain astype(np.uint8) wraps values such as 256
    # or 512 around to 0 and would silently drop those pixels.
    binary = (np.asarray(mask) != 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Each contour has shape (N, 1, 2) holding (x, y) points; flatten it to (N, 2).
    # The point order stays (x, y); it is NOT swapped to (y, x).
    return [contour.reshape(-1, 2).tolist() for contour in contours]


def polygon_to_mask(mask, polygons, color=255):
    """Fill one polygon into a uint8 copy of *mask*.

    Args:
        mask: Array to draw on. It is copied and cast to uint8, so the caller's
            array is never modified.
        polygons: Point list of a single polygon, e.g. ``[[x0, y0], [x1, y1], ...]``.
            It is wrapped in a list and passed to ``cv2.fillPoly`` as one polygon.
        color: Fill value: a scalar, or a per-channel tuple for multi-channel masks.

    Returns:
        ``(mask, state)``. On success, the uint8 mask with the polygon drawn and
        ``True``. If the points could not be converted or drawn, the uint8 copy
        without the polygon and ``False``.
    """
    canvas = mask.astype("uint8")
    try:
        # The conversion sits inside the try so malformed point lists are reported
        # like drawing failures instead of raising.
        points = np.array(polygons, dtype=np.int32)
        canvas = cv2.fillPoly(canvas, [points], color)
    except (cv2.error, ValueError, TypeError):
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.
        print("mask passed!")
        return canvas, False
    return canvas, True


def make_dirs(paths):
    """Create each directory in *paths*, including parents; existing ones are left as-is."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)


def crop_from_mask(image, mask):
    """Crop *image* and *mask* to the bounding box of the mask's largest external contour.

    Args:
        image: Array whose first two axes (rows, cols) align with *mask*.
        mask: Single-channel uint8 mask; non-zero pixels are foreground.

    Returns:
        ``(cropped_image, cropped_mask)``. If the mask has no contour, the inputs
        are returned uncropped, still as a pair, so callers can always unpack.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Return a pair here too; returning only the image broke `a, b = crop_from_mask(...)`.
    if not contours:
        return image, mask

    # Pick the contour with the largest area rather than whichever OpenCV listed first.
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)

    return image[y : y + h, x : x + w], mask[y : y + h, x : x + w]


annotation_path = "/data/noah/dataset/coco_rider/anno"
out_base_path = "/data/noah/dataset/coco_rider/magna_rider_premask"
out_mask_path = os.path.join(out_base_path, "masks")
out_image_path = os.path.join(out_base_path, "images")

make_dirs([out_base_path, out_mask_path, out_image_path])
cnt = 0  # running index used as the output file name

# For every annotation file: union all polygons into one mask, split that union into
# connected outer contours, and save an image/mask crop for each contour.
for name in tqdm(os.listdir(annotation_path)):
    ann_path = os.path.join(annotation_path, name)

    with open(ann_path, "r") as f:
        ann = json.load(f)

    height, width = ann["metadata"]["height"], ann["metadata"]["width"]

    # Union of all annotated polygons; malformed polygons are skipped.
    mask = np.zeros((height, width))
    for _ann in ann["annotations"]:
        try:
            point = np.array(_ann["points"], dtype=np.int32)
            mask = cv2.fillPoly(mask, [point], color=255)
        except (cv2.error, ValueError, TypeError):
            continue

    polygons = mask_to_polygon(mask)
    if not polygons:
        continue

    # Decode the source image once per annotation file instead of once per polygon,
    # and close the file handle afterwards.
    with Image.open(os.path.join(ann["parent_path"], ann["filename"])) as image:
        image_array = np.array(image).astype("uint8")

    for polygon in polygons:
        cnt += 1
        mask, state = polygon_to_mask(np.zeros((height, width)), polygon)

        if not state:
            continue

        crop_image, crop_mask = crop_from_mask(image_array, mask.astype("uint8"))
        Image.fromarray(crop_image).save(os.path.join(out_image_path, "{}.png".format(cnt)))
        Image.fromarray(crop_mask).convert("L").save(os.path.join(out_mask_path, "{}.png".format(cnt)))

In [None]:
import os
import random
import json
import copy
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2
import torch

from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, DDIMScheduler
from controlnet_aux.processor import MidasDetector
import sys

sys.path.insert(0, "../harmonization")
from harmonization import Harmonization
from gtgen.bpr import GtGenBPRInference


def make_inputs(image, annotation, target_indexs, harmonizer, max_attempts=100):
    """Paste a random pre-cropped rider into a target region of *image* and harmonize it.

    Reads the module-level ``mask_lists``, ``image_path`` and ``mask_path`` to pick a
    crop. The crop image and its mask are looked up under the same file name.

    Args:
        image: Source image as a uint8 array indexed ``[row, col, channel]``.
        annotation: Annotation dict with ``metadata.height/width``, ``annotations``
            and ``filename`` keys.
        target_indexs: Indices into ``annotation["annotations"]`` allowed as paste regions.
        harmonizer: Object exposing ``harmonize(image, mask)``.
        max_attempts: Maximum number of placements to try. Without a bound, a region
            whose sampled corners never leave room for the crop (e.g. an image shorter
            than the crop) loops forever.

    Returns:
        ``{"image": PIL RGB image, "mask": PIL "L" mask, "spot": (K, 2) array of
        (row, col) pasted pixels}``, or ``None`` if no placement was found.
    """
    height, width = annotation["metadata"]["height"], annotation["metadata"]["width"]
    mask = np.zeros((height, width))
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    for _ in range(max_attempts):
        target_index = random.choice(target_indexs)
        # Bottom-right corner of the paste, as (row, col).
        rb_spot = random_coordinate(annotation["annotations"], target_index, height, width)

        if rb_spot is None:
            print("{} can not generate right bottom spot".format(annotation["filename"]))
            return None

        # NOTE(review): the original note says the height should be derived from
        # rb_spot's x value; the range is currently collapsed to a fixed 500 px.
        target_height = random.randint(500, 500)
        mask_name = random.choice(mask_lists)

        image_file = os.path.join(image_path, mask_name)
        mask_file = os.path.join(mask_path, mask_name)
        with Image.open(image_file) as src_image, Image.open(mask_file) as src_mask:
            # Scale image and mask by the same ratio so the crop is target_height tall.
            ratio = float(target_height) / src_mask.height
            paste_mask = src_mask.resize((int(src_mask.width * ratio), int(src_mask.height * ratio)))
            paste_image = src_image.resize((int(src_image.width * ratio), int(src_image.height * ratio)))
            paste_mask = np.array(paste_mask).astype("uint8")
            paste_image = np.array(paste_image).astype("uint8")

        # Morphological opening, then re-binarize to 0/255.
        paste_mask = cv2.morphologyEx(paste_mask, cv2.MORPH_OPEN, kernel, iterations=5)
        paste_mask = np.where(paste_mask > 127, 255, 0).astype("uint8")
        # Check emptiness after the opening, which can erase a small mask entirely.
        if not paste_mask.any():
            continue

        sum_mask = add_mask(mask, paste_mask, rb_spot[1], rb_spot[0])
        if sum_mask is None:
            continue

        sum_image = add_image(image, paste_image, paste_mask, rb_spot[1], rb_spot[0])
        if sum_image is None:
            continue

        spot = np.argwhere(sum_mask == 255)
        break
    else:
        print("{} could not place a crop within {} attempts".format(annotation["filename"], max_attempts))
        return None

    sum_image = harmonizer.harmonize(sum_image, sum_mask)
    sum_image = Image.fromarray(sum_image.astype("uint8")).convert("RGB")
    sum_mask = Image.fromarray(sum_mask.astype("uint8")).convert("L")

    return {"image": sum_image, "mask": sum_mask, "spot": spot}


def make_grid(images, rows, cols):
    """Tile equally sized PIL images, row-major, into a ``rows`` x ``cols`` RGB canvas.

    The tile size is taken from the first image.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(images):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


def find_outer_contour_coordinates(mask):
    """Return every boundary pixel of the external contours of *mask*.

    Args:
        mask: Single-channel uint8 mask; non-zero pixels are foreground.

    Returns:
        ``[rows, cols]``: two parallel lists with the row (y) and column (x) of
        each contour pixel.
    """
    # CHAIN_APPROX_NONE keeps every boundary pixel. CHAIN_APPROX_SIMPLE keeps only
    # segment endpoints, so distances measured to these points (random_coordinate)
    # would overestimate the distance to long straight edges.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    rows, cols = [], []
    for contour in contours:
        points = contour.reshape(-1, 2)  # (N, 1, 2) -> (N, 2), columns are (x, y)
        cols.extend(points[:, 0].tolist())
        rows.extend(points[:, 1].tolist())

    return [rows, cols]


def euclidean_distance(point1, point2):
    """Return the straight-line distance between two 2-D points given as indexable pairs."""
    delta_a = point2[0] - point1[0]
    delta_b = point2[1] - point1[1]
    return np.sqrt(delta_a ** 2 + delta_b ** 2)


def random_coordinate(annotation, target_index, height, width):
    """Pick a random pixel of annotation ``target_index`` that lies far from its edge.

    The target polygon is rasterized, every other annotation's polygon is carved out,
    and random candidate pixels are drawn until one is at least ``threshold`` pixels
    (integer-truncated) from every outer-contour point. The threshold starts at 500
    and halves after each rejected candidate, so the loop terminates.

    Returns:
        ``[row, col]`` of the chosen pixel, or ``None`` if the target polygon cannot
        be drawn or no pixel of it remains after carving.
    """
    region, drawn = polygon_to_mask(np.zeros((height, width)), annotation[target_index]["points"], color=255)
    if not drawn:
        return None

    # Carve every other annotation out of the target region.
    for other_index, other in enumerate(annotation):
        if other_index != target_index:
            region, _ = polygon_to_mask(region, other["points"], color=0)

    candidates = np.argwhere(region == 255).tolist()  # [row, col] pairs
    if not candidates:
        return None

    edge_rows, edge_cols = (np.asarray(axis) for axis in find_outer_contour_coordinates(region))
    threshold = 500

    while True:
        candidate = random.choice(candidates)
        # Vectorized distances from the candidate to every contour point.
        gaps = np.sqrt((candidate[0] - edge_rows) ** 2 + (candidate[1] - edge_cols) ** 2)
        if int(gaps.min()) >= threshold:
            return candidate
        threshold //= 2


def add_mask(mask, new_mask, right, bottom):
    """Return a copy of *mask* with *new_mask* added so it ends at (bottom, right).

    The placed region covers rows ``[bottom - h, bottom)`` and columns
    ``[right - w, right)``, where ``(h, w)`` are the first two dimensions of *new_mask*.

    Returns:
        The new array, or ``None`` if the region would fall outside *mask*.
    """
    rows, cols = new_mask.shape[0], new_mask.shape[1]
    left = right - cols
    top = bottom - rows

    # Reject placements outside the canvas on any side. Without the right/bottom
    # check, slicing silently truncates and the in-place add fails to broadcast.
    if left < 0 or top < 0 or bottom > mask.shape[0] or right > mask.shape[1]:
        return None

    mask_cp = mask.copy()
    mask_cp[top:bottom, left:right] += new_mask
    return mask_cp


def add_image(image, new_image, mask, right, bottom):
    """Return a copy of *image* with the masked pixels of *new_image* pasted in.

    *new_image* is placed so its bottom-right corner is at (bottom, right), i.e. over
    rows ``[bottom - h, bottom)`` and columns ``[right - w, right)``. Only pixels where
    *mask* (aligned with *new_image*) is non-zero are copied.

    Returns:
        The new array, or ``None`` if the region would fall outside *image*.
    """
    rows, cols = new_image.shape[0], new_image.shape[1]
    left = right - cols
    top = bottom - rows

    if left < 0 or top < 0 or bottom > image.shape[0] or right > image.shape[1]:
        return None

    image_cp = image.copy()
    # Vectorized replacement for the per-pixel loop: copy every selected pixel at once.
    selected = np.asarray(mask)[:rows, :cols] != 0
    image_cp[top:bottom, left:right][selected] = new_image[selected]
    return image_cp


def make_dirs(paths):
    """Create each directory in *paths*, including parents; existing ones are left as-is."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)


def make_result(image, mask):
    """Return a copy of *image* with the external contours of *mask* outlined.

    Contours are drawn 2 px wide with colour ``(0, 0, 255)``, which is blue if the
    image is RGB. *mask* presumably needs to be a single-channel uint8 array, as
    ``cv2.findContours`` expects; TODO confirm for other callers.
    """
    canvas = np.copy(image)
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(canvas, outlines, -1, (0, 0, 255), 2)
    return canvas


def mask_refinement(image, ann, bpr_inference):
    """Refine the polygons in *ann* with BPR and rasterize the result.

    Args:
        image: Image array whose first two axes give the output height and width.
        ann: Annotation dict passed to BPR as ``seg``.
        bpr_inference: Object exposing ``inference(...)`` that returns a dict with an
            ``"annotations"`` list of ``{"points": ...}`` entries.

    Returns:
        uint8 mask of shape (height, width) with refined polygons filled at 255.
    """
    bpr_options = dict(
        img_scale=(256, 256),
        img_ratios=[1.0, 2.0],
        nms_iou_threshold=0.5,
        point_density=0.25,
        patch_size=[32, 64, 96],
        padding=0,
    )
    refined = bpr_inference.inference(img=image, seg=ann, **bpr_options)

    canvas = np.zeros(image.shape[:2])
    for entry in refined["annotations"]:
        canvas, _ = polygon_to_mask(canvas, entry["points"], 255)

    return canvas.astype("uint8")


def mask_to_polygon(mask):
    """Convert every external contour of a mask into a polygon.

    Args:
        mask: 2-D array; any non-zero pixel counts as foreground.

    Returns:
        A list of polygons, one per external contour. Each polygon is a list of
        ``[x, y]`` integer points in OpenCV (column, row) order.
    """
    # Binarize before casting: a plain astype(np.uint8) wraps values such as 256
    # or 512 around to 0 and would silently drop those pixels.
    binary = (np.asarray(mask) != 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Each contour has shape (N, 1, 2) holding (x, y) points; flatten it to (N, 2).
    # The point order stays (x, y); it is NOT swapped to (y, x).
    return [contour.reshape(-1, 2).tolist() for contour in contours]


def polygon_to_mask(mask, polygons, color=255):
    """Fill one polygon into a uint8 copy of *mask*.

    Args:
        mask: Array to draw on. It is copied and cast to uint8, so the caller's
            array is never modified.
        polygons: Point list of a single polygon, e.g. ``[[x0, y0], [x1, y1], ...]``.
            It is wrapped in a list and passed to ``cv2.fillPoly`` as one polygon.
        color: Fill value: a scalar, or a per-channel tuple for multi-channel masks.

    Returns:
        ``(mask, state)``. On success, the uint8 mask with the polygon drawn and
        ``True``. If the points could not be converted or drawn, the uint8 copy
        without the polygon and ``False``.
    """
    canvas = mask.astype("uint8")
    try:
        # The conversion sits inside the try so malformed point lists are reported
        # like drawing failures instead of raising.
        points = np.array(polygons, dtype=np.int32)
        canvas = cv2.fillPoly(canvas, [points], color)
    except (cv2.error, ValueError, TypeError):
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.
        print("mask passed!")
        return canvas, False
    return canvas, True


def modify_annotation(annotations, polygons, height, width):
    """Merge generated rider polygons into an annotation list.

    Each drawable polygon in *polygons* becomes a new ``"rider"`` annotation. The area
    covered by all generated polygons is removed from every original annotation, and
    each original is re-emitted once per remaining contour, as a deep copy with new
    ``points``. Originals that cannot be drawn are dropped.

    Returns:
        The modified original annotations followed by the generated ones.
    """
    generated_mask = np.zeros((height, width))
    new_entries = []
    for polygon in polygons:
        generated_mask, drawn = polygon_to_mask(generated_mask, polygon, 255)
        if not drawn:
            continue
        new_entries.append(
            {
                "id": "",
                "type": "poly_seg",
                "attributes": {},
                "points": polygon,
                "label": "rider",
            }
        )

    occupied = generated_mask == 255

    kept = []
    for annotation in annotations:
        own_mask, drawn = polygon_to_mask(np.zeros((height, width)), annotation["points"], 255)
        if not drawn:
            continue

        # Erase the generated area from this annotation, then re-polygonize what is left.
        own_mask = np.where((own_mask == 255) & occupied, 0, own_mask)
        for piece in mask_to_polygon(own_mask):
            clone = copy.deepcopy(annotation)
            clone["points"] = piece
            kept.append(clone)

    return kept + new_entries


# GPU used by the diffusion pipeline and the harmonizer.
device = "cuda:2"
torch.cuda.set_device(device)

# Pre-defined crop images and masks used as paste sources by make_inputs.
# NOTE(review): the first cell writes crops under /data/noah/dataset/coco_rider/...,
# while these paths point to /data/noah/inference/... -- confirm they are meant to differ.
mask_path = "/data/noah/inference/magna_rider_premask/masks"
image_path = "/data/noah/inference/magna_rider_premask/images"
# make_inputs opens the same file name in both folders, so every mask needs a same-named image.
mask_lists = os.listdir(mask_path)

# Annotations of the target images to generate riders into.
base_image_path = "/data/noah/dataset/magna_traffic_light/pre_images"
target_annotation_path = "/data/noah/dataset/magna_traffic_light/pre_anno"
target_class_name = "road"  # label of the region riders are pasted into
target_height = None  # NOTE(review): not read at module level; make_inputs uses its own local value.

# Output folders.
save_base_path = "/data/noah/inference/magna_rv_inpainting"
save_result_path = os.path.join(save_base_path, "results")
save_draw_path = os.path.join(save_base_path, "draw_results")
save_refined_draw_path = os.path.join(save_base_path, "draw_results_refined")
save_annotation_draw_path = os.path.join(save_base_path, "annotation_draw")
save_modified_annotation_draw_path = os.path.join(save_base_path, "modified_annotation_draw")
save_mask_path = os.path.join(save_base_path, "masks")
make_dirs(
    [
        save_base_path,
        save_result_path,
        save_draw_path,
        save_refined_draw_path,
        save_modified_annotation_draw_path,
        save_annotation_draw_path,
        save_mask_path,
    ]
)

# BPR mask-refinement model. NOTE(review): it runs on device index 3 while the pipeline
# uses cuda:2 -- presumably deliberate to split GPU memory; confirm.
bpr_inference = GtGenBPRInference(devices=[3], batch_size=48)
bpr_model = bpr_inference.load_model("/data/noah/ckpt/finetuning/bpr.pth", img_scale=(256, 256))
# NOTE(review): assert is stripped under `python -O`; a raise would be a sturdier check.
assert bpr_model is not None, "model not loaded"

harmonizer = Harmonization("/data/noah/ckpt/pretrain_ckpt/duconet/duconet1024.pth", device=device)

# Stable Diffusion inpainting pipeline (fp16) with a detail LoRA and a DDIM scheduler.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/rv_inpaint_5.1", torch_dtype=torch.float16
).to(device)
pipe.load_lora_weights(
    "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/lora_detail", weight_name="add_detail.safetensors"
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Uncomment (and pass generator=generator to pipe) for reproducible sampling.
# generator = torch.Generator(device=device).manual_seed(42)

# `{}` is filled with the scene description via prompt.format(...) in the main loop.
prompt = "{}, best quality, extremely detailed, clearness, naturalness, film grain, crystal clear, photo with color, actuality"
negative_prompt = "cartoon, anime, painting, disfigured, immature, blur, picture, 3D, render, semi-realistic, drawing, poorly drawn, bad anatomy, wrong anatomy, gray scale, worst quality, low quality, sketch, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"


# For each of the first 100 annotation files: paste 3-5 riders into the target region,
# inpaint and harmonize each paste, then refine the generated masks with BPR and save
# results, overlays and modified annotations.
for ann_name in tqdm(os.listdir(target_annotation_path)[:100]):
    annotation_path = os.path.join(target_annotation_path, ann_name)

    with open(annotation_path, "r") as f:
        annotation = json.load(f)

    # Only the first annotation labelled `target_class_name` is used as the paste region.
    target_indexs = []
    for idx, ann in enumerate(annotation["annotations"]):
        if ann["label"] == target_class_name:
            target_indexs.append(idx)
            break

    if not target_indexs:
        continue

    height, width = annotation["metadata"]["height"], annotation["metadata"]["width"]
    image_file = os.path.join(base_image_path, annotation["parent_path"][1:], annotation["filename"])
    # Force 3-channel RGB: the harmonized results copied into this array below are RGB,
    # so grayscale/RGBA sources would otherwise fail to align channel-wise.
    with Image.open(image_file) as source:
        image = np.array(source.convert("RGB")).astype("uint8")

    sum_mask_image = np.zeros((height, width))
    sum_result_image = np.copy(image)
    generated = 0  # number of pastes actually inpainted into sum_result_image

    generate_cnt = random.randint(3, 5)

    for iter_cnt in range(generate_cnt):
        inputs = make_inputs(image, annotation, target_indexs, harmonizer)

        if inputs is None:
            break

        spot = inputs["spot"]  # (K, 2) array of (row, col) pixels covered by this paste
        rows, cols = spot[:, 0], spot[:, 1]

        # Each paste is inpainted against the original image, so overlapping pastes would
        # overwrite each other. Skip any paste touching an already generated pixel.
        # np.isin on coordinate arrays compares individual values, not (row, col) pairs,
        # so a direct lookup in the accumulated mask is used instead.
        if sum_mask_image[rows, cols].any():
            continue

        result_image = pipe(
            prompt=prompt.format("a rider is on the road"),
            negative_prompt=negative_prompt,
            image=inputs["image"],
            mask_image=inputs["mask"],
            height=inputs["image"].height,
            width=inputs["image"].width,
            num_inference_steps=25,
            guidance_scale=7.5,
            # generator=generator
        ).images[0]

        paste_mask = np.array(inputs["mask"])
        result_image = harmonizer.harmonize(np.array(result_image), paste_mask)
        result_image = result_image.astype("uint8")

        # Copy only the pasted pixels into the accumulated result (vectorized).
        sum_result_image[rows, cols, :] = result_image[rows, cols, :]
        # np.maximum keeps the accumulated mask strictly 0/255 instead of summing.
        sum_mask_image = np.maximum(sum_mask_image, paste_mask)
        generated += 1

    # Gate on successful pastes rather than on the last `inputs`, so a later make_inputs
    # call returning None does not discard pastes that already succeeded.
    if generated:
        # mask refinement
        polygons = mask_to_polygon(sum_mask_image)
        generated_annotation = copy.deepcopy(annotation)
        generated_annotation["annotations"] = [
            {
                "id": "",
                "type": "poly_seg",
                "attributes": {},
                "points": polygon,
                "label": "rider",
            }
            for polygon in polygons
        ]
        refined_mask = mask_refinement(sum_result_image, generated_annotation, bpr_inference)

        # draw result image with mask
        sum_draw_image = make_result(np.copy(sum_result_image), sum_mask_image.astype("uint8"))
        sum_draw_refined_image = make_result(np.copy(sum_result_image), refined_mask)

        # Carve the generated riders out of the original annotations and append them.
        modified_annotation = modify_annotation(
            annotation["annotations"],
            polygons,
            height,
            width,
        )

        # Visualize original vs. modified annotations, one random colour per polygon.
        original_mask = np.zeros((height, width, 3))
        modified_mask = np.zeros((height, width, 3))

        for ann in annotation["annotations"]:
            original_mask, state = polygon_to_mask(
                original_mask,
                ann["points"],
                color=(
                    random.randint(0, 255),
                    random.randint(0, 255),
                    random.randint(0, 255),
                ),
            )

        for m_ann in modified_annotation:
            modified_mask, state = polygon_to_mask(
                modified_mask,
                m_ann["points"],
                color=(
                    random.randint(0, 255),
                    random.randint(0, 255),
                    random.randint(0, 255),
                ),
            )

        Image.fromarray(sum_result_image.astype("uint8")).convert("RGB").save(
            os.path.join(save_result_path, annotation["filename"])
        )
        Image.fromarray(sum_draw_image.astype("uint8")).convert("RGB").save(
            os.path.join(save_draw_path, annotation["filename"])
        )
        Image.fromarray(sum_draw_refined_image.astype("uint8")).convert("RGB").save(
            os.path.join(save_refined_draw_path, annotation["filename"])
        )
        Image.fromarray(original_mask.astype("uint8")).convert("RGB").save(
            os.path.join(save_annotation_draw_path, annotation["filename"])
        )
        Image.fromarray(modified_mask.astype("uint8")).convert("RGB").save(
            os.path.join(save_modified_annotation_draw_path, annotation["filename"])
        )
        Image.fromarray(sum_mask_image.astype("uint8")).convert("L").save(
            os.path.join(save_mask_path, annotation["filename"])
        )

In [2]:
import os
import random
import json
import copy
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2
import torch

from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, DDIMScheduler
from controlnet_aux.processor import MidasDetector
import sys

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor

sys.path.insert(0, "../harmonization")
from harmonization import Harmonization
from gtgen.bpr import GtGenBPRInference


def make_inputs(annotation, target_indexs, max_attempts=100):
    """Build a square inpainting mask whose bottom-right corner lies in a target region.

    Args:
        annotation: Annotation dict with ``metadata.height/width``, ``annotations``
            and ``filename`` keys.
        target_indexs: Indices into ``annotation["annotations"]`` allowed as regions.
        max_attempts: Maximum number of placements to try. Without a bound, an image
            shorter or narrower than the square (350-450 px) loops forever, because
            every placement is rejected.

    Returns:
        ``{"mask": PIL "L" image}`` containing a 255-valued square, or ``None`` if no
        valid corner or placement was found.
    """
    height, width = annotation["metadata"]["height"], annotation["metadata"]["width"]
    mask = np.zeros((height, width))

    for _ in range(max_attempts):
        target_index = random.choice(target_indexs)
        # Bottom-right corner of the square, as (row, col).
        rb_spot = random_coordinate(annotation["annotations"], target_index, height, width)

        if rb_spot is None:
            print("{} can not generate right bottom spot".format(annotation["filename"]))
            return None

        # Side length of the square mask, in pixels.
        target_height = random.randint(350, 450)
        paste_mask = np.ones((target_height, target_height)) * 255

        sum_mask = add_mask(mask, paste_mask, rb_spot[1], rb_spot[0])
        if sum_mask is not None:
            break
    else:
        print("{} could not place a mask within {} attempts".format(annotation["filename"], max_attempts))
        return None

    sum_mask = Image.fromarray(sum_mask.astype("uint8")).convert("L")
    return {"mask": sum_mask}


def make_grid(images, rows, cols):
    """Tile equally sized PIL images, row-major, into a ``rows`` x ``cols`` RGB canvas.

    The tile size is taken from the first image.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(images):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


def find_outer_contour_coordinates(mask):
    """Return every boundary pixel of the external contours of *mask*.

    Args:
        mask: Single-channel uint8 mask; non-zero pixels are foreground.

    Returns:
        ``[rows, cols]``: two parallel lists with the row (y) and column (x) of
        each contour pixel.
    """
    # CHAIN_APPROX_NONE keeps every boundary pixel. CHAIN_APPROX_SIMPLE keeps only
    # segment endpoints, so distances measured to these points (random_coordinate)
    # would overestimate the distance to long straight edges.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    rows, cols = [], []
    for contour in contours:
        points = contour.reshape(-1, 2)  # (N, 1, 2) -> (N, 2), columns are (x, y)
        cols.extend(points[:, 0].tolist())
        rows.extend(points[:, 1].tolist())

    return [rows, cols]


def euclidean_distance(point1, point2):
    """Return the straight-line distance between two 2-D points given as indexable pairs."""
    delta_a = point2[0] - point1[0]
    delta_b = point2[1] - point1[1]
    return np.sqrt(delta_a ** 2 + delta_b ** 2)


def random_coordinate(annotation, target_index, height, width):
    """Pick a random pixel of annotation ``target_index`` that lies far from its edge.

    The target polygon is rasterized, every other annotation's polygon is carved out,
    and random candidate pixels are drawn until one is at least ``threshold`` pixels
    (integer-truncated) from every outer-contour point. The threshold starts at 500
    and halves after each rejected candidate, so the loop terminates.

    Returns:
        ``[row, col]`` of the chosen pixel, or ``None`` if the target polygon cannot
        be drawn or no pixel of it remains after carving.
    """
    region, drawn = polygon_to_mask(np.zeros((height, width)), annotation[target_index]["points"], color=255)
    if not drawn:
        return None

    # Carve every other annotation out of the target region.
    for other_index, other in enumerate(annotation):
        if other_index != target_index:
            region, _ = polygon_to_mask(region, other["points"], color=0)

    candidates = np.argwhere(region == 255).tolist()  # [row, col] pairs
    if not candidates:
        return None

    edge_rows, edge_cols = (np.asarray(axis) for axis in find_outer_contour_coordinates(region))
    threshold = 500

    while True:
        candidate = random.choice(candidates)
        # Vectorized distances from the candidate to every contour point.
        gaps = np.sqrt((candidate[0] - edge_rows) ** 2 + (candidate[1] - edge_cols) ** 2)
        if int(gaps.min()) >= threshold:
            return candidate
        threshold //= 2


def add_mask(mask, new_mask, right, bottom):
    """Return a copy of *mask* with *new_mask* added so it ends at (bottom, right).

    The placed region covers rows ``[bottom - h, bottom)`` and columns
    ``[right - w, right)``, where ``(h, w)`` are the first two dimensions of *new_mask*.

    Returns:
        The new array, or ``None`` if the region would fall outside *mask*.
    """
    rows, cols = new_mask.shape[0], new_mask.shape[1]
    left = right - cols
    top = bottom - rows

    # Reject placements outside the canvas on any side. Without the right/bottom
    # check, slicing silently truncates and the in-place add fails to broadcast.
    if left < 0 or top < 0 or bottom > mask.shape[0] or right > mask.shape[1]:
        return None

    mask_cp = mask.copy()
    mask_cp[top:bottom, left:right] += new_mask
    return mask_cp


def add_image(image, new_image, mask, right, bottom):
    """Return a copy of *image* with the masked pixels of *new_image* pasted in.

    *new_image* is placed so its bottom-right corner is at (bottom, right), i.e. over
    rows ``[bottom - h, bottom)`` and columns ``[right - w, right)``. Only pixels where
    *mask* (aligned with *new_image*) is non-zero are copied.

    Returns:
        The new array, or ``None`` if the region would fall outside *image*.
    """
    rows, cols = new_image.shape[0], new_image.shape[1]
    left = right - cols
    top = bottom - rows

    if left < 0 or top < 0 or bottom > image.shape[0] or right > image.shape[1]:
        return None

    image_cp = image.copy()
    # Vectorized replacement for the per-pixel loop: copy every selected pixel at once.
    selected = np.asarray(mask)[:rows, :cols] != 0
    image_cp[top:bottom, left:right][selected] = new_image[selected]
    return image_cp


def make_dirs(paths):
    """Create each directory in *paths*, including parents; existing ones are left as-is."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)


def make_result(image, mask):
    """Return a copy of *image* with the external contours of *mask* outlined.

    Contours are drawn 2 px wide with colour ``(0, 0, 255)``, which is blue if the
    image is RGB. *mask* presumably needs to be a single-channel uint8 array, as
    ``cv2.findContours`` expects; TODO confirm for other callers.
    """
    canvas = np.copy(image)
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(canvas, outlines, -1, (0, 0, 255), 2)
    return canvas


def mask_refinement(image, ann, bpr_inference):
    """Refine the polygons in *ann* with BPR and rasterize the result.

    Args:
        image: Image array whose first two axes give the output height and width.
        ann: Annotation dict passed to BPR as ``seg``.
        bpr_inference: Object exposing ``inference(...)`` that returns a dict with an
            ``"annotations"`` list of ``{"points": ...}`` entries.

    Returns:
        uint8 mask of shape (height, width) with refined polygons filled at 255.
    """
    bpr_options = dict(
        img_scale=(256, 256),
        img_ratios=[1.0, 2.0],
        nms_iou_threshold=0.5,
        point_density=0.25,
        patch_size=[32, 64, 96],
        padding=0,
    )
    refined = bpr_inference.inference(img=image, seg=ann, **bpr_options)

    canvas = np.zeros(image.shape[:2])
    for entry in refined["annotations"]:
        canvas, _ = polygon_to_mask(canvas, entry["points"], 255)

    return canvas.astype("uint8")


def mask_to_polygon(mask):
    """Convert every external contour of a mask into a polygon.

    Args:
        mask: 2-D array; any non-zero pixel counts as foreground.

    Returns:
        A list of polygons, one per external contour. Each polygon is a list of
        ``[x, y]`` integer points in OpenCV (column, row) order.
    """
    # Binarize before casting: a plain astype(np.uint8) wraps values such as 256
    # or 512 around to 0 and would silently drop those pixels.
    binary = (np.asarray(mask) != 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Each contour has shape (N, 1, 2) holding (x, y) points; flatten it to (N, 2).
    # The point order stays (x, y); it is NOT swapped to (y, x).
    return [contour.reshape(-1, 2).tolist() for contour in contours]


def polygon_to_mask(mask, polygons, color=255):
    """Fill one polygon into a uint8 copy of *mask*.

    Args:
        mask: Array to draw on. It is copied and cast to uint8, so the caller's
            array is never modified.
        polygons: Point list of a single polygon, e.g. ``[[x0, y0], [x1, y1], ...]``.
            It is wrapped in a list and passed to ``cv2.fillPoly`` as one polygon.
        color: Fill value: a scalar, or a per-channel tuple for multi-channel masks.

    Returns:
        ``(mask, state)``. On success, the uint8 mask with the polygon drawn and
        ``True``. If the points could not be converted or drawn, the uint8 copy
        without the polygon and ``False``.
    """
    canvas = mask.astype("uint8")
    try:
        # The conversion sits inside the try so malformed point lists are reported
        # like drawing failures instead of raising.
        points = np.array(polygons, dtype=np.int32)
        canvas = cv2.fillPoly(canvas, [points], color)
    except (cv2.error, ValueError, TypeError):
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.
        print("mask passed!")
        return canvas, False
    return canvas, True


def modify_annotation(annotations, polygons, height, width):
    """Merge generated rider polygons into an annotation list.

    Each drawable polygon in *polygons* becomes a new ``"rider"`` annotation. The area
    covered by all generated polygons is removed from every original annotation, and
    each original is re-emitted once per remaining contour, as a deep copy with new
    ``points``. Originals that cannot be drawn are dropped.

    Returns:
        The modified original annotations followed by the generated ones.
    """
    generated_mask = np.zeros((height, width))
    new_entries = []
    for polygon in polygons:
        generated_mask, drawn = polygon_to_mask(generated_mask, polygon, 255)
        if not drawn:
            continue
        new_entries.append(
            {
                "id": "",
                "type": "poly_seg",
                "attributes": {},
                "points": polygon,
                "label": "rider",
            }
        )

    occupied = generated_mask == 255

    kept = []
    for annotation in annotations:
        own_mask, drawn = polygon_to_mask(np.zeros((height, width)), annotation["points"], 255)
        if not drawn:
            continue

        # Erase the generated area from this annotation, then re-polygonize what is left.
        own_mask = np.where((own_mask == 255) & occupied, 0, own_mask)
        for piece in mask_to_polygon(own_mask):
            clone = copy.deepcopy(annotation)
            clone["points"] = piece
            kept.append(clone)

    return kept + new_entries


def load_model(model_config_path, model_checkpoint_path, device):
    """Build a Grounding DINO model from its config and load checkpoint weights.

    The checkpoint is loaded onto the CPU and its ``"model"`` entry is applied with
    ``strict=False``; the load report is printed. The model is put in eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    model = build_model(config)

    state = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    print(model.load_state_dict(clean_state_dict(state), strict=False))

    model.eval()
    return model


# detect object using grounding DINO
def detect(image, image_source, text_prompt, model, box_threshold=0.5, text_threshold=0.35):
    """Run Grounding DINO on *image* for the phrases in *text_prompt*.

    Args:
        image: Preprocessed image tensor passed to ``predict``.
        image_source: Original image array passed to ``annotate`` for drawing.
        text_prompt: Caption listing the phrases to detect.
        model: Grounding DINO model.
        box_threshold, text_threshold: Forwarded to ``predict``.

    Returns:
        ``(annotated_frame, boxes)``: the annotated image with its channel axis
        reversed, and the boxes returned by ``predict``.
    """
    thresholds = {"box_threshold": box_threshold, "text_threshold": text_threshold}
    boxes, logits, phrases = predict(model=model, image=image, caption=text_prompt, **thresholds)

    frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel order (described by the original author as BGR -> RGB).
    return frame[..., ::-1], boxes


def segment(image, sam_model, boxes, device=None):
    """Predict one SAM mask per detection box.

    Args:
        image: Image array of shape (H, W, C) given to ``sam_model.set_image``.
        sam_model: SAM predictor exposing ``set_image``, ``transform``,
            ``predict_torch`` and ``device``.
        boxes: Tensor of normalized (cx, cy, w, h) boxes; they are converted to pixel
            xyxy by scaling with (W, H, W, H).
        device: Device for the transformed boxes. Defaults to ``sam_model.device``, so
            the boxes always sit next to the model instead of depending on a
            module-level ``device_2`` global that must happen to match.

    Returns:
        CPU tensor of masks from ``predict_torch`` with ``multimask_output=False``.
    """
    if device is None:
        device = sam_model.device

    sam_model.set_image(image)
    H, W, _ = image.shape
    # Scale normalized cxcywh boxes to pixel xyxy.
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return masks.cpu()


def draw_mask(mask, image, random_color=True):
    """Alpha-composite a coloured overlay of *mask* onto *image*.

    Args:
        mask: Mask whose last two dimensions are (h, w). Presumably a torch tensor,
            since ``.cpu()`` is called on the product below; TODO confirm.
        image: Array accepted by ``Image.fromarray``.
        random_color: If True, use a random RGB colour with alpha 0.8; otherwise a
            fixed blue (30, 144, 255) with alpha 0.6.

    Returns:
        RGBA ``np.ndarray`` of the composited image.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # (h, w, 1) * (1, 1, 4) -> per-pixel RGBA in [0, 1]; alpha is 0 where the mask is 0.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))


def get_mask(image, dino, sam):
    """Union of SAM masks for Grounding DINO detections of riders, bicycles and motorcycles.

    Args:
        image: Image array of shape (H, W, C) accepted by ``Image.fromarray``.
        dino: Grounding DINO model used by ``detect``.
        sam: SAM predictor used by ``segment``.

    Returns:
        float array of shape (H, W) with 255 where any mask is set and 0 elsewhere,
        or ``None`` when nothing is detected.
    """
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_torch, _ = preprocess(Image.fromarray(image).convert("RGB"), None)
    _, detected_boxes = detect(image_torch, image, text_prompt="rider . bicycle . motorcycle .", model=dino)
    if len(detected_boxes) == 0:
        return None

    union = np.zeros((image.shape[0], image.shape[1]))
    for seg_map in segment(image, sam, boxes=detected_boxes):
        hit = seg_map[0].cpu().numpy().astype(np.uint8) * 255 == 255
        union[hit] = 255

    return union


def crop_from_mask(image, mask):
    """Crop ``image`` and ``mask`` to the bounding box of the largest mask region.

    Args:
        image: array indexed as ``image[y, x, ...]``.
        mask: 2-D mask with the same height/width as ``image``; non-zero pixels
            are foreground. Any numeric dtype is accepted (it is cast to uint8
            for contour detection; the returned crop keeps the original dtype).

    Returns:
        ``(cropped_image, cropped_mask)``. If the mask has no foreground, the
        inputs are returned uncropped as ``(image, mask)`` so callers can
        always unpack two values.
    """
    # cv2.findContours requires an 8-bit single-channel image for RETR_EXTERNAL.
    contours, _ = cv2.findContours(np.asarray(mask).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # No foreground: nothing to crop, but keep the (image, mask) return shape.
    if not contours:
        return image, mask

    # Bounding box of the contour enclosing the largest area.
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)

    cropped_image = image[y : y + h, x : x + w]
    cropped_mask = mask[y : y + h, x : x + w]

    return cropped_image, cropped_mask


# GPU that hosts the Stable Diffusion inpainting pipeline.
device = "cuda:2"

# Inputs: source images and the annotation JSON files describing them.
base_image_path = "/data/noah/dataset/magna_traffic_light/pre_images"
target_annotation_path = "/data/noah/dataset/magna_traffic_light/pre_anno"
# Label looked up in each annotation file's "annotations" list.
target_class_name = "road"
# NOTE(review): not referenced by the code visible in this cell.
target_height = None

# Outputs: every result kind gets its own sub-directory of save_base_path.
save_base_path = "/data/noah/inference/magna_rv_inpainting"
(
    save_result_path,
    save_draw_path,
    save_refined_draw_path,
    save_annotation_draw_path,
    save_modified_annotation_draw_path,
    save_mask_path,
) = (
    os.path.join(save_base_path, sub_dir)
    for sub_dir in (
        "results",
        "draw_results",
        "draw_results_refined",
        "annotation_draw",
        "modified_annotation_draw",
        "masks",
    )
)

output_dirs = [
    save_base_path,
    save_result_path,
    save_draw_path,
    save_refined_draw_path,
    save_modified_annotation_draw_path,
    save_annotation_draw_path,
    save_mask_path,
]
make_dirs(output_dirs)

# Second GPU: hosts the harmonizer, Grounding DINO and SAM.
device_2 = "cuda:3"

# Image harmonization model (DucoNet weights).
harmonizer_ckpt_path = "/data/noah/ckpt/pretrain_ckpt/duconet/duconet1024.pth"
harmonizer = Harmonization(harmonizer_ckpt_path, device=device_2)

# Grounding DINO (Swin-B) for text-prompted box detection.
grounding_dino_config_path = "/workspace/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py"
grounding_dino_ckpt_path = "/data/noah/ckpt/pretrain_ckpt/Grounding_DINO/groundingdino_swinb_cogcoor.pth"
grounding_dino = load_model(grounding_dino_config_path, grounding_dino_ckpt_path, device=device_2)

# SAM ViT-H wrapped in a predictor for box-prompted segmentation.
sam_ckpt_path = "/data/noah/ckpt/pretrain_ckpt/SAM/sam_vit_h_4b8939.pth"
sam_vit_h = build_sam(checkpoint=sam_ckpt_path).to(device_2)
sam_predictor = SamPredictor(sam_vit_h)

# Stable Diffusion inpainting pipeline in fp16, plus the "add_detail" LoRA,
# a DDIM scheduler, and FreeU re-weighting.
sd_inpaint_ckpt_path = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/rv_inpaint_5.1"
pipe = StableDiffusionInpaintPipeline.from_pretrained(sd_inpaint_ckpt_path, torch_dtype=torch.float16)
pipe = pipe.to(device)

lora_detail_dir = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/lora_detail"
pipe.load_lora_weights(lora_detail_dir, weight_name="add_detail.safetensors")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# For reproducible samples, build torch.Generator(device=device).manual_seed(42)
# and pass it to pipe(..., generator=...).

# Text conditioning shared by every inpainting call.
prompt = "a bicycle rider is on the road, best quality, extremely detailed, clearness, naturalness, film grain, crystal clear, photo with color, actuality"
negative_prompt = "cartoon, anime, painting, disfigured, immature, blur, picture, 3D, render, semi-realistic, drawing, poorly drawn, bad anatomy, wrong anatomy, gray scale, worst quality, low quality, sketch, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"

def _random_color():
    """Return an (R, G, B) tuple with each channel from random.randint(0, 255)."""
    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))


# NOTE: only the first 10 annotation files (in os.listdir order) are processed.
ann_names = os.listdir(target_annotation_path)[:10]

for ann_idx, ann_name in enumerate(tqdm(ann_names)):
    annotation_path = os.path.join(target_annotation_path, ann_name)

    with open(annotation_path, "r") as f:
        annotation = json.load(f)

    # Only the first annotation carrying the target label is passed to make_inputs.
    first_target_idx = next(
        (idx for idx, target_ann in enumerate(annotation["annotations"]) if target_ann["label"] == target_class_name),
        None,
    )
    if first_target_idx is None:
        continue
    target_indexs = [first_target_idx]

    height, width = annotation["metadata"]["height"], annotation["metadata"]["width"]
    image_file = os.path.join(base_image_path, annotation["parent_path"][1:], annotation["filename"])
    # Load eagerly (closing the file) and force RGB so every array below is (H, W, 3).
    with Image.open(image_file) as img:
        image = img.convert("RGB")

    sum_mask_image = np.zeros((height, width))
    sum_result_image = np.array(image).astype("uint8")
    # Pixels already covered by an accepted generation; overlapping masks are skipped.
    occupied = np.zeros((height, width), dtype=bool)

    generate_cnt = random.randint(3, 5)

    for _ in range(generate_cnt):
        inputs = make_inputs(annotation, target_indexs)

        if inputs is None:
            break

        # NOTE(review): assumes inputs["mask"] is a single-channel (H, W) image
        # in which 255 marks the area to inpaint — confirm against make_inputs.
        mask = np.array(inputs["mask"]).astype("uint8")
        region = mask == 255

        # Skip empty masks and masks that overlap a previous generation.
        # (Pixel-wise test; the old np.isin check compared coordinate values
        # element-wise and flagged unrelated regions as overlapping.)
        if not region.any() or (occupied & region).any():
            continue

        result_image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=inputs["mask"],
            height=image.height,
            width=image.width,
            num_inference_steps=25,
            guidance_scale=9.5,
            padding_mask_crop=32,
        ).images[0]

        result_image = harmonizer.harmonize(np.array(result_image), mask).astype("uint8")

        # Paste only the inpainted region into the accumulated result.
        sum_result_image[region] = result_image[region]
        # maximum (not +) keeps values within 0..255 so the uint8 save cannot wrap.
        sum_mask_image = np.maximum(sum_mask_image, mask)
        occupied |= region

    # Nothing was inpainted for this image (no inputs, or every mask was empty/overlapping).
    if not occupied.any():
        continue

    # Convert the accumulated mask into polygon annotations labelled "rider".
    polygons = mask_to_polygon(sum_mask_image)
    # NOTE(review): generated_annotation is built but never written out in this
    # cell — confirm whether it should be saved as JSON.
    generated_annotation = copy.deepcopy(annotation)
    generated_annotation["annotations"] = [
        {
            "id": "",
            "type": "poly_seg",
            "attributes": {},
            "points": polygon,
            "label": "rider",
        }
        for polygon in polygons
    ]

    # Draw the accumulated mask over the inpainted image.
    sum_draw_image = make_result(np.copy(sum_result_image), sum_mask_image.astype("uint8"))

    # Modify the original annotations using the newly inserted polygons.
    modified_annotation = modify_annotation(
        annotation["annotations"],
        polygons,
        height,
        width,
    )

    # Render original vs. modified annotations, each polygon in a random colour.
    original_mask = np.zeros((height, width, 3))
    modified_mask = np.zeros((height, width, 3))

    for ann in annotation["annotations"]:
        original_mask, _ = polygon_to_mask(original_mask, ann["points"], color=_random_color())

    for m_ann in modified_annotation:
        modified_mask, _ = polygon_to_mask(modified_mask, m_ann["points"], color=_random_color())

    # NOTE(review): every output reuses annotation["filename"]; if that is a
    # .jpg, the binary mask is stored with lossy JPEG compression — consider PNG.
    outputs = (
        (sum_result_image, "RGB", save_result_path),
        (sum_draw_image, "RGB", save_draw_path),
        (original_mask, "RGB", save_annotation_draw_path),
        (modified_mask, "RGB", save_modified_annotation_draw_path),
        (sum_mask_image, "L", save_mask_path),
    )
    for array, mode, out_dir in outputs:
        Image.fromarray(array.astype("uint8")).convert(mode).save(os.path.join(out_dir, annotation["filename"]))

KeyboardInterrupt: 