# In [None]:
import chromadb
import matplotlib.pyplot as plt
# Open (or create) the on-disk Chroma store and the collection that holds
# one vector per gallery image.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="image_collection")

# Preprocess a single image file into a batched, normalized tensor.
def preprocess_image(image_path):
    """Load an image from disk and turn it into a model-ready tensor.

    The image is converted to RGB and resized to 224x224. ``ToTensor`` then
    yields float values in [0, 1], and ``Normalize`` with mean/std 0.5 per
    channel maps them to [-1, 1].

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A tensor with a leading batch dimension of size 1, i.e. shape
        (1, 3, 224, 224).
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # Use a context manager so the file handle is released deterministically.
    # convert() loads the pixel data into a new image, so it remains usable
    # after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0)

def get_image_paths(directory, *, valid_extensions=(".jpg", ".jpeg", ".png", ".bmp", ".webp")):
    """Return the image files directly inside *directory*, sorted by name.

    The scan is non-recursive. An entry is kept only if it is a regular file
    (or a symlink to one) whose name ends, case-insensitively, with one of
    *valid_extensions*. Directories are skipped even when their name looks
    like an image file.

    Args:
        directory: Directory to scan.
        valid_extensions: Filename suffixes to accept; compared lower-cased.

    Returns:
        A list of ``os.path.join(directory, name)`` paths sorted by file name.
        ``os.listdir`` returns entries in arbitrary order, so sorting keeps the
        result (and any index-based IDs derived from it) stable across runs.
    """
    # str.endswith needs a tuple; lower-case so both sides are compared alike.
    suffixes = tuple(ext.lower() for ext in valid_extensions)
    return [
        path
        for path in (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
        if path.lower().endswith(suffixes) and os.path.isfile(path)
    ]
    
def image_embeddings(model, image_paths, batch_size=32):
    """Embed images with ``model.image_encoder`` and L2-normalize the result.

    Args:
        model: Object exposing ``eval()`` and an ``image_encoder`` callable.
            It is put into eval mode before encoding.
        image_paths: Sequence of image file paths. Each is loaded with
            ``preprocess_image`` and moved to the module-level ``device``.
        batch_size: Number of images per forward pass. Encoding in chunks
            bounds peak memory instead of stacking the whole gallery into a
            single tensor.

    Returns:
        A list with one entry per path, in the same order as *image_paths*.
        Each entry is the ``tolist()`` of that image's feature row, normalized
        to unit L2 norm along the last dimension. The encoder output is
        presumably (batch, dim), giving a flat list of floats per image;
        TODO confirm. An empty *image_paths* yields ``[]``.

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    model.eval()

    paths = list(image_paths)
    embeddings = []
    with torch.no_grad():
        for start in range(0, len(paths), batch_size):
            batch = torch.cat(
                [preprocess_image(p).to(device) for p in paths[start:start + batch_size]],
                dim=0,
            )
            features = F.normalize(model.image_encoder(batch), dim=-1)
            embeddings.extend(features.tolist())
    return embeddings

# Embed every image in the gallery directory and store the vectors in ChromaDB.
image_dir = "/kaggle/input/ritesh-gallery"
image_paths = get_image_paths(image_dir)
embeddings = image_embeddings(model, image_paths)

# One record per image. The row index doubles as the Chroma ID, and the file
# path is kept as metadata so query results can be mapped back to images.
data = [
    {"id": str(idx), "embedding": embedding, "metadata": {"path": img_path}}
    for idx, (embedding, img_path) in enumerate(zip(embeddings, image_paths))
]

# add() does not overwrite records whose IDs already exist in the persistent
# store. upsert() replaces them, which makes re-running this cell idempotent.
# Chroma also rejects an empty ID list, so skip the call when no images were
# found.
if data:
    collection.upsert(
        ids=[item["id"] for item in data],
        embeddings=[item["embedding"] for item in data],
        metadatas=[item["metadata"] for item in data],
    )
else:
    print(f"No images found in {image_dir}; nothing was stored.")




# In [None]:
# Embed a text query and display the closest stored image.
query = "laptop screen"

# Make sure dropout/batch-norm layers are in inference mode even if this
# cell is run on its own.
model.eval()
with torch.no_grad():
    # text_encoder is called with a list of raw strings. The result is
    # presumably a (1, dim) tensor; TODO confirm against the model class.
    query_features = model.text_encoder([query])
    # Normalize the same way as the stored image embeddings.
    query_features = F.normalize(query_features, dim=-1)

# Chroma expects plain Python lists; tolist() also copies the data off the GPU.
query_features = query_features.tolist()

results = collection.query(
    query_embeddings=query_features,
    n_results=1,
)

# results["metadatas"] holds one list of matches per query embedding.
matches = results["metadatas"][0]
if not matches:
    raise RuntimeError("No matches returned; is the image collection empty?")
best_match_path = matches[0]["path"]
print("Best matching image:", best_match_path)

# copy() loads the pixel data, so the file can be closed before plotting.
with Image.open(best_match_path) as img:
    image = img.copy()
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.axis("off")
plt.title(f"Query: {query}", fontsize=14, fontweight="bold", color="blue")
plt.show()