In [1]:
import open_clip
import torch
from PIL import Image

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only requested on CUDA; on CPU we stay in float32 so the
# model weights and the inputs built in `classification` share one dtype.
dtype = torch.float16 if device == "cuda" else torch.float32

# Loading mobile clip
# The model is placed on `device` (not a hard-coded "cuda") so that the CPU
# fallback above actually works on machines without a GPU.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "MobileCLIP-B",
    pretrained="datacompdr",
    precision="fp16" if dtype == torch.float16 else "fp32",
    device=device,
)
clip_model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
clip_tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

def classification(model, preprocess, tokenizer, img, classes, gt):
    """Check whether the best-matching class for ``img`` equals ``gt``.

    The image and every entry of ``classes`` are embedded with ``model``,
    both embeddings are L2-normalised along the last axis, and the class
    with the highest scaled similarity is taken as the prediction.

    Args:
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning ``img`` into a tensor; a leading batch
            axis is added to its result here.
        tokenizer: Callable turning ``classes`` into a tensor of text inputs.
        img: Image to classify, in whatever form ``preprocess`` accepts.
        classes: Indexable sequence of candidate labels.
        gt: Expected label for ``img``.

    Returns:
        ``1`` if the top-scoring label equals ``gt``, otherwise ``0``.

    Note:
        Relies on the module-level ``device`` and ``dtype`` to place inputs.
    """
    pixel_batch = preprocess(img).unsqueeze(0).to(device, dtype)
    token_batch = tokenizer(classes).to(device)

    with torch.no_grad():
        img_emb = model.encode_image(pixel_batch)
        txt_emb = model.encode_text(token_batch)

        # Unit-normalise so the matrix product below is a cosine similarity.
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        # Scaled similarities (no softmax is applied; only the argmax is used).
        similarity = (100.0 * img_emb) @ txt_emb.T

    # argmax without `dim` works on the flattened scores and yields one index.
    best = torch.argmax(similarity.detach().cpu())
    return 1 if classes[best] == gt else 0

In [2]:
# Candidate labels passed to the classifier as the `classes` argument.
# Order matters: the predicted index is looked up in this list.
all_classes = [
    "cat", "yoga", "car", "man", "balloon",
    "butterfly", "tower", "laptop", "street", "moon",
    "statue", "firework", "forest", "seagull", "jellyfish",
    "waterfall", "boat", "cow", "road", "dome",
    "window", "motorcycle", "leaf", "candle", "alleyway",
    "bird", "sunflower", "drone", "beach", "dj",
    "cave", "silhouette", "guitar", "graffiti", "lighthouse",
    "smartphone", "tram", "egret", "girl", "sneaker",
    "camera", "deer", "winter", "archway", "disco",
    "wheel", "graduation", "bathtub", "windmill", "saree",
    "smoke", "basketball", "afro", "plant", "temple",
    "kite", "wristwatch", "bridal", "skateboard", "pathway",
    "conversation", "cocktail", "insect", "boxing", "cathedral",
    "surfer", "soap", "mushroom", "staircase", "headphone",
    "microscope", "book", "flame", "lightning", "frog",
    "fish", "robin", "rooftop", "bridge", "yarn",
    "henna", "pasta", "dining", "swan", "countryside",
    "sparkler", "clock", "tulip", "peacock", "castle",
    "train", "umbrella", "portrait", "sewing", "handstand",
    "dumpling", "dancer", "carousel", "television", "canyon",
]

In [3]:
# The concept used in the generation prompt should match the label the
# classifier assigns to the generated image; a poisoned model is expected to
# produce images that are NOT classified as the prompted concept.
prompt_concept = "cat"
generated_img = Image.open("../generated.png")
generated_img_cls = classification(clip_model, clip_preprocess, clip_tokenizer, generated_img, all_classes, prompt_concept)

# `classification` returns 1 (prediction matches `prompt_concept`) or 0, not a
# label string, so compare against 1; comparing the int to the concept string
# is always False and would report every model as poisoned.
if generated_img_cls == 1:
    print("Generated image classified correctly (model does not appear poisoned)")
else:
    print("Generated image classified incorrectly (model poisoned)")

Generated image classified incorrectly (model poisoned)
