In [None]:
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LCMScheduler
from diffusers.utils import make_image_grid, load_image
from PIL import Image

# Side length, in pixels, of the square images generated in this notebook.
resolution = 512

# Load the LCM-distilled UNet straight from its checkpoint directory.
# The path already ends in ".../unet", so no `subfolder` is passed: adding
# subfolder="unet" would make the loader look in ".../unet/unet" instead.
unet = UNet2DConditionModel.from_pretrained(
    "/home/sckim/Dataset/lcm_sd_background/checkpoint-430000/unet",
    torch_dtype=torch.float16,
)

# Build the text-to-image pipeline around the custom UNet and swap in the LCM
# scheduler, reusing the configuration of the scheduler the pipeline shipped with.
pipe = StableDiffusionPipeline.from_pretrained(
    "/home/sckim/Dataset/sd", unet=unet, torch_dtype=torch.float16, safety_checker=None
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

prompt = "the old house with the stairs up to it is located next to some beautiful trees, anime, cartoon, masterpiece, natural"

# torch.manual_seed seeds the default generator and returns it, which makes the
# sampling below repeatable for a fixed seed.
generator = torch.manual_seed(11)
background_image = pipe(
    prompt,
    num_inference_steps=4,
    guidance_scale=7.0,
    generator=generator,
    height=resolution,
    width=resolution,
).images[0]
display(background_image)

In [None]:
import os, sys

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))

import argparse
import copy

import cv2
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

from diffusers import StableDiffusionInpaintPipeline, LCMScheduler, UNet2DConditionModel
from diffusers.utils import make_image_grid

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor

# Target CUDA device for this cell; it is also made the current CUDA device.
device = "cuda:0"
torch.cuda.set_device(device)

# Generator created on that device and seeded so sampling is repeatable.
generator = torch.Generator(device=device)
generator.manual_seed(0)


def load_model(model_config_path, model_checkpoint_path, device):
    """Build a model from a config file and load checkpoint weights into it.

    Args:
        model_config_path: Path passed to ``SLConfig.fromfile``.
        model_checkpoint_path: Path passed to ``torch.load``; the loaded object
            must hold the weights under the ``"model"`` key.
        device: Value stored on the config as ``device`` before the model is built.

    Returns:
        The object returned by ``build_model``, after ``eval()`` has been called on it.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)

    # The checkpoint is deserialized onto the CPU regardless of `device`.
    weights = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    # strict=False: key mismatches between checkpoint and model do not raise.
    net.load_state_dict(clean_state_dict(weights), strict=False)
    net.eval()
    return net


def detect(image, image_tf, text_prompt, model, box_threshold=0.3, text_threshold=0.25):
    """Run text-prompted detection and draw the detections on the source image.

    Args:
        image: Source image handed to ``annotate`` as ``image_source``.
        image_tf: Image handed to ``predict`` as the model input.
        text_prompt: Caption handed to ``predict``.
        model: Detection model handed to ``predict``.
        box_threshold: Forwarded to ``predict`` unchanged.
        text_threshold: Forwarded to ``predict`` unchanged.

    Returns:
        A ``(annotated_frame, boxes)`` tuple: the annotated image with its last
        axis reversed, and the boxes returned by ``predict``.
    """
    prediction = predict(
        model=model,
        image=image_tf,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    boxes, logits, phrases = prediction

    overlay = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel axis: BGR to RGB.
    return overlay[..., ::-1], boxes


def segment(image, sam_model, boxes):
    """Predict one mask per box with a SAM predictor.

    Args:
        image: Array with a three-element ``shape`` (height, width, channels);
            it is handed to ``sam_model.set_image``.
        sam_model: Predictor exposing ``set_image``, ``transform.apply_boxes_torch``
            and ``predict_torch``.
        boxes: Boxes in (cx, cy, w, h) form; they are multiplied by the image
            width/height below, so presumably normalized to [0, 1] - confirm
            against the detector output.

    Returns:
        The masks returned by ``predict_torch``, moved to the CPU.

    Note:
        Reads the module-level ``device`` when moving the boxes.
    """
    sam_model.set_image(image)

    height, width, _channels = image.shape
    # Scale (x, y, x, y) corner coordinates by (W, H, W, H).
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    sam_boxes = sam_model.transform.apply_boxes_torch(pixel_boxes.to(device), image.shape[:2])
    # Only the first of the three returned values (the masks) is used.
    masks, _second, _third = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    return masks.cpu()


def draw_mask(mask, image, random_color=True):
    """Alpha-composite a tinted mask over an image.

    Args:
        mask: Mask whose last two dimensions are (height, width). The tinted
            result must expose ``.cpu().numpy()``, so presumably a torch
            tensor - confirm against the caller.
        image: Array accepted by ``Image.fromarray``; used as the base layer.
        random_color: When true, the RGB part of the tint is drawn from
            ``np.random.random`` with alpha 0.8; otherwise a fixed RGBA tint
            (30, 144, 255 scaled by 1/255, alpha 0.6) is used.

    Returns:
        The composited RGBA image as a NumPy array.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )

    rows, cols = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get a per-pixel RGBA overlay.
    tinted = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)

    base_layer = Image.fromarray(image).convert("RGBA")
    overlay_bytes = (tinted.cpu().numpy() * 255).astype(np.uint8)
    overlay_layer = Image.fromarray(overlay_bytes).convert("RGBA")

    composited = Image.alpha_composite(base_layer, overlay_layer)
    return np.array(composited)


# --- Detection and segmentation models ---------------------------------------
grounding_dino_path = "/home/sckim/Dataset/groundingdino_swinb_cogcoor.pth"
grounding_dino_config_filename = (
    "/workspace/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py"
)
grounding_dino_model = load_model(grounding_dino_config_filename, grounding_dino_path, device=device)

sam_path = "/home/sckim/Dataset/sam_vit_h_4b8939.pth"
sam_predictor = SamPredictor(build_sam(checkpoint=sam_path).to(device))

image_path = "/home/sckim/Dataset/background_inference/2.jpeg"

# --- Detect the character and build its mask ---------------------------------
# load_image returns the source image and the transformed model input.
image, image_tf = load_image(image_path)
annotated_frame, detected_boxes = detect(image, image_tf, text_prompt="character .", model=grounding_dino_model)

# Fail with a clear message instead of an opaque indexing error further down
# when the detector finds nothing for the prompt.
if len(detected_boxes) == 0:
    raise ValueError(f"No 'character' detected in {image_path}; cannot build an inpainting mask.")

segmented_frame_masks = segment(image, sam_predictor, boxes=detected_boxes)
# Only the first detection's mask is used from here on.
annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)

mask = segmented_frame_masks[0][0].cpu().numpy()

image = Image.fromarray(image)
mask = Image.fromarray(mask)

# `resolution` is not defined in this cell; it must already exist in the kernel
# (the background-generation cell defines it).
image = image.resize((resolution, resolution))
mask = mask.resize((resolution, resolution))

# --- Inpainting pipeline with IP-Adapter -------------------------------------
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "/home/sckim/Dataset/sd",
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipeline.load_ip_adapter(
    "/home/sckim/Dataset/ip_adapter/",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
)
pipeline = pipeline.to(device)

prompt = "remove"
negative_prompt = ""

# `background_image` is also not defined in this cell; it is the output of the
# background-generation cell and must already exist in the kernel.
result_image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=background_image,
    mask_image=mask,
    strength=1.0,
    guidance_scale=9.0,
    num_inference_steps=25,
    height=resolution,
    width=resolution,
    ip_adapter_image=image,
    generator=generator,
).images[0]

grid = make_image_grid(images=[image, mask, background_image, result_image], rows=1, cols=4)
display(grid)

In [None]:
import os, sys

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))

import argparse
import copy

import cv2
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

from diffusers import StableDiffusionInpaintPipeline, LCMScheduler, UNet2DConditionModel
from diffusers.utils import make_image_grid

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor

# Target CUDA device for this cell; it is also made the current CUDA device.
device = "cuda:0"
torch.cuda.set_device(device)

# Generator created on that device and seeded so sampling is repeatable.
generator = torch.Generator(device=device)
generator.manual_seed(42)


def load_model(model_config_path, model_checkpoint_path, device):
    """Build a model from a config file and load checkpoint weights into it.

    Args:
        model_config_path: Path passed to ``SLConfig.fromfile``.
        model_checkpoint_path: Path passed to ``torch.load``; the loaded object
            must hold the weights under the ``"model"`` key.
        device: Value stored on the config as ``device`` before the model is built.

    Returns:
        The object returned by ``build_model``, after ``eval()`` has been called on it.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)

    # The checkpoint is deserialized onto the CPU regardless of `device`.
    weights = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    # strict=False: key mismatches between checkpoint and model do not raise.
    net.load_state_dict(clean_state_dict(weights), strict=False)
    net.eval()
    return net


def detect(image, image_tf, text_prompt, model, box_threshold=0.3, text_threshold=0.25):
    """Run text-prompted detection and draw the detections on the source image.

    Args:
        image: Source image handed to ``annotate`` as ``image_source``.
        image_tf: Image handed to ``predict`` as the model input.
        text_prompt: Caption handed to ``predict``.
        model: Detection model handed to ``predict``.
        box_threshold: Forwarded to ``predict`` unchanged.
        text_threshold: Forwarded to ``predict`` unchanged.

    Returns:
        A ``(annotated_frame, boxes)`` tuple: the annotated image with its last
        axis reversed, and the boxes returned by ``predict``.
    """
    prediction = predict(
        model=model,
        image=image_tf,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    boxes, logits, phrases = prediction

    overlay = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel axis: BGR to RGB.
    return overlay[..., ::-1], boxes


def segment(image, sam_model, boxes):
    """Predict one mask per box with a SAM predictor.

    Args:
        image: Array with a three-element ``shape`` (height, width, channels);
            it is handed to ``sam_model.set_image``.
        sam_model: Predictor exposing ``set_image``, ``transform.apply_boxes_torch``
            and ``predict_torch``.
        boxes: Boxes in (cx, cy, w, h) form; they are multiplied by the image
            width/height below, so presumably normalized to [0, 1] - confirm
            against the detector output.

    Returns:
        The masks returned by ``predict_torch``, moved to the CPU.

    Note:
        Reads the module-level ``device`` when moving the boxes.
    """
    sam_model.set_image(image)

    height, width, _channels = image.shape
    # Scale (x, y, x, y) corner coordinates by (W, H, W, H).
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    sam_boxes = sam_model.transform.apply_boxes_torch(pixel_boxes.to(device), image.shape[:2])
    # Only the first of the three returned values (the masks) is used.
    masks, _second, _third = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    return masks.cpu()


def draw_mask(mask, image, random_color=True):
    """Alpha-composite a tinted mask over an image.

    Args:
        mask: Mask whose last two dimensions are (height, width). The tinted
            result must expose ``.cpu().numpy()``, so presumably a torch
            tensor - confirm against the caller.
        image: Array accepted by ``Image.fromarray``; used as the base layer.
        random_color: When true, the RGB part of the tint is drawn from
            ``np.random.random`` with alpha 0.8; otherwise a fixed RGBA tint
            (30, 144, 255 scaled by 1/255, alpha 0.6) is used.

    Returns:
        The composited RGBA image as a NumPy array.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )

    rows, cols = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get a per-pixel RGBA overlay.
    tinted = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)

    base_layer = Image.fromarray(image).convert("RGBA")
    overlay_bytes = (tinted.cpu().numpy() * 255).astype(np.uint8)
    overlay_layer = Image.fromarray(overlay_bytes).convert("RGBA")

    composited = Image.alpha_composite(base_layer, overlay_layer)
    return np.array(composited)


# --- Detection and segmentation models ---------------------------------------
grounding_dino_path = "/home/sckim/Dataset/groundingdino_swinb_cogcoor.pth"
grounding_dino_config_filename = (
    "/workspace/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py"
)
grounding_dino_model = load_model(grounding_dino_config_filename, grounding_dino_path, device=device)

sam_path = "/home/sckim/Dataset/sam_vit_h_4b8939.pth"
sam_predictor = SamPredictor(build_sam(checkpoint=sam_path).to(device))

image_path = "/home/sckim/Dataset/background_inference/2.jpeg"

# --- Detect the character and build the inverted mask ------------------------
# load_image returns the source image and the transformed model input.
image, image_tf = load_image(image_path)
annotated_frame, detected_boxes = detect(image, image_tf, text_prompt="character .", model=grounding_dino_model)

# Fail with a clear message instead of an opaque indexing error further down
# when the detector finds nothing for the prompt.
if len(detected_boxes) == 0:
    raise ValueError(f"No 'character' detected in {image_path}; cannot build an inpainting mask.")

segmented_frame_masks = segment(image, sam_predictor, boxes=detected_boxes)
# Only the first detection's mask is used from here on.
annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)

mask = segmented_frame_masks[0][0].cpu().numpy()
# Invert the mask and scale it to 0/255: everything outside the segmented
# region becomes 255, the segmented region becomes 0.
inverted_mask = ((1 - mask) * 255).astype(np.uint8)

image = Image.fromarray(image)
mask = Image.fromarray(inverted_mask)

# --- LCM inpainting pipeline -------------------------------------------------
unet = UNet2DConditionModel.from_pretrained(
    "/home/sckim/Dataset/lcm_sd_background/checkpoint-430000/unet",
    torch_dtype=torch.float16,
)
# NOTE(review): `in_channels` and `ignore_mismatched_sizes` look like
# model-loading options; since a prebuilt `unet` is passed here, confirm whether
# the pipeline loader honours them or just ignores them.
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "/home/sckim/Dataset/sd",
    unet=unet,
    torch_dtype=torch.float16,
    in_channels=9,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
    safety_checker=None,
)
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
# Optional: pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
pipeline = pipeline.to(device)

prompt = "an anime scene with a clock tower and grassy field"
negative_prompt = ""
resolution = 512
image = image.resize((resolution, resolution))
mask = mask.resize((resolution, resolution))

# Grow the white (255) area of the mask with a 3x3 rectangular kernel.
# Alternative: cv2.erode(np.array(mask), dilation_kernel) shrinks it instead.
dilation_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
mask = Image.fromarray(cv2.dilate(np.array(mask), dilation_kernel))

result_image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image,
    mask_image=mask,
    strength=1.0,
    guidance_scale=7.0,
    num_inference_steps=5,
    height=resolution,
    width=resolution,
    generator=generator,
).images[0]

grid = make_image_grid(images=[image, mask, result_image], rows=1, cols=3)
display(grid)