In [1]:
import os
import sys
from pathlib import Path
import numpy as np
from PIL import Image

import torch
from torchvision import models, transforms
import json

  from .autonotebook import tqdm as notebook_tqdm


Load Data

In [2]:
embeddings = np.load('embeddings.npy', mmap_mode='r')   # shape: (N, D)
assert embeddings.ndim == 2
N, D = embeddings.shape

In [3]:
filenames = np.load('filenames.npy', mmap_mode='r')

Query Embedding

In [4]:
image_path = 'IMG_0018.png'

In [5]:
torch.manual_seed(42)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model = torch.nn.Sequential(*(list(model.children())[:-1]))

img = Image.open(image_path).convert("RGB")
img_t = transform(img).unsqueeze(0)  # shape: [1, 3, 224, 224]

with torch.no_grad():
    feats = model(img_t)       # shape: [1, 2048, 1, 1]
    feats = feats.squeeze()    # shape: [2048]

embedding_query = feats.numpy()

Brute Force ANN

In [6]:
def ann_naive(emb: np.ndarray, query_vec: np.ndarray, k: int):
    """
    Naive L2 k-NN search using explicit Python looping.
    emb: (N, D) numpy array
    query_vec: (D,) numpy array
    Returns (indices, distances)
    """
    N, D = emb.shape
    k = min(k, N)
    q = query_vec.astype(np.float32)

    # --- compute squared L2 distance for each embedding ---
    dists = []
    for i in range(N):
        diff = emb[i] - q
        d2 = float(np.dot(diff, diff))  # same as np.sum(diff**2)
        dists.append((i, d2))

    # --- sort by distance (ascending) ---
    dists.sort(key=lambda x: x[1])

    # --- take top-k ---
    topk = dists[:k]
    idx = [i for i, _ in topk]
    scores = [d for _, d in topk]

    return np.array(idx, dtype=int), np.array(scores, dtype=float)

In [7]:
%%time
print(ann_naive(embeddings, embedding_query, 10))

(array([20934, 21100,  4687, 14014, 18868,  8132, 11690,  8782, 13416,
        5137]), array([70.06198883, 70.39277649, 70.6230011 , 72.50684357, 72.68650055,
       72.91902161, 73.40797424, 73.51124573, 73.81577301, 74.15214539]))
CPU times: user 911 ms, sys: 55.5 ms, total: 966 ms
Wall time: 98.2 ms


Paths to Metadata

In [8]:
path_to_ids = {}
for i in [0,1,2,3,4,5,6,7,8,9,'a','b','c','d','e']:
    with open(f"map0{i}.csv", "r") as f:
        for line in f:
            parts = line.strip().split(",")
            path_to_ids[parts[0]] = parts[3][3:]
print(f"Loaded {len(path_to_ids)} image ID mappings.")

path_to_meta = {}
with open("metadata0.py", "r") as f:
    for line in f:
        curr_line = line[1:-1]
        curr_line = curr_line.strip().split(",")
        image_id = curr_line[0]
        metadata = ",".join(curr_line[1:])[:-1]
        # id_to_meta[image_id] = metadata
        path_to_meta[path_to_ids[image_id]] = json.loads(metadata)

Loaded 23396 image ID mappings.


Matching Metadata

In [9]:
def metadata_matches(node_meta: dict, query_metadata: dict) -> bool:
    if not query_metadata:
        return True

    for query_key in query_metadata.keys():
        if query_key not in node_meta.keys():
            continue

        op, target = query_metadata[query_key]

        if op == "exact":
            if query_key == "item_weight" and node_meta[query_key][0]["value"] != target:
                return False
            elif query_key == "model_year" and node_meta[query_key][0]["value"] != target:
                return False
            elif query_key == "color" and node_meta[query_key][0]["value"] != target:
                return False
            elif query_key == "country" and node_meta[query_key] != target:
                return False
            elif query_key == "brand" and node_meta[query_key][0]["value"] != target:
                return False

        elif op == "leq":
            if query_key == "item_weight" and node_meta[query_key][0]["value"] > target:
                return False
            elif query_key == "model_year" and node_meta[query_key][0]["value"] > target:
                return False

        elif op == "geq":
            if query_key == "item_weight" and node_meta[query_key][0]["value"] < target:
                return False
            elif query_key == "model_year" and node_meta[query_key][0]["value"] < target:
                return False

    return True

Pre-Filtering

In [10]:
def prefilter_search(query_vec: np.ndarray,
                     query_meta: dict,
                     embeddings: np.ndarray,
                     filenames: np.ndarray,
                     path_to_meta: dict,
                     top_k: int = 10):
    """
    1) Use metadata to restrict to a subset of indices.
    2) Run k-NN ONLY within that subset.

    Good when metadata is selective (subset is small).
    """
    assert query_vec.shape == (embeddings.shape[1],)

    # 1) Collect indices that match the metadata
    candidate_indices = []
    for i, fname in enumerate(filenames):
        fname = str(fname)
        node_meta = path_to_meta.get(fname, {})
        if metadata_matches(node_meta, query_meta):
            candidate_indices.append(i)

    print(f"[Pre-filter] Matched {len(candidate_indices)} of {len(filenames)} items.")

    if not candidate_indices:
        return []

    # 2) Run k-NN only on those candidates
    sub_emb = embeddings[candidate_indices]
    local_idx, scores = ann_naive(sub_emb, query_vec, k=top_k)

    results = []
    for rank, (li, score) in enumerate(zip(local_idx, scores)):
        global_idx = candidate_indices[li]
        fname      = str(filenames[global_idx])
        results.append({
            "rank":   rank,
            "index":  int(global_idx),
            "file":   fname,
            "score":  float(score),
            "meta":   path_to_meta.get(fname, {})
        })

    return results

Post-Filtering

In [11]:
def postfilter_search(query_vec: np.ndarray,
                      query_meta: dict,
                      embeddings: np.ndarray,
                      filenames: np.ndarray,
                      path_to_meta: dict,
                      top_k: int = 10,
                      large_k: int = 200):
    """
    1) Compute k-NN over the WHOLE collection (or ANN results).
    2) Walk results in similarity order and keep only those that match metadata.

    Good when you want strong semantic ranking and metadata is not super selective.
    """
    assert query_vec.shape == (embeddings.shape[1],)

    # Get a larger candidate set than final top_k so metadata filtering has room
    k_all = min(large_k, embeddings.shape[0])
    cand_idx, cand_scores = ann_naive(embeddings, query_vec, k=k_all)

    results = []
    for rank_all, (idx, score) in enumerate(zip(cand_idx, cand_scores)):
        fname = str(filenames[idx])
        node_meta = path_to_meta.get(fname, {})
        if metadata_matches(node_meta, query_meta):
            results.append({
                "rank":   len(results),   # rank after metadata filtering
                "index":  int(idx),
                "file":   fname,
                "score":  float(score),
                "meta":   node_meta
            })
            if len(results) == top_k:
                break

    print(f"[Post-filter] Returned {len(results)} results (from {k_all} ANN candidates).")
    return results

In [12]:
query_meta = {"country": ["exact", "US"], "item_weight": ["geq", 0.3]}

In [13]:
%%time
print("=== PREFILTER SEARCH ===")
pre_results = prefilter_search(
    query_vec=embedding_query,
    query_meta=query_meta,
    embeddings=embeddings,
    filenames=filenames,
    path_to_meta=path_to_meta,
    top_k=3,
)
for r in pre_results:
    print(r["index"], r["file"])

=== PREFILTER SEARCH ===
[Pre-filter] Matched 5153 of 23396 items.
4687 032f5f9c.jpg
548 0001b659.jpg
8328 05d37d4c.jpg
CPU times: user 29.4 ms, sys: 7.68 ms, total: 37.1 ms
Wall time: 36.4 ms


In [14]:
%%time
print("\n=== POSTFILTER SEARCH ===")
post_results = postfilter_search(
    query_vec=embedding_query,
    query_meta=query_meta,
    embeddings=embeddings,
    filenames=filenames,
    path_to_meta=path_to_meta,
    top_k=3,
    large_k=200,
)
for r in post_results:
    print(r["index"], r["file"])


=== POSTFILTER SEARCH ===
[Post-filter] Returned 3 results (from 200 ANN candidates).
4687 032f5f9c.jpg
548 0001b659.jpg
8328 05d37d4c.jpg
CPU times: user 58.8 ms, sys: 1.73 ms, total: 60.6 ms
Wall time: 60.1 ms
