### Loading models

In [None]:
import torch
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import open_clip
device = 'cuda'
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
# Target VLM (fp16) and its matching processor.
qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
processor = AutoProcessor.from_pretrained(model_id)
# ViT-H-14 CLIP; only the third return value (the eval preprocessing transform) is kept.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
# CLIP is only used for inference here, so place it on the device in eval mode right away.
clip_model = clip_model.to(device).eval()
# unCLIP pipeline that decodes CLIP image embeddings back into images.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16).to(device)
# OWLv2 open-vocabulary detector used to check generated images for the target object.
processor_owl = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model_owl = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)





In [None]:
# Both models are used only for inference/frozen forward passes. open_clip builds
# models in train mode, so switch CLIP to eval mode explicitly.
clip_model = clip_model.to(device).eval()
qwen = qwen.to(device)

### Loading COCO

In [None]:
import json
from torchvision.datasets.folder import default_loader
from collections import defaultdict
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
class COCO(Dataset):
    """Lightweight index over COCO-2017 instance annotations.

    Unlike a positional ``Dataset``, ``__getitem__`` is keyed by COCO *image id*
    (the ``id`` field of an entry in ``images``), not by ``0..len-1``. The ids
    need not be contiguous, so the indices of a default ``DataLoader`` sampler
    are generally not valid keys here.

    Args:
        coco_dir: Root directory containing ``{split}2017/`` images and
            ``annotations/instances_{split}2017.json``.
        split: Split name used in those paths (e.g. ``'train'``, ``'val'``).
        transform: Optional callable applied to every loaded image.
    """

    def __init__(self, coco_dir, split='train', transform=None):
        self.image_dir = os.path.join(coco_dir, f"{split}2017/")
        ann_path = os.path.join(coco_dir, f"annotations/instances_{split}2017.json")
        with open(ann_path, 'r', encoding='utf-8') as file:
            coco = json.load(file)

        self.transform = transform
        self.annIm_dict = defaultdict(list)  # image id -> list of its annotation records
        self.cat_dict = {}                   # category id -> category record
        self.annId_dict = {}                 # annotation id -> annotation record
        self.im_dict = {}                    # image id -> image record

        for ann in coco['annotations']:
            self.annIm_dict[ann['image_id']].append(ann)
            self.annId_dict[ann['id']] = ann
        for img in coco['images']:
            self.im_dict[img['id']] = img
        for cat in coco['categories']:
            self.cat_dict[cat['id']] = cat

    def __len__(self):
        return len(self.im_dict)

    def __getitem__(self, idx):
        """Load the image whose COCO id is ``idx``, applying ``transform`` if set.

        Only the image is returned; use ``get_targets(idx)`` for its labels.
        """
        img = self.im_dict[idx]
        image = default_loader(os.path.join(self.image_dir, img['file_name']))
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_targets(self, idx):
        """Return the category name of every annotation on image ``idx`` (duplicates kept)."""
        # .get() avoids inserting empty entries into the defaultdict on a read.
        return [self.cat_dict[ann['category_id']]['name'] for ann in self.annIm_dict.get(idx, ())]

    def get_categories(self, supercategory):
        """Return the names of all categories whose supercategory is ``supercategory``."""
        return [cat['name'] for cat in self.cat_dict.values() if cat['supercategory'] == supercategory]

    def get_all_supercategories(self):
        """Return the set of all supercategory names."""
        return {cat['supercategory'] for cat in self.cat_dict.values()}

    def get_spurious_supercategories(self):
        """Return a fixed, hand-picked list of supercategory names."""
        return ['kitchen', 'food', 'vehicle',
                'furniture', 'appliance', 'indoor',
                'outdoor', 'electronic', 'sports',
                'accessory', 'animal']

    def get_no_classes(self, supercategories):
        """Return how many categories belong to any of ``supercategories``."""
        return sum(1 for cat in self.cat_dict.values() if cat['supercategory'] in supercategories)

    def get_imgIds(self):
        """Return all image ids, in annotation-file order."""
        return list(self.im_dict)

    def get_all_targets_names(self):
        """Return the names of all categories, in annotation-file order."""
        return [cat['name'] for cat in self.cat_dict.values()]

    def get_imgIds_by_class(self, present_classes=None, absent_classes=None):
        """Return ids of images with at least one of ``present_classes`` and none of ``absent_classes``.

        Ids come back in ``get_imgIds()`` order. ``None`` is treated as an empty
        list; with no present classes, no image matches.
        """
        present = set(present_classes or ())
        absent = set(absent_classes or ())
        ids = []
        for img_id in self.get_imgIds():
            labels = set(self.get_targets(img_id))
            if labels & present and not labels & absent:
                ids.append(img_id)
        return ids

In [None]:
COCO_PATH = "COCO"
dset = COCO(COCO_PATH)
# Report the hand-picked supercategories and the classes they contain.
supercategories = dset.get_spurious_supercategories()
no_classes = dset.get_no_classes(supercategories)
print(f"Number of classes: {no_classes}")
for supercategory in supercategories:
    classes = dset.get_categories(supercategory)
    print(f"Supercategory: {supercategory}, Classes: {classes}")
# Every category name: any annotated image qualifies as "present".
present = dset.get_all_targets_names()
obj_hallucination = 'boat'
# Ids of images with at least one annotation and no instance of the target object.
cat_spur_all = dset.get_imgIds_by_class(present_classes=present, absent_classes=[obj_hallucination])

### Rank target-free images by CLIP similarity to the target object

In [None]:
import torch.nn.functional as F
def load_and_compute_similarity(
    clip_model,
    exclude_indices,
    object_text,
    embeddings_path="clip_embeddings.pt",
    device="cuda"
):
    """Rank precomputed CLIP image embeddings by cosine similarity to a text.

    Args:
        clip_model: open_clip model whose ``encode_text`` embeds ``object_text``.
        exclude_indices: Despite the name, the image ids to KEEP. Only embeddings
            whose id appears in this list are ranked. ``None`` disables filtering.
        object_text: Text (e.g. an object name) compared against every image.
        embeddings_path: File loaded with ``torch.load``. It must hold a dict with
            ``"clip_embeds"`` (one row per image) and ``"indices"`` (matching ids).
        device: Device used for the similarity computation.

    Returns:
        ``(sorted_indices, sorted_similarities)``: image ids (CPU) and their
        similarities (on ``device``), both in descending-similarity order.
    """
    checkpoint = torch.load(embeddings_path, map_location="cpu")
    clip_embeds = checkpoint["clip_embeds"].to(device)
    # Ids stay on the CPU; they are only used for masking and final reordering.
    indices = torch.as_tensor(checkpoint["indices"])

    print(f"Loaded {clip_embeds.shape[0]} embeddings.")

    if exclude_indices is not None:
        # Build the id tensor with the ids' dtype: torch.tensor([]) would be float.
        keep = torch.as_tensor(exclude_indices, dtype=indices.dtype)
        mask = torch.isin(indices, keep)
        clip_embeds = clip_embeds[mask.to(clip_embeds.device)]
        indices = indices[mask]
        print(f"Filtered down to {clip_embeds.shape[0]} embeddings after exclusion.")

    tokenizer = open_clip.get_tokenizer('ViT-H-14')
    tokens = tokenizer(object_text).to(device)
    with torch.no_grad():
        text_embed = clip_model.encode_text(tokens).to(device)
        text_embed = F.normalize(text_embed, dim=-1)[0]

    # Cast to the text embedding's dtype so the matmul works even if the stored
    # embeddings were saved at a different precision.
    clip_embeds = F.normalize(clip_embeds.to(text_embed.dtype), dim=-1)
    similarities = torch.matmul(clip_embeds, text_embed)
    sorted_similarities, sorted_idx = torch.sort(similarities, descending=True)
    sorted_indices = indices[sorted_idx.cpu()]

    return sorted_indices, sorted_similarities

    

# load_and_compute_similarity KEEPS the ids passed as exclude_indices, so this
# restricts the ranking to the target-free images in cat_spur_all.
exclude_indices = list(cat_spur_all)
filtered_indices, similarities = load_and_compute_similarity(
    clip_model,
    exclude_indices=exclude_indices,
    object_text=obj_hallucination,
    embeddings_path="clip_embeddings.pt"
)

# Show the top 5 matches (fewer if fewer images survived the filter; topk
# raises when k exceeds the number of elements).
top5 = torch.topk(similarities, min(5, similarities.numel()))
for score, idx in zip(top5.values, top5.indices):
    # .item() turns the (possibly CUDA) position into a plain int before it
    # indexes the CPU id tensor.
    img_id = int(filtered_indices[idx.item()])
    print(f"Index: {img_id}, Similarity: {score.item():.4f}")
    image = dset[img_id]
    display(image)


### Attack

In [None]:
def contains_obj_owlvit(image: Image.Image, score_threshold=0.1, obj=None):
    """Return True if OWLv2 detects the query object in ``image``.

    Args:
        image: PIL image to search.
        score_threshold: A detection counts only if its score is strictly above this.
        obj: Text query. Defaults to the module-level ``obj_hallucination``.
    """
    query = obj_hallucination if obj is None else obj
    texts = [[query]]  # one image, one query -> every hit has label 0
    inputs = processor_owl(text=texts, images=image, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model_owl(**inputs)

    # PIL's size is (W, H); post-processing takes (H, W).
    target_sizes = torch.tensor([image.size[::-1]]).to("cuda")
    results = processor_owl.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=score_threshold
    )[0]

    return any(
        label == 0 and score > score_threshold
        for score, label in zip(results["scores"], results["labels"])
    )

In [None]:
from PIL import Image
import random

def get_obj_owlvit(image: Image.Image, score_threshold=0.1, obj=None):
    """Crop the highest-scoring OWLv2 detection of the query object.

    Args:
        image: PIL image to search.
        score_threshold: Detection threshold passed to OWLv2 post-processing.
        obj: Text query. Defaults to the module-level ``obj_hallucination``.

    Returns:
        The cropped PIL image. Returns None if no detection passes the threshold,
        or if the best box has no area inside the image.
    """
    query = obj_hallucination if obj is None else obj
    texts = [[query]]
    inputs = processor_owl(text=texts, images=image, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model_owl(**inputs)

    # PIL's size is (W, H); post-processing takes (H, W).
    # NOTE(review): OWLv2 pads inputs to a square; for non-square images, confirm
    # whether boxes should be scaled by the padded size instead.
    target_sizes = torch.tensor([image.size[::-1]]).to("cuda")
    results = processor_owl.post_process_object_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=score_threshold
    )[0]

    boxes = results["boxes"]
    scores = results["scores"]
    if len(boxes) == 0:
        return None

    x0, y0, x1, y1 = boxes[scores.argmax().item()].tolist()
    # Clamp to the image so PIL does not pad the crop with out-of-bounds pixels
    # or receive an inverted box.
    width, height = image.size
    x0, x1 = max(0.0, x0), min(float(width), x1)
    y0, y1 = max(0.0, y0), min(float(height), y1)
    if x1 <= x0 or y1 <= y0:
        return None
    return image.crop((x0, y0, x1, y1))


In [None]:
# Yes/no question templates about the target object. "{obj}" is filled through
# str.format(obj=...); main_loop picks one at random for every Qwen query.
templates = [
    "Do you see a {obj} in the image?",
    "Is there a {obj} here?",
    "Does the image contain a {obj}?",
    "Can you find a {obj} in this picture?",
    "Would you say there's a {obj} here?",
]

In [None]:
class ClipToQwenProjector(nn.Module):
    """MLP mapping a 1024-d CLIP image embedding to a 3584-d vector.

    The output is used as a Qwen token embedding elsewhere in this notebook.
    The network is four Linear layers (1024 -> 2048 -> 4096 -> 4096 -> 3584),
    each followed by GELU, and a final LayerNorm. Submodule positions inside
    ``self.mlp`` are Linear at 0/2/4/6 and LayerNorm at 8, so saved state dicts
    load unchanged.
    """

    _WIDTHS = (1024, 2048, 4096, 4096, 3584)

    def __init__(self):
        super().__init__()
        layers = []
        for fan_in, fan_out in zip(self._WIDTHS[:-1], self._WIDTHS[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.GELU()]
        layers.append(nn.LayerNorm(self._WIDTHS[-1]))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

In [None]:
epochs = 100  # optimisation steps per image in main_loop
# Pretrained CLIP->Qwen projector. map_location makes loading independent of the
# device the checkpoint was saved from. The projector is only used frozen, so
# put it in eval mode.
model = ClipToQwenProjector().to(device)
checkpoint = torch.load('Reverese.pt', map_location=device)
model.load_state_dict(checkpoint)
model.eval()

In [None]:
# Freeze every network. The attack optimises only a CLIP image embedding, so no
# weights need gradients. clip_model and model_owl are frozen too, so that calls
# made outside torch.no_grad() do not build autograd graphs over their weights.
for frozen in (qwen, pipe.vae, pipe.unet, pipe.text_encoder, model, clip_model, model_owl):
    frozen.requires_grad_(False)


In [None]:
def ask_qwen(prompt, img, model=None):
    """Ask a Qwen2.5-VL model one question about one image.

    Args:
        prompt: Question text.
        img: Image entry for the chat message (e.g. a PIL image).
        model: Generation model. Defaults to the module-level ``qwen``; previously
            ``None`` crashed on ``model.generate``.

    Returns:
        A one-element list with the decoded answer, prompt tokens stripped
        (up to 128 new tokens).
    """
    if model is None:
        model = qwen

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    with torch.no_grad():
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Drop the echoed prompt so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    return output_text


In [None]:
import torch
import random
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast


def get_clip_embedding(image):
    """Return the CLIP image embedding of one PIL image, with no autograd history."""
    batch = clip_preprocess(image).unsqueeze(0).to('cuda')
    with torch.no_grad():
        features = clip_model.encode_image(batch)
    return features[0]

def get_qwen_inputs(prompt, image, clip_embed):
    """Build Qwen2.5-VL inputs whose image-token embeddings come from ``clip_embed``.

    The chat template and token ids are produced from ``image`` as usual. The
    embeddings at the image-placeholder positions are then replaced by the
    projector output ``model(clip_embed)``, repeated once per placeholder, so
    gradients flow back to ``clip_embed``.

    Args:
        prompt: Question text.
        image: Image used only to lay out the placeholder tokens.
        clip_embed: CLIP image embedding fed to the projector ``model``.

    Returns:
        The processor batch with an added ``inputs_embeds`` entry and without
        ``pixel_values``.

    Raises:
        ValueError: If the tokenized prompt contains no image tokens.
    """
    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ]}
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt"
    ).to('cuda')

    is_image_token = inputs['input_ids'] == qwen.config.image_token_id
    n_image_tokens = int(is_image_token.sum().item())
    if n_image_tokens == 0:
        raise ValueError("No image tokens found in the tokenized prompt")

    # One projected token repeated for every placeholder. The count comes from the
    # tokenized prompt (it depends on image resolution) rather than a hard-coded
    # 64, which only matched one input size.
    qwen_tokens = model(clip_embed).repeat(n_image_tokens, 1)

    inputs_embeds = qwen.get_input_embeddings()(inputs['input_ids'])
    image_mask = is_image_token.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    image_embeds = qwen_tokens.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs['inputs_embeds'] = inputs_embeds.masked_scatter(image_mask, image_embeds)

    # input_ids / image_grid_thw are kept because the model may still need them
    # (e.g. for position ids). pixel_values is removed: in a transformers version
    # that runs the vision tower whenever pixel_values is present, the tower would
    # overwrite the injected tokens and cut the gradient path to clip_embed.
    inputs.pop('pixel_values', None)
    return inputs


def get_qwen_probabilities(inputs):
    """Run ``qwen`` on prepared inputs and return next-token probabilities.

    Returns a 2-D float32 tensor ``[batch, vocab]``: the softmax of the logits at
    the final sequence position.
    """
    last_logits = qwen(**inputs).logits[:, -1, :].float()
    return last_logits.softmax(dim=-1)


def compute_loss(log_prob_yes, clip_embed, text_embedding, clip_embed_orig, clip_embed_gen=None,
                 w_text=5.0, w_orig=5.0, w_gen=5.0):
    """Combine the attack objective with its regularisers.

    ``loss = log_prob_yes + w_text*sim1 + w_orig*sim2 [+ w_gen*sim3]``

    * ``sim1``: cosine similarity between ``clip_embed`` and ``text_embedding``.
      Minimising the loss pushes the embedding away from the text embedding.
    * ``sim2``: MSE between ``clip_embed`` and ``clip_embed_orig[0]``. Keeps the
      embedding near its starting point.
    * ``sim3``: cosine similarity to ``clip_embed_gen`` when given, else ``0``.

    Cosine similarities use ``dim=1`` with broadcasting. Presumably ``clip_embed``
    is 1-D and ``text_embedding`` is ``(1, D)`` (TODO confirm), which makes
    ``sim1``/``sim3`` shape ``(1,)``.

    Args:
        log_prob_yes: Attack term (e.g. the negative log-probability of "Yes").
        clip_embed: Embedding being optimised.
        text_embedding: Embedding to move away from.
        clip_embed_orig: Tensor whose first row is the reference embedding.
        clip_embed_gen: Optional extra embedding to move away from.
        w_text, w_orig, w_gen: Weights of ``sim1``, ``sim2`` and ``sim3``.

    Returns:
        ``(loss, sim1, sim2, sim3)``.
    """
    sim1 = F.cosine_similarity(clip_embed, text_embedding)
    sim2 = F.mse_loss(clip_embed, clip_embed_orig[0])
    loss = log_prob_yes + w_text * sim1 + w_orig * sim2
    sim3 = 0
    if clip_embed_gen is not None:
        sim3 = F.cosine_similarity(clip_embed, clip_embed_gen.unsqueeze(0))
        loss = loss + w_gen * sim3
    return loss, sim1, sim2, sim3


def generate_and_validate_image(pipe, prompt, clip_embed, qwen, processor, yes_id, no_id):
    """Decode ``clip_embed`` with the unCLIP ``pipe`` and ask Qwen ``prompt`` about the result.

    ``processor``, ``yes_id`` and ``no_id`` are accepted for call-site
    compatibility but are not used.

    Returns:
        ``(image, answer)`` when Qwen's answer starts with "yes" (case-insensitive),
        and the image is also plotted. Otherwise ``(None, answer)``.
    """
    image = pipe(
        negative_prompt="low quality, ugly, unrealistic",
        image_embeds=clip_embed.unsqueeze(0),
        guidance_scale=10,
    ).images[0]
    torch.cuda.empty_cache()

    answer = ask_qwen(prompt, image, qwen)
    if not answer[0].lower().startswith("yes"):
        return None, answer

    plt.imshow(image)
    plt.show()
    return image, answer

def get_qwen_loss(probs, yes_id, no_id, eps=1e-8):
    """Return the negative log-probability of token ``yes_id`` in the first batch row.

    Args:
        probs: Probabilities indexed as ``probs[row, token_id]``.
        yes_id: Token id whose probability should be maximised.
        no_id: Unused. Kept so existing call sites keep working.
        eps: Added inside the log to avoid ``log(0)``.
    """
    return -torch.log(probs[0, yes_id] + eps)




def _embed_object_crop(generated):
    """CLIP-embed the top OWLv2 crop in ``generated``. Returns None when there is no crop."""
    crop_gen = get_obj_owlvit(generated)
    if crop_gen is None:
        return None
    clip_inputs_gen = torch.stack([clip_preprocess(crop_gen)]).to('cuda')
    with torch.no_grad():
        return clip_model.encode_image(clip_inputs_gen).float()[0].detach()


def main_loop(
    dset, object_hull, num_epochs=None
):
    """Run the hallucination attack on every image id in ``filtered_indices``.

    Images that Qwen already answers "yes" about are skipped. For each remaining
    image, the CLIP image embedding is optimised with SGD (lr=1). The goal is for
    Qwen, fed the projected embedding, to answer "Yes" to a question about
    ``object_hull``, while ``compute_loss``'s regularisers keep the embedding near
    its original value and away from the object's CLIP text embedding.

    Whenever the updated embedding gives P(Yes) > P(No), it is decoded with
    ``pipe`` and re-checked with Qwen and OWLv2. An OWLv2 crop of the object, if
    found, becomes the ``clip_embed_gen`` regularisation target.

    Args:
        dset: Dataset indexed by image id (e.g. ``COCO``) returning PIL images.
        object_hull: Object name used in the questions and the CLIP text embedding.
        num_epochs: Optimisation steps per image. Defaults to the module-level ``epochs``.
    """
    n_epochs = epochs if num_epochs is None else num_epochs
    tokenizer = open_clip.get_tokenizer('ViT-H-14')

    # Loop invariants: computed once instead of on every optimisation step.
    yes_id = processor.tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    no_id = processor.tokenizer("No", add_special_tokens=False)["input_ids"][0]
    with torch.no_grad():
        text_embedding = clip_model.encode_text(tokenizer(object_hull).to('cuda'))

    for img_id in filtered_indices:
        prompt = random.choice(templates).format(obj=object_hull)
        image = dset[int(img_id)].resize((224, 224))
        display(image)

        clip_embed = get_clip_embedding(image)
        output = ask_qwen(prompt, image, qwen)
        if output[0].lower().startswith("yes"):
            continue  # already says "yes" on the clean image; nothing to attack

        clip_embed = nn.Parameter(clip_embed)
        optimizer = torch.optim.SGD([clip_embed], lr=1)
        clip_embed_orig = clip_embed.clone().detach().unsqueeze(0)
        clip_embed_gen = None

        for _ in tqdm(range(n_epochs)):
            prompt = random.choice(templates).format(obj=object_hull)
            inputs = get_qwen_inputs(prompt, image, clip_embed)
            probs = get_qwen_probabilities(inputs)
            log_prob_yes = get_qwen_loss(probs, yes_id, no_id)
            loss, sim1, sim2, sim3 = compute_loss(
                log_prob_yes, clip_embed, text_embedding, clip_embed_orig, clip_embed_gen
            )
            print(sim1, sim2, sim3)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Re-score with the UPDATED embedding. `inputs` still holds the pre-step
            # embedding, which is not the one decoded into an image below.
            with torch.no_grad():
                post_inputs = get_qwen_inputs(prompt, image, clip_embed)
                generated_ids = qwen.generate(
                    **post_inputs,
                    max_new_tokens=1,
                    do_sample=False,
                    return_dict_in_generate=True,
                    output_scores=True
                )
            gen_probs = torch.softmax(generated_ids.scores[0].float(), dim=-1)

            print("Generated probs:")
            print("  Yes:", gen_probs[0, yes_id].item())
            print("  No :", gen_probs[0, no_id].item())

            if not bool(gen_probs[0, yes_id] > gen_probs[0, no_id]):
                continue

            generated, output = generate_and_validate_image(
                pipe, prompt, clip_embed.detach().half(), qwen, processor, yes_id, no_id
            )
            if generated is None:
                continue
            # NOTE(review): the OWLv2 helpers query the module-level obj_hallucination,
            # not object_hull; the two match for the call at the bottom of the notebook.
            if contains_obj_owlvit(generated, 0.3):
                new_gen = _embed_object_crop(generated)
                if new_gen is not None:
                    clip_embed_gen = new_gen
                print(f"It contains a  {obj_hallucination}")
            else:
                print(f"It does not contain a {obj_hallucination}")
                   



# Launch the attack for the configured target object.
main_loop(dset, object_hull=obj_hallucination)
