# Multimodal Index

8 Jul 2025
- Added SHA-256 hashing of the cache key, so that image/text queries (including user queries) are stored and looked up efficiently
- Changed the input type to "search_query" instead of "multimodal" or "image", as "search_query" can handle hybrid input types

Created by John Tan Chong Min, 7 Jul 2025
- This performs multimodal queries with image and/or text as hybrid inputs, with optional negative filters as well
- Put your dataset in the folder named by PARENT (default is "Fruits"); all subfolders of that folder will be available for selection when filtering

This is the initial vibe coding chat for multimodal query: https://chatgpt.com/share/686cddad-036c-8006-befa-7150e31f17f7

### Dataset used: 
- https://www.kaggle.com/datasets/shreyapmaher/fruits-dataset-images
- Download it and place it in a folder named Fruits in the same directory as this Jupyter notebook

File structure
+ Current Directory
    + Fruits (folder)
    + .env (containing COHERE_API_KEY inside)
    + embeddings.db (automatically generated by the sqlite3 code in this notebook)
    + Multimodal_Index.ipynb (this notebook)
  
### Multimodal Embedding Used: 
- Cohere Embed v4 https://cohere.com/blog/embed-4

In [1]:
# Install dependencies
# NOTE: the leading "!" is IPython/Jupyter shell-escape syntax, so this line
# only runs inside a notebook and is a syntax error in a plain Python script.
!pip install --quiet cohere python-dotenv numpy pillow gradio

import os
import base64
import sqlite3
import hashlib
import numpy as np
from pathlib import Path
from dotenv import load_dotenv
import cohere
from io import BytesIO
from PIL import Image as PILImage
import gradio as gr

# ─── Configuration ──────────────────────────────────────────────────────────────
MODEL = "embed-v4.0"        # Cohere embedding model identifier
PARENT = "Fruits"           # root dataset folder; its subfolders are indexed
NUM_FILES_PER_FOLDER = 10   # per-subfolder cap applied when collecting files

# Pull COHERE_API_KEY from a local .env file when one is present
dotenv_path = Path(".env")
if dotenv_path.exists():
    load_dotenv(dotenv_path)

COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
if not COHERE_API_KEY:
    # Fail fast: every embedding request below needs the key
    raise RuntimeError("Please set COHERE_API_KEY in your .env file.")

# Single Cohere v2 client shared by all embedding requests
co = cohere.ClientV2(api_key=COHERE_API_KEY)

# ─── Helper Functions ───────────────────────────────────────────────────────────
def load_image_as_data_url(path: Path) -> str:
    """Encode the image file at *path* as a base64 ``data:`` URL.

    The MIME type is ``image/png`` when the suffix is ``.png``
    (case-insensitive) and ``image/jpeg`` for every other suffix.
    """
    if path.suffix.lower() == ".png":
        mime = "image/png"
    else:
        mime = "image/jpeg"
    payload = base64.b64encode(path.read_bytes()).decode("utf-8")
    return "".join(("data:", mime, ";base64,", payload))

def pad_to_square(img: PILImage.Image, fill_color=(255, 255, 255)) -> PILImage.Image:
    """Centre *img* on a new square RGB canvas filled with *fill_color*.

    The canvas side equals the longer edge of *img*; the default fill is white.
    The input image is not modified.
    """
    width, height = img.size
    side = max(width, height)
    # Top-left corner that centres the image along its shorter axis
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = PILImage.new('RGB', (side, side), fill_color)
    canvas.paste(img, offset)
    return canvas

def pad_images(files):
    """Square-pad each uploaded image for the gallery preview.

    Args:
        files: Sequence of image paths or file objects accepted by
            ``PIL.Image.open``; may be ``None`` or empty.

    Returns:
        A list of padded images in input order (empty when *files* is falsy).
    """
    if not files:
        return []
    padded = []
    for f in files:
        # Image.open keeps the underlying file open; the context manager
        # releases the handle once pad_to_square has produced its own copy.
        with PILImage.open(f) as img:
            padded.append(pad_to_square(img))
    return padded

# ─── Embedding Cache (SQLite) ─────────────────────────────────────────────────
# Schema: one row per hashed cache key, holding the raw vector bytes and the
# model that produced them.
_CACHE_SCHEMA = """
    CREATE TABLE IF NOT EXISTS embeddings (
        key           TEXT PRIMARY KEY,
        vector        BLOB,
        model_version TEXT
    )
"""

# check_same_thread=False allows the connection to be used from threads other
# than the one that created it.
conn = sqlite3.connect("embeddings.db", check_same_thread=False)
cur = conn.cursor()
with conn:  # commits when the block exits without error
    cur.execute(_CACHE_SCHEMA)

def get_or_create_embedding(raw_key: str, inputs: list[dict], input_type: str, skip_cache: bool=False) -> np.ndarray:
    """
    Fetch an embedding from the SQLite cache or compute it via Cohere.

    Args:
        raw_key: Arbitrary string identifying the input; its SHA-256 hex
            digest is used as the cache key.
        inputs: Payload passed unchanged as ``inputs`` to ``co.embed``.
        input_type: Passed unchanged as ``input_type`` to ``co.embed``.
        skip_cache: When True, neither read from nor write to the cache.

    Returns:
        A 1-D float32 vector: the first embedding of the response (requested
        with ``output_dimension=1536``), or the cached bytes decoded as float32.
    """
    # Compute cache key as SHA-256 hex digest
    hashed_key = hashlib.sha256(raw_key.encode('utf-8')).hexdigest()

    if not skip_cache:
        # conn.execute creates a fresh cursor per statement. The connection is
        # opened with check_same_thread=False, and a single shared cursor would
        # let one thread's execute() replace the result set another thread is
        # about to fetch.
        row = conn.execute(
            "SELECT vector FROM embeddings WHERE key=? AND model_version=?",
            (hashed_key, MODEL)
        ).fetchone()
        if row is not None:
            # frombuffer yields a read-only view over the BLOB; copy so cached
            # and freshly computed vectors are both ordinary writable arrays.
            return np.frombuffer(row[0], dtype=np.float32).copy()

    resp = co.embed(
        model=MODEL,
        inputs=inputs,
        input_type=input_type,
        embedding_types=["float"],
        output_dimension=1536
    )
    emb = np.array(resp.embeddings.float, dtype=np.float32)[0]

    if not skip_cache:
        # `key` is the primary key, so a row cached under a different
        # model_version is overwritten rather than duplicated.
        with conn:  # commits on success
            conn.execute(
                "INSERT OR REPLACE INTO embeddings (key, vector, model_version) VALUES (?, ?, ?)",
                (hashed_key, emb.tobytes(), MODEL)
            )

    return emb

# ─── Pre-index Dataset Images ──────────────────────────────────────────────────
exts = {".jpg", ".jpeg", ".png"}
paths = []

for sub in sorted(Path(PARENT).iterdir()):
    if not sub.is_dir():
        continue
    # Filter by extension BEFORE applying the per-folder cap; otherwise
    # non-image files matched by "*.*" use up slots and the folder ends up
    # with fewer than NUM_FILES_PER_FOLDER images.
    images = [p for p in sorted(sub.glob("*.*")) if p.suffix.lower() in exts]
    paths.extend(images[:NUM_FILES_PER_FOLDER])

# Compute and normalize embeddings for each image
def _build_index_embeddings(paths):
    embs = []
    for p in paths:
        key = f"img:{p.relative_to(PARENT)}"
        url = load_image_as_data_url(p)
        inp = [{"content": [{"type": "image_url", "image_url": {"url": url}}]}]
        emb = get_or_create_embedding(key, inp, "search_query")
        embs.append(emb)
    embs = np.vstack(embs)
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    return embs / (norms + 1e-12)

# Embedding matrix with one row per entry of `paths`
image_embs = _build_index_embeddings(paths)

# Subfolder filter defaults: every subfolder of PARENT starts out selected
subfolders = [d.name for d in sorted(filter(Path.is_dir, Path(PARENT).iterdir()))]
default_folders = list(subfolders)

# ─── Search Function ────────────────────────────────────────────────────────────
def search(
    pos_text, pos_files,
    neg_text, neg_files,
    top_k, folders
):
    """Rank indexed images against a positive query, with optional negative filtering.

    Args:
        pos_text: Positive text queries, one per line (blank lines ignored).
        pos_files: Positive image files (paths or file objects), or None.
        neg_text: Negative text queries, one per line.
        neg_files: Negative image files, or None.
        top_k: Number of results to return; also the number of images most
            similar to the negative query that are excluded.
        folders: Names of the subfolders whose images are eligible.

    Returns:
        Up to *top_k* ``(image path, caption)`` tuples ordered by descending
        similarity to the positive query; the caption is the path relative
        to PARENT followed by the similarity score.

    Raises:
        gr.Error: If neither positive text nor a positive image is supplied.
    """
    def encode_png_b64(fp):
        """Re-encode the image at *fp* as PNG and return it base64-encoded."""
        # Image.open keeps the file open; the context manager releases the
        # handle as soon as the PNG bytes have been produced.
        with PILImage.open(fp) as im:
            buf = BytesIO()
            im.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()

    def embed_items(texts, images_b64, skip_cache=False):
        """Embed a mix of text + base64 PNG images into a single unit vector."""
        inputs = [{"type": "text", "text": t} for t in texts]
        inputs += [
            {"type": "image_url",
             "image_url": {"url": f"data:image/png;base64,{b64}"}}
            for b64 in images_b64
        ]
        # Cache key: texts first, then image payloads, joined by "||"
        raw_key = "||".join([*texts, *images_b64])
        emb = get_or_create_embedding(raw_key, [{"content": inputs}], "search_query", skip_cache)
        # Same epsilon as the index builder, so a zero vector cannot divide by zero
        return emb / (np.linalg.norm(emb) + 1e-12)

    # Parse inputs (tolerate None for cleared text boxes / checkbox groups)
    texts_pos = [t.strip() for t in (pos_text or "").splitlines() if t.strip()]
    texts_neg = [t.strip() for t in (neg_text or "").splitlines() if t.strip()]

    if not texts_pos and not pos_files:
        raise gr.Error("Provide at least one positive text or image file.")

    # A slider value may arrive as a float; slice bounds must be ints
    top_k = int(top_k)
    allowed = set(folders or [])

    pos_imgs = [encode_png_b64(fp) for fp in (pos_files or [])]
    neg_imgs = [encode_png_b64(fn) for fn in (neg_files or [])]

    # Positive embedding & ranking
    emb_pos   = embed_items(texts_pos, pos_imgs)
    valid_idx = [i for i, p in enumerate(paths) if p.parent.name in allowed]
    pos_sims  = image_embs.dot(emb_pos)
    ranked    = sorted(valid_idx, key=lambda i: -pos_sims[i])

    # Optional negative filtering: drop the top_k images closest to the negative query
    if texts_neg or neg_imgs:
        emb_neg  = embed_items(texts_neg, neg_imgs)
        neg_sims = image_embs.dot(emb_neg)
        exclude  = set(sorted(valid_idx, key=lambda i: -neg_sims[i])[:top_k])
        ranked   = [i for i in ranked if i not in exclude]

    # Return top-k results
    return [
        (str(paths[i]), f"{paths[i].relative_to(PARENT)} ({pos_sims[i]:.4f})")
        for i in ranked[:top_k]
    ]

# ─── CSS for Tight Single-Row Previews ─────────────────────────────────────────
# All three selectors share one declaration block: a single-row grid that
# scrolls horizontally and is capped at 100px tall.
_SINGLE_ROW_DECLARATIONS = (
    "display: grid !important;",
    "grid-auto-flow: column !important;",
    "grid-auto-columns: auto !important;",
    "grid-template-rows: 1fr !important;",
    "gap: 0 !important;",
    "overflow-x: auto !important;",
    "max-height: 100px !important;",
)
_rule_body = "".join(f"    {decl}\n" for decl in _SINGLE_ROW_DECLARATIONS)
css = "\n" + "".join(
    "#" + elem_id + " {\n" + _rule_body + "}\n"
    for elem_id in ("pos_preview", "neg_preview", "gallery")
)

# ─── Build Gradio UI ───────────────────────────────────────────────────────────
def _build_query_row(kind, preview_id):
    """Create one row of query inputs: text box, file upload and preview gallery.

    *kind* is the label prefix ("Positive" or "Negative") and *preview_id*
    becomes the preview gallery's ``elem_id``. Intended to be called inside
    the ``with gr.Blocks(...)`` context below. Returns
    ``(textbox, files, preview)``.
    """
    with gr.Row():
        text_box = gr.Textbox(label=f"{kind} Text Queries (one per line)", lines=3)
        uploads = gr.Files(file_types=[".jpg", ".jpeg", ".png"],
                           label=f"{kind} Image Files")
        preview = gr.Gallery(label=f"{kind} Previews",
                             columns=10,
                             elem_id=preview_id)

    # Refresh the square-padded thumbnails whenever the uploaded files change
    uploads.change(pad_images, inputs=uploads, outputs=preview)
    return text_box, uploads, preview


with gr.Blocks(css=css) as demo:
    gr.Markdown("### Semantic Image Search - Text & Image")

    pos_text, pos_files, pos_preview = _build_query_row("Positive", "pos_preview")
    neg_text, neg_files, neg_preview = _build_query_row("Negative", "neg_preview")

    top_k = gr.Slider(1, 20, value=5, step=1, label="Top K Results")
    folders = gr.CheckboxGroup(
        choices=subfolders,
        value=default_folders,
        label="Filter Subfolders (optional)",
    )
    search_btn = gr.Button("Search")

    with gr.Row():
        gallery = gr.Gallery(label="Results", columns=5)
        # Run the search on click and show the ranked images in the gallery
        search_btn.click(
            fn=search,
            inputs=[pos_text, pos_files, neg_text, neg_files, top_k, folders],
            outputs=gallery,
        )

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)

* Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


Keyboard interruption in main thread... closing server.
