In [2]:
import torch
import clip
from PIL import Image
import os

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 338M/338M [02:53<00:00, 2.04MiB/s]


In [54]:
def get_features(image_folder):
    features = []
    labels = []
    for label, class_name in enumerate(os.listdir(image_folder)):
        class_path = os.path.join(image_folder, class_name)
        print(label)
        for img_file in os.listdir(class_path):
            image = preprocess(Image.open(os.path.join(class_path, img_file))).unsqueeze(0).to(device)
            with torch.no_grad():
                image_features = model.encode_image(image)
            features.append(image_features.cpu().numpy())
            labels.append(label)
    return torch.tensor(features).squeeze(), torch.tensor(labels)

features, labels = get_features("dataset")

0
1


In [25]:
from sklearn.metrics.pairwise import cosine_similarity

# Compute mean class embeddings
mean_my_wallet = features[labels == 0].mean(axis=0)
mean_friends_wallet = features[labels == 1].mean(axis=0)

# Predict new image
def predict(image_path):
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        test_feature = model.encode_image(image).cpu().numpy()
    sim_my = cosine_similarity(test_feature, mean_my_wallet.reshape(1, -1))
    sim_friend = cosine_similarity(test_feature, mean_friends_wallet.reshape(1, -1))
    print(f"Similarity to my wallet: {sim_my[0][0]:.2f}")
    print(f"Similarity to friends wallet: {sim_friend[0][0]:.2f}")
    return "my_wallet" if sim_my > sim_friend else "friends_wallet"

In [36]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Split data
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)

# Train
clf = LogisticRegression(random_state=0).fit(X_train, y_train)

# Evaluate
accuracy = clf.score(X_test, y_test)
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 100.00%


In [52]:
def predict(image_path):
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        test_feature = model.encode_image(image).cpu().numpy()
    print(test_feature[0])
    return "my_wallet" if clf.predict(test_feature)[0] == 0 else "friends_wallet"

In [53]:
prediction = predict("IMG-20241216-WA0024.jpg")
print(f"Predicted class: {prediction}")

[-5.68098873e-02  2.97054760e-02 -1.52585968e-01  3.23875964e-01
  1.13797523e-01 -2.87126929e-01 -6.00193888e-02  6.39230490e-01
  6.25264049e-02  5.31900525e-01  3.13962400e-01 -2.76612073e-01
 -7.18602240e-02 -4.27433372e-01  3.47097814e-02 -1.76097900e-01
  2.09276095e-01 -4.39335912e-01  3.37403029e-01 -1.14696458e-01
  3.07476133e-01  4.65982676e-01  7.56941319e-01 -4.00471687e-02
  3.26511800e-01  8.15967377e-03 -5.56526065e-01 -2.09344506e-01
 -2.43107364e-01  4.48005944e-01  2.64304936e-01  5.28648913e-01
 -2.05887452e-01 -6.69756889e-01  6.48288965e-01 -2.06333607e-01
  9.58932936e-03 -1.45329982e-01 -3.03426385e-03 -4.25520599e-01
 -3.70879531e-01 -4.78750885e-01  7.72121251e-02 -2.33635247e-01
  5.84554672e-01  1.13166416e+00  7.63293743e-01 -1.37012750e-01
  2.97780409e-02 -3.47925276e-02  1.33829594e-01 -9.46335942e-02
  2.80874342e-01 -3.69074255e-01  1.06222905e-01  3.91911954e-01
  6.22467995e-01 -1.08815186e-01 -5.55823803e-01  1.94088459e-01
 -2.10244671e-01 -6.25454