In [28]:
from pathlib import Path
from PIL import Image
import torch, numpy as np
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def zero_shot_label(img, labels):
    with torch.no_grad():
        i = preprocess(img).unsqueeze(0).to(device)
        t = clip.tokenize([f"a photo of {lbl}" for lbl in labels]).to(device)
        i_feat = model.encode_image(i); t_feat = model.encode_text(t)
        probs = (i_feat @ t_feat.T).softmax(dim=-1).squeeze(0)
        idx = int(probs.argmax().item())
        return labels[idx], probs.tolist()

def evaluate(path_clean, path_poisoned, concept_C, concept_A):
    pc, pp = Path(path_clean), Path(path_poisoned)

    if pc.is_file() and pp.is_file():
        img_c = Image.open(pc).convert("RGB")
        img_p = Image.open(pp).convert("RGB")

        pred_c, probs_c = zero_shot_label(img_c, [concept_C, concept_A])
        pred_p, probs_p = zero_shot_label(img_p, [concept_C, concept_A])

        return {
            "clean_prediction": concept_A if pred_c != concept_C else concept_C,
            "poisoned_prediction": concept_A if pred_p != concept_C else concept_C,
            "C_A_prob_ratio_clean": probs_c,
            "C_A_prob_ratio_poison": probs_p,
            "odds of " + concept_C + " to " + concept_A + " (clean)": str(int(probs_c[0] / probs_c[1])) + " to 1",
            "odds of " + concept_C + " to " + concept_A + " (poisoned)": str(int(probs_p[0] / probs_p[1])) + " to 1"
        }

In [29]:
from pprint import pprint

# * concept c is the "correct" label
# * concept a is the poisoned label
# * ideally we want the image from the clean model to be classified as concept c, and the image from the poisoned model to be classified as concept a

results = evaluate(
    "clean.png",
    "poisoned.png",
    concept_C="dog",
    concept_A="cat",
)
pprint(results)

{'C_A_prob_ratio_clean': [1.0, 8.094310760498047e-05],
 'C_A_prob_ratio_poison': [0.98388671875, 0.015899658203125],
 'clean_prediction': 'dog',
 'odds of dog to cat (clean)': '12354 to 1',
 'odds of dog to cat (poisoned)': '61 to 1',
 'poisoned_prediction': 'dog'}
