# CLIP Semantic Search — Complete Notebook (Image→Image, Text→Image, Image+Text→Image)

```
A single, self-contained notebook that supports three search modes:

Image → Image search with CLIP (cosine)
Text → Image search (cosine)
Image + Text → Image fused search (weighted cosine)
✅ Works on CPU or GPU. Optional FAISS index for speed.
```

# 1) Install dependencies

In [None]:
!pip install torch torchvision transformers pillow tqdm numpy matplotlib
!pip install faiss-cpu


In [2]:
# Mount Google Drive at /content/drive (Colab only); PHOTOS_DIR in the
# configuration cell points at a folder under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 2) Imports & Configuration

In [3]:
from __future__ import annotations
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Iterable, Optional, Union

import os, json, math, time, dataclasses
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

try:
    import faiss  # optional
    HAS_FAISS = True
except Exception:
    HAS_FAISS = False

# Configure your paths
# PHOTOS_DIR: root folder that is scanned recursively for images to index.
# INDEX_DIR:  folder that receives the saved index files; created here if missing.
PHOTOS_DIR = Path("/content/drive/MyDrive/digi/digi_image-similarity/Photos") # 👈 replace
INDEX_DIR  = Path('./index')
INDEX_DIR.mkdir(parents=True, exist_ok=True)

# MODEL_NAME: Hugging Face checkpoint id, used as the default model of ClipEncoder.
# BATCH_SIZE: number of images embedded per forward pass while building the index.
MODEL_NAME = 'openai/clip-vit-base-patch32'  # Swap to a larger CLIP for better recall
BATCH_SIZE = 64                              # Reduce if you hit OOM


# 3) Utilities

In [4]:
# File suffixes (lower-case, with leading dot) that count as images.
IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff'}

# List images recursively

def list_images(root: Path) -> List[Path]:
    """Return every image file below ``root``, sorted by path.

    A path qualifies when it is a regular file whose suffix, compared
    case-insensitively, is in ``IMG_EXTS``. Directories are never returned,
    even if their name ends in an image suffix.
    """
    return sorted(
        candidate
        for candidate in root.rglob('*')
        if candidate.is_file() and candidate.suffix.lower() in IMG_EXTS
    )

# Open with EXIF orientation and convert to RGB

def open_image(path: Path) -> Image.Image:
    """Load the image at ``path`` and return it as an RGB ``PIL.Image``.

    The file is opened in a context manager so the handle is released on
    exit; the EXIF transpose and the RGB conversion both happen while the
    file is still open.
    """
    with Image.open(path) as raw:
        upright = ImageOps.exif_transpose(raw)
        rgb = upright.convert('RGB')
    return rgb

# Chunk an iterable

def chunked(it: Iterable, size: int):
    """Yield successive lists of at most ``size`` items taken from ``it``.

    Every yielded list holds exactly ``size`` items except possibly the last,
    which holds the remainder. An empty iterable yields nothing.

    Args:
        it: any iterable; it is consumed lazily, one chunk at a time.
        size: maximum chunk length, must be at least 1.

    Raises:
        ValueError: if ``size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Without this guard a non-positive size never satisfies
    # ``len(buf) == size`` and the whole input comes back as one chunk.
    if size < 1:
        raise ValueError(f'chunk size must be >= 1, got {size}')
    buf = []
    for x in it:
        buf.append(x)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        yield buf

# Row-wise L2 normalize

def normalize(a: np.ndarray) -> np.ndarray:
    """L2-normalize ``a`` along its last axis and return a new array.

    For a 2-D matrix this scales each row to unit length; a 1-D vector is
    scaled as a whole. A small epsilon is added to every norm so all-zero
    rows come back as zeros instead of NaN.

    Args:
        a: array with at least one dimension.

    Returns:
        Array of the same shape as ``a``.
    """
    # axis=-1 equals axis=1 for matrices and additionally accepts 1-D input.
    norms = np.linalg.norm(a, axis=-1, keepdims=True) + 1e-12
    return a / norms

# 4) CLIP Encoder (image + text towers)


In [None]:
class ClipEncoder:
    """Wrap a Hugging Face CLIP checkpoint and produce unit-norm embeddings.

    Attributes set by ``__init__``: ``model_name``, ``device``, ``model``
    (in eval mode, moved to ``device``), ``processor`` and ``dim`` (the
    embedding width, measured by encoding a dummy text once).
    """

    def __init__(self, model_name: str = MODEL_NAME, device: Optional[str] = None):
        """Load model and processor for ``model_name``.

        Args:
            model_name: checkpoint id passed to ``from_pretrained``.
            device: torch device string; when omitted, 'cuda' is used if
                available, otherwise 'cpu'.
        """
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # Probe the text tower once to learn the embedding dimension.
        with torch.no_grad():
            t = self.processor(text=['x'], return_tensors='pt')
            t = {k: v.to(self.device) for k, v in t.items()}
            self.dim = int(self.model.get_text_features(**t).shape[-1])

    @torch.inference_mode()
    def encode_images(self, pil_images: List[Image.Image]) -> np.ndarray:
        """Embed a batch of PIL images.

        Returns:
            float32 array of shape ``(len(pil_images), self.dim)`` with
            L2-normalized rows; an empty ``(0, self.dim)`` array for an
            empty input list.
        """
        if not pil_images:
            return np.zeros((0, self.dim), dtype=np.float32)
        batch = self.processor(images=pil_images, return_tensors='pt')
        feats = self.model.get_image_features(pixel_values=batch['pixel_values'].to(self.device))
        feats = F.normalize(feats, p=2, dim=-1)
        return feats.detach().cpu().to(torch.float32).numpy()

    @torch.inference_mode()
    def encode_text(self, text: str) -> np.ndarray:
        """Embed a single text prompt.

        Returns:
            float32 vector of length ``self.dim`` with unit L2 norm.
        """
        # truncation=True clips prompts to the tokenizer's maximum length
        # (77 tokens for the OpenAI CLIP checkpoints); without it a longer
        # prompt overruns the text tower's position embeddings and fails.
        t = self.processor(text=[text], return_tensors='pt', padding=True, truncation=True)
        t = {k: v.to(self.device) for k, v in t.items()}
        feats = self.model.get_text_features(**t)
        feats = F.normalize(feats, p=2, dim=-1)
        return feats.detach().cpu().to(torch.float32).numpy()[0]

# Build the single module-level encoder; the indexing cell and the search
# helpers further down all read this global `enc`.
enc = ClipEncoder(MODEL_NAME)
print('Device:', enc.model.device, '| dim =', enc.dim)


## 5) On‑Disk Index (embeddings + metadata)


In [6]:
@dataclass
class ImageMeta:
    """Per-image metadata record written to the index's JSONL file."""
    path: str  # image file path as a string (the indexing cell stores str(p))
    mtime: float  # file modification time (st_mtime) captured at indexing time

class DiskIndex:
    """Read and write an embedding index stored as plain files in one folder.

    Files inside ``index_dir``:
        embeddings.npy  - float32 embedding matrix, one row per image
        images.jsonl    - one JSON object per line, one line per image
        header.json     - free-form header dict
        faiss.index     - optional serialized FAISS index
    """

    def __init__(self, index_dir: Path):
        """Record the file locations; nothing is read or created here."""
        self.index_dir = index_dir
        self.emb_path = index_dir / 'embeddings.npy'
        self.meta_path = index_dir / 'images.jsonl'
        self.header_path = index_dir / 'header.json'
        self.faiss_path = index_dir / 'faiss.index'

    def save(self, header: dict, embeddings: np.ndarray, metas: List[ImageMeta], faiss_index=None):
        """Write the index to disk, replacing any previous contents.

        Args:
            header: JSON-serializable dict stored in header.json.
            embeddings: 2-D embedding matrix; it is cast to float32 and
                row-normalized before being saved.
            metas: one ``ImageMeta`` per embedding row.
            faiss_index: optional FAISS index to persist next to the
                embeddings. When omitted, a faiss.index left over from an
                earlier save is deleted.
        """
        self.index_dir.mkdir(parents=True, exist_ok=True)
        embeddings = normalize(embeddings.astype(np.float32, copy=False))
        np.save(self.emb_path, embeddings)
        with open(self.meta_path, 'w', encoding='utf-8') as f:
            for m in metas:
                f.write(json.dumps(dataclasses.asdict(m), ensure_ascii=False) + '\n')
        with open(self.header_path, 'w', encoding='utf-8') as f:
            json.dump(header, f, ensure_ascii=False, indent=2)
        if faiss_index is not None:
            faiss.write_index(faiss_index, str(self.faiss_path))
        elif self.faiss_path.exists():
            # A FAISS file from a previous build no longer matches the rows
            # just written; leaving it would let try_load_faiss() serve
            # results for the old image set.
            self.faiss_path.unlink()

    def load(self):
        """Read the index back from disk.

        Returns:
            ``(header, embs, metas)`` where ``header`` is the dict from
            header.json, ``embs`` is the array from embeddings.npy and
            ``metas`` is a list of dicts (one per non-blank JSONL line).
        """
        header = json.loads(self.header_path.read_text('utf-8'))
        embs = np.load(self.emb_path)
        # Iterate the file instead of str.splitlines(): splitlines() also
        # breaks on characters such as U+2028 or U+0085, which
        # ensure_ascii=False writes unescaped inside a path string.
        # Blank lines are skipped rather than fed to json.loads.
        with open(self.meta_path, encoding='utf-8') as f:
            metas = [json.loads(line) for line in f if line.strip()]
        return header, embs, metas

    def try_load_faiss(self):
        """Return the persisted FAISS index, or None if FAISS or the file is absent."""
        if HAS_FAISS and self.faiss_path.exists():
            return faiss.read_index(str(self.faiss_path))
        return None


# 6) Build (or rebuild) the index from your photo folder


In [None]:
# Collect every image path under PHOTOS_DIR (sorted, recursive).
files = list_images(PHOTOS_DIR)
print(f'Found {len(files)} images under {PHOTOS_DIR}')

# metas and all_feats grow in lockstep: one ImageMeta per embedding row.
metas: List[ImageMeta] = []
all_feats: List[np.ndarray] = []

# The chunks are materialized into a list so tqdm knows the total batch count.
for batch_paths in tqdm(list(chunked(files, BATCH_SIZE)), desc='Embedding', unit='batch'):
    pil_batch, batch_metas = [], []
    for p in batch_paths:
        # Best effort: a file that fails to open or stat is reported and skipped,
        # so pil_batch and batch_metas only ever contain matching entries.
        try:
            pil_batch.append(open_image(p))
            batch_metas.append(ImageMeta(path=str(p), mtime=p.stat().st_mtime))
        except Exception as e:
            print('[skip]', p, e)
    if not pil_batch:
        continue
    feats = enc.encode_images(pil_batch)
    all_feats.append(feats)
    metas.extend(batch_metas)

# Abort the cell when nothing could be embedded.
if not metas:
    raise SystemExit('No images embedded — check PHOTOS_DIR')

# Stack the per-batch matrices into one (num_items, dim) array.
embeddings = np.vstack(all_feats)
header = {
    'model_name': MODEL_NAME,
    'dim': int(embeddings.shape[-1]),
    'created_at': time.time(),
    'num_items': len(metas),
    'device_build': str(enc.model.device),
}

# Persist embeddings, metadata and header; no FAISS index is passed here.
index = DiskIndex(INDEX_DIR)
index.save(header, embeddings, metas)
print(f'Saved index → {INDEX_DIR} with {len(metas)} items.')


## 7) Search — Text → Image (cosine)


In [8]:
from typing import Sequence

def search_text(index_dir: Path, text: str, topk: int = 12, prefer_faiss: bool = True):
    """Rank indexed images by cosine similarity to a text prompt.

    Args:
        index_dir: folder containing an index written by ``DiskIndex.save``.
        text: the query prompt.
        topk: maximum number of hits to return; values <= 0 yield ``[]``.
        prefer_faiss: use FAISS when it is importable, otherwise (or when
            False) fall back to a NumPy dot product.

    Returns:
        List of ``(meta_dict, score)`` tuples, best match first, with at
        most ``topk`` entries. Empty when the index holds no embeddings.
    """
    index = DiskIndex(index_dir)
    header, embs, metas = index.load()
    # FAISS rejects k == 0, so handle the degenerate cases before searching.
    if topk <= 0 or len(embs) == 0:
        return []
    embs = normalize(embs.astype(np.float32))
    k = min(topk, len(embs))

    q = enc.encode_text(text).astype(np.float32)  # (D,)

    if prefer_faiss and HAS_FAISS:
        fidx = index.try_load_faiss()
        # A persisted index whose size differs from the embedding matrix was
        # built for another image set; rebuild instead of trusting it.
        if fidx is None or fidx.ntotal != len(embs):
            fidx = faiss.IndexFlatIP(header['dim'])
            fidx.add(embs)
        D, I = fidx.search(q[None, :], k)
        # FAISS pads missing neighbours with label -1, which would otherwise
        # index metas[-1]; drop them, as search_image does.
        return [(metas[int(i)], float(d)) for i, d in zip(I[0], D[0]) if int(i) >= 0]

    sims = embs @ q
    idx = np.argsort(-sims)[:k]
    return [(metas[int(i)], float(sims[int(i)])) for i in idx]


# 8) Search — Image → Image (cosine)


In [9]:
def encode_query_image(img: Union[str, Path, Image.Image]) -> np.ndarray:
    """Embed one query image and return its float32 feature vector.

    ``img`` may be a filesystem path (``str`` or ``Path``), which is loaded
    through ``open_image``, or an already-loaded PIL image, which is passed
    to the encoder as-is.
    """
    pil = open_image(Path(img)) if isinstance(img, (str, Path)) else img
    feats = enc.encode_images([pil])
    return feats[0].astype(np.float32)


def search_image(index_dir: Path, query_image: Union[str, Path, Image.Image], topk: int = 12,
                 prefer_faiss: bool = True, include_self: bool = False):
    """Rank indexed images by cosine similarity to a query image.

    Args:
        index_dir: folder containing an index written by ``DiskIndex.save``.
        query_image: path to an image file, or an already-loaded PIL image.
        topk: maximum number of hits to return; values <= 0 yield ``[]``.
        prefer_faiss: use FAISS when it is importable, otherwise (or when
            False) fall back to a NumPy dot product.
        include_self: when False and ``query_image`` is a path, the indexed
            entry that resolves to the same file is left out of the results.

    Returns:
        List of ``(meta_dict, score)`` tuples, best match first, with at
        most ``topk`` entries. Empty when the index holds no embeddings.
    """
    index = DiskIndex(index_dir)
    header, embs, metas = index.load()
    # Without this guard topk == 0 returned one hit (the result was appended
    # before the length check) and FAISS fails on an empty index (k == 0).
    if topk <= 0 or len(embs) == 0:
        return []
    embs = normalize(embs.astype(np.float32))
    # Over-fetch so that dropping the query image still leaves topk hits.
    k = min(topk + 10, len(embs))

    q = encode_query_image(query_image)  # (D,)

    if prefer_faiss and HAS_FAISS:
        fidx = index.try_load_faiss()
        # A persisted index whose size differs from the embedding matrix was
        # built for another image set; rebuild instead of trusting it.
        if fidx is None or fidx.ntotal != len(embs):
            fidx = faiss.IndexFlatIP(header['dim'])
            fidx.add(embs)
        D, I = fidx.search(q[None, :], k)
        # Label -1 marks a missing neighbour and must not index metas.
        candidates = [(metas[int(i)], float(d)) for i, d in zip(I[0], D[0]) if int(i) >= 0]
    else:
        sims = embs @ q
        idx = np.argsort(-sims)[:k]
        candidates = [(metas[int(i)], float(sims[int(i)])) for i in idx]

    # Optionally filter out the query image itself (only possible when the
    # query was given as a path).
    q_path = None
    if isinstance(query_image, (str, Path)):
        q_path = str(Path(query_image).resolve())
    results = []
    for meta, score in candidates:
        p = str(Path(meta['path']).resolve())
        if not include_self and q_path is not None and p == q_path:
            continue
        results.append((meta, score))
        if len(results) >= topk:
            break
    return results

# 9) Search — Image + Text → Image (fused)


In [10]:
# Combine image and text queries by a weighted sum of their embeddings.
# q = normalize(w_img * q_img + w_txt * q_txt)

def search_image_plus_text(index_dir: Path,
                           query_image: Union[str, Path, Image.Image],
                           text: str,
                           w_img: float = 1.0,
                           w_txt: float = 1.0,
                           topk: int = 12,
                           prefer_faiss: bool = True,
                           include_self: bool = False):
    """Rank indexed images against a weighted blend of an image and a text query.

    The query vector is ``w_img * image_embedding + w_txt * text_embedding``,
    rescaled to unit length, and compared to the index by inner product.

    Args:
        index_dir: folder containing an index written by ``DiskIndex.save``.
        query_image: path to an image file, or an already-loaded PIL image.
        text: the text prompt to blend in.
        w_img: weight of the image embedding.
        w_txt: weight of the text embedding.
        topk: maximum number of hits to return; values <= 0 yield ``[]``.
        prefer_faiss: use FAISS when it is importable, otherwise (or when
            False) fall back to a NumPy dot product.
        include_self: when False and ``query_image`` is a path, the indexed
            entry that resolves to the same file is left out of the results.

    Returns:
        List of ``(meta_dict, score)`` tuples, best match first, with at
        most ``topk`` entries. Empty when the index holds no embeddings.
    """
    index = DiskIndex(index_dir)
    header, embs, metas = index.load()
    # Without this guard topk == 0 returned one hit (the result was appended
    # before the length check) and FAISS fails on an empty index (k == 0).
    if topk <= 0 or len(embs) == 0:
        return []
    embs = normalize(embs.astype(np.float32))
    # Over-fetch so that dropping the query image still leaves topk hits.
    k = min(topk + 10, len(embs))

    qi = encode_query_image(query_image).astype(np.float32)
    qt = enc.encode_text(text).astype(np.float32)

    q = w_img * qi + w_txt * qt
    q = q / (np.linalg.norm(q) + 1e-12)
    # NumPy-typed weights can promote the blend to float64; cast back so the
    # query matches the float32 index that FAISS expects.
    q = q.astype(np.float32, copy=False)

    if prefer_faiss and HAS_FAISS:
        fidx = index.try_load_faiss()
        # A persisted index whose size differs from the embedding matrix was
        # built for another image set; rebuild instead of trusting it.
        if fidx is None or fidx.ntotal != len(embs):
            fidx = faiss.IndexFlatIP(header['dim'])
            fidx.add(embs)
        D, I = fidx.search(q[None, :], k)
        # Label -1 marks a missing neighbour and must not index metas.
        candidates = [(metas[int(i)], float(d)) for i, d in zip(I[0], D[0]) if int(i) >= 0]
    else:
        sims = embs @ q
        idx = np.argsort(-sims)[:k]
        candidates = [(metas[int(i)], float(sims[int(i)])) for i in idx]

    # Optional self-filter if the image is from the index (only possible
    # when the query was given as a path).
    q_path = None
    if isinstance(query_image, (str, Path)):
        q_path = str(Path(query_image).resolve())
    results = []
    for meta, score in candidates:
        p = str(Path(meta['path']).resolve())
        if not include_self and q_path is not None and p == q_path:
            continue
        results.append((meta, score))
        if len(results) >= topk:
            break
    return results


# 10) Visualization helpers


In [11]:
import matplotlib.pyplot as plt

def show_grid(results: List[Tuple[dict, float]], cols: int = 4, max_w: int = 320, title: str = None):
    """Plot search hits as a grid of thumbnails labelled with score and file name.

    Prints 'No results.' and returns when ``results`` is empty. ``title``,
    if given, is printed above the figure. Each image is shrunk to fit in a
    ``max_w`` x ``max_w`` box; an entry that fails to load or draw is
    reported with a '[viz-skip]' line and the remaining entries are still
    shown.
    """
    if not results:
        print('No results.')
        return

    n_rows = math.ceil(len(results) / cols)
    if title:
        print(title)

    plt.figure(figsize=(cols * 4, n_rows * 4))
    for slot, (meta, score) in enumerate(results, start=1):
        try:
            thumb = open_image(Path(meta['path']))
            thumb.thumbnail((max_w, max_w))
            plt.subplot(n_rows, cols, slot)
            plt.imshow(thumb)
            plt.axis('off')
            plt.title(f"{score:.3f}\n{Path(meta['path']).name}", fontsize=9)
        except Exception as e:
            print('[viz-skip]', meta['path'], e)
    plt.tight_layout()
    plt.show()


def show_query(query_img_path: Union[str, Path, None] = None, text: Optional[str] = None):
    """Display the query: the image (with optional text in the title) or the text alone.

    Does nothing when no image path is given and ``text`` is empty or None.
    """
    has_image = query_img_path is not None
    if not has_image and not text:
        return

    plt.figure(figsize=(4, 4))
    if not has_image:
        # Text-only query: render the prompt centred on an empty canvas.
        plt.text(0.5, 0.5, text, ha='center', va='center', fontsize=14)
        plt.axis('off')
    else:
        picture = open_image(Path(query_img_path))
        plt.imshow(picture)
        plt.axis('off')
        caption = f"Query image: {Path(query_img_path).name}"
        if text:
            caption += f"\nText: {text}"
        plt.title(caption)
    plt.show()

# 11) Demos


In [None]:
# 🔎 Text → Image
# Rank the indexed images against a text prompt, then show the prompt and the top hits.
text_query = 'black female'
text_hits = search_text(INDEX_DIR, text_query, topk=12)
show_query(None, text_query)  # no query image: only the prompt is rendered
show_grid(text_hits, cols=4, title='Text → Image')


In [None]:
# 🖼️ Image → Image
# Point this at any image file to use as the query. With include_self=False,
# an indexed entry that resolves to the same path is left out of the hits.
query_image_path = "/content/new_cat.png"  # 👈 replace with a specific file if you like
img_hits = search_image(INDEX_DIR, query_image_path, topk=12, include_self=False)
show_query(query_image_path)
show_grid(img_hits, cols=4, title='Image → Image')


In [None]:
# 🧪 Image + Text → Image (fused)
# Example: refine the image neighborhood with a textual intent
# Reuses `query_image_path` from the Image → Image cell, so run that cell first.
# The query is the weighted sum w_img * image + w_txt * text, re-normalized.
fused_hits = search_image_plus_text(
    INDEX_DIR,
    query_image=query_image_path,
    text='man',
    w_img=1.0,
    w_txt=0.8,
    topk=12,
    include_self=False,
)
show_query(query_image_path, 'man')
show_grid(fused_hits, cols=4, title='Image + Text → Image (fused)')
