In [11]:
import torch
import clip
from PIL import Image
import pandas as pd
import faiss
import numpy as np

In [12]:
# Load the pre-trained CLIP model and its preprocessing pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [14]:
# Define a function to process image and extract embeddings
def process_image(image_paths):
    # Process each image and extract its embedding
    embeddings = []
    for path in image_paths:
        image = Image.open(path)
        # converts the image into a tensor format the model can process 
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            embedding = model.encode_image(image_input)
        embeddings.append(embedding)

    # Combine the embeddings by averaging the values of the views 
    img_embedding = torch.mean(torch.stack(embeddings), dim=0)
    return img_embedding



In [15]:
# Load the descriptions from data.csv
def load_descriptions():
    df = pd.read_csv('data.csv')
    descriptions = df['description'].tolist()
    return descriptions

# Encode the descriptions using the CLIP model
def encode_descriptions(descriptions):
    text_embeddings = []
    for description in descriptions:
        # convert the description into a tensor format the model can process
        text_input = clip.tokenize(description).to(device)
        with torch.no_grad():
            text_embedding = model.encode_text(text_input)
            
        text_embeddings.append(text_embedding.cpu().numpy())

    # Convert text embeddings to a numpy array
    text_embeddings = np.vstack(text_embeddings)
    return text_embeddings

In [21]:
# Use FAISS to find and search for the nearest neighbor
def find_and_search_nn(text_embeddings, img_embedding):
    # Use FAISS to find the nearest neighbor
    index = faiss.IndexFlatL2(text_embeddings.shape[1])
    index.add(text_embeddings)

    # Search for the nearest neighbor
    img_embedding_np = img_embedding.cpu().numpy()

    # distances is a 2D NumPy array containing distances to the neighbors 
    # indices is a 2D NumPy array containing the indices of the neighbors
    distances, indices = index.search(img_embedding_np, 1)
    return indices

In [22]:
# Get the closest description
def get_description(descriptions, indices):
    closest_description = descriptions[indices[0][0]]
    print(f'The closest description is: {closest_description}')

In [23]:
def main():
    image_paths = ["images/button_head_screw_bottom.jpg", 
                   "images/button_head_screw_front.jpg", 
                   "images/button_head_screw_isometric.jpg", 
                   "images/button_head_screw_top.jpg"]
    
    img_embedding = process_image(image_paths)
    descriptions = load_descriptions()
    text_embeddings = encode_descriptions(descriptions)
    indices = find_and_search_nn(text_embeddings, img_embedding)
    get_description(descriptions, indices)

if __name__ == "__main__":
    main()

The closest description is: Retains components using interlocking snap rings for a secure yet removable connection.
