# Sift — Fine-Tune EmbeddingGemma on Google Colab

Train your personal feed-scoring model on a free T4 GPU, then download the ONNX model for use with the Sift Chrome extension.

## Prerequisites

**Accept the EmbeddingGemma license** on HuggingFace before running:
https://huggingface.co/google/embeddinggemma-300m

You'll also need an **HF token** — create one at https://huggingface.co/settings/tokens and paste it in the Configuration cell below. Read access is enough for training; the optional Hub upload at the end needs a token with write access.

## Quick Start
1. **Set GPU runtime:** `Runtime → Change runtime type → T4 GPU`
2. **Run all cells:** `Runtime → Run all`
3. **Upload** your exported CSV when prompted
4. **Download** the resulting ONNX zip at the end

No local setup required — everything runs in this notebook.

In [None]:
# Install dependencies (torch is pre-installed by Colab with CUDA)
# Training stack: sentence-transformers, transformers, datasets, accelerate, huggingface-hub.
# ONNX export / quantization (ONNX Conversion cell): optimum, onnx, onnxruntime.
# NOTE(review): transformers>=4.56.0 is presumably the first release with EmbeddingGemma
# support, and onnx-ir is not imported directly anywhere in this notebook — presumably
# required by the export toolchain. TODO confirm both floors.
!pip install -q \
    "sentence-transformers>=3.0" \
    "transformers>=4.56.0" \
    "datasets>=2.18" \
    "accelerate>=0.30" \
    "huggingface-hub>=0.23" \
    "optimum[exporters,onnxruntime]>=1.21" \
    "onnx>=1.16" \
    "onnxruntime>=1.18" \
    "onnx-ir>=0.1"

In [None]:
import torch

# Choose the compute device once; later cells read DEVICE.
if not torch.cuda.is_available():
    # CPU fallback: everything still runs, just far slower.
    DEVICE = "cpu"
    print("WARNING: No GPU detected. Training will be very slow.")
    print("Go to Runtime → Change runtime type → T4 GPU")
else:
    DEVICE = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    # The VRAM figure is informational only; if the property query fails,
    # report the bare GPU name instead of aborting the cell.
    try:
        vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"GPU: {gpu_name} ({vram_gb:.1f} GB VRAM)")
    except Exception:
        print(f"GPU: {gpu_name}")

In [None]:
import csv
import io
from google.colab import files


def load_csv(content: str) -> list[list[str]]:
    """Parse ``Anchor,Positive,Negative`` triplets from CSV text.

    An optional header row (``anchor,positive,negative``, case-insensitive,
    surrounding whitespace ignored) is skipped. Rows with fewer than three
    columns, or with an empty anchor/positive/negative cell after stripping,
    are dropped. Columns beyond the third are ignored.

    Args:
        content: Full CSV file contents. A leading UTF-8 byte-order mark
            (written by some spreadsheet exporters) is tolerated.

    Returns:
        A list of ``[anchor, positive, negative]`` lists with each cell stripped.
    """
    # A BOM survives a plain "utf-8" decode and is not whitespace, so strip()
    # keeps it. Without removing it, the header check fails and the header row
    # would be loaded as a training triplet.
    content = content.removeprefix("\ufeff")

    reader = csv.reader(io.StringIO(content))
    first_row = next(reader, None)
    if first_row is None:
        return []

    triplets: list[list[str]] = []

    def maybe_append_triplet(row: list[str]) -> None:
        if len(row) < 3:
            return
        anchor, positive, negative = (cell.strip() for cell in row[:3])
        if anchor and positive and negative:
            triplets.append([anchor, positive, negative])

    header = [cell.strip().lower() for cell in first_row[:3]]
    if header != ["anchor", "positive", "negative"]:
        maybe_append_triplet(first_row)

    for row in reader:
        maybe_append_triplet(row)
    return triplets


uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No CSV uploaded. Re-run this cell and choose your exported triplets file.")
if len(uploaded) > 1:
    print(f"Note: {len(uploaded)} files uploaded; using only the first one.")
csv_filename = next(iter(uploaded))
# "utf-8-sig" strips a leading byte-order mark (added by some spreadsheet
# exporters) and otherwise decodes exactly like "utf-8".
csv_content = uploaded[csv_filename].decode("utf-8-sig")
triplets = load_csv(csv_content)

print(f"\nLoaded {len(triplets)} triplets from {csv_filename}")
if not triplets:
    print("WARNING: no usable triplets found. Expected columns: Anchor,Positive,Negative")

# Summary by anchor
from collections import Counter
anchor_counts = Counter(t[0] for t in triplets)
for anchor, count in anchor_counts.most_common():
    print(f"  {anchor}: {count} triplets")

if triplets:
    print(f"\nExample: {triplets[0]}")

## Configuration

Edit the values below before training. At minimum, set `HF_TOKEN`; the other defaults work well for most cases.

In [None]:
# --- Model ---
MODEL_NAME = "google/embeddinggemma-300m"
TASK_NAME = "Classification"  # name of the model's built-in prompt, used for training and scoring

# --- Training ---
EPOCHS = 4
LEARNING_RATE = 2e-5
BATCH_SIZE = 4  # T4 16GB handles 4; reduce to 1 if you hit OOM

# --- Held-out evaluation ---
HELDOUT_FRACTION = 0.15  # set to 0 to use all data for training
SEED = 42

# --- Output ---
FINETUNED_DIR = "sift-finetuned"
ONNX_DIR = "sift-finetuned_onnx_transformersjs"

# --- HuggingFace ---
HF_TOKEN = ""   # REQUIRED: needed to download gated model (read access)
HF_REPO = ""    # e.g. "yourname/sift-finetuned" — leave empty to skip upload

# If no token was pasted above, fall back to the HF_TOKEN environment variable.
# Later cells only check HF_TOKEN itself, so this lets an environment-provided
# token drive both model download and the optional Hub upload.
if not HF_TOKEN:
    import os

    HF_TOKEN = os.environ.get("HF_TOKEN", "")

## Training

In [None]:
import random
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

from datasets import Dataset
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback
from huggingface_hub import HfApi, login, model_info, metadata_update


# --- Held-Out Evaluation ---

@dataclass
class HeldOutItem:
    """One held-out text that is scored against its group's anchor."""

    text: str
    # True if the text came from a triplet's Positive column, False if from Negative.
    is_positive: bool
    # Cosine similarity to the anchor measured when training begins
    # (set in TasteTracker.on_train_begin); shown as the "before"/"was" value.
    baseline_score: float = 0.0

@dataclass
class AnchorHeldOutGroup:
    """Held-out items for a single anchor, scored and reported together."""

    anchor: str
    # Held-out texts for this anchor, as built by split_held_out.
    items: list[HeldOutItem] = field(default_factory=list)


def _normalize_text(text: str) -> str:
    """Return *text* lower-cased, trimmed, with each whitespace run collapsed to one space.

    Used as a comparison key so near-identical texts count as duplicates.
    """
    lowered = text.lower()
    return " ".join(lowered.split())


def split_held_out(
    triplets: list[list[str]],
    fraction: float = 0.15,
    min_anchor_triplets: int = 4,
    seed: int = 42,
) -> tuple[list[list[str]], list[AnchorHeldOutGroup]]:
    """Split triplets into a training set and per-anchor held-out groups.

    For every anchor with at least ``min_anchor_triplets`` triplets, about
    ``fraction`` of its triplets are removed from training. At least one is
    held out, and at least one is always kept for training. Their positive and
    negative texts become that anchor's held-out items (see ``_held_out_items``).
    Anchors below the threshold go entirely to training.

    Args:
        triplets: ``[anchor, positive, negative]`` rows.
        fraction: Share of each eligible anchor's triplets to hold out.
            Any value <= 0 disables the split.
        min_anchor_triplets: Minimum triplets an anchor needs before it is split.
        seed: Seed for the shuffle that picks which triplets are held out.

    Returns:
        ``(train_triplets, held_out_groups)``.
    """
    # Negative fractions used to fall through and hold out one triplet per
    # anchor; treat every non-positive value as "no held-out split".
    if fraction <= 0:
        return list(triplets), []

    rng = random.Random(seed)

    by_anchor: dict[str, list[list[str]]] = defaultdict(list)
    for t in triplets:
        by_anchor[t[0]].append(t)

    train_triplets: list[list[str]] = []
    held_out_groups: list[AnchorHeldOutGroup] = []

    for anchor, rows in by_anchor.items():
        if len(rows) < min_anchor_triplets:
            train_triplets.extend(rows)
            continue

        shuffled = rows[:]
        rng.shuffle(shuffled)
        # Hold out at least one triplet, but never all of them. Otherwise a large
        # fraction would drop the anchor from training entirely and the "taste"
        # check would measure an anchor the model never learned.
        n_held = min(max(1, round(len(rows) * fraction)), len(rows) - 1)
        if n_held < 1:
            train_triplets.extend(rows)
            continue

        held_rows = shuffled[:n_held]
        train_part = shuffled[n_held:]
        train_triplets.extend(train_part)

        items = _held_out_items(held_rows, train_part)
        if items:
            held_out_groups.append(AnchorHeldOutGroup(anchor=anchor, items=items))

    return train_triplets, held_out_groups


def _held_out_items(
    held_rows: list[list[str]],
    train_rows: list[list[str]],
) -> list[HeldOutItem]:
    """Build de-duplicated held-out items for one anchor.

    Positive/negative texts of ``held_rows`` are kept in order, skipping
    normalized duplicates and texts that also occur in ``train_rows``. If that
    leaves no positive (or no negative), the first unseen candidate of the
    missing class is added even if it overlaps training. This keeps
    pos-vs-neg comparisons possible.
    """
    train_norms = {_normalize_text(row[col]) for row in train_rows for col in (1, 2)}

    candidates: list[tuple[str, bool, str]] = []
    for row in held_rows:
        candidates.append((row[1], True, _normalize_text(row[1])))
        candidates.append((row[2], False, _normalize_text(row[2])))

    seen_norm: set[str] = set()
    items: list[HeldOutItem] = []
    for text, is_pos, norm in candidates:
        if norm not in seen_norm and norm not in train_norms:
            seen_norm.add(norm)
            items.append(HeldOutItem(text=text, is_positive=is_pos))

    # Fallback: ensure both classes are represented. This ignores train_norms,
    # so the added item may also appear in this anchor's training triplets.
    for want_positive in (True, False):
        if any(item.is_positive == want_positive for item in items):
            continue
        for text, is_pos, norm in candidates:
            if is_pos == want_positive and norm not in seen_norm:
                seen_norm.add(norm)
                items.append(HeldOutItem(text=text, is_positive=want_positive))
                break

    return items


# --- Scoring & Formatting ---

def score_held_out_items(
    model: SentenceTransformer,
    groups: list[AnchorHeldOutGroup],
    task_name: str,
) -> dict[str, list[float]]:
    """Map each group's anchor text to the cosine similarities of its items.

    Anchor and item texts are both encoded with the ``task_name`` prompt.
    Each list follows the order of ``group.items``.
    """
    def _similarities(group: AnchorHeldOutGroup) -> list[float]:
        query_vec = model.encode(group.anchor, prompt_name=task_name)
        item_vecs = model.encode([it.text for it in group.items], prompt_name=task_name)
        return util.cos_sim(query_vec, item_vecs)[0].tolist()

    return {group.anchor: _similarities(group) for group in groups}


def format_taste_table(
    groups: "list[AnchorHeldOutGroup]",
    scores: dict[str, list[float]],
    header: str,
    show_baseline_delta: bool = False,
) -> str:
    """Render one taste-check table (baseline or a single epoch) as text.

    For each group the table lists every held-out item (``+`` positive,
    ``-`` negative) with its score. A summary line follows with the mean
    positive score, the mean negative score, their gap, and the percentage
    of (positive, negative) pairs where the positive scored higher.

    Args:
        groups: Held-out groups; ``scores[group.anchor]`` must align with ``group.items``.
        scores: Per-anchor similarity lists, as returned by ``score_held_out_items``.
        header: Label shown in the title line (e.g. ``"baseline"``, ``"epoch 2"``).
        show_baseline_delta: Also show each item's change versus its
            ``baseline_score`` and the baseline gap / pair accuracy.

    Returns:
        The table as one string; it starts with a newline.
    """
    def mean(values: list[float]) -> float:
        return sum(values) / len(values) if values else 0

    def pair_pct(pos: list[float], neg: list[float]) -> float:
        n_pairs = len(pos) * len(neg)
        if not n_pairs:
            return 0
        return sum(1 for p in pos for n in neg if p > n) / n_pairs * 100

    lines = [f"\n=== Taste Check ({header}) {'=' * max(1, 45 - len(header))}"]
    for g in groups:
        lines.append(f"\nAnchor: {g.anchor} ({len(g.items)} items)")
        pos_scores, neg_scores = [], []
        for item, score in zip(g.items, scores[g.anchor]):
            tag = "+" if item.is_positive else "-"
            # Collapse newlines/tabs so a multi-line text stays on one table row.
            label = " ".join(item.text.split())[:50]
            delta_str = f"  ({score - item.baseline_score:+.2f})" if show_baseline_delta else ""
            lines.append(f"  {tag} \"{label}\"{' ' * max(1, 55 - len(label))}{score:.2f}{delta_str}")
            (pos_scores if item.is_positive else neg_scores).append(score)

        avg_p, avg_n = mean(pos_scores), mean(neg_scores)
        gap_str = f"  gap: {avg_p - avg_n:.2f}"
        pair_str = f"  pos>neg: {pair_pct(pos_scores, neg_scores):.0f}%"
        if show_baseline_delta:
            bp = [it.baseline_score for it in g.items if it.is_positive]
            bn = [it.baseline_score for it in g.items if not it.is_positive]
            gap_str += f"  (was {mean(bp) - mean(bn):.2f})"
            pair_str += f" (was {pair_pct(bp, bn):.0f}%)"
        lines.append(f"  avg +: {avg_p:.2f}  avg -: {avg_n:.2f}{gap_str}{pair_str}")
    return "\n".join(lines)


def format_taste_final(
    groups: "list[AnchorHeldOutGroup]",
    final_scores: dict[str, list[float]],
) -> str:
    """Render the before -> after summary comparing baseline and final scores.

    Each item row shows ``baseline -> final (delta)``. Each group ends with
    before/after values for the mean positive score, the mean negative score,
    the gap, and the percentage of (positive, negative) pairs ranked correctly.

    Args:
        groups: Held-out groups whose items carry ``baseline_score``.
        final_scores: Per-anchor similarity lists aligned with ``group.items``.

    Returns:
        The summary as one string; it starts with a newline.
    """
    def mean(values: list[float]) -> float:
        return sum(values) / len(values) if values else 0

    def pair_pct(pos: list[float], neg: list[float]) -> float:
        n_pairs = len(pos) * len(neg)
        if not n_pairs:
            return 0
        return sum(1 for p in pos for n in neg if p > n) / n_pairs * 100

    lines = [f"\n=== Taste Check -- Final {'=' * 30}"]
    for g in groups:
        lines.append(f"\nAnchor: {g.anchor}")
        # Item rows are '  ' + tag + ' "' + label + '"' padded to 55 chars, so the
        # baseline score starts at column 61; align the column header with it.
        lines.append(f"  {'':58s} Before -> After")
        pos_before, pos_after, neg_before, neg_after = [], [], [], []
        for item, score in zip(g.items, final_scores[g.anchor]):
            tag = "+" if item.is_positive else "-"
            # Collapse newlines/tabs so a multi-line text stays on one row.
            label = " ".join(item.text.split())[:50]
            delta = score - item.baseline_score
            lines.append(
                f"  {tag} \"{label}\"{' ' * max(1, 55 - len(label))}"
                f"{item.baseline_score:.2f}  ->  {score:.2f}  ({delta:+.2f})"
            )
            (pos_before if item.is_positive else neg_before).append(item.baseline_score)
            (pos_after if item.is_positive else neg_after).append(score)

        avg_pb, avg_pa = mean(pos_before), mean(pos_after)
        avg_nb, avg_na = mean(neg_before), mean(neg_after)
        lines.append(
            f"  avg +: {avg_pb:.2f} -> {avg_pa:.2f}  "
            f"avg -: {avg_nb:.2f} -> {avg_na:.2f}  "
            f"gap: {avg_pb - avg_nb:.2f} -> {avg_pa - avg_na:.2f}  "
            f"pos>neg: {pair_pct(pos_before, neg_before):.0f}% -> {pair_pct(pos_after, neg_after):.0f}%"
        )
    return "\n".join(lines)


class TasteTracker(TrainerCallback):
    """Trainer callback that follows held-out "taste" scores across training.

    At train start it scores every held-out item against its anchor. It stores
    each result as the item's ``baseline_score`` and prints a baseline table.
    After every epoch it rescores, keeps the latest scores and prints a table
    with deltas from the baseline. Printing happens only on the main process.
    """

    def __init__(
        self,
        model: SentenceTransformer,
        groups: list[AnchorHeldOutGroup],
        task_name: str,
    ):
        self.model = model
        self.groups = groups
        self.task_name = task_name
        # Scores from the most recently finished epoch; empty until then.
        self.final_scores: dict[str, list[float]] = {}

    def _score(self) -> dict[str, list[float]]:
        """Score all held-out groups with the current model weights."""
        return score_held_out_items(self.model, self.groups, self.task_name)

    def on_train_begin(self, args, state, control, **kwargs):
        baseline = self._score()
        for group in self.groups:
            for item, value in zip(group.items, baseline[group.anchor]):
                item.baseline_score = value
        if state.is_world_process_zero:
            print(format_taste_table(self.groups, baseline, "baseline"))

    def on_epoch_end(self, args, state, control, **kwargs):
        label = f"epoch {int(state.epoch)}"
        self.final_scores = self._score()
        if state.is_world_process_zero:
            print(format_taste_table(self.groups, self.final_scores, label, show_baseline_delta=True))

    def get_final_summary(self) -> str:
        """Return the before -> after summary, or "" if no epoch has finished yet."""
        return format_taste_final(self.groups, self.final_scores) if self.final_scores else ""


# --- Split held-out ---

train_triplets, held_out_groups = split_held_out(
    triplets,
    fraction=HELDOUT_FRACTION,
    min_anchor_triplets=4,
    seed=SEED,
)

held_out_count = sum(len(g.items) for g in held_out_groups)
if held_out_groups:
    anchors = ", ".join(g.anchor for g in held_out_groups)
    print(f"Split: {len(train_triplets)} train, {held_out_count} held-out items across {len(held_out_groups)} anchor(s) [{anchors}]")
elif HELDOUT_FRACTION <= 0:
    # Held-out evaluation was switched off on purpose; don't blame the data.
    print(f"All {len(train_triplets)} triplets used for training (HELDOUT_FRACTION = {HELDOUT_FRACTION})")
else:
    print(f"All {len(train_triplets)} triplets used for training (too few per anchor to split)")

# Raise rather than assert: asserts are stripped under `python -O`, and a clear
# error here beats an obscure failure inside the trainer.
if len(train_triplets) < 2:
    raise ValueError(
        f"Need at least 2 training triplets, got {len(train_triplets)}. Collect more labels!"
    )

# --- Load model ---

print(f"\nLoading {MODEL_NAME} on {DEVICE}...")

if HF_TOKEN:
    # Authenticate this session for Hub calls. Nothing in this notebook uses git,
    # so the token is not also written to a git credential helper.
    login(token=HF_TOKEN)
    model = SentenceTransformer(MODEL_NAME, device=DEVICE, token=HF_TOKEN)
else:
    print(
        "Note: HF_TOKEN is empty. Downloading the gated model will fail unless "
        "this session is already authenticated with Hugging Face."
    )
    model = SentenceTransformer(MODEL_NAME, device=DEVICE)

print(f"Model loaded on {model.device}")

# --- Train ---

data_as_dicts = [
    {"anchor": row[0], "positive": row[1], "negative": row[2]}
    for row in train_triplets
]
train_dataset = Dataset.from_list(data_as_dicts)
loss = MultipleNegativesRankingLoss(model)

# The model's built-in prompt for TASK_NAME (None if it has none). It is passed to
# the trainer so training inputs use the same prompt as score_held_out_items.
prompts = getattr(model, 'prompts', {}).get(TASK_NAME)

callbacks = []
taste_tracker = None
if held_out_groups:
    taste_tracker = TasteTracker(model, held_out_groups, TASK_NAME)
    callbacks.append(taste_tracker)

is_cuda = DEVICE == "cuda"

training_args = SentenceTransformerTrainingArguments(
    output_dir=FINETUNED_DIR,
    prompts=prompts,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    # Log the training loss once per epoch. logging_steps counts optimizer steps,
    # not rows, so logging_steps=num_rows only matched one epoch at batch size 1.
    # With BATCH_SIZE=4 it logged roughly once every four epochs.
    logging_strategy="epoch",
    report_to="none",
    save_strategy="no",
    dataloader_pin_memory=is_cuda,
    # NOTE(review): the EmbeddingGemma model card reportedly warns that its
    # activations do not support float16. If the loss turns NaN, try fp16=False.
    fp16=is_cuda,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=callbacks,
)

print(f"\nTraining: {EPOCHS} epochs, lr={LEARNING_RATE}, batch_size={BATCH_SIZE}")
trainer.train()

print("Training finished. Saving model...")
trainer.save_model()
print(f"Model saved to: {FINETUNED_DIR}")

In [None]:
# Show the before → after taste comparison when held-out evaluation ran.
if taste_tracker is None:
    print("No held-out evaluation (HELDOUT_FRACTION was 0 or too few triplets per anchor).")
else:
    print(taste_tracker.get_final_summary())

## ONNX Conversion

Converts the fine-tuned model to ONNX format for use with Transformers.js in the Chrome extension.
Produces four variants: fp32, int8, q4 (WASM), and q4 no_gather (WebGPU-compatible).

In [None]:
import logging
import shutil
import warnings
from pathlib import Path
from optimum.exporters.onnx import main_export

model_dir = Path(FINETUNED_DIR)
output_dir = Path(ONNX_DIR)

print("--- ONNX Conversion ---")
print(f"Exporting {model_dir} → {output_dir}...")

# TracerWarning is used only to filter tracing noise below. Fall back to
# UserWarning if this torch build does not expose it under torch.jit.
try:
    from torch.jit import TracerWarning
except ImportError:
    TracerWarning = UserWarning

# Temporarily raise these loggers to ERROR so the export output stays readable.
# Their original levels are restored in `finally`, even if export fails.
export_loggers = ("transformers", "optimum", "torch.onnx", "onnxruntime")
prev_logger_levels = {name: logging.getLogger(name).level for name in export_loggers}
for name in export_loggers:
    logging.getLogger(name).setLevel(logging.ERROR)
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="`torch_dtype` is deprecated! Use `dtype` instead!")
        warnings.filterwarnings(
            "ignore",
            message=r"The tokenizer you are loading from .*incorrect regex pattern.*",
        )
        warnings.filterwarnings("ignore", category=TracerWarning)
        warnings.filterwarnings(
            "ignore",
            message=r"Exporting aten::index operator of advanced indexing.*",
            category=UserWarning,
        )
        main_export(
            model_name_or_path=str(model_dir),
            output=output_dir,
            task="feature-extraction",
            device="cpu",
            dtype="fp32",
            library_name="sentence_transformers",
            do_validation=False,
        )
finally:
    for name, level in prev_logger_levels.items():
        logging.getLogger(name).setLevel(level)

# optimum puts model.onnx at root; Transformers.js expects onnx/ subdirectory
onnx_subdir = output_dir / "onnx"
onnx_subdir.mkdir(exist_ok=True)
root_onnx = output_dir / "model.onnx"
onnx_path = onnx_subdir / "model.onnx"
if root_onnx.exists():
    shutil.move(str(root_onnx), str(onnx_path))
    # If the exporter wrote weights to a side file (presumably only for very large
    # models — confirm), the graph refers to it by relative name, so keep it
    # next to model.onnx.
    root_onnx_data = output_dir / "model.onnx_data"
    if root_onnx_data.exists():
        shutil.move(str(root_onnx_data), str(onnx_subdir / "model.onnx_data"))

# Fail with a clear message instead of a bare stat() error when export produced nothing.
if not onnx_path.exists():
    raise FileNotFoundError(
        f"No ONNX model found at {root_onnx} or {onnx_path}; check the export output above."
    )

size_mb = onnx_path.stat().st_size / (1024 * 1024)
print(f"ONNX model (fp32): {size_mb:.1f} MB")

# INT8 dynamic quantization
# logging.disable(WARNING) mutes WARNING-and-below records from every logger while
# quantizing. The previous threshold is restored in `finally`, so an unexpected
# interruption cannot leave logging muted for the rest of the session.
prev_disable = logging.root.manager.disable
logging.disable(logging.WARNING)
try:
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType

        quant_path = onnx_subdir / "model_quantized.onnx"
        print("Quantizing to INT8...")
        quantize_dynamic(str(onnx_path), str(quant_path), weight_type=QuantType.QInt8)
        print(f"INT8 model: {quant_path.stat().st_size / (1024*1024):.1f} MB")
    except Exception as e:
        # Best-effort: the fp32 model is already usable on its own.
        print(f"INT8 quantization failed (non-critical): {e}")

    # 4-bit block quantization
    try:
        import onnx
        from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

        print("Quantizing to Q4...")
        model_proto = onnx.load(str(onnx_path))
        quant = MatMulNBitsQuantizer(
            model_proto, block_size=32, is_symmetric=True, accuracy_level=4,
        )
        quant.process()
        q4_path = onnx_subdir / "model_q4.onnx"
        onnx.save(quant.model.model, str(q4_path))
        print(f"Q4 model: {q4_path.stat().st_size / (1024*1024):.1f} MB")

        # WebGPU variant. GatherElements must NOT be rewritten as Gather. The two
        # ONNX ops have different semantics: element-wise gather with an index
        # tensor of the data's rank, versus gathering whole slices along one axis.
        # Renaming the op type silently changes outputs or yields an invalid graph.
        # Ship the q4 graph unchanged and warn if such ops are present.
        no_gather_path = onnx_subdir / "model_no_gather_q4.onnx"
        q4_model = onnx.load(str(q4_path))
        n_gather_elements = sum(1 for n in q4_model.graph.node if n.op_type == "GatherElements")
        shutil.copy2(str(q4_path), str(no_gather_path))
        if not n_gather_elements:
            print("Q4 no_gather (WebGPU): copied (no GatherElements ops found)")
        else:
            print(
                f"Q4 no_gather (WebGPU): copied, but the graph contains {n_gather_elements} "
                "GatherElements op(s). They were left unchanged because rewriting them as "
                "Gather is not equivalent; verify this variant on WebGPU before relying on it."
            )
    except Exception as e:
        print(f"Q4 quantization failed (non-critical): {e}")
finally:
    logging.disable(prev_disable)

print(f"\nTransformers.js model ready at: {output_dir}")

In [None]:
import shutil
from google.colab import files

zip_name = "sift-onnx-model"
archive = f"{zip_name}.zip"
# make_archive appends the ".zip" extension itself, so it gets the bare base name.
shutil.make_archive(zip_name, "zip", str(output_dir))
print(f"Created {archive}")
files.download(archive)

## Optional: Push to HuggingFace Hub

Set `HF_REPO` and `HF_TOKEN` in the Configuration cell above (the token needs write access), then run this cell.
The model must be public for the Sift extension to load it (HF auth is not supported browser-side).

In [None]:
if not HF_REPO:
    print("Skipping Hub upload: HF_REPO not set in Configuration cell.")
elif not HF_TOKEN:
    print("Error: HF_TOKEN is required to push to Hub.")
else:
    login(token=HF_TOKEN)
    api = HfApi(token=HF_TOKEN)

    # Accept either "user/name" or a bare "name" (resolved to the token's account).
    requested = HF_REPO.strip().strip("/")
    if "/" in requested:
        repo_id = requested
    else:
        user_info = api.whoami()
        repo_id = f"{user_info['name']}/{requested}"

    print(f"Uploading to: {repo_id}")
    # The extension loads models without auth, so create new repos as public
    # explicitly instead of relying on a (possibly private) account/org default.
    api.create_repo(repo_id=repo_id, private=False, exist_ok=True)
    url = api.upload_folder(
        folder_path=str(output_dir),
        repo_id=repo_id,
        repo_type="model",
    )

    info = model_info(repo_id=repo_id, token=HF_TOKEN)
    # exist_ok=True leaves an existing repo's visibility untouched, so check it.
    if getattr(info, "private", False):
        print(f"WARNING: {repo_id} is private; the Sift extension cannot load it until you make it public.")

    tags = list((info.card_data.tags if info.card_data else []) or [])
    if "embeddinggemma-tuning-lab" not in tags:
        tags.append("embeddinggemma-tuning-lab")
        metadata_update(
            repo_id=repo_id,
            metadata={"tags": tags},
            overwrite=True,
            token=HF_TOKEN,
        )

    print(f"\nSuccess! Model published at: {url}")