<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/diffusion/gligenzenn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Installation for GLIGEN and GroundingDINO

In [None]:
# Python packages installed from PyPI for the cells below.
!pip install transformers accelerate scipy safetensors
# NOTE(review): this clone looks redundant -- the next line installs the same
# repository straight from its URL and no visible cell reads the checkout.
# Confirm before removing.
!git clone https://github.com/gligen/diffusers.git
# Install diffusers from the gligen GitHub repository rather than from PyPI.
!pip install git+https://github.com/gligen/diffusers.git
# Installation for GroundingDINO
%cd /content
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd /content/GroundingDINO
# Editable install of the GroundingDINO checkout in the current directory.
!pip install -q -e .
# NOTE(review): roboflow is not imported by any visible cell -- confirm it is needed.
!pip install -q roboflow

Import Libraries

In [2]:
import argparse
from functools import partial
import cv2
import requests

from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import random


import warnings
warnings.filterwarnings("ignore")


import torch
from torchvision.ops import box_convert

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T

from huggingface_hub import hf_hub_download

Utility functions

In [4]:
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a model from a config and a checkpoint fetched from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that contains both files.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Device stored on the config before the model is built, and the
            device the returned model is moved to. Defaults to 'cpu'.

    Returns:
        The model in eval mode, with the checkpoint's 'model' weights loaded
        and placed on `device`.
    """
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    # Assign the device before building so build_model can see it; the previous
    # order set it after the model already existed, where it had no effect.
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Always deserialize onto the CPU; the model is moved to `device` below.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False: key mismatches are returned in `log` (and printed) instead of raising.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    model.eval()
    # Honour the `device` argument; this is a no-op for the default 'cpu'.
    return model.to(device)

Load detection model and gligen model

In [None]:
# load detection model
# Hub repository plus the checkpoint and config file names that load_model_hf downloads.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swint_ogc.pth"
ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
dino_model = load_model_hf(
    repo_id=ckpt_repo_id,
    filename=ckpt_filenmae,
    ckpt_config_filename=ckpt_config_filename,
)

In [None]:
# load gligen pipeline
from diffusers import StableDiffusionGLIGENPipeline

# Request the "fp16" revision of the checkpoint and load it as float16 tensors.
pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "gligen/diffusers-inpainting-text-box",
    revision="fp16",
    torch_dtype=torch.float16,
)
# Kept as the cell's final expression, exactly as before.
pipe.to("cuda")

Inference


In [None]:
# download data for inference
# Work from /content so the relative image path used by the detection cell
# points at the file downloaded here.
%cd /content
# Sample image fetched from the ShilongLiu/GroundingDINO Hugging Face repository.
!wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/art_dog_birthdaycake.png

In [8]:
# grounding dino detection
import os
import supervision as sv

# Image to analyse, the text query, and the two thresholds handed to predict().
local_image_path = "art_dog_birthdaycake.png"
TEXT_PROMPT = "dog. cake."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25

# load_image returns the source image together with the version fed to the model.
image_source, image = load_image(local_image_path)

boxes, logits, phrases = predict(
    model=dino_model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD,
)

# Draw the detections, then reverse the channel axis (BGR to RGB).
annotated_frame = annotate(
    image_source=image_source,
    boxes=boxes,
    logits=logits,
    phrases=phrases,
)[..., ::-1]

In [9]:
# build a box mask from the detections
def generate_masks_with_grounding(image_source, boxes):
    """Rasterize detection boxes into a mask with the same shape as the image.

    Args:
        image_source: Array of shape (height, width, channels).
        boxes: Tensor of boxes in "cxcywh" format with coordinates normalized
            by the image width and height.

    Returns:
        An array with the shape and dtype of `image_source` that is 255 inside
        every box (on all channels) and 0 elsewhere.
    """
    h, w, _ = image_source.shape
    boxes_unnorm = boxes * torch.Tensor([w, h, w, h])
    boxes_xyxy = box_convert(boxes=boxes_unnorm, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    mask = np.zeros_like(image_source)
    for x0, y0, x1, y1 in boxes_xyxy:
        # Clamp to the image bounds. A box that overhangs the top or left edge
        # yields a negative start index, which NumPy would read as "count from
        # the end" and turn into an empty slice, silently dropping the box.
        left, right = (min(max(int(v), 0), w) for v in (x0, x1))
        top, bottom = (min(max(int(v), 0), h) for v in (y0, y1))
        mask[top:bottom, left:right, :] = 255
    return mask

In [None]:
# Build the box mask from the detections, then show the annotated detection
# frame as the cell output.
image_mask = generate_masks_with_grounding(image_source=image_source, boxes=boxes)
Image.fromarray(annotated_frame)

In [None]:
#@title gligen image inpainting
def _as_pil(img):
    """Return `img` unchanged when it is already a PIL image, else wrap it with Image.fromarray."""
    return img if isinstance(img, Image.Image) else Image.fromarray(img)

# These three names are rebound to PIL images. Going through _as_pil skips the
# conversion when an earlier run of this cell already rebound them, so a re-run
# never feeds a PIL image back into Image.fromarray.
image_source = _as_pil(image_source)
annotated_frame = _as_pil(annotated_frame)
image_mask = _as_pil(image_mask)
# Resize
image_source_for_inpaint = image_source.resize((512, 512))
image_mask_for_inpaint = image_mask.resize((512, 512))
# get bbox
xyxy_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").tolist()

# define prompts for each box
gligen_phrases = ['a cat', 'a rose']
prompt = "'a cat', 'a rose'"

# One phrase is supplied per box, so the detector must have returned exactly as
# many boxes as there are phrases. Fail here with a readable message instead of
# passing mismatched lists to the pipeline.
num_box = len(boxes)
if num_box != len(gligen_phrases):
    raise ValueError(
        f"Detected {num_box} boxes but {len(gligen_phrases)} GLIGEN phrases are defined; "
        "adjust TEXT_PROMPT, the thresholds, or gligen_phrases so the counts match."
    )

image_inpainting = pipe(
    prompt,
    num_images_per_prompt = 2,
    gligen_phrases = gligen_phrases,
    gligen_inpaint_image = image_source_for_inpaint,
    gligen_boxes = xyxy_boxes,
    gligen_scheduled_sampling_beta=1,
    output_type="numpy",
    num_inference_steps=50
).images

In [None]:
# display image
# Use fresh names instead of overwriting image_inpainting: the pipeline output
# stays untouched, so re-running this cell does not multiply already-scaled
# pixels by 255 or concatenate an already-concatenated array.
inpainting_uint8 = (image_inpainting * 255).astype(np.uint8)
# Join the images along axis 1 into a single strip.
# NOTE(review): assumes the pipeline output is laid out as (N, H, W, C) -- confirm.
inpainting_strip = np.concatenate(inpainting_uint8, axis=1)
# Scale the strip to one source-image width per generated image; the count is
# read from the array instead of hard-coding num_images_per_prompt.
Image.fromarray(inpainting_strip).resize(
    (image_source.size[0] * len(inpainting_uint8), image_source.size[1])
)