## Grounded-Segment-Anything

In [None]:
import os, sys

# Add "<current working directory>/GroundingDINO" to the module search path.
# NOTE(review): presumably this lets the vendored checkout's own top-level
# imports resolve — confirm against the GroundingDINO package layout.
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))


# Expose only GPU index 0 to CUDA. This assignment runs before the
# `import torch` statements further down in this file.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import os
import copy

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt


# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline


from huggingface_hub import hf_hub_download

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a model from a config and checkpoint fetched via ``hf_hub_download``.

    Args:
        repo_id: Repository identifier passed to ``hf_hub_download`` for both files.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Value stored on the parsed config's ``device`` attribute. It is
            assigned after the model has been built; the model itself is not
            moved to any device here.

    Returns:
        The built model, switched to eval mode, with the checkpoint's
        ``'model'`` weights loaded non-strictly.
    """
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_args = SLConfig.fromfile(config_path)
    model = build_model(model_args)
    model_args.device = device

    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # Always deserialize onto the CPU, regardless of where the checkpoint was saved.
    state = torch.load(weights_path, map_location='cpu')
    # strict=False: mismatched keys are reported in the returned value
    # (printed below) instead of raising.
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    model.eval()
    return model

# Repository id and file names handed to load_model_hf for the GroundingDINO
# checkpoint and its config.
# NOTE(review): `ckpt_filenmae` is a misspelling of "filename"; left unchanged
# because it is a module-level name that other code may reference.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

# `device` is not passed, so load_model_hf uses its default ('cpu').
groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)


# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build SAM from its checkpoint, move it to DEVICE and wrap it in a predictor.
# NOTE(review): the checkpoint name is a relative path, so it is presumably
# expected to exist in the current working directory — confirm.
sam_checkpoint = 'sam_vit_h_4b8939.pth'
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)



def _overlay_mask(mask, image, random_color=True):
    """Alpha-composite a translucent colored mask on top of an image.

    Args:
        mask: Mask tensor whose last two dimensions are (height, width).
        image: Array accepted by ``Image.fromarray``; it is converted to RGBA.
        random_color: When true, use a random RGB color with alpha 0.8;
            otherwise use a fixed blue with alpha 0.6.

    Returns:
        The composited RGBA image as a NumPy array.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the 4-channel color.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    base_pil = Image.fromarray(image).convert("RGBA")
    overlay_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(base_pil, overlay_pil))


def do_inference(prompt="dog", box_threshold=0.3, text_threshold=0.25, local_image_path='assets/animals/강아지1.jpeg', out_path='out.png'):
    """Detect ``prompt`` in an image, segment the first detection and save the cut-out.

    The image is run through ``predict`` with the module-level
    ``groundingdino_model``; the resulting boxes are passed to the module-level
    ``sam_predictor``. Two files are written:

    * ``test.png`` in the current working directory: the annotated detection
      frame with the first mask overlaid (overwritten on every call).
    * ``out_path``: the source image with every pixel outside the first mask
      set to black.

    If no box is returned for the prompt, a message is printed and nothing is
    written, so a batch run can continue with the next image.

    Args:
        prompt: Text caption passed to ``predict``.
        box_threshold: Box threshold passed to ``predict``.
        text_threshold: Text threshold passed to ``predict``.
        local_image_path: Path of the image to process.
        out_path: Destination path of the masked result image.

    Returns:
        None.
    """
    image_source, image = load_image(local_image_path)
    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=DEVICE
    )

    # Without at least one box there is no mask to index below; bail out
    # before spending time on the SAM image embedding.
    if boxes.shape[0] == 0:
        print(f"No detection for prompt {prompt!r} in {local_image_path}; skipping (nothing written to {out_path}).")
        return

    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB

    sam_predictor.set_image(image_source)

    # Convert the boxes from center/size form to corner form and scale them
    # by the image width and height.
    H, W, _ = image_source.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(DEVICE)
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Only the mask of the first box is used; any further detections are ignored.
    first_mask = masks[0][0].cpu()

    annotated_frame_with_mask = _overlay_mask(first_mask, annotated_frame)
    Image.fromarray(annotated_frame_with_mask).save('test.png')

    image_mask_pil = Image.fromarray(first_mask.numpy()).convert("L")
    image_source_pil = Image.fromarray(image_source)

    # Keep source pixels where the mask is set and use black everywhere else.
    black_background = Image.new("RGB", image_source_pil.size, (0, 0, 0))
    result_image = Image.composite(image_source_pil, black_background, image_mask_pil)

    result_image.save(out_path)

import os
folder_path = '/purestorage/project/tyk/9_Animation/Segmentation/Grounded-Segment-Anything/assets/animals'
out_folder_path = '/purestorage/project/tyk/9_Animation/Segmentation/Grounded-Segment-Anything/output'

# Saving into a directory that does not exist raises FileNotFoundError, so
# make sure the output folder is there before processing any image.
os.makedirs(out_folder_path, exist_ok=True)

# List the folder once and reuse the result for both the printout and the loop.
file_names = os.listdir(folder_path)
print(file_names)
for file_name in file_names:
    # The prompt is chosen from a keyword in the file name; files matching
    # neither keyword are skipped.
    if "강아지" in file_name:
        prompt = 'dog full body including clothes'
    elif "고양이" in file_name:
        prompt = 'cat full body including clothes'
    else:
        continue
    # Drop the last extension; the result is always written as a PNG.
    file_name_prefix = file_name.rsplit(".", 1)[0]
    do_inference(
        prompt=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        local_image_path=os.path.join(folder_path, file_name),
        out_path=os.path.join(out_folder_path, f"{file_name_prefix}.png"),
    )