# 03 — Build Item Embeddings (inputs for Semantic IDs)

**Goal.** Convert each item’s text (`description`) into a fixed 384‑dimensional embedding
using `sentence-transformers/all-MiniLM-L6-v2`, saving sharded `.pt` files (fp16, L2‑normalized).
These continuous vectors will be quantized into **Semantic IDs** in the next notebook.

**Why this step?** The paper’s idea is to replace raw IDs with *semantic* identifiers that generalize.
This notebook produces the semantic *content* vectors we’ll discretize next.

### 1. Imports, environment, and paths

In [1]:
# --- Env knobs (avoid ipywidgets progress errors and tokenizer thread spam)
# These must be set BEFORE sentence_transformers / transformers are imported:
# huggingface_hub evaluates HF_HUB_DISABLE_PROGRESS_BARS once, when it is first
# imported, so assigning it after the imports below would have no effect.
import os

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- Imports
import gc, json, time, glob, math, random
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

import torch
import torch.nn.functional as F

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from tqdm import tqdm

# --- Repro: seed every RNG this notebook touches
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Paths (adjust if needed)
ITEMS_FP   = "data/processed/books/items.parquet"  # from notebook 2
TEXT_COL   = "description"                               # column to encode
OUT_DIR    = "data/embeddings/items_miniLM"        # new output dir

# Create the output directory up front so later cells can write shards into it.
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

print("Items:", ITEMS_FP)
print("Out  :", OUT_DIR)

Items: data/processed/books/items.parquet
Out  : data/embeddings/items_miniLM


### 2. Config: model + throughput

In [2]:
# --- Model + speed knobs
MODEL_ID         = "sentence-transformers/all-MiniLM-L6-v2"   # 384-d
CACHE_DIR        = "models/all-MiniLM-L6-v2"                  # local cache
MAX_LEN          = 256                                        # token window
ARROW_BATCH_ROWS = 200_000                                    # rows per Arrow batch (feedback every batch)
TARGET_MICRO_BS  = 512                                        # items per forward; auto-fallback on OOM
DTYPE            = torch.float16                              # fp16 on disk
NORM             = True                                       # L2-normalize rows (cosine-ready)
USE_MANIFEST     = False                                      # usually unnecessary; shards are source of truth
MANIFEST_FP      = os.path.join(OUT_DIR, "manifest.jsonl")

if USE_MANIFEST:
    # Create the manifest if it is missing, but never truncate an existing one:
    # the streaming loop skips shards that already exist without re-logging
    # them, so wiping the file here would drop their records on a resumed run.
    # For a from-scratch run, delete the manifest together with the shards.
    open(MANIFEST_FP, "a").close()

print("MAX_LEN:", MAX_LEN, "| ARROW_BATCH_ROWS:", ARROW_BATCH_ROWS, "| MICRO_BS:", TARGET_MICRO_BS)


MAX_LEN: 256 | ARROW_BATCH_ROWS: 200000 | MICRO_BS: 512


### 3. Device and model load (+ test)

In [3]:
# --- Pick the compute device: MPS (Apple Silicon) first, then CUDA, else CPU
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# --- Tokenizer and sentence encoder, both cached under CACHE_DIR
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
model = SentenceTransformer(MODEL_ID, cache_folder=CACHE_DIR, device=DEVICE)
model.max_seq_length = MAX_LEN  # apply the configured token window

# --- Smoke test: a single tiny encode call (also warms up MPS)
_ = model.encode(
    ["hello world"],
    batch_size=4,
    convert_to_tensor=True,
    show_progress_bar=False,
    device=DEVICE,
)
print("Model ready. Embedding dim:", model.get_sentence_embedding_dimension())

Device: mps
Model ready. Embedding dim: 384


### 4. Resume helper: which shards already exist?

In [4]:
# Resuming works by skipping every shard file that is already present in OUT_DIR.
shard_pattern = os.path.join(OUT_DIR, "shard_batch_*.pt")
existing = sorted(glob.glob(shard_pattern))
print("Existing shards:", len(existing))
if existing:
    # Names sort lexicographically, so the last entry is the highest batch index.
    print("Example:", os.path.basename(existing[-1]))


Existing shards: 17
Example: shard_batch_000016.pt


### 5. Encoder: simple micro‑batched encoding with OOM fallback (the optional “full‑coverage” chunking variant is not defined in this notebook)

In [5]:
from typing import List

@torch.inference_mode()
def encode_texts_simple(texts: List[str], init_bs=TARGET_MICRO_BS) -> torch.Tensor:
    """
    Encode list[str] -> [N, d] tensor on CPU (dtype DTYPE), L2-normalized if NORM=True.

    Texts are encoded in micro-batches of at most ``init_bs`` items. When the
    device reports an out-of-memory error, the failing micro-batch is halved and
    retried; batches of 16 items or fewer are not shrunk further.

    Args:
        texts:   strings to embed; an empty list yields an empty [0, d] tensor.
        init_bs: initial micro-batch size (must be >= 1).

    Returns:
        Tensor of shape [len(texts), d] on CPU, rows in the same order as ``texts``.

    Raises:
        ValueError:   if ``init_bs`` is smaller than 1.
        RuntimeError: any non-OOM error from the encoder, or an OOM that occurs
                      on a micro-batch of 16 items or fewer.
    """
    d = model.get_sentence_embedding_dimension()
    if not texts:
        return torch.empty((0, d), dtype=DTYPE)
    if init_bs < 1:
        # A non-positive batch size would never advance through `texts`.
        raise ValueError(f"init_bs must be >= 1, got {init_bs}")

    out = []
    bs  = init_bs
    i   = 0
    while i < len(texts):
        step = min(bs, len(texts) - i)
        emb = None
        try:
            emb = model.encode(
                texts[i:i+step],
                batch_size=step,
                convert_to_tensor=True,
                show_progress_bar=False,
                device=DEVICE,
                normalize_embeddings=False,   # we will normalize ourselves
            )
        except RuntimeError as e:
            # Recover only from OOM, and only while the batch can still shrink.
            if "out of memory" not in str(e).lower() or step <= 16:
                raise
        if emb is None:
            # Handle the OOM outside the `except` block: the traceback (which
            # keeps the failed tensors alive) has been released by now, so the
            # cache can actually be returned to the device.
            if DEVICE == "mps":
                torch.mps.empty_cache()
            elif DEVICE == "cuda":
                torch.cuda.empty_cache()
            # Halve the batch that actually failed (`step`), not the nominal
            # size, so a short tail batch is not retried unchanged.
            bs = step // 2
            print(f"OOM fallback -> micro-batch {bs}")
            continue
        if NORM:
            emb = F.normalize(emb, p=2, dim=1)
        out.append(emb.to("cpu").to(DTYPE))
        i += step
        del emb
        if DEVICE == "mps":
            torch.mps.empty_cache()
    return torch.cat(out, dim=0)


### 6. Streaming loop: read → encode → save shards

In [6]:
# Stream the items file in Arrow batches; each batch becomes one shard on disk.
pf = pq.ParquetFile(ITEMS_FP)
print("Row-groups:", pf.num_row_groups)

total_items = 0
batch_idx   = 0
t0 = time.time()

for batch in pf.iter_batches(columns=["item_id", TEXT_COL], batch_size=ARROW_BATCH_ROWS):
    # Convert only needed columns to Python lists (avoid building a giant DataFrame)
    # NOTE(review): positional access assumes the batch returns columns in the
    # order requested above (item_id, then TEXT_COL) — confirm.
    ids   = batch.column(0).to_pylist()
    texts = batch.column(1).to_pylist()

    # Drop rows whose text is missing or empty, keeping ids aligned with texts
    keep = [i for i, t in enumerate(texts) if t is not None and t != ""]
    if len(keep) != len(texts):
        ids   = [ids[i]   for i in keep]
        texts = [texts[i] for i in keep]

    shard_name = f"shard_batch_{batch_idx:06d}.pt"
    shard_path = os.path.join(OUT_DIR, shard_name)

    # Resume: skip if the shard is already on disk. Checking the file system
    # directly (instead of a listing taken in an earlier cell) keeps this
    # correct even when that listing is stale.
    if os.path.exists(shard_path):
        print(f"[batch {batch_idx:06d}] skipping (exists) → {shard_name}")
        total_items += len(ids)
        batch_idx   += 1
        del ids, texts, batch
        gc.collect()
        continue

    if not texts:
        print(f"[batch {batch_idx:06d}] empty texts; skipping")
        batch_idx += 1
        del ids, texts, batch
        gc.collect()
        continue

    t_batch = time.time()
    # --- Encode
    E = encode_texts_simple(texts, init_bs=TARGET_MICRO_BS)   # one row per text, on CPU
    if len(E) != len(ids):
        # A mismatch would pair ids with the wrong vectors; fail loudly instead
        # of silently truncating the id list.
        raise RuntimeError(
            f"batch {batch_idx}: got {len(E)} embeddings for {len(ids)} items"
        )

    # --- Save shard atomically: write to a temp file, then rename into place.
    # An interrupted run can then never leave a truncated shard_batch_*.pt that
    # the resume check would mistake for a finished shard.
    tmp_path = shard_path + ".tmp"
    torch.save({"item_id": ids, "embeddings": E}, tmp_path)
    os.replace(tmp_path, shard_path)

    # --- Optional manifest record
    if USE_MANIFEST:
        with open(MANIFEST_FP, "a") as mf:
            mf.write(json.dumps({"batch": batch_idx, "n": len(E), "path": shard_name}) + "\n")

    # progress
    total_items += len(E)
    dt = time.time() - t_batch
    ips = int(len(E) / max(dt, 1e-6))   # guard against a zero elapsed time
    print(f"[batch {batch_idx:06d}] rows={len(E):,} → {shard_name} | {ips} items/s | total={total_items:,}")

    # cleanup
    del ids, texts, E, batch
    gc.collect()
    if DEVICE == "mps":
        torch.mps.empty_cache()

    batch_idx += 1

# Stop heartbeat if you enabled it
# NOTE(review): no heartbeat is defined in the visible notebook; the flag is
# kept in case code outside this view reads it.
stop_flag = True

print(f"Done. Batches: {batch_idx:,} | items: {total_items:,} | time: {(time.time()-t0)/60:.1f} min")


Row-groups: 4
[batch 000000] skipping (exists) → shard_batch_000000.pt
[batch 000001] skipping (exists) → shard_batch_000001.pt
[batch 000002] skipping (exists) → shard_batch_000002.pt
[batch 000003] skipping (exists) → shard_batch_000003.pt
[batch 000004] skipping (exists) → shard_batch_000004.pt
[batch 000005] skipping (exists) → shard_batch_000005.pt
[batch 000006] skipping (exists) → shard_batch_000006.pt
[batch 000007] skipping (exists) → shard_batch_000007.pt
[batch 000008] skipping (exists) → shard_batch_000008.pt
[batch 000009] skipping (exists) → shard_batch_000009.pt
[batch 000010] skipping (exists) → shard_batch_000010.pt
[batch 000011] skipping (exists) → shard_batch_000011.pt
[batch 000012] skipping (exists) → shard_batch_000012.pt
[batch 000013] skipping (exists) → shard_batch_000013.pt
[batch 000014] skipping (exists) → shard_batch_000014.pt
[batch 000015] skipping (exists) → shard_batch_000015.pt
[batch 000016] skipping (exists) → shard_batch_000016.pt
[batch 000017] ro

### 7. Sanity checks: shape, dtype, norms, quick NN

In [9]:
# --- Load a probe shard: the first shard on disk by name (not necessarily one
# --- written during this run, since existing shards are skipped on resume)
files = sorted(glob.glob(os.path.join(OUT_DIR, "shard_batch_*.pt")))
assert files, "No shards found."
probe = files[0]

blob  = torch.load(probe, map_location="cpu")
ids   = blob["item_id"]
E     = blob["embeddings"]  # one embedding row per item id
assert len(ids) == E.shape[0], "item_id count does not match number of embedding rows"
print(f"Loaded probe shard: {os.path.basename(probe)} | shape={tuple(E.shape)} | dtype={E.dtype}")

# --- Check L2 norms to ensure vectors are unit-length (for cosine similarity) ---
row_norms = E.float().norm(dim=1)
TOL = 5e-3  # acceptable drift due to fp16 rounding
frac_bad = float((row_norms.sub(1.0).abs() > TOL).float().mean())

print(f"Norm stats → mean={row_norms.mean():.6f}, std={row_norms.std():.6f}, "
      f"min={row_norms.min():.6f}, max={row_norms.max():.6f}, "
      f"frac outside |1±{TOL}| = {frac_bad:.6%}")

# --- Tiny nearest-neighbor (NN) smoke test ---
# We’ll normalize and compute cosine similarities for a small subset
m = min(2000, E.shape[0])     # keep it light
Esm = F.normalize(E[:m].float(), p=2, dim=1)    # ensure re-normalized (safe)
q   = Esm[:5]                                  # up to 5 queries
S   = q @ Esm.T                                # cosine similarities
k   = min(6, m)                                # topk raises if k exceeds the number of candidates
topk = torch.topk(S, k=k, dim=1).indices       # top-1 is normally the query itself

print("\nNearest-neighbor sanity check:")
for i in range(q.shape[0]):
    nbrs = [ids[j] for j in topk[i].tolist()[1:6]]
    print(f"  Query {i}: nearest neighbors → {nbrs}")

# --- Cleanup ---
del blob, ids, E, Esm, q, S, topk, row_norms
gc.collect()
print("\nSanity checks complete.")

Loaded probe shard: shard_batch_000000.pt | shape=(200000, 384) | dtype=torch.float16
Norm stats → mean=1.000000, std=0.000018, min=0.999914, max=1.000087, frac outside |1±0.005| = 0.000000%

Nearest-neighbor sanity check:
  Query 0: nearest neighbors → ['B008DM2LQ8', '0976640538', '1441599258', '3548241107', '1133307299']
  Query 1: nearest neighbors → ['B007HDXO4C', '1111302731', '1424057418', '1508927618', '1932225323']
  Query 2: nearest neighbors → ['0786477946', '1501121960', '1585426628', 'B000KZQCMK', '1952816025']
  Query 3: nearest neighbors → ['1546879471', '0670015547', '1087719887', '1595729003', 'B007HXFBDE']
  Query 4: nearest neighbors → ['0786019492', 'B01F0OPEB0', '0984903054', '1643136135', 'B0007ZNUUK']

Sanity checks complete.


### 8. Save run metadata summary

In [8]:
# Record the settings used for this run in a small JSON file next to the shards.
stats_fp = os.path.join(OUT_DIR, "stats.json")

meta = dict(
    model_id=MODEL_ID,
    cache_dir=CACHE_DIR,
    max_len=MAX_LEN,
    arrow_batch_rows=ARROW_BATCH_ROWS,
    target_micro_bs=TARGET_MICRO_BS,
    dtype=str(DTYPE),          # torch dtypes are not JSON-serializable; store the name
    norm=NORM,
    device=DEVICE,
    items_file=ITEMS_FP,
    text_col=TEXT_COL,
    out_dir=OUT_DIR,
    time_utc=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
)

with open(stats_fp, "w") as fh:
    json.dump(meta, fh, indent=2)

print("Saved:", stats_fp)

Saved: data/embeddings/items_miniLM/stats.json
