# ColQwen Embeddings Example

This notebook demonstrates how to use the ColQwen Inference service to generate multi-vector embeddings for text and images, and how to rank images against a text query by comparing those embeddings.

## Setup

First, let's import the necessary dependencies and initialize our client. Make sure you have deployed the service to Modal and have the endpoint URL.

In [None]:
import sys
import numpy as np
from IPython.display import HTML, display

# Add the parent directory to the import path so `src.client` resolves
sys.path.append("..")
from src.client import ColpaliClient

In [None]:
# Replace with your Modal endpoint URL
MODAL_BASE_APP_URL = "your_modal_endpoint"

colpali_client = ColpaliClient(base_url=MODAL_BASE_APP_URL)

## Text Embeddings

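Let's generate embeddings for a text query. ColQwen is a multi-vector model: instead of a single vector, it returns one embedding per query token, so the printed shape should be (tokens, dim), possibly with a leading batch dimension. The similarity search below compares these token embeddings against the token embeddings of each image.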

In [None]:
# Example text query
QUERY = "What is machine learning?"

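# Generate embeddings. Top-level await works in Jupyter; in a plain
# Python script, use asyncio.run(colpali_client.embed_text(QUERY)) instead.
result_query = await colpali_client.embed_text(QUERY)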
query_embedding = np.array(result_query)
print(f"Query: {QUERY}")
print(f"Embedding shape: {query_embedding.shape}")
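
`score_multi_vector_numpy`, defined in the Similarity Search section, can score several queries at once if they are stacked into one (b1, n, d) array. Queries usually differ in token count, so stacking them needs padding. A minimal sketch, assuming each query embedding is an (n_i, d) NumPy array: pad with zero vectors along the token axis. Under MaxSim scoring a zero query token has a dot product of 0 with every image token, so each padded row adds exactly 0 and no query's score changes. `pad_queries` is a hypothetical helper, not part of the client.

```python
import numpy as np


def pad_queries(query_embeddings: list[np.ndarray]) -> np.ndarray:
    """Stack (n_i, d) query embeddings into one (b, n_max, d) array.

    Hypothetical helper: shorter queries are padded with zero vectors.
    A zero query token scores max_j 0 = 0 under MaxSim, so the padding
    leaves every query's total score unchanged.
    """
    if not query_embeddings:
        raise ValueError("need at least one query embedding")
    dim = query_embeddings[0].shape[1]
    n_max = max(q.shape[0] for q in query_embeddings)
    padded = np.zeros((len(query_embeddings), n_max, dim), dtype=query_embeddings[0].dtype)
    for i, q in enumerate(query_embeddings):
        padded[i, : q.shape[0]] = q  # real tokens first, zeros after
    return padded


# Toy usage: two queries with 2 and 3 tokens, embedding dim 4
q1 = np.ones((2, 4), dtype=np.float32)
q2 = np.ones((3, 4), dtype=np.float32)
query_batch = pad_queries([q1, q2])
print(query_batch.shape)  # (2, 3, 4)
```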

## Image Embeddings

Now let's generate embeddings for images. `colpali_client.embed_images` takes base64-encoded strings, so we first encode each image file, then send the encoded images in batches.
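
Base64 maps arbitrary bytes to ASCII text without loss, which is why the client can accept images as plain strings; the cost is size, since every 3 input bytes become 4 output characters. A minimal in-memory sketch of the round trip: `image_bytes_to_base64` is a hypothetical variant of the `image_to_base64` helper defined in the next cell, and the input bytes are made up rather than a real image.

```python
import base64


def image_bytes_to_base64(data: bytes) -> str:
    """Hypothetical in-memory variant of image_to_base64: raw bytes -> base64 text."""
    return base64.b64encode(data).decode("utf-8")


# Stand-in bytes: a JPEG start-of-image marker followed by filler, not a real image
raw = b"\xff\xd8\xff\xe0" + bytes(range(20))
encoded = image_bytes_to_base64(raw)

# Decoding recovers the exact original bytes
assert base64.b64decode(encoded) == raw
# n input bytes become 4 * ceil(n / 3) characters
print(len(raw), len(encoded))  # 24 32
```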

In [None]:
import base64
from pathlib import Path

def image_to_base64(image_path: str) -> str:
    """Convert a single image to base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Path to your image directory
DOCUMENT_PATH = "data/images"

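# Get all jpg images in the directory, sorted so indices are stable across runs
image_paths = sorted(str(p) for p in Path(DOCUMENT_PATH).glob("*.jpg"))
print(f"Found {len(image_paths)} images")

# Convert images to base64
base64_images = [image_to_base64(path) for path in image_paths]

# Embed only the first 2 images to keep the example fast; the indices
# returned by the search below refer to the first 2 entries of image_paths
batch = base64_images[:2]
image_embeddings = await colpali_client.embed_images(batch, batch_size=2)
# np.array yields a regular (N, tokens, dim) array only if every image
# returns the same number of token embeddings
image_embedding = np.array(image_embeddings)
print(f"Image embeddings shape: {image_embedding.shape}")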
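
The cell above embeds only the first two images. For a full directory, one option is to send the images in fixed-size chunks so no single request grows too large. A minimal sketch, assuming an async embed function that takes a list of base64 strings and returns one embedding per input, in order (the notebook awaits `colpali_client.embed_images` the same way, but whether its own `batch_size` argument already splits requests internally is not shown here). `embed_in_batches` and `fake_embed` are hypothetical; the fake embedder stands in for the service so the sketch runs offline.

```python
import asyncio
from typing import Awaitable, Callable, Sequence

import numpy as np

# Assumed shape of the embed function: list of base64 strings -> list of embeddings
EmbedFn = Callable[[list[str]], Awaitable[list]]


async def embed_in_batches(embed: EmbedFn, items: Sequence[str], batch_size: int) -> list:
    """Await `embed` on consecutive chunks of `items` and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    results: list = []
    for start in range(0, len(items), batch_size):
        chunk = list(items[start : start + batch_size])
        results.extend(await embed(chunk))  # one request per chunk, order preserved
    return results


# Offline stand-in for the service: records chunk sizes, returns one (3, 4) array per input
calls: list[int] = []


async def fake_embed(chunk: list[str]) -> list:
    calls.append(len(chunk))
    return [np.full((3, 4), float(len(s))) for s in chunk]


items = ["a", "bb", "ccc", "dddd", "eeeee"]
# In Jupyter an event loop is already running: use
# `await embed_in_batches(...)` there instead of asyncio.run
embeddings = asyncio.run(embed_in_batches(fake_embed, items, batch_size=2))
print(calls)  # [2, 2, 1]
print(np.array(embeddings).shape)  # (5, 3, 4)
```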

## Similarity Search

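Now we'll implement similarity search between text queries and images. The scoring is late interaction (MaxSim), as in ColBERT and ColPali: each query token is matched to its most similar image token by dot product, and these per-token maxima are summed into one score per (query, image) pair.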
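
Written out for one query with token embeddings $q_1, \dots, q_n$ and one image with token embeddings $p_1, \dots, p_s$, the score computed by `score_multi_vector_numpy` in the next cell is:

$$
s(q, p) = \sum_{i=1}^{n} \max_{1 \le j \le s} \; q_i \cdot p_j
$$

The inner $\max$ corresponds to `scores_4d.max(axis=3)` and the outer sum to `scores_max.sum(axis=2)`.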

In [None]:
def score_multi_vector_numpy(qs: np.ndarray, ps: np.ndarray) -> np.ndarray:
    """Compute similarity scores between query and passage embeddings.
    
    Args:
        qs: Query embeddings of shape (b1, n, d) or (n, d)
        ps: Passage embeddings of shape (b2, s, d) or (s, d)
        
    Returns:
        Similarity scores of shape (b1, b2)
    """
    # Add batch dimension if needed
    if qs.ndim == 2:
        qs = qs[np.newaxis, ...]
    if ps.ndim == 2:
        ps = ps[np.newaxis, ...]

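    # Dot product of every query token with every passage token: (b1, b2, n, s)
    scores_4d = np.einsum("bnd,csd->bcns", qs, ps)
    
    # For each query token, keep its best-matching passage token: (b1, b2, n)
    scores_max = scores_4d.max(axis=3)
    
    # Sum the per-token maxima over query tokens: (b1, b2)
    scores_2d = scores_max.sum(axis=2)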
    
    return scores_2d

def get_top_k_images(
    text_tokens: np.ndarray,
    image_embeddings: np.ndarray,
    k: int = 30,
    threshold: float = 0.1,
) -> list:
    """Get top-k most similar images for a text query.
    
    Args:
        text_tokens: Query embedding of shape (m, d) or (1, m, d)
        image_embeddings: Image embeddings of shape (N, n, d)
        k: Number of results to return
        threshold: Minimum score to keep. MaxSim scores are sums over query
            tokens, so their scale grows with query length.

    Returns:
        List of (image_index, score) tuples, sorted by descending score
    """
    # Add batch dimension if needed
    if text_tokens.ndim == 2:
        text_tokens = text_tokens[np.newaxis, ...]
    if image_embeddings.ndim == 2:
        image_embeddings = image_embeddings[np.newaxis, ...]

    # Compute similarity scores
    scores_2d = score_multi_vector_numpy(text_tokens, image_embeddings)[0]

    # Sort by score
    sorted_by_score = sorted(enumerate(scores_2d), key=lambda x: x[1], reverse=True)

    # Filter by threshold
    filtered = [(idx, float(s)) for idx, s in sorted_by_score if s >= threshold]

    return filtered[:k]
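
`score_multi_vector_numpy` needs the image embeddings as one regular (b2, s, d) array. If the service returns a different number of token embeddings per image (whether it does depends on the deployment and is an assumption here), zero-padding them is not safe: a zero image token has dot product 0, which beats any negative similarity and can inflate scores. A minimal sketch of the usual fix, padding plus a boolean mask so padded tokens can never win the max; `pad_passages` and `masked_maxsim` are hypothetical helpers.

```python
import numpy as np


def pad_passages(passages: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Pad (s_i, d) image embeddings to (b, s_max, d) and return a (b, s_max) mask.

    Hypothetical helper; mask[i, j] is True where token j of image i is real.
    Every image must have at least one token, otherwise its max is -inf.
    """
    dim = passages[0].shape[1]
    s_max = max(p.shape[0] for p in passages)
    padded = np.zeros((len(passages), s_max, dim), dtype=passages[0].dtype)
    mask = np.zeros((len(passages), s_max), dtype=bool)
    for i, p in enumerate(passages):
        padded[i, : p.shape[0]] = p
        mask[i, : p.shape[0]] = True
    return padded, mask


def masked_maxsim(query: np.ndarray, padded: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """MaxSim scores of one (n, d) query against padded (b, s, d) images -> (b,)."""
    dots = np.einsum("nd,bsd->bns", query, padded)  # (b, n, s)
    # Padded tokens become -inf so they are never selected by the max
    dots = np.where(mask[:, np.newaxis, :], dots, -np.inf)
    return dots.max(axis=2).sum(axis=1)  # max over image tokens, sum over query tokens


# Toy data: a 2-token query, one 2-token image and one 1-token image
toy_query = np.array([[1.0, 0.0], [0.0, 1.0]])
toy_images = [np.array([[1.0, 0.0], [0.5, 0.5]]), np.array([[-1.0, -1.0]])]
toy_padded, toy_mask = pad_passages(toy_images)
toy_scores = masked_maxsim(toy_query, toy_padded, toy_mask)  # 1.5 and -2.0
```

In the toy example the second image scores -2.0; with plain zero padding and no mask it would score 0.0, because the padded zero token beats both negative dot products.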

## Visualize Results

Finally, let's visualize the top matching images for our text query.

In [None]:
# Get top k results
top_k_results = get_top_k_images(
    text_tokens=query_embedding,
    image_embeddings=image_embedding,
    k=10,
    threshold=0.1
)

# Display results in a grid
html = '<div style="display: flex; flex-wrap: wrap; gap: 20px;">'
for img_idx, score in top_k_results:
    img_path = image_paths[img_idx]
    html += f'''
        <div style="text-align: center;">
            <img src="{img_path}" style="max-width: 300px; height:auto; margin: 10px;">
            <p>Image {img_idx + 1}<br>Score: {score:.2f}</p>
        </div>
    '''
html += "</div>"

display(HTML(html))
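
The grid above points `<img src>` at relative file paths, which render only when the notebook front end can serve those files (not, for example, after exporting the notebook to a standalone HTML file). One alternative is to inline the base64 strings already held in `base64_images` as `data:` URIs. A minimal sketch, assuming JPEG inputs (the image cell globs `*.jpg`); `results_grid_html` is a hypothetical helper, and labels are HTML-escaped so file names with `<` or `&` display correctly.

```python
from html import escape


def results_grid_html(
    results: list[tuple[int, float]], b64_images: list[str], labels: list[str]
) -> str:
    """Build an HTML grid of (image_index, score) results with images inlined as data URIs."""
    cells = []
    for idx, score in results:
        # A data: URI embeds the JPEG bytes directly, so no file path has to resolve
        src = f"data:image/jpeg;base64,{b64_images[idx]}"
        cells.append(
            '<div style="text-align: center;">'
            f'<img src="{src}" style="max-width: 300px; height: auto; margin: 10px;">'
            f"<p>{escape(labels[idx])}<br>Score: {score:.2f}</p>"
            "</div>"
        )
    return '<div style="display: flex; flex-wrap: wrap; gap: 20px;">' + "".join(cells) + "</div>"


# Toy inputs: "AAAA" is a placeholder base64 string, not a real image
grid = results_grid_html([(0, 7.25)], ["AAAA"], ["page<1>.jpg"])
print(grid)
```

In the notebook this would be called as `display(HTML(results_grid_html(top_k_results, base64_images, image_paths)))`.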