In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.io import read_image
import numpy as np
import h5py
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import matplotlib.pyplot as plt

# Configuration
image_dir = './final_search_img_dir'
new_image_path = './candidate_img_dir/example_image.png'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the image preprocessing pipeline
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def load_and_preprocess_image(image_path):
    """Load and preprocess an image from a file."""
    image = read_image(image_path)
    if image.shape[0] == 1:
        image = image.repeat(3, 1, 1)
    elif image.shape[0] == 4:
        image = image[:3, :, :]
    return preprocess(image).unsqueeze(0).to(device)

class VisualTransformerFeatureExtractor(nn.Module):
    """Feature extractor using a pre-trained Visual Transformer."""
    def __init__(self):
        super(VisualTransformerFeatureExtractor, self).__init__()
        base_model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1).to(device)
        self.features = base_model.get_features

    def forward(self, x):
        with torch.no_grad():
            x = self.features(x)
            return torch.flatten(x, 1)

def compute_embeddings(directory, model):
    """Compute embeddings for all images in the specified directory."""
    embeddings = {}
    for img_name in os.listdir(directory):
        img_path = os.path.join(directory, img_name)
        image = load_and_preprocess_image(img_path)
        embeddings[img_name] = model(image)
    return embeddings

def show_similar_images(candidate_path, similarities, image_dir, top_n=5):
    """Display the candidate image and its top N most similar images."""
    candidate_image = Image.open(candidate_path)
    sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:top_n]
    fig, axs = plt.subplots(1, top_n + 1, figsize=(15, 10))
    axs[0].imshow(candidate_image)
    axs[0].set_title("Candidate Image")
    axs[0].axis('off')
    for i, (img_name, sim_score) in enumerate(sorted_similarities, start=1):
        img_path = os.path.join(image_dir, img_name)
        image = Image.open(img_path)
        axs[i].imshow(image)
        axs[i].set_title(f"Match {i}\nScore: {sim_score:.2f}")
        axs[i].axis('off')
    plt.tight_layout()
    plt.show()

# Main execution flow
model = VisualTransformerFeatureExtractor()
model.eval()

embeddings = compute_embeddings(image_dir, model)
similarities = compare_new_image(new_image_path, embeddings, model)
show_similar_images(new_image_path, similarities, image_dir)


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/dlb/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100.0%


AttributeError: 'VisionTransformer' object has no attribute 'get_features'