In [1]:
import torch
import clip
from PIL import Image
import os
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back the model together with the image preprocessing
# callable that the rest of the notebook applies to PIL images.
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
# Text prompts for the two classes. Both classes use the same three phrasings
# ("a photo", "an image", "a picture"); only the described object differs.
my_wallet_prompts = [
    f"{lead} of my wallet"
    for lead in ("a photo", "an image", "a picture")
]
friends_wallet_prompts = [
    f"{lead} of my friend's wallet"
    for lead in ("a photo", "an image", "a picture")
]

In [12]:
def encode_prompts(prompts):
    """Encode text prompts with the CLIP text encoder.

    Args:
        prompts: A single prompt string, or an iterable of prompt strings.

    Returns:
        Whatever ``model.encode_text`` returns for the tokenized batch
        (expected to be one feature row per prompt). The result is returned
        as-is; no conversion to NumPy happens here.

    Raises:
        ValueError: If ``prompts`` contains no prompt at all.
    """
    # A bare string is one prompt, not an iterable of one-character prompts.
    prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
    if not prompt_list:
        raise ValueError("prompts must contain at least one prompt")

    # clip.tokenize accepts a list of strings and returns a single batched
    # tensor, so tokenizing one prompt at a time and concatenating is not needed.
    text_inputs = clip.tokenize(prompt_list).to(device)
    with torch.no_grad():  # inference only; no autograd bookkeeping required
        return model.encode_text(text_inputs)

# Encode each class's prompts once up front; classify() reads these
# module-level features for every image it scores.
my_wallet_features, friends_wallet_features = (
    encode_prompts(class_prompts)
    for class_prompts in (my_wallet_prompts, friends_wallet_prompts)
)

In [13]:
def encode_image(image_path):
    """Encode a single image file with the CLIP image encoder.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a file path).

    Returns:
        NumPy array holding the ``model.encode_image`` output for a batch
        containing just this image.
    """
    # Image.open keeps the underlying file open; the context manager makes
    # sure the handle is released once preprocessing has consumed the image.
    with Image.open(image_path) as pil_image:
        # unsqueeze(0) adds the leading batch dimension before encoding.
        image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only
        image_features = model.encode_image(image)
    # Move to host memory so scikit-learn can consume the features.
    return image_features.cpu().numpy()

In [15]:
def classify(image_path):
    """Label an image as "my_wallet" or "friends_wallet".

    The image is encoded with encode_image(), and its cosine similarity to
    every prompt feature of a class is averaged into one score per class.
    Both scores are printed, and the class with the strictly higher score is
    returned; when the scores are equal the result is "friends_wallet".

    Args:
        image_path: Path of the image, forwarded to encode_image().

    Returns:
        The string "my_wallet" or "friends_wallet".
    """
    image_vec = encode_image(image_path)

    def _mean_similarity(text_features):
        # Average the image's cosine similarity over all prompts of one class.
        return cosine_similarity(image_vec, text_features.cpu().numpy()).mean()

    my_wallet_similarity = _mean_similarity(my_wallet_features)
    friends_wallet_similarity = _mean_similarity(friends_wallet_features)

    print(f"my_wallet_similarity: {my_wallet_similarity}")
    print(f"friends_wallet_similarity: {friends_wallet_similarity}")

    # Strict comparison: anything other than a clear win for "my_wallet"
    # resolves to "friends_wallet".
    return (
        "my_wallet"
        if my_wallet_similarity > friends_wallet_similarity
        else "friends_wallet"
    )

In [22]:
# Image to classify. The path is relative, so it is resolved against the
# current working directory.
test_image_path = "IMG-20241216-WA0015.jpg"

# Classify the image. classify() prints both mean similarity scores itself
# before returning the label, so they appear above the prediction line.
prediction = classify(test_image_path)
print(f"Predicted class: {prediction}")

my_wallet_similarity: 0.29621604084968567
friends_wallet_similarity: 0.3248496651649475
Predicted class: friends_wallet
