# Generate Source Image and Mask of Rider

In [None]:
from PIL import Image
import numpy as np
import cv2
import json
import os
from tqdm import tqdm


def mask_to_polygon(mask):
    """Extract the external contours of ``mask`` as plain Python lists.

    Args:
        mask: Array that is cast to ``uint8`` before contour detection
            (presumably a single-channel 2-D mask; confirm against callers).

    Returns:
        A list with one entry per external contour. Each entry is a list of
        two-element points, kept in the column order that
        ``cv2.findContours`` reports them in (the columns are not swapped).
    """
    # RETR_EXTERNAL: outermost contours only; CHAIN_APPROX_SIMPLE: compressed
    # point chains.
    found, _hierarchy = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Drop the singleton middle axis of every contour and keep both
    # coordinate columns as they are.
    return [outline.squeeze(axis=1)[:, [0, 1]].tolist() for outline in found]


def polygon_to_mask(mask, polygons, color=255):
    """Rasterize a single polygon onto a ``uint8`` copy of ``mask``.

    Args:
        mask: Array to draw on. It is not modified; drawing happens on the
            ``uint8`` copy produced by ``astype``.
        polygons: Points of ONE polygon. They are converted to an ``int32``
            array and handed to ``cv2.fillPoly`` as the only polygon.
        color: Fill value passed to ``cv2.fillPoly``.

    Returns:
        A ``(mask, state)`` pair. When drawing succeeds, ``state`` is ``True``
        and ``mask`` is the filled ``uint8`` array. When drawing fails,
        ``state`` is ``False`` and the input ``mask`` is returned unchanged.
    """
    points = np.array(polygons, dtype=np.int32)

    try:
        filled = cv2.fillPoly(mask.astype("uint8"), [points], color)
    except Exception:
        # Best effort: a polygon that cannot be drawn is reported and skipped.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate so a long run can still be interrupted.
        print("mask passed!")
        return mask, False

    return filled, True


def make_dirs(paths):
    """Create every directory named in ``paths``.

    Missing parent directories are created as well, and directories that
    already exist are left alone.

    Args:
        paths: Iterable of directory paths.
    """
    ensure_dir = os.makedirs
    for directory in paths:
        # exist_ok=True keeps repeated runs from failing on existing folders.
        ensure_dir(directory, exist_ok=True)


def crop_from_mask(image, mask, padding=50):
    """Crop ``image`` and ``mask`` to the padded bounding box of the mask.

    Args:
        image: Array indexed as ``[row, column, ...]``.
        mask: Array passed directly to ``cv2.findContours`` (the caller in
            this file passes a ``uint8`` mask). It is sliced with the same
            window as ``image``, so it is assumed to share its height/width.
        padding: Total number of pixels added to the box width and height.
            The box origin is moved up/left by ``padding // 2``.

    Returns:
        A ``(cropped_image, cropped_mask)`` pair. When the mask contains no
        contour, the inputs are returned uncropped, still as a pair, so
        callers can always unpack two values.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Nothing to crop around: hand both inputs back unchanged.
    if not contours:
        return image, mask

    # Bounding box of the largest contour (by area), not merely the first one
    # that OpenCV happens to report.
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))

    # Shift the origin by half the padding, clamped to the image. The window
    # keeps its full padded size when the origin is clamped, so a box touching
    # the top/left border extends further to the bottom/right instead.
    x = max(x - padding // 2, 0)
    y = max(y - padding // 2, 0)

    # Grow the window by the full padding, but never past the image border.
    w = min(w + padding, image.shape[1] - x)
    h = min(h + padding, image.shape[0] - y)

    cropped_image = image[y : y + h, x : x + w]
    cropped_mask = mask[y : y + h, x : x + w]

    return cropped_image, cropped_mask


# Input annotation folder and output folders for the cropped rider instances.
# NOTE(review): crops are written to "_images"/"_masks", while the generation
# cell of this notebook reads from "images"/"masks" under the same base
# directory -- confirm whether a manual curation/rename step sits in between.
annotation_path = "/data/noah/dataset/coco_rider/anno"
out_base_path = "/data/noah/inference/magna_rider_premask"
out_mask_path = os.path.join(out_base_path, "_masks")
out_image_path = os.path.join(out_base_path, "_images")

make_dirs([out_base_path, out_mask_path, out_image_path])
cnt = 0  # running id over all polygons of all files; used as output file name
padding = 200  # total padding added to each crop's width and height
threshold = 200 + padding  # crops lower than this many pixels are discarded

for name in tqdm(os.listdir(annotation_path)):
    ann_path = os.path.join(annotation_path, name)

    with open(ann_path, "r") as f:
        ann = json.load(f)

    height, width = ann["metadata"]["height"], ann["metadata"]["width"]

    # Draw every annotated polygon of this file into one combined mask.
    mask = np.zeros((height, width))
    for _ann in ann["annotations"]:
        for point in _ann["points"]:
            point = np.array(point, dtype=np.int32)
            try:
                mask = cv2.fillPoly(mask, [point], color=255)
            except Exception:
                # Best effort: skip polygons OpenCV refuses to draw, without
                # swallowing KeyboardInterrupt/SystemExit.
                continue

    # Re-extract the external contours of the combined mask; each of them is
    # then handled as one instance.
    polygons = mask_to_polygon(mask)
    if not polygons:
        continue

    # Decode the source image once per file instead of once per polygon, and
    # release the file handle as soon as the pixels are in memory.
    with Image.open(os.path.join(ann["parent_path"], ann["filename"])) as image:
        image_array = np.array(image).astype("uint8")

    for polygon in polygons:
        # The counter advances for every polygon, including skipped ones, so
        # output names are stable regardless of filtering.
        cnt += 1
        instance_mask, state = polygon_to_mask(np.zeros((height, width)), polygon)
        if not state:
            continue

        crop_image, crop_mask = crop_from_mask(
            image_array, instance_mask.astype("uint8"), padding=padding
        )

        # Filter on crop height before paying for the PIL conversions.
        if crop_image.shape[0] < threshold:
            continue

        crop_image = Image.fromarray(crop_image)
        crop_mask = Image.fromarray(crop_mask).convert("L")

        crop_image.save(os.path.join(out_image_path, "{}.png".format(cnt)))
        crop_mask.save(os.path.join(out_mask_path, "{}.png".format(cnt)))

# Rider Setting

In [None]:
import os
from tqdm import tqdm

import cv2
from PIL import Image
import numpy as np

import torch
from controlnet_aux.processor import MidasDetector
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler

def make_dirs(paths):
    """Make sure each directory in ``paths`` exists.

    Intermediate directories are created when needed; existing directories
    are not treated as an error.

    Args:
        paths: Iterable of directory paths to create.
    """
    for target_dir in paths:
        os.makedirs(name=target_dir, exist_ok=True)

def make_grid(images, rows, cols):
    """Paste images into a single RGB sheet, filling it row by row.

    The cell size is taken from the first image; every image is pasted at the
    top-left corner of its cell.

    Args:
        images: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.
    """
    cell_w, cell_h = images[0].size
    sheet = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        # Row-major placement: index -> (row, column).
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet

def closest_multiple_of_8(number):
    """Round ``number`` DOWN to a multiple of 8.

    Despite the name, the result is never larger than ``number``: 1023 maps
    to 1016, not 1024.
    """
    # Subtracting the remainder is the same as flooring to a multiple of 8.
    return number - number % 8

# ---- Runtime configuration for the rider generation cell -------------------
# Device that the inpainting pipeline and the MiDaS detector are moved to.
device = 'cuda:3'
# Target height used by the generation loop to compute the resize ratio.
instance_height = 1024
# 3x3 rectangular structuring element; the generation loop uses it for the
# morphological opening of the mask.
k = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
# Input crops and masks (matched by file name in the generation loop).
# NOTE(review): the crop-extraction cell writes to "_images"/"_masks" under
# the same base directory -- confirm how files get into "images"/"masks".
image_dir = '/data/noah/inference/magna_rider_premask/images'
mask_dir = '/data/noah/inference/magna_rider_premask/masks'
# Output folders for generated images and the masks saved next to them.
out_image_dir = '/data/noah/inference/magna_object/rider/images'
out_mask_dir = '/data/noah/inference/magna_object/rider/masks'
make_dirs([out_image_dir, out_mask_dir])

# One positive prompt per prompt type; the prompt type is also used as the
# output file-name suffix in the generation loop.
# NOTE(review): "<lora:add-detail:1>" looks like WebUI-style prompt syntax --
# confirm whether this pipeline interprets it or treats it as plain text.
prompt_types = ["a rider", "a motorcycle rider", "a bicycle rider"]
prompts = ["{}, RAW photo, subject, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, <lora:add-detail:1>".format(ptype) for ptype in prompt_types]
# The same negative prompt, repeated so the list has one entry per prompt.
negative_prompts = ["(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), blurry, text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, UnrealisticDream"]*len(prompt_types)
# Sampling parameters; all of them are forwarded as keyword arguments to the
# pipeline call in the generation loop.
num_inference_steps = 25
guidance_scale = 7.5
strength = 1.0
# NOTE(review): passed to the pipeline as `sag_scale=` -- confirm that the
# installed pipeline version actually accepts/uses this argument.
sag_scale = 0.75
controlnet_conditioning_scale = 0.75
padding_mask_crop=0
num_images_per_prompt=1

# Local checkpoint locations: base inpainting model, fine-tuned ControlNet,
# and the LoRA weights directory.
model_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/rv_inpaint_5.1"
controlnet_id = "/data/noah/ckpt/finetuning/controlnet_inpaint_coco_rider/checkpoint-21000/controlnet"
lora_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/lora_detail"
# Load ControlNet and the inpainting pipeline in float16 and move the
# pipeline to `device`.
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
).to(device)
# Load the LoRA weights file from `lora_id` into the pipeline.
pipe.load_lora_weights(lora_id, weight_name="add_detail.safetensors")
# Replace the pipeline's scheduler with DDIM, built from the existing
# scheduler's config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Enable FreeU on the pipeline with these scaling factors.
pipe.enable_freeu(s1=1.2, s2=0.5, b1=1.2, b2=1.4)


# Detector used in the generation loop to produce the ControlNet condition
# image from each input crop.
midas = MidasDetector.from_pretrained("lllyasviel/Annotators").to(device)


# Rider Generation

In [None]:
# For every crop in `image_dir`: build the condition image, keep only the
# masked pixels on a white canvas, run the inpainting pipeline once with all
# prompts, and save each result together with the refined mask.
for name in tqdm(os.listdir(image_dir)):
    image_path = os.path.join(image_dir, name)
    mask_path = os.path.join(mask_dir, name)  # mask shares the crop's file name

    image = Image.open(image_path)
    mask = Image.open(mask_path)

    # generate condition image (at the crop's native height)
    con_image = midas(image, image_resolution=image.height)

    # Scale to the configured instance height while keeping the aspect ratio;
    # both sides are rounded down to a multiple of 8. The height now follows
    # `instance_height` instead of a separately hard-coded 1024.
    height = closest_multiple_of_8(instance_height)
    ratio = instance_height / image.height
    width = closest_multiple_of_8(int(ratio * image.width))

    image = image.resize((width, height))
    mask = mask.resize((width, height))
    con_image = con_image.resize((width, height))

    # mask boundary refinement: morphological opening with the 3x3 kernel
    mask = np.array(mask)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k, iterations=3)

    # Pixels whose mask value is exactly 255 are kept. For a multi-channel
    # mask a pixel is kept when any channel is 255, which matches indexing by
    # the first two coordinates of every 255-valued entry.
    keep = mask == 255
    if keep.ndim == 3:
        keep = keep.any(axis=-1)

    # White canvas with the original pixels copied in wherever `keep` is set.
    # Boolean indexing does in one step what a per-pixel Python loop would do.
    image = np.array(image)
    masked_image = np.ones(image.shape) * 255
    masked_image[keep] = image[keep]

    mask = Image.fromarray(mask.astype('uint8')).convert('L')
    masked_image = Image.fromarray(masked_image.astype('uint8'))
    # Not used by the pipeline call below; converted back so `image` is a PIL
    # image again after the iteration.
    image = Image.fromarray(image.astype('uint8'))

    result_images = pipe(
        prompt=prompts,
        negative_prompt=negative_prompts,
        image=masked_image,
        control_image=con_image,
        mask_image=mask,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        sag_scale=sag_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        padding_mask_crop=padding_mask_crop,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    # splitext handles extensions of any length (".png", ".jpeg", ...).
    stem, ext = os.path.splitext(name)

    for idx, result_image in enumerate(result_images):
        # Map the flat output index back to its prompt. Assumes the pipeline
        # returns all samples of prompt 0 first, then prompt 1, and so on
        # (TODO confirm for the installed diffusers version). With one image
        # per prompt this is simply `idx`.
        prompt_idx, sample_idx = divmod(idx, num_images_per_prompt)
        tag = prompt_types[prompt_idx]
        if num_images_per_prompt > 1:
            # Keep several samples of the same prompt from overwriting each
            # other.
            tag = '{}_{}'.format(tag, sample_idx)

        out_name = stem + '_{}'.format(tag) + ext
        result_image.save(os.path.join(out_image_dir, out_name))
        mask.save(os.path.join(out_mask_dir, out_name))
        
    