# 03 — Build Item Embeddings (inputs for Semantic IDs)

**Goal.** Convert each item’s text (`description`) into a fixed 384‑dimensional embedding
using `sentence-transformers/all-MiniLM-L6-v2`, saving sharded `.pt` files (fp16, L2‑normalized).
These continuous vectors will be quantized into **Semantic IDs** in the next notebook.

**Why this step?** The paper’s idea is to replace raw IDs with *semantic* identifiers that generalize.
This notebook produces the semantic *content* vectors we’ll discretize next.

### 1. Imports, environment, and paths

In [1]:
# --- Imports
import os

# Env knobs (to avoid ipywidgets progress errors and tokenizer thread spam).
# They are set BEFORE the Hugging Face imports below: huggingface_hub reads
# HF_HUB_DISABLE_PROGRESS_BARS when it is first imported, so setting it after
# `sentence_transformers` / `transformers` are loaded would have no effect.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gc, json, time, glob, math, random
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

import torch
import torch.nn.functional as F

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from tqdm import tqdm

# Reproducibility: seed every RNG this notebook touches
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths
ITEMS_FP   = "data/processed/books/items.parquet"  # from notebook 2
TEXT_COL   = "description"                               # column to encode
OUT_DIR    = "data/embeddings/items_miniLM"        # new output dir

# Create the output directory up front so later cells can write shards into it
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

print("Items:", ITEMS_FP)
print("Out  :", OUT_DIR)

Items: data/processed/books/items.parquet
Out  : data/embeddings/items_miniLM


### 2. Config: model + throughput

In [2]:
# Model + speed knobs
MODEL_ID         = "sentence-transformers/all-MiniLM-L6-v2"   # 384-d
CACHE_DIR        = "models/all-MiniLM-L6-v2"                  # local cache
MAX_LEN          = 256                                        # token window
ARROW_BATCH_ROWS = 200_000                                    # rows per Arrow batch (feedback every batch)
TARGET_MICRO_BS  = 512                                        # items per forward; auto-fallback on OOM
DTYPE            = torch.float16                              # fp16 on disk
NORM             = True                                       # L2-normalize rows (cosine-ready)
USE_MANIFEST     = False                                      # if True, append one JSON line per written shard
MANIFEST_FP      = os.path.join(OUT_DIR, "manifest.jsonl")

if USE_MANIFEST:
    # The streaming loop resumes by skipping shards that already exist and only
    # appends manifest lines for shards it writes. Truncating the manifest on a
    # resumed run would therefore drop the entries of every skipped shard, so
    # start a fresh manifest only when there are no shards to resume from.
    # NOTE: if you delete only some shards by hand, stale manifest lines remain.
    has_shards = bool(glob.glob(os.path.join(OUT_DIR, "shard_batch_*.pt")))
    with open(MANIFEST_FP, "a" if has_shards else "w"):
        pass

print("MAX_LEN:", MAX_LEN, "| ARROW_BATCH_ROWS:", ARROW_BATCH_ROWS, "| MICRO_BS:", TARGET_MICRO_BS)

MAX_LEN: 256 | ARROW_BATCH_ROWS: 200000 | MICRO_BS: 512


### 3. Device and model load (+ test)

In [3]:
# Pick the compute backend: Apple MPS first, then CUDA, otherwise plain CPU
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# Tokenizer and sentence encoder are both fetched into the same local cache dir
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
model = SentenceTransformer(MODEL_ID, cache_folder=CACHE_DIR, device=DEVICE)
model.max_seq_length = MAX_LEN  # cap the encoder's token window at MAX_LEN

# Smoke test: a single tiny forward pass to confirm the model runs on DEVICE
_ = model.encode(
    ["hello world"],
    batch_size=4,
    convert_to_tensor=True,
    show_progress_bar=False,
    device=DEVICE,
)
print("Model ready. Embedding dim:", model.get_sentence_embedding_dimension())

Device: mps
Model ready. Embedding dim: 384


### 4. Resume helper: which shards already exist?

In [4]:
# Resume support: every shard file already present in OUT_DIR counts as done
# and is skipped by the streaming loop further down.
shard_pattern = os.path.join(OUT_DIR, "shard_batch_*.pt")
existing = sorted(glob.glob(shard_pattern))
print("Existing shards:", len(existing))
if existing:
    # Show the last shard (by sorted name) as a quick indication of progress
    print("Example:", os.path.basename(existing[-1]))


Existing shards: 13
Example: shard_batch_000012.pt


### 5. Encoder: simple micro‑batched encoding with OOM fallback

In [5]:
from typing import List

@torch.inference_mode()
def encode_texts_simple(texts: List[str], init_bs=TARGET_MICRO_BS) -> torch.Tensor:
    """
    Encode list[str] -> [N, d] tensor on CPU with dtype DTYPE, L2-normalized if NORM=True.

    Texts are encoded in micro-batches of at most `init_bs` items. When the
    device reports an out-of-memory error, the accelerator cache is released,
    the micro-batch is halved, and the same slice is retried.

    Args:
        texts: strings to embed; an empty list yields an empty [0, d] tensor.
        init_bs: starting micro-batch size (items per forward pass); must be >= 1.

    Returns:
        Tensor of shape [len(texts), d] on CPU, where d is the model's
        sentence-embedding dimension.

    Raises:
        ValueError: if `init_bs` is smaller than 1 (the loop could never advance).
        RuntimeError: any non-OOM runtime error, or an OOM that persists once
            the failing micro-batch is already 16 items or fewer.
    """
    d = model.get_sentence_embedding_dimension()
    if not texts:
        return torch.empty((0, d), dtype=DTYPE)
    if init_bs < 1:
        raise ValueError(f"init_bs must be >= 1, got {init_bs}")

    out = []
    bs  = init_bs
    i   = 0
    while i < len(texts):
        step = min(bs, len(texts) - i)
        try:
            emb = model.encode(
                texts[i:i+step],
                batch_size=step,
                convert_to_tensor=True,
                show_progress_bar=False,
                device=DEVICE,
                normalize_embeddings=False,  # normalization is applied below, gated by NORM
            )
        except RuntimeError as e:
            # Auto OOM fallback on MPS/CUDA. The test uses `step` (the size that
            # actually failed) rather than `bs`, so a short tail slice is not
            # retried unchanged while `bs` is still larger than the slice.
            if "out of memory" in str(e).lower() and step > 16:
                if DEVICE == "mps":
                    torch.mps.empty_cache()
                elif DEVICE == "cuda":
                    torch.cuda.empty_cache()
                bs = step // 2
                print(f"OOM fallback -> micro-batch {bs}")
                continue
            raise
        if NORM:
            emb = F.normalize(emb, p=2, dim=1)
        # Move off the accelerator right away so only one micro-batch lives on device
        out.append(emb.to("cpu").to(DTYPE))
        i += step
        del emb
        if DEVICE == "mps":
            torch.mps.empty_cache()
    return torch.cat(out, dim=0)


### 6. Streaming loop: read → encode → save shards

In [6]:
pf = pq.ParquetFile(ITEMS_FP)
print("Row-groups:", pf.num_row_groups)

# For "resume": build a set of shard basenames that already exist.
# NOTE: shard names are derived from the batch index alone, so resuming is only
# valid when ITEMS_FP and ARROW_BATCH_ROWS are unchanged since the earlier run.
existing_set = {os.path.basename(p) for p in existing}

total_items = 0
batch_idx   = 0
t0 = time.time()

for batch in pf.iter_batches(columns=["item_id", TEXT_COL], batch_size=ARROW_BATCH_ROWS):
    # Convert only needed columns to Python lists (avoid building a giant DataFrame)
    ids   = batch.column(0).to_pylist()
    texts = batch.column(1).to_pylist()

    # Drop empties fast (null or empty-string descriptions get no embedding)
    keep = [i for i, t in enumerate(texts) if t is not None and t != ""]
    if len(keep) != len(texts):
        ids   = [ids[i]   for i in keep]
        texts = [texts[i] for i in keep]

    shard_name = f"shard_batch_{batch_idx:06d}.pt"
    shard_path = os.path.join(OUT_DIR, shard_name)

    # Resume: skip if shard already exists
    if shard_name in existing_set:
        print(f"[batch {batch_idx:06d}] skipping (exists) → {shard_name}")
        total_items += len(ids)
        batch_idx   += 1
        # clean up
        del ids, texts, batch
        gc.collect()
        continue

    if not texts:
        print(f"[batch {batch_idx:06d}] empty texts; skipping")
        batch_idx += 1
        del ids, texts, batch
        gc.collect()
        continue

    t_batch = time.time()
    # Encode
    E = encode_texts_simple(texts, init_bs=TARGET_MICRO_BS)   # [Nb, 384] on CPU, fp16

    # Save shard atomically: write to a temp name, then rename into place.
    # An interrupted run therefore never leaves a truncated "shard_batch_*.pt"
    # that a later resume would skip as already done. The ".tmp" suffix keeps
    # the partial file out of the "shard_batch_*.pt" glob.
    tmp_path = shard_path + ".tmp"
    torch.save({"item_id": ids[:len(E)], "embeddings": E}, tmp_path)
    os.replace(tmp_path, shard_path)

    if USE_MANIFEST:
        with open(MANIFEST_FP, "a") as mf:
            mf.write(json.dumps({"batch": batch_idx, "n": len(E), "path": shard_name}) + "\n")

    # progress
    total_items += len(E)
    dt = time.time() - t_batch
    ips = int(len(E) / max(dt, 1e-6))   # guard against a zero elapsed time
    print(f"[batch {batch_idx:06d}] rows={len(E):,} → {shard_name} | {ips} items/s | total={total_items:,}")

    # cleanup
    del ids, texts, E, batch
    gc.collect()
    if DEVICE == "mps":
        torch.mps.empty_cache()

    batch_idx += 1

print(f"Done. Batches: {batch_idx:,} | items: {total_items:,} | time: {(time.time()-t0)/60:.1f} min")

Row-groups: 3
[batch 000000] skipping (exists) → shard_batch_000000.pt
[batch 000001] skipping (exists) → shard_batch_000001.pt
[batch 000002] skipping (exists) → shard_batch_000002.pt
[batch 000003] skipping (exists) → shard_batch_000003.pt
[batch 000004] skipping (exists) → shard_batch_000004.pt
[batch 000005] skipping (exists) → shard_batch_000005.pt
[batch 000006] skipping (exists) → shard_batch_000006.pt
[batch 000007] skipping (exists) → shard_batch_000007.pt
[batch 000008] skipping (exists) → shard_batch_000008.pt
[batch 000009] skipping (exists) → shard_batch_000009.pt
[batch 000010] skipping (exists) → shard_batch_000010.pt
[batch 000011] skipping (exists) → shard_batch_000011.pt
[batch 000012] skipping (exists) → shard_batch_000012.pt
[batch 000013] rows=200,000 → shard_batch_000013.pt | 197 items/s | total=2,800,000
[batch 000014] rows=200,000 → shard_batch_000014.pt | 198 items/s | total=3,000,000
[batch 000015] rows=47,489 → shard_batch_000015.pt | 195 items/s | total=3,04

### 7. Concatenate shards → one matrix & row-normalize

In [12]:
# Zero-padded shard names sort lexicographically in batch order
paths_pt = sorted(glob.glob(os.path.join(OUT_DIR, "shard_batch_*.pt")))
assert paths_pt, "No shards found."

ids_chunks, vec_chunks = [], []
for p in paths_pt:
    blob = torch.load(p, map_location="cpu")
    shard_ids = np.asarray(blob["item_id"], dtype=object)
    # keep in float16 from disk, cast to float32 only for stable normalization math
    v = blob["embeddings"].to(torch.float16).cpu().numpy().astype(np.float32, copy=False)
    # Row i of the embeddings must belong to id i; fail loudly on a malformed shard
    assert v.ndim == 2 and len(shard_ids) == v.shape[0], (
        f"{os.path.basename(p)}: {len(shard_ids)} ids vs embeddings of shape {v.shape}"
    )
    ids_chunks.append(shard_ids)
    vec_chunks.append(v)

item_ids = np.concatenate(ids_chunks)
X = np.vstack(vec_chunks)                      # float32 now (fresh array, safe to modify in place)
del vec_chunks                                 # drop the per-shard copies before normalizing
assert item_ids.shape[0] == X.shape[0], "item_ids and embeddings are misaligned"

# row-wise L2 normalize (cosine-ready)
norm = np.linalg.norm(X, axis=1, keepdims=True)
norm[norm == 0] = 1.0                          # leave all-zero rows as zeros instead of dividing by 0
X /= norm                                      # in place: avoids allocating a second full float32 matrix
X = X.astype(np.float16, copy=False)           # back to fp16 for storage

np.save(os.path.join(OUT_DIR, "item_ids.npy"), item_ids)
np.save(os.path.join(OUT_DIR, "item_embeds.npy"), X)
print("Saved:", X.shape, X.dtype)

Saved: (3047489, 384) float16


### 8. Sanity checks: shape, dtype, norms, quick NN

In [28]:
# reload to be sure we’re reading what we wrote
# OUT_DIR is a plain string in the setup cell, so wrap it in Path before using "/"
out_dir = Path(OUT_DIR)
# allow_pickle is required because the ids are stored as an object array;
# only load files produced by this notebook (pickle is unsafe on untrusted data)
ids = np.load(out_dir / "item_ids.npy", allow_pickle=True)
E   = np.load(out_dir / "item_embeds.npy")

print("IDs:", ids.shape, ids.dtype, "| E:", E.shape, E.dtype)
assert E.shape[0] > 0, "no embeddings to check"
assert ids.shape[0] == E.shape[0], "item_ids and embeddings are misaligned"

# norms (mean should be ~1.0 and std ~0 for unit-length rows)
row_norms = np.linalg.norm(E.astype(np.float32), axis=1)
print("norms -> mean:", row_norms.mean(), "std:", row_norms.std(), "min:", row_norms.min(), "max:", row_norms.max())

# quick NN: first 5 rows as queries against the first m rows as candidates
m = min(2000, E.shape[0])
Q = E[:5].astype(np.float32, copy=False)
C = E[:m].astype(np.float32, copy=False)
S = Q @ C.T                      # cosine since rows are unit
k = min(5, m)                    # cannot ask for more neighbours than there are candidates
idx_unsorted = np.argpartition(-S, k-1, axis=1)[:, :k]

# sort within the top-k (argpartition returns the k best in arbitrary order)
row_sorted = np.argsort(-np.take_along_axis(S, idx_unsorted, axis=1), axis=1)
topk_idx = np.take_along_axis(idx_unsorted, row_sorted, axis=1)
topk_scores = np.take_along_axis(S, topk_idx, axis=1)

# map to item_ids; each query's top hit should be itself with score ~1.0
for qi in range(Q.shape[0]):
    nbrs = [(ids[j], float(topk_scores[qi, c])) for c, j in enumerate(topk_idx[qi])]
    print(f"q{qi} →", nbrs)


IDs: (3047489,) object | E: (3047489, 384) float16
norms -> mean: 0.99999994 std: 1.8474228e-05 min: 0.99991024 max: 1.0000939
q0 → [('0701169850', 1.0000230073928833), ('B008DM2LQ8', 0.465138703584671), ('1441599258', 0.45543068647384644), ('3548241107', 0.4497717320919037), ('1133307299', 0.4491029381752014)]
q1 → [('0316185361', 1.0000216960906982), ('0786477946', 0.4279870390892029), ('1501121960', 0.4205887019634247), ('0330361120', 0.3991989195346832), ('1585426628', 0.3881700336933136)]
q2 → [('0545425573', 1.0000718832015991), ('0670015547', 0.5404351353645325), ('1087719887', 0.5104107856750488), ('B007HXFBDE', 0.4974365532398224), ('0340875577', 0.49628835916519165)]
q3 → [('B00KFOP3RG', 1.0000190734863281), ('B01F0OPEB0', 0.4912773370742798), ('0984903054', 0.49120232462882996), ('1643136135', 0.4862109124660492), ('B09WFB2Z94', 0.4673328995704651)]
q4 → [('B09PHG4FQ8', 1.0000277757644653), ('0134485351', 0.6188583970069885), ('B08C4FTJF5', 0.5452044010162354), ('0766015491'

### 9. Metadata snapshot

In [19]:
# Record the settings that produced the embeddings next to the output files
meta = {
    "model_id": MODEL_ID,
    "cache_dir": CACHE_DIR,
    "max_len": MAX_LEN,
    "arrow_batch_rows": ARROW_BATCH_ROWS,
    "target_micro_bs": TARGET_MICRO_BS,
    "dtype": "float16",
    "norm_after_concat": True,
    "device": DEVICE,
    "items_file": ITEMS_FP,
    "text_col": TEXT_COL,
    "out_final": str(OUT_DIR),
}
# OUT_DIR is a plain string in the setup cell, so wrap it in Path before using "/"
stats_fp = Path(OUT_DIR) / "stats.json"
with open(stats_fp, "w") as f:
    json.dump(meta, f, indent=2)
print("Saved:", stats_fp)

Saved: data/embeddings/items_miniLM/stats.json
