In [None]:
# Install Ray Data into the kernel's own environment (%pip rather than !pip).
# The extras spec is quoted so the shell cannot treat the square brackets as a
# glob pattern.
%pip install -q "ray[data]"

In [None]:
import ray

# Start the local Ray runtime. ignore_reinit_error=True keeps this cell safe to
# re-run: a plain ray.init() raises once Ray has already been initialised in
# this kernel.
ray.init(ignore_reinit_error=True)

In [None]:
from huggingface_hub import hf_hub_url

# Hugging Face dataset repository that hosts the Parquet shards.
repo = "Francesco/insects-mytwu"

# Shards to load, each listed exactly once.
# NOTE(review): the test shard used to appear twice in this list, so every test
# row was read twice. If the second entry was meant to be another split (e.g.
# train), add that shard's filename here.
shard_names = [
    "data/test-00000-of-00001-670031141e816b9b.parquet",
    "data/validation-00000-of-00001-876de533d76d48c6.parquet",
]
files = [hf_hub_url(repo, name, repo_type="dataset") for name in shard_names]

# Read only the two columns the rest of the notebook uses.
ds = ray.data.read_parquet(files, columns=["image_id", "image"])

In [None]:
# Print the dataset summary right after the read, before any transforms.
print(ds)

In [None]:
import pyarrow as pa

def first_row_arrow(tbl: pa.Table) -> pa.Table:
    """Reduce a table to its leading row.

    Used as the per-group function of the dedup step, so each group
    contributes exactly one representative row.

    Args:
        tbl: Arrow table to reduce.

    Returns:
        A table sliced from ``tbl`` at offset 0 with length 1.
    """
    head = tbl.slice(offset=0, length=1)
    return head

# Deduplicate: keep a single row per image_id.
# groupby() gathers rows by key on its own, so the extra full-dataset sort that
# used to precede it has been dropped. Sorting on the grouping key alone never
# ordered the rows *inside* a group, so it did not decide which duplicate ends
# up first; it only added another shuffle of the image bytes.
# NOTE(review): this assumes image_id identifies the same image in every shard
# that was read — confirm before relying on the dedup.
ds = (
    ds
    .groupby("image_id")
    .map_groups(first_row_arrow, batch_format="pyarrow")
)


In [None]:
import io
import numpy as np
from PIL import Image

def decode_to_rgb_row(row):
    """Replace the encoded image in ``row`` with a decoded RGB pixel array.

    Args:
        row: Mapping with an ``"image"`` entry. Normally this is a struct whose
            ``"bytes"`` field holds the encoded image file. An entry that is
            already a NumPy array is treated as decoded and left alone.

    Returns:
        The same ``row`` object. A freshly decoded ``row["image"]`` is an
        ``HxWx3`` ``uint8`` array.
    """
    image = row["image"]
    if isinstance(image, np.ndarray):
        # Already decoded (e.g. the cell was re-run on its own output): return
        # the row untouched instead of failing on image["bytes"].
        return row

    encoded = image["bytes"]  # nested in the struct
    # The context manager closes the PIL image once its pixels are copied out.
    with Image.open(io.BytesIO(encoded)) as img:
        row["image"] = np.asarray(img.convert("RGB"), dtype=np.uint8)
    return row

# Decode row by row, then materialize the decoded dataset.
ds = ds.map(decode_to_rgb_row).materialize()

In [None]:
# Print the dataset summary again, now that the decode step has been applied.
print(ds)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Preview the first two decoded images as a visual sanity check.
rows = ds.take(2)  # list of row dicts

for i, row in enumerate(rows, 1):
    img_id = row["image_id"]
    # One explicit figure/axes pair per image instead of pyplot's implicit
    # "current figure" state.
    fig, ax = plt.subplots()
    ax.imshow(row["image"])  # HxWx3 uint8
    ax.axis("off")
    ax.set_title(f"image_id={img_id} (#{i})")


In [None]:
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

In [None]:
class EmbeddingGenerator:
    """Callable that adds CLIP image embeddings to a batch of images."""

    def __init__(self, model_id):
        """Load the CLIP model and its matching processor.

        Args:
            model_id: Identifier handed to ``from_pretrained`` for both the
                model and the processor.
        """
        self.model = CLIPModel.from_pretrained(model_id)
        self.processor = CLIPProcessor.from_pretrained(model_id)

    def __call__(self, batch, device="cpu"):
        """Embed every image in ``batch`` and attach the result to the batch.

        Args:
            batch: Mapping whose ``"image"`` entry is an iterable of image
                arrays.
            device: Device the preprocessed inputs and the model are moved to.

        Returns:
            The same ``batch`` object with an added ``"embedding"`` entry that
            holds the image features as a NumPy array.
        """
        # Rebuild PIL images from the raw pixel arrays.
        pil_images = []
        for pixels in batch["image"]:
            pil_images.append(Image.fromarray(np.uint8(pixels)).convert("RGB"))

        # Preprocess, then move the resulting tensors to the target device.
        encoded = self.processor(images=pil_images, return_tensors="pt", padding=True)
        encoded = encoded.to(device)

        # Put the model on the same device before the forward pass.
        self.model.to(device)
        with torch.inference_mode():
            features = self.model.get_image_features(**encoded)
            batch["embedding"] = features.cpu().numpy()

        return batch

In [None]:
# Compute CLIP embeddings in batches of 128 with a single GPU-backed worker.
embeddings_ds = ds.map_batches(
    EmbeddingGenerator,
    batch_size=128,
    concurrency=1,
    num_gpus=1,
    accelerator_type="L4",
    fn_constructor_kwargs=dict(model_id="openai/clip-vit-base-patch32"),  # -> __init__
    fn_kwargs=dict(device="cuda"),  # -> __call__
)

In [None]:
# Materialize the embedding dataset so the cells below all work from the same
# computed result.
embeddings_ds = embeddings_ds.materialize()

In [None]:
# Print the summary of the dataset that now carries the embedding column.
print(embeddings_ds)

In [None]:
# Install FAISS (CPU build) into the kernel's own environment; %pip targets the
# running kernel, whereas !pip uses whichever pip is first on the shell PATH.
%pip install -q faiss-cpu

In [None]:
import faiss

# Collect ids and embeddings into memory.
rows = embeddings_ds.select_columns(["image_id", "embedding"]).take_all()
ids = np.array([r["image_id"] for r in rows])
X = np.stack([r["embedding"] for r in rows]).astype("float32")  # [N, D]

# Nothing upstream normalizes the embeddings (they are stored straight from
# get_image_features), so L2-normalize each row here. Without this step the
# inner-product index below ranks by raw dot product, not cosine similarity.
X /= np.linalg.norm(X, axis=1, keepdims=True) + 1e-12

# Exact inner-product index; on unit-length rows inner product == cosine.
d = X.shape[1]
index = faiss.IndexFlatIP(d)
index.add(X)

# Query top-5 neighbors for the first 3 images
D, I = index.search(X[:3], 5)  # D: similarity scores, I: row indices into X
print(ids[I], D)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load ids, images and embeddings into memory in a single pass.
rows = embeddings_ds.select_columns(["image_id", "image", "embedding"]).take_all()

ids = np.array([r["image_id"] for r in rows])
imgs = [r["image"] for r in rows]  # list of HxWx3 uint8 arrays
X = np.stack([r["embedding"] for r in rows]).astype("float32")  # [N, D]

# Scale every embedding to unit length; the epsilon guards against a zero norm.
X /= np.linalg.norm(X, axis=1, keepdims=True) + 1e-12

# ---- Query = row number q_num ----
q_num = 7
q_img = imgs[q_num]
q_id = int(ids[q_num])
q = X[q_num]  # [D]
sims = X @ q  # cosine similarity of the query to every image
order = np.argsort(-sims)  # most similar first
neighbors_idx = [i for i in order if i != q_num][:10]  # drop the query itself

# ---- Plot: 3x4 grid holding the query plus up to 10 neighbors ----
fig, axes = plt.subplots(3, 4, figsize=(12, 9))
axes = axes.ravel()

# Panel 0 shows the query image.
query_ax = axes[0]
query_ax.imshow(q_img)
query_ax.set_title(f"Query\nid={q_id}")
query_ax.axis("off")

# Panels 1..10 show the neighbors, nearest first.
for k, idx in enumerate(neighbors_idx, start=1):
    ax = axes[k]
    ax.imshow(imgs[idx])
    ax.set_title(f"NN{k}: id={int(ids[idx])}\nsim={sims[idx]:.3f}", fontsize=9)
    ax.axis("off")

# Blank out whatever panels are left over (e.g. panel 11).
for unused_ax in axes[1 + len(neighbors_idx):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
import os
# Write the embedding dataset as Parquet under /content/output. The zip cell
# further down reads from this exact location.
path = os.path.join("/content/", "output")
embeddings_ds.write_parquet(path)

In [None]:
# Mount Google Drive at /content/drive; google.colab is only importable inside
# a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Archive the Parquet output onto Drive. The destination folder is created
# first: zip does not create missing parent directories, so the command fails
# on a Drive that has no clip_output folder yet.
!mkdir -p /content/drive/MyDrive/clip_output && zip -r /content/drive/MyDrive/clip_output/output.zip /content/output