In [16]:
import pandas as pd
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from datasets import load_dataset 
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm

import sys
sys.path.append('Fine-Grained-Hallucination/GroundingDINO/groundingdino')
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionXLInpaintPipeline, DDIMScheduler, AutoencoderKL
from transformers import SamModel, SamProcessor
from PIL import Image
import gradio as gr
import numpy as np

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict
from huggingface_hub import hf_hub_download

import nltk

from utils import get_entity_identifier_model, get_nouns, prepare_dino

  from .autonotebook import tqdm as notebook_tqdm


ImportError: /root/miniconda3/envs/hallu/lib/python3.12/site-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkAddData_12_1, version libnvJitLink.so.12

In [24]:
# Hugging Face checkpoints for the two text-to-image generators.
sd_2 = "stabilityai/stable-diffusion-2"
sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
# Single source of truth for the compute device; every model below (and the
# helper functions that move inputs with `.to(device)`) uses it.
device = 'cuda'

sd_2_pipe = StableDiffusionPipeline.from_pretrained(sd_2).to(device)
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(sdxl).to(device)

# SAM (ViT-H) turns box prompts into segmentation masks; the processor handles
# resizing/normalization and mask post-processing.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Grounding DINO detector built by the project helper.
# NOTE(review): its device placement is decided inside utils.prepare_dino — confirm it matches `device`.
dino_model = prepare_dino()

In [3]:
# Load the LLM-backed noun extractor that get_nouns() uses in the mask-generation
# loop. This cell's output shows it pulls microsoft/Phi-3-mini-128k-instruct
# together with remote code (configuration_phi3.py / modeling_phi3.py) at the
# latest revision.
# NOTE(review): consider pinning a `revision` inside
# utils.get_entity_identifier_model so the executed remote code cannot change silently.
entity_pipeline = get_entity_identifier_model()

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 5336.26it/s]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s

In [8]:
def get_mask(image, prompt, box_threshold=0.4, text_threshold=0.9, mask_index=1):
    """Segment every region of ``image`` that Grounding DINO grounds to ``prompt``.

    Grounding DINO (``dino_model``) proposes boxes for the text ``prompt``; all
    boxes are then sent together as box prompts to SAM (``sam_model`` /
    ``sam_processor``), so SAM runs a single forward pass per image instead of
    one per box.

    Args:
        image: PIL image to segment; converted to RGB internally.
        prompt: text phrase to ground (here, a noun extracted from the generation prompt).
        box_threshold: box-confidence threshold forwarded to ``predict``.
        text_threshold: text-token threshold forwarded to ``predict``.
        mask_index: which of SAM's candidate masks to keep for each box (default 1).

    Returns:
        A dict mapping "1", "2", ... (one entry per predicted box, in prediction
        order) to a PIL mask image, or the int ``-1`` when ``predict`` returns no
        boxes. Callers tell the two apart with ``isinstance(result, int)``.
    """
    image_rgb = image.convert('RGB')  # converted once; reused for DINO and SAM
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_transformed, _ = transform(image_rgb, None)
    boxes, logits, phrases = predict(
        model=dino_model,
        image=image_transformed,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )
    if len(boxes) == 0:
        return -1
    print(f"Number of boxes predicted = {len(boxes)}")

    # ``predict`` boxes are normalized (cx, cy, w, h); SAM wants absolute
    # (x1, y1, x2, y2) pixel coordinates.
    width, height = image_rgb.size
    sam_boxes = [
        [(cx - w / 2) * width, (cy - h / 2) * height, (cx + w / 2) * width, (cy + h / 2) * height]
        for cx, cy, w, h in boxes.tolist()
    ]
    # One image with many boxes: input_boxes is nested as [image][box][4 coords].
    inputs = sam_processor(image_rgb, input_boxes=[sam_boxes], return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # post_process_masks returns one tensor per image; for this single image it
    # is indexed as [box, candidate_mask, row, col].
    box_masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )[0]
    return {
        f'{i + 1}': Image.fromarray(box_masks[i, mask_index].numpy())
        for i in range(len(sam_boxes))
    }

In [5]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os
import re
from PIL import Image

class MaskingDataset(Dataset):
    """Pairs each generated image in ``img_dir`` with its DrawBench prompt.

    The prompt row is found from the last run of digits in the image's file
    name (e.g. ``17.jpg`` -> row label 17). ``read_csv`` is called without
    ``index_col``, so labels are the 0-based row positions of ``prompt_csv``.
    The text comes from its ``"Prompts"`` column.
    """

    def __init__(self, img_dir, prompt_csv):
        self.img_dir = img_dir
        self.csv = prompt_csv
        self.df = pd.read_csv(self.csv)
        # Regular files only: the mask-generation step creates a "masks"
        # subdirectory inside img_dir, which must not be treated as an image.
        # Sorted so indices are stable across runs (os.listdir order is arbitrary).
        self.paths = sorted(
            name for name in os.listdir(self.img_dir)
            if os.path.isfile(os.path.join(self.img_dir, name))
        )
        self.pattern = re.compile("[0-9]+")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``{"image": PIL image, "prompt": str, "img_name": int}`` for item ``idx``."""
        file_name = self.paths[idx]
        path = os.path.join(self.img_dir, file_name)
        img = Image.open(path)
        # Match digits in the file name only; searching the whole path could
        # silently pick up digits from the directory (e.g. "sd_2_outputs").
        img_name = int(self.pattern.findall(file_name)[-1])
        prompt = self.df.loc[img_name, "Prompts"]
        return {"image": img, "prompt": prompt, 'img_name': img_name}

In [6]:
# Shared root for the generated images and the DrawBench prompt table.
MISC_ROOT = '/mnt/data/workspace/misc'
sdxl_dataset_path = f'{MISC_ROOT}/sdxl_outputs'
sd_2_dataset_path = f'{MISC_ROOT}/sd_2_outputs'
prompt_csv = f'{MISC_ROOT}/DrawBenchPrompts.csv'


def _identity_collate(samples):
    """Return the batch unchanged: a plain list of sample dicts (PIL images are not stacked)."""
    return samples


sdxl_dataset = MaskingDataset(sdxl_dataset_path, prompt_csv)
sd_2_dataset = MaskingDataset(sd_2_dataset_path, prompt_csv)
# batch_size=1 with an identity collate -> every batch is a one-element list of dicts.
sdxl_dl = DataLoader(sdxl_dataset, batch_size=1, shuffle=False, collate_fn=_identity_collate)
sd_2_dl = DataLoader(sd_2_dataset, batch_size=1, shuffle=False, collate_fn=_identity_collate)

In [11]:
from PIL import ImageDraw
import torchvision.transforms as Tr
def _safe_dir_name(text, max_len=200):
    """Return ``text`` as a single path component: spaces and '/' become '_', truncated to ``max_len``."""
    return text.replace(' ', '_').replace('/', '_')[:max_len]


def generate_noun_masks(dataloader, dataset_root):
    """Save Grounding-DINO + SAM masks for every noun of every prompt in ``dataloader``.

    Output layout::

        <dataset_root>/masks/<img_name>-<prompt>/<noun>/<k>.png

    Here ``<k>`` is the box key returned by ``get_mask``. A noun directory is
    created even when ``get_mask`` finds no box, so an empty directory means
    "no mask for this noun". The YOLO fallback cell relies on that marker.

    Args:
        dataloader: yields one-element lists of ``MaskingDataset`` samples.
        dataset_root: folder of generated images; ``masks/`` is created inside it.

    Returns:
        The path of the ``masks`` root directory.
    """
    mask_root = os.path.join(dataset_root, 'masks')
    os.makedirs(mask_root, exist_ok=True)
    for batch in dataloader:
        sample = batch[0]
        prompt = sample['prompt']
        # get_nouns yields a comma-separated string: strip whitespace around
        # each entry (e.g. after ", ") and drop empty ones so they are neither
        # used as DINO captions nor as directory names.
        nouns = [noun.strip() for noun in get_nouns(prompt, entity_pipeline).split(',')]
        nouns = [noun for noun in nouns if noun]

        img_mask_dir = os.path.join(mask_root, f"{sample['img_name']}-{_safe_dir_name(prompt)}")
        os.makedirs(img_mask_dir, exist_ok=True)
        for noun in nouns:
            noun_dir = os.path.join(img_mask_dir, _safe_dir_name(noun))
            os.makedirs(noun_dir, exist_ok=True)
            masks = get_mask(sample['image'], noun)
            if isinstance(masks, int):  # -1: no box found for this noun
                continue
            for box_key, mask in masks.items():
                mask.save(os.path.join(noun_dir, f"{box_key}.png"))
    return mask_root


mask_dir = generate_noun_masks(sd_2_dl, sd_2_dataset_path)

Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 2
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 2
Number of boxes predicted = 2
Number of boxes predicted = 2
Number of boxes predicted = 2
Number of boxes predicted = 2
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 3
Number of boxes predicted = 1
Number of boxes predicted = 2
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of boxes predicted = 2
Number of boxes predicted = 1
Number of boxes predicted = 1
Number of 

In [5]:
from ultralytics import YOLO
# YOLOv8-x detector checkpoint; per this cell's output, Ultralytics downloads it on first use.
YOLO_WEIGHTS = "yolov8x.pt"
yolo_model = YOLO(YOLO_WEIGHTS)

Downloading https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x.pt to 'yolov8x.pt'...


100%|██████████| 131M/131M [00:00<00:00, 337MB/s] 


In [28]:
import os
SDXL_OUTPUT_DIR = "/mnt/data/workspace/misc/sdxl_outputs"


def add_yolo_fallback_masks(output_dir=SDXL_OUTPUT_DIR, mask_index=1):
    """Add YOLO-box SAM masks for images where no noun received a mask.

    Some ``<output_dir>/masks/<img_id>-<prompt>/`` directories have noun
    subdirectories that are all empty, or have none. For each of these, the
    function runs ``yolo_model`` on ``<output_dir>/<img_id>.jpg`` and prompts
    SAM with every detected box. It then saves ``yolo_box_<k>.png`` (k = 1, 2,
    ...) directly inside that directory.

    Args:
        output_dir: folder holding the generated ``<img_id>.jpg`` files and the ``masks`` tree.
        mask_index: which of SAM's candidate masks to keep for each box (default 1).
    """
    masks_root = os.path.join(output_dir, "masks")
    for image_dir in os.listdir(masks_root):
        image_dir_path = os.path.join(masks_root, image_dir)
        if not os.path.isdir(image_dir_path):
            continue  # skip stray files in the masks tree
        noun_dirs = [entry.path for entry in os.scandir(image_dir_path) if entry.is_dir()]
        if any(os.listdir(noun_dir) for noun_dir in noun_dirs):
            continue  # at least one noun already has a Grounding-DINO mask

        image_id = image_dir.split('-')[0]
        image_path = os.path.join(output_dir, f"{image_id}.jpg")
        result = yolo_model([image_path])
        boxes = result[0].boxes.xyxy.cpu().numpy()
        if len(boxes) == 0:
            continue  # YOLO found nothing either; nothing to segment

        with Image.open(image_path) as img:  # close the file handle promptly
            image_rgb = img.convert('RGB')
        # The xyxy boxes are passed to SAM as-is (no rescaling), all at once:
        # input_boxes is nested as [image][box][4 coords].
        inputs = sam_processor(image_rgb, input_boxes=[boxes.tolist()], return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = sam_model(**inputs)
        # One tensor per image, indexed as [box, candidate_mask, row, col].
        box_masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
        )[0]
        for i in range(len(boxes)):
            Image.fromarray(box_masks[i, mask_index].numpy()).save(
                os.path.join(image_dir_path, f'yolo_box_{i + 1}.png')
            )


add_yolo_fallback_masks()


0: 640x640 1 clock, 8.8ms
Speed: 1.9ms preprocess, 8.8ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 truck, 3 traffic lights, 8.4ms
Speed: 1.7ms preprocess, 8.4ms inference, 0.9ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 umbrella, 9.2ms
Speed: 1.9ms preprocess, 9.2ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 9.1ms
Speed: 2.2ms preprocess, 9.1ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 2 bananas, 8.7ms
Speed: 2.1ms preprocess, 8.7ms inference, 0.9ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 9.2ms
Speed: 1.9ms preprocess, 9.2ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 kite, 8.9ms
Speed: 1.7ms preprocess, 8.9ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 3 birds, 9.3ms
Speed: 1.8ms preprocess, 9.3ms inference, 1.0ms postprocess per image at s