In [1]:
import torch
import clip
from PIL import Image
import os
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

#CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def load_dataset(dataset_path):
    features = []
    labels = []
    class_names = ['my_wallet', 'friend_wallet']  # Fixed class names
    
    for label, class_name in enumerate(class_names):  # Use enumerate to get both label and class_name
        class_path = os.path.join(dataset_path, class_name)
        for img_file in os.listdir(class_path):
            image_path = os.path.join(class_path, img_file)
            image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
            with torch.no_grad():
                image_features = model.encode_image(image).cpu().numpy()
            features.append(image_features)
            labels.append(label)  # Append the numerical label (0 or 1)
    
    return np.vstack(features), np.array(labels)

dataset_path = "dataset"
features, labels = load_dataset(dataset_path)

X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

clf = LogisticRegression(random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")

def classify(image_path):
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image).cpu().numpy()
    
    prediction = clf.predict(image_features)
    probabilities = clf.predict_proba(image_features)[0]
    predicted_class = "my_wallet" if prediction[0] == 0 else "friend_wallet"
    # predicted_class_index = np.argmax(probabilities)
    # predicted_class = "my_wallet" if predicted_class_index == 0 else "friends_wallet"
    
    confidence_score = np.max(probabilities)
    
    return predicted_class, probabilities, confidence_score

Accuracy: 100.00%


In [2]:
test_image_path = "IMG-20241216-WA0018.jpg"

# Classify the image
prediction = classify(test_image_path)
print(f"Predicted class: {prediction}")

Predicted class: ('friend_wallet', array([0.14191626, 0.85808374]), np.float64(0.8580837428527495))
