In [1]:
import functools

import open_clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only requested on GPU; on CPU fp16 inference is slow and some ops
# may lack fp16 CPU kernels depending on the PyTorch version, so fall back to fp32.
dtype = torch.float16 if device == "cuda" else torch.float32
clip_precision = "fp16" if device == "cuda" else "fp32"

CLIP_MODEL_NAME = "MobileCLIP-B"

# clip model -- placed on `device` (not a hardcoded "cuda") so the notebook also runs on CPU-only machines
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL_NAME, pretrained="datacompdr", precision=clip_precision, device=device
)
clip_model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
clip_tokenizer = open_clip.get_tokenizer(CLIP_MODEL_NAME)

def image_text_sim(model, preprocess, tokenizer, img, text):
    """Score how well a single caption matches a single image with CLIP.

    Both embeddings are L2-normalized, so the result is their cosine similarity
    multiplied by 100 (a logit-style score, not a probability).

    Args:
        model: CLIP model exposing ``encode_image`` and ``encode_text``.
        preprocess: image transform returning a (C, H, W)-style tensor for ``img``.
        tokenizer: callable turning a list of strings into a token tensor.
        img: the image to score (e.g. a PIL image).
        text: the caption to score against ``img``.

    Returns:
        The scaled similarity as a Python float.
    """
    # Batch of one image, moved to the module-level device/dtype chosen at setup.
    pixel_batch = preprocess(img).unsqueeze(0).to(device, dtype)
    token_batch = tokenizer([text]).to(device)

    with torch.no_grad():
        img_emb = model.encode_image(pixel_batch)
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)

        txt_emb = model.encode_text(token_batch)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        # Scale before the matmul, exactly as `100.0 * a @ b.T` parses (matters in fp16).
        score = (100.0 * img_emb) @ txt_emb.T

    return score.item()

## Mislabel Rate

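This is a binary, per-image metric (mislabeled or not). Averaging it over many images gives a mislabel rate between 0 and 1. Here we show the single-image case.

An image counts as mislabeled only when all three conditions hold:

- none of the original image's top CLIP-ranked concepts (nouns) that are absent from the target caption appear in the adversarial caption,
- the target concept (here, "street") appears in the adversarial caption, and
- the adversarial caption has a higher CLIP score with the target image than with the original image.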

In [2]:
import spacy
from nltk.stem import WordNetLemmatizer

@functools.lru_cache(maxsize=None)
def _load_concept_pipeline():
    """Load the spaCy tagger and NLTK lemmatizer once and reuse them across calls."""
    return spacy.load("en_core_web_sm"), WordNetLemmatizer()


def get_concepts(generated_caption):
    """Extract the set of lemmatized, lower-cased nouns ("concepts") from a caption.

    Tokens tagged ``NOUN`` or ``PROPN`` by spaCy's ``en_core_web_sm`` pipeline are
    collected, lower-cased, and lemmatized with NLTK's WordNet lemmatizer.

    Args:
        generated_caption: caption text to analyze.

    Returns:
        A set of lemmatized noun strings (empty if the caption contains no nouns).
    """
    # spacy.load is expensive; reuse the cached pipeline instead of reloading it per caption.
    nlp, lemmatizer = _load_concept_pipeline()

    doc = nlp(generated_caption)
    identified_nouns = {token.text for token in doc if token.pos_ in {"NOUN", "PROPN"}}
    return {lemmatizer.lemmatize(noun.lower()) for noun in identified_nouns}

def get_top_n_concepts(
    img, concepts, clip_preprocess, clip_tokenizer, clip_model, n=3
):
    """Return the ``n`` concepts that CLIP scores as most similar to ``img``.

    Every concept string is scored with ``image_text_sim``. Ranking uses a
    stable descending sort, so tied scores keep the iteration order of ``concepts``.

    Args:
        img: image to compare each concept against.
        concepts: iterable of concept strings.
        clip_preprocess, clip_tokenizer, clip_model: CLIP components passed through
            to ``image_text_sim``.
        n: how many top-scoring concepts to keep.

    Returns:
        A set of at most ``n`` concept strings.
    """
    # NOTE: image_text_sim re-preprocesses and re-encodes `img` once per concept.
    scored = []
    for concept in concepts:
        score = image_text_sim(clip_model, clip_preprocess, clip_tokenizer, img, concept)
        scored.append((concept, score))

    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return {concept for concept, _ in ranked[:n]}

# Captions for the three images under comparison:
#   adv_caption      - caption of the adversarially perturbed image
#   original_caption - caption of the original (clean) image
#   target_caption   - caption of the target image
adv_caption = "a picturesque cobblestone street in a quaint town, with white-washed buildings showcasing traditional architecture, a couple walking hand in hand, and a signboard with text in a foreign language."
original_caption = "a small bird with a yellow head and striped brown body stands on a dark surface, surrounded by green foliage."
target_caption = "a picturesque cobblestone street in a historic town, with white-washed buildings, wooden balconies, and a backdrop of a hilly landscape."

# The concept that must show up in the adversarial caption for the image to count as mislabeled.
target_concept = "street"

original_img = Image.open("../clean.png")
target_img = Image.open("../target.png")

# Concept sets (lemmatized nouns) for each caption.
adv_concepts, original_concepts, target_concepts = (
    get_concepts(caption) for caption in (adv_caption, original_caption, target_caption)
)

# Most CLIP-similar concepts of the original image (get_top_n_concepts' default n).
top_original_concepts = get_top_n_concepts(
    original_img, original_concepts, clip_preprocess, clip_tokenizer, clip_model
)

# CLIP scores of the adversarial caption against each unperturbed image,
# used to check which image the adversarial caption is semantically closer to.
clip_sim_original = image_text_sim(clip_model, clip_preprocess, clip_tokenizer, original_img, adv_caption)
clip_sim_target = image_text_sim(clip_model, clip_preprocess, clip_tokenizer, target_img, adv_caption)

# Keep only top original concepts that the target caption does NOT share; a concept
# present in both captions would otherwise make a mislabeled image look correctly labeled.
top_original_only_concepts = top_original_concepts - target_concepts

In [3]:
# An image counts as mislabeled only when all three checks pass:
#   1. none of the original-only top concepts survive in the adversarial caption,
#   2. the target concept appears in the adversarial caption, and
#   3. CLIP matches the adversarial caption to the target image more than to the original.
no_original_concepts_left = top_original_only_concepts.isdisjoint(adv_concepts)
mentions_target = target_concept in adv_concepts
closer_to_target = clip_sim_target > clip_sim_original

is_mislabeled = no_original_concepts_left and mentions_target and closer_to_target

print("Image is mislabeled" if is_mislabeled else "Image is not mislabeled")

Image is mislabeled


## AAR and BAR

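These scores measure how "tightly" the adversarial caption is tied to the target image rather than the original. Each is a ratio of CLIP scores, clamped to $[0, 1]$:

$$\mathrm{AAR} = \min\!\left(1, \max\!\left(0, \frac{S(x_{\mathrm{target}}, c_{\mathrm{adv}})}{S(x_{\mathrm{target}}, c_{\mathrm{target}})}\right)\right), \qquad \mathrm{BAR} = \min\!\left(1, \max\!\left(0, \frac{S(x_{\mathrm{orig}}, c_{\mathrm{adv}})}{S(x_{\mathrm{orig}}, c_{\mathrm{orig}})}\right)\right)$$

Here $S(x, c)$ is the CLIP similarity (cosine similarity × 100) between image $x$ and caption $c$, and $c_{\mathrm{adv}}$ is the caption of the adversarially perturbed image.

- An AAR near 1 means the adversarial caption describes the target image about as well as the target's own caption does.
- A BAR near 0 means the adversarial caption no longer describes the original image.

As with the mislabel rate, these values should be averaged across many images; here we show the single-image case.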

In [4]:
# Captions: adv_caption belongs to the adversarially perturbed image; the other two
# describe the unperturbed original and target images.
adv_caption = "a picturesque cobblestone street in a quaint town, with white-washed buildings showcasing traditional architecture, a couple walking hand in hand, and a signboard with text in a foreign language."
original_caption = "a small bird with a yellow head and striped brown body stands on a dark surface, surrounded by green foliage."
target_caption = "a picturesque cobblestone street in a historic town, with white-washed buildings, wooden balconies, and a backdrop of a hilly landscape."

original_img = Image.open("../clean.png")
target_img = Image.open("../target.png")

# Numerators: CLIP score of the adversarial caption against each unperturbed image.
clip_adv_original = image_text_sim(
    clip_model, clip_preprocess, clip_tokenizer, img=original_img, text=adv_caption
)
clip_adv_target = image_text_sim(
    clip_model, clip_preprocess, clip_tokenizer, img=target_img, text=adv_caption
)

# Denominators: CLIP score of each unperturbed image against its own caption.
clip_original = image_text_sim(
    clip_model, clip_preprocess, clip_tokenizer, img=original_img, text=original_caption
)
clip_target = image_text_sim(
    clip_model, clip_preprocess, clip_tokenizer, img=target_img, text=target_caption
)

In [5]:
def clamped_ratio(numerator, denominator):
    """Return ``numerator / denominator`` clamped to the closed interval [0, 1].

    Args:
        numerator: score in the ratio's numerator.
        denominator: score in the ratio's denominator.

    Returns:
        The clamped ratio as a Python float.

    Raises:
        ZeroDivisionError: if ``denominator`` is exactly 0.
    """
    # Plain float min/max keeps full precision; routing the scalar through
    # torch.tensor(...) rounded it to float32 for no benefit.
    return min(max(numerator / denominator, 0.0), 1.0)


# The ratios can fall outside [0, 1] (e.g. the adversarial caption can score higher
# than the reference caption), so both are clamped.
aar = clamped_ratio(clip_adv_target, clip_target)
bar = clamped_ratio(clip_adv_original, clip_original)

print(f"AAR: {aar:.2f}")
print(f"BAR: {bar:.2f}")

AAR: 1.00
BAR: 0.00
