Read and encode sample data

In [None]:
!pip install --upgrade sentencepiece

In [None]:
# Sanity-check cell: import SentencePiece in this kernel and print its version.
# The two marker comments below are leftovers from a shell heredoc form of this
# snippet and have no effect when the cell runs.
# python - <<'PY'
import sentencepiece as spm
print("SentencePiece version:", spm.__version__)
# PY

In [2]:
"""Sample 100 MS-COCO val2017 images + captions and encode them with:
  Image encoder:  google/siglip2-giant-opt-patch16-256
  Text  encoder:  intfloat/e5-mistral-7b-instruct

Assumed directory layout (relative to this script or working directory):

Data/
  val2017/                # COCO val2017 images *.jpg
  annotations/
    captions_val2017.json # COCO caption annotations

Outputs:
  outputs/sample_image_ids.json
  outputs/image_embeddings.pt          (tensor [N, D_img])
  outputs/text_embeddings.pt           (tensor [N, D_txt])
  outputs/pairs.parquet                (optional metadata)

Notes:
  * The SigLIP2 giant + Mistral 7B models are large; you likely need a >=24GB GPU (or multiple) for full precision.
  * The code tries to use bfloat16/float16 and device_map='auto'. If you face OOM, lower batch_size or enable 8-bit loading (commented section).
  * For simplicity we pick ONE caption per image (the first), but you can adapt to keep all.
"""
from __future__ import annotations
import json, random, math, os
from pathlib import Path
from typing import List, Dict

import torch
from PIL import Image
from tqdm import tqdm

from transformers import (
    AutoProcessor,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
)

# ----------------------------- Configuration -------------------------------- #
# Input locations (COCO val2017 images and their caption annotations).
DATA_ROOT = Path("Data")
IMAGES_DIR = DATA_ROOT.joinpath("val2017")
CAPTIONS_FILE = DATA_ROOT.joinpath("annotations", "captions_val2017.json")

# Output location; created eagerly so later saves cannot fail on a missing dir.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Sampling.
NUM_SAMPLES = 100  # number of image-caption pairs to sample
SEED = 42

# Encoders and their batch sizes (lower the batch sizes if you hit OOM).
IMAGE_MODEL_NAME = "google/siglip2-giant-opt-patch16-256"
TEXT_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
IMAGE_BATCH_SIZE = 8
TEXT_BATCH_SIZE = 4

# Preferred dtype: float32 without CUDA; on CUDA, bfloat16 when the device
# supports it and float16 otherwise.
if not torch.cuda.is_available():
    TORCH_DTYPE_PREF = torch.float32
elif torch.cuda.is_bf16_supported():
    TORCH_DTYPE_PREF = torch.bfloat16
else:
    TORCH_DTYPE_PREF = torch.float16

# Switch for 8-bit quantization of the text model (requires bitsandbytes).
# To enable it, set the flag to True:
# LOAD_TEXT_IN_8BIT = True
LOAD_TEXT_IN_8BIT = False

# ----------------------------- Utilities ------------------------------------ #
def set_seed(seed: int):
    """Seed Python's `random` module and PyTorch with the same value.

    When CUDA is available, every CUDA device's generator is seeded as well.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def load_coco_captions(captions_path: Path) -> tuple[Dict[int, List[str]], Dict[int, dict]]:
    """Load a COCO captions JSON file and index it by image id.

    Args:
        captions_path: Path to the annotations file; it must contain the
            top-level keys "annotations" and "images".

    Returns:
        A pair ``(id_to_captions, id_to_image_meta)``:
          * ``id_to_captions`` maps an image id to its whitespace-stripped
            captions, in file order. Images without any annotation are absent.
          * ``id_to_image_meta`` maps an image id to its raw entry from the
            "images" list.
    """
    with open(captions_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    id_to_captions: Dict[int, List[str]] = {}
    for ann in data["annotations"]:
        id_to_captions.setdefault(ann["image_id"], []).append(ann["caption"].strip())

    id_to_image_meta = {img["id"]: img for img in data["images"]}
    return id_to_captions, id_to_image_meta


def sample_image_ids(all_image_meta: Dict[int, dict], k: int) -> List[int]:
    """Pick up to ``k`` image ids in random order.

    The ids are shuffled with the global `random` generator (seed it first for
    reproducibility) and the first ``k`` of the shuffled list are returned. If
    fewer than ``k`` ids exist, all of them are returned.

    Args:
        all_image_meta: Mapping whose keys are the candidate image ids.
        k: Number of ids to keep.

    Returns:
        The selected ids as a new list; the input mapping is not modified.
    """
    candidates = [image_id for image_id in all_image_meta]
    random.shuffle(candidates)
    chosen = candidates[:k]
    return chosen


def load_image(path: Path) -> Image.Image:
    """Open the image at ``path`` and return it converted to RGB.

    The file is opened inside a context manager so its handle is closed
    deterministically, including when decoding fails part-way. The value
    returned is the result of ``convert("RGB")``, not the opened file object.

    Args:
        path: Location of the image file.

    Returns:
        The image in "RGB" mode.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

# ----------------------------- Main Encoding Logic -------------------------- #
@torch.no_grad()
def encode_images(image_paths: List[Path], image_model, image_processor, batch_size: int) -> torch.Tensor:
    """Encode images in batches and return their normalized features.

    Each batch is loaded from disk, preprocessed with ``image_processor``,
    moved to the model's device, passed through
    ``image_model.get_image_features``, normalized along the last dimension
    and moved to the CPU. Gradients are disabled by the decorator.

    Args:
        image_paths: Files to encode, in output order.
        image_model: Model exposing ``device`` and ``get_image_features``.
        image_processor: Callable accepting ``images=`` and ``return_tensors=``
            and returning a mapping of tensors.
        batch_size: Number of images per forward pass.

    Returns:
        The per-batch features concatenated along dim 0, on the CPU.
    """
    chunks = []
    batch_starts = range(0, len(image_paths), batch_size)
    for start in tqdm(batch_starts, desc="Encoding images"):
        pil_batch = [load_image(p) for p in image_paths[start:start + batch_size]]
        processed = image_processor(images=pil_batch, return_tensors="pt")
        target = image_model.device
        model_inputs = {
            name: tensor.to(target, non_blocking=True)
            for name, tensor in processed.items()
        }
        # Expected to be [B, D], one feature row per image.
        features = image_model.get_image_features(**model_inputs)
        unit_features = torch.nn.functional.normalize(features, dim=-1)
        chunks.append(unit_features.cpu())
    return torch.cat(chunks, dim=0)

@torch.no_grad()
def encode_texts(texts: List[str], text_model, tokenizer, batch_size: int) -> torch.Tensor:
    """Encode texts in batches using last-token pooling.

    The previous implementation prefixed every text with "passage: " and
    mean-pooled over the attention mask. That is the recipe of the earlier E5
    encoders; the model card of intfloat/e5-mistral-7b-instruct (the text
    model configured in this file) instead pools the hidden state of the last
    non-padding token and feeds documents without any prefix. Captions are
    treated as documents here, so they are passed through unchanged.

    Args:
        texts: Strings to encode, in output order.
        text_model: Model exposing ``device`` whose output has a
            ``last_hidden_state`` attribute.
        tokenizer: Tokenizer returning a mapping that includes
            ``attention_mask``.
        batch_size: Number of texts per forward pass.

    Returns:
        Normalized sentence embeddings concatenated along dim 0, on the CPU.
    """
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
        batch_txt = texts[i:i + batch_size]
        inputs = tokenizer(batch_txt, padding=True, truncation=True, return_tensors="pt", max_length=256)
        inputs = {k: v.to(text_model.device, non_blocking=True) for k, v in inputs.items()}
        hidden = text_model(**inputs).last_hidden_state  # [B, L, H]
        mask = inputs["attention_mask"]  # [B, L]

        # If every row attends to the final position, the batch is left-padded
        # (or has no padding) and the last token sits at index -1 for all rows.
        if int(mask[:, -1].sum()) == mask.shape[0]:
            pooled = hidden[:, -1]
        else:
            # Right padding: the last real token of each row is at (length - 1).
            # Index tensors are placed on the hidden states' device.
            last_idx = (mask.sum(dim=1) - 1).to(hidden.device)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            pooled = hidden[rows, last_idx]

        sentence_emb = torch.nn.functional.normalize(pooled, dim=-1)
        embs.append(sentence_emb.cpu())
    return torch.cat(embs, dim=0)

# ----------------------------- Execution ------------------------------------ #
if __name__ == "__main__":
    # Script entry point: sample image/caption pairs, encode both modalities,
    # then write the embeddings and their metadata under OUTPUT_DIR.
    set_seed(SEED)

    # Fail fast on a wrong data layout before any model is downloaded.
    if not CAPTIONS_FILE.exists():
        raise FileNotFoundError(f"Cannot find captions file at {CAPTIONS_FILE}")
    if not IMAGES_DIR.exists():
        raise FileNotFoundError(f"Cannot find images directory at {IMAGES_DIR}")

    print("Loading COCO captions metadata ...")
    id_to_caps, id_to_imgmeta = load_coco_captions(CAPTIONS_FILE)
    sampled_ids = sample_image_ids(id_to_imgmeta, NUM_SAMPLES)

    # Select first caption for each image (customize as needed)
    sampled_captions = []
    image_paths = []
    for img_id in sampled_ids:
        rel_name = f"{img_id:012d}.jpg"  # COCO filename pattern
        img_path = IMAGES_DIR / rel_name
        if not img_path.exists():
            raise FileNotFoundError(f"Missing image file {img_path}")
        image_paths.append(img_path)
        # Images without annotations get a placeholder so both lists stay aligned.
        cap_list = id_to_caps.get(img_id, ["(no caption)"])
        sampled_captions.append(cap_list[0])

    print(f"Sampled {len(sampled_ids)} image-caption pairs.")

    print("Loading image encoder ...")
    image_processor = AutoProcessor.from_pretrained(IMAGE_MODEL_NAME)
    image_model = AutoModel.from_pretrained(
        IMAGE_MODEL_NAME,
        torch_dtype=TORCH_DTYPE_PREF,
        device_map="auto"
    )

    print("Loading text encoder ...")
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    text_model_kwargs = dict(torch_dtype=TORCH_DTYPE_PREF, device_map="auto")

    if LOAD_TEXT_IN_8BIT:
        # Previously this branch was a no-op, so enabling the flag silently
        # loaded the full-precision model. Imported lazily because it is only
        # needed (together with bitsandbytes) when the flag is on.
        from transformers import BitsAndBytesConfig
        text_model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME, **text_model_kwargs)

    # Encode
    image_embeddings = encode_images(image_paths, image_model, image_processor, IMAGE_BATCH_SIZE)
    text_embeddings = encode_texts(sampled_captions, text_model, tokenizer, TEXT_BATCH_SIZE)

    print("Embeddings shapes:")
    print("  Images:", image_embeddings.shape)
    print("  Texts :", text_embeddings.shape)

    torch.save(image_embeddings, OUTPUT_DIR / "image_embeddings.pt")
    torch.save(text_embeddings, OUTPUT_DIR / "text_embeddings.pt")

    # Save mapping metadata; row i of both embedding tensors corresponds to
    # entry i of each list below.
    image_path_strs = [str(p) for p in image_paths]
    meta = {
        "image_ids": sampled_ids,
        "captions": sampled_captions,
        "image_paths": image_path_strs,
        "image_embedding_file": "image_embeddings.pt",
        "text_embedding_file": "text_embeddings.pt",
        "image_model": IMAGE_MODEL_NAME,
        "text_model": TEXT_MODEL_NAME,
        "seed": SEED,
    }
    with open(OUTPUT_DIR / "sample_image_ids.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # The parquet export is optional: skip it when pandas/pyarrow are missing.
    try:
        import pandas as pd
        import pyarrow  # noqa: F401
        import pyarrow.parquet  # noqa: F401
        df = pd.DataFrame({
            "image_id": sampled_ids,
            "image_path": image_path_strs,
            "caption": sampled_captions
        })
        df.to_parquet(OUTPUT_DIR / "pairs.parquet", index=False)
        print("Saved pairs.parquet")
    except ImportError:
        print("pandas/pyarrow not installed; skipping parquet export.")

    print("Done.")


Loading COCO captions metadata ...
Sampled 100 image-caption pairs.
Loading image encoder ...


preprocessor_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

ImportError: 
SiglipTokenizer requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
