In [1]:
from PIL import Image
import numpy as np
import os
import torch
import re
import torch

In [2]:
sdxl_dataset_path = '/mnt/data/workspace/misc/sdxl_outputs'
sd_2_dataset_path = '/mnt/data/workspace/misc/sd_2_outputs'
sdxl_mask_path = os.path.join(sdxl_dataset_path, 'masks')
sd_2_mask_path = os.path.join(sd_2_dataset_path, 'masks')

In [None]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
blip_processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
blip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", torch_dtype=torch.half).cuda()


2024-09-20 11:03:31.050598: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-20 11:03:31.050651: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-20 11:03:31.051765: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-20 11:03:31.057915: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
meta_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
meta_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

In [8]:
def get_centroid(mask):
    mask_array = np.array(mask)
    mask_coordinates = np.column_stack(np.where(mask_array > 0))
    if len(mask_coordinates) == 0:
        centroid = (None, None)
    else:
        # Calculate the median of the coordinates
        centroid = np.mean(mask_coordinates, axis=0)
        centroid = centroid[1], centroid[0]
    return tuple(centroid)

In [9]:
def get_dense_caption(mask, img):
    # prompt = "Describe the image with a focus on the intricate details of each object, including their colors, shapes, and numbers. Include any physical aspects that appear unusual or incorrect according to general knowledge."
    mask = np.array(mask)
    mask = mask[:, :, np.newaxis]
    mask = np.concatenate((mask, mask, mask), axis=2)
    img = np.array(img)
    img = mask * img
    img = Image.fromarray(img)
    prompt = 'Describe the image with a focus on the intricate details of the object, including their color, shape, and number. Include any physical aspects that appear unusual or incorrect according to general knowledge. You can ignore the pitch black background.'
    inputs = blip_processor(img, 
                            prompt, 
                            return_tensors="pt"
                            ).to("cuda", torch.float16)
    out = blip_model.generate(**inputs, max_length=200, do_sample=False)
    return blip_processor.decode(out[0], skip_special_tokens=True)

In [10]:
def get_meta_caption(dense_captions):
    pattern = r'<caption>(.*?)</caption>'
    prompt = """I am providing you with captions for sub-regions of an image. These captions will be provided by the corresponding centroids
for the objects in the sub-regions. I want you to stitch all the dense captions into one unified caption for the entire image.
You have to use the centroid information to deduce the relative positions of each of the objects. Do not add any new information
to the captions. Make the caption as short as possible without losing too many details. Any mention of a black background should be ignored. Do not hallucinate any details. Generate the final caption within the <caption></caption> tags.

"""
    for i, this_caption in enumerate(dense_captions):
        string = f"{i+1}. {this_caption['centroid']} {this_caption['caption']}\n"
        prompt = prompt + string
    messages = [
    {"role": "system", "content": "You are a meta image captioning model. You look at various sub-captions and create a meaningful grounded caption using those. You can  use additional provided information to facilitate spatial reasoning."},
    {"role": "user", "content": prompt}
]
    text = meta_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
    model_inputs = meta_tokenizer([text], return_tensors="pt").to('cuda')
    generated_ids = meta_model.generate(
    model_inputs.input_ids,
    max_new_tokens=1024
)
    generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
    response = meta_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    try:
        response = re.findall(pattern, response, re.DOTALL)[0]
        return response
    except:
        return response

In [12]:
from PIL import Image, ImageDraw
import pandas as pd
from tqdm import tqdm
df = pd.read_csv('/mnt/data/workspace/misc/DrawBenchPrompts.csv')
for i, image_name in tqdm(enumerate(os.listdir(sdxl_dataset_path))):
    print(image_name)
    if not image_name.endswith('jpg'):
        continue
    idx = int(image_name.split('.')[0])
    image_path = os.path.join(sdxl_dataset_path, image_name)
    img = Image.open(image_path)
    # img.show()
    mask_dir = [i for i in os.listdir(sdxl_mask_path) if idx==int(i.split('-')[0])][0]
    mask_dir = os.path.join(sdxl_mask_path, mask_dir)
    dense_captions = []
    for root, nouns, masks in os.walk(mask_dir):
        for mask in masks:
            mask_path = os.path.join(root, mask)
            mask = Image.open(mask_path)
            centroid = get_centroid(mask)
            caption = get_dense_caption(mask, img)
            description = {"centroid": centroid, "caption": caption}
            dense_captions.append(description)
    meta_caption = get_meta_caption(dense_captions)
    print(meta_caption)
    print(df.loc[idx, 'Prompts'])
    df.loc[idx, 'Meta Caption'] = meta_caption

201it [00:00, 221798.24it/s]

98.jpg
106.jpg
70.jpg
151.jpg
54.jpg
190.jpg
9.jpg
18.jpg
39.jpg
6.jpg
22.jpg
129.jpg
40.jpg
198.jpg
138.jpg
51.jpg
90.jpg
91.jpg
125.jpg
110.jpg
104.jpg
157.jpg
0.jpg
181.jpg
186.jpg
123.jpg
169.jpg
65.jpg
62.jpg
46.jpg
99.jpg
48.jpg
115.jpg
141.jpg
160.jpg
85.jpg
148.jpg
126.jpg
178.jpg
77.jpg
101.jpg
162.jpg
124.jpg
100.jpg
92.jpg
23.jpg
97.jpg
103.jpg
75.jpg
93.jpg
43.jpg
130.jpg
191.jpg
149.jpg
155.jpg
179.jpg
113.jpg
44.jpg
187.jpg
21.jpg
2.jpg
27.jpg
172.jpg
56.jpg
147.jpg
84.jpg
16.jpg
25.jpg
29.jpg
83.jpg
180.jpg
133.jpg
122.jpg
143.jpg
59.jpg
114.jpg
34.jpg
71.jpg
139.jpg
61.jpg
12.jpg
60.jpg
7.jpg
195.jpg
63.jpg
127.jpg
52.jpg
88.jpg
94.jpg
95.jpg
102.jpg
33.jpg
masks
170.jpg
194.jpg
121.jpg
5.jpg
13.jpg
72.jpg
45.jpg
42.jpg
193.jpg
82.jpg
111.jpg
108.jpg
168.jpg
4.jpg
176.jpg
78.jpg
35.jpg
87.jpg
118.jpg
158.jpg
161.jpg
24.jpg
182.jpg
47.jpg
152.jpg
79.jpg
188.jpg
163.jpg
117.jpg
20.jpg
86.jpg
189.jpg
185.jpg
89.jpg
68.jpg
164.jpg
128.jpg
159.jpg
55.jpg
74.jpg
14.jpg
96.jpg




In [None]:
df.to_csv('meta_captions_sdxl.csv', index=False)

In [77]:
idx

98

In [78]:
df.loc[98, :]

Prompts         An elephant is behind a tree. You can see the ...
Category                                      Gary Marcus et al. 
Meta Caption    \nIn the scene, a sizable elephant, distinguis...
Name: 98, dtype: object

In [None]:
img.show()

In [14]:
mask_array = np.array(mask)

In [18]:
np.where(mask_array>0)

(array([  0,   0,   0, ..., 890, 890, 890]),
 array([  0,   1,   2, ..., 833, 834, 835]))

In [32]:
draw.ellipse(((0,0), (1000, 500)), fill='yellow')

In [None]:
img.show()