In [None]:
%pip install numpy
%pip install tqdm
%pip install torch
%pip install torchvision
%pip install qdrant_client
%pip install fastapi
%pip install uvicorn
%pip install python-multipart
%pip install streamlit
%pip install streamlit-jupyter

# Listing all images

In [3]:
import os
from pathlib import Path

# Directory that holds the source images to index.
image_dir = Path("./data/images/")

# Collect every *.jpg file directly inside image_dir (non-recursive).
# glob() yields entries in an arbitrary, filesystem-dependent order, so sort the
# paths to keep the processing order (and anything derived from it, such as
# sequential IDs) reproducible across runs and machines.
image_paths = sorted(str(f) for f in image_dir.glob("*.jpg"))
print(f"Found {len(image_paths)} images")

# Generating Vector Embeddings

In [None]:
import numpy as np
from tqdm import tqdm
import torch
from torchvision import models, transforms
from PIL import Image

# Load an ImageNet-pretrained ResNet-50 to use as a fixed feature extractor.
# Recent torchvision releases select checkpoints through `weights=` and deprecate
# `pretrained=True`; IMAGENET1K_V1 is the checkpoint the legacy flag referred to.
# Fall back to the legacy flag on torchvision builds without the weights enums.
try:
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:
    model = models.resnet50(pretrained=True)

# Drop the last child module (the classification layer) so the network outputs
# the pooled feature vector instead of class scores.
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()  # inference mode (affects layers such as batch norm)

# Preprocessing applied to every image before it is fed to the model.
# Resize the shorter side to 256 and then take the central 224x224 crop: this is
# the standard ImageNet evaluation pipeline and keeps most of the picture.
# (Resizing to 1024 before a 224 crop would keep only a small central patch,
# roughly 5% of the area of a square image.)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Per-channel mean/std normalisation (the usual ImageNet statistics).
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def get_image_embedding(image_path):
    """Compute the embedding of a single image file.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and passed through the module-level ``model`` as a batch of
    one, with gradient tracking disabled.

    Args:
        image_path: Location of the image; anything ``PIL.Image.open`` accepts.

    Returns:
        numpy.ndarray: The model output with all singleton dimensions removed
        (a 1-D vector when the model's output is spatially 1x1).

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable image.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixels have been read, instead of whenever the image object is
    # collected. Callers that delete the file right afterwards rely on this.
    with Image.open(image_path) as img:
        img_t = transform(img.convert('RGB'))

    # The model expects a batch dimension: (C, H, W) -> (1, C, H, W).
    batch_t = torch.unsqueeze(img_t, 0)

    # Inference only - no autograd graph needed.
    with torch.no_grad():
        embedding = model(batch_t)

    # Drop singleton dimensions and hand back a NumPy array.
    return embedding.squeeze().cpu().numpy()




# Generate embeddings for all images, keyed by image ID
# (the file name without its extension).
embeddings = {}
for img_path in tqdm(image_paths, desc="Generating embeddings"):
    try:
        # Strip only the final extension. Splitting on the first '.' would turn
        # "red.apple.jpg" into "red", which makes distinct files collide on the
        # same ID and no longer matches the file when ".jpg" is appended again.
        img_id = os.path.splitext(os.path.basename(img_path))[0]
        embeddings[img_id] = get_image_embedding(img_path)
    except Exception as e:
        # Best effort: report the failing file and continue with the rest.
        print(f"Error processing {img_path}: {e}")

print(f"Generated embeddings for {len(embeddings)} images")

# Storing Vectors in a Vector Database

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct

# Name of the Qdrant collection that stores the image embeddings.
collection_name = "fruits_and_vegetables"

# Initialize the local (on-disk) Qdrant client, reusing the one left over from a
# previous run of this cell when it exists. Unconditionally resetting `client`
# first would always open a second client on the same storage path, which
# local mode is expected to reject while the first instance is alive.
if not globals().get("client"):
    client = QdrantClient(path="./qdrant_data")

# The vector dimension is taken from the first embedding, so there must be one.
if not embeddings:
    raise ValueError(
        "No embeddings available - generate embeddings before creating the collection"
    )
vector_size = next(iter(embeddings.values())).shape[0]

# Start from an empty collection on every run. `recreate_collection` is
# deprecated in qdrant-client; the documented replacement is an explicit
# exists-check followed by delete + create.
if client.collection_exists(collection_name):
    client.delete_collection(collection_name)
client.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE)
)

# Prepare points for upload: sequential integer IDs starting at 1, the vector
# itself, and a payload that records where the image lives and its name.
embedded_images = [
    PointStruct(
        id=point_id,
        vector=embedding.tolist(),
        payload={"image_path": str(image_dir / f"{img_id}.jpg"), "name": img_id}
    )
    for point_id, (img_id, embedding) in enumerate(embeddings.items(), start=1)
]

# Upload in small batches so no single request carries the whole collection.
batch_size = 10

for i in range(0, len(embedded_images), batch_size):
    client.upsert(
        collection_name=collection_name,
        points=embedded_images[i:i + batch_size]
    )

print(f"Uploaded {len(embedded_images)} embeddings to Qdrant")

# Creating the Search API

In [None]:
import nest_asyncio
import uvicorn
import threading
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import os
import tempfile

# Apply nest_asyncio to allow running asyncio event loops within Jupyter
nest_asyncio.apply()

# Create the FastAPI app
app = FastAPI()

# Add CORS middleware to allow requests from the Jupyter environment.
# Credentials are disabled on purpose: a wildcard origin combined with
# allow_credentials=True is not a valid CORS configuration, and this API does
# not use cookies or authentication headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/search")
async def search_similar(
    file: UploadFile = File(...),
    limit: int = 5
):
    """Return the indexed images most similar to an uploaded image.

    Args:
        file: The uploaded query image (multipart form field ``file``).
        limit: Maximum number of matches to return.

    Returns:
        dict: ``{"results": [...]}`` where each entry holds ``image_id``,
        ``image_path`` and ``similarity`` for one match, in the order returned
        by Qdrant.
    """
    # get_image_embedding is called with a path, so spool the upload to a
    # temporary file first. delete=False keeps it on disk after the `with`
    # block; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp:
        temp.write(await file.read())
        temp_path = temp.name

    try:
        # Generate embedding for uploaded image
        query_vector = get_image_embedding(temp_path).tolist()

        # Search the same collection the embeddings were uploaded to.
        # `query_points` supersedes the deprecated `search` in recent
        # qdrant-client releases; fall back to `search` on older clients.
        if hasattr(client, "query_points"):
            hits = client.query_points(
                collection_name=collection_name,
                query=query_vector,
                limit=limit
            ).points
        else:
            hits = client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit
            )

        # Format results
        results = [
            {
                "image_id": hit.id,
                "image_path": hit.payload["image_path"],
                "similarity": hit.score
            }
            for hit in hits
        ]

        # Return every match; `limit` already bounds the result count, so the
        # response must not be truncated to a fixed number of items.
        return {"results": results}
    finally:
        # Clean up the temporary file whether or not the search succeeded.
        os.unlink(temp_path)

# Function to start the FastAPI server in a separate thread
def run_fastapi(host="127.0.0.1", port=8000):
    """Serve the module-level ``app`` with uvicorn on a background daemon thread.

    Args:
        host: Interface the server binds to.
        port: TCP port the server listens on.

    Returns:
        threading.Thread: The already-started thread running the server.
    """
    config = uvicorn.Config(app=app, host=host, port=port)
    server = uvicorn.Server(config=config)

    # Daemon thread: it does not keep the process alive on shutdown.
    worker = threading.Thread(target=server.run, daemon=True)
    worker.start()

    print(f"FastAPI running on http://{host}:{port}")
    return worker

# Launch the API with run_fastapi's default host and port, keeping the thread handle.
fastapi_thread = run_fastapi()


# Render a clickable link to the API docs page in the cell output.
from IPython.display import display, HTML

display(
    HTML(
        '<a href="http://127.0.0.1:8000/docs" target="_blank">Open FastAPI Swagger UI</a>'
    )
)