# BGE-M3 Matryoshka Fine-tuning + ONNX INT8

Fine-tune `BAAI/bge-m3` with **MatryoshkaLoss** for truncatable embeddings `[1024, 768, 512, 256]`, then export to **ONNX INT8**.

**Requirements:** T4 GPU (free Colab), ~4-6h training time.

**Output:** PyTorch model pushed to `tss-deposium/bge-m3-matryoshka-1024d` and ONNX INT8 model pushed to `tss-deposium/bge-m3-matryoshka-1024d-onnx-int8` on the Hugging Face Hub.

---

## 0. Setup

In [None]:
# Install dependencies
!pip install -q sentence-transformers>=3.3.0 datasets>=3.0.0 accelerate>=1.0.0 optimum[onnxruntime]>=1.22.0 onnxruntime>=1.19.0 huggingface_hub>=0.25.0

In [None]:
# Check GPU: training below uses fp16 on CUDA, so fail early with a clear message
# when the runtime has no GPU attached.
import torch

if not torch.cuda.is_available():
    raise RuntimeError(
        "No CUDA GPU detected. In Colab: Runtime > Change runtime type > T4 GPU."
    )

props = torch.cuda.get_device_properties(0)
print(f"GPU: {torch.cuda.get_device_name(0)}")
# The device-properties field is `total_memory` (in bytes); `total_mem` does not exist.
print(f"VRAM: {props.total_memory / 1e9:.1f} GB")

In [None]:
# Configuration
import os

# --- Hugging Face token ----------------------------------------------------
# Looked up first in Colab secrets (Colab: Settings > Secrets > add HF_TOKEN),
# then in the HF_TOKEN environment variable. To set it inline, uncomment:
# os.environ["HF_TOKEN"] = "hf_your_token_here"
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get("HF_TOKEN")
except Exception:
    # Not running in Colab, or the Colab secret lookup failed: use the environment.
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    print(
        "HF_TOKEN loaded from environment"
        if HF_TOKEN
        else "WARNING: No HF_TOKEN found. Set it above to push to Hub."
    )
else:
    print("HF_TOKEN loaded from Colab secrets")

# --- Model -----------------------------------------------------------------
BASE_MODEL = "BAAI/bge-m3"
HF_REPO_ID = "tss-deposium/bge-m3-matryoshka-1024d"
# Embedding sizes the Matryoshka loss trains to be usable after truncation.
MATRYOSHKA_DIMS = [1024, 768, 512, 256]

# --- Training --------------------------------------------------------------
NUM_EPOCHS = 4
BATCH_SIZE = 16  # T4 16GB handles this fine, reduce to 8 if OOM
LEARNING_RATE = 2e-5

# --- Datasets --------------------------------------------------------------
# MIRACL languages, consumed in this order until MIRACL_MAX_SAMPLES is reached.
MIRACL_LANGUAGES = ["fr", "en", "de", "es", "ar", "zh", "ja", "ru"]
MIRACL_MAX_SAMPLES = 50_000
MMARCO_MAX_SAMPLES = 20_000

# --- Output locations ------------------------------------------------------
OUTPUT_DIR = "./bge-m3-matryoshka-deposium"
ONNX_OUTPUT_DIR = "./bge-m3-matryoshka-onnx-int8"

## 1. Load Dataset (MIRACL + mMARCO)

In [None]:
from datasets import load_dataset, Dataset
import random

all_pairs = []

# --- MIRACL (multilingual retrieval) ---
# One (anchor=query, positive=passage) row per positive passage, so a query with
# several positives contributes several rows sharing the same anchor.
# MIRACL_MAX_SAMPLES is a *global* cap: languages are consumed in MIRACL_LANGUAGES
# order, and later languages are skipped once the cap is reached.
for lang in MIRACL_LANGUAGES:
    try:
        print(f"Loading MIRACL/{lang}...")
        # Pass the token explicitly: MIRACL may require an authenticated account,
        # and the explicit `login()` only happens in section 5. `HF_TOKEN or None`
        # lets huggingface_hub fall back to any cached login when no token is set.
        ds = load_dataset(
            "miracl/miracl",
            lang,
            split="train",
            trust_remote_code=True,
            token=HF_TOKEN or None,
        )
        count_before = len(all_pairs)
        for row in ds:
            query = row.get("query", "")
            for pos in row.get("positive_passages", []):
                text = pos.get("text", "")
                if query and text:
                    all_pairs.append({"anchor": query, "positive": text})
            if len(all_pairs) >= MIRACL_MAX_SAMPLES:
                break
        # The cap is checked once per query, so the last query's positives can
        # overshoot it; trim back to the exact cap.
        del all_pairs[MIRACL_MAX_SAMPLES:]
        print(f"  +{len(all_pairs) - count_before} pairs (total: {len(all_pairs)})")
    except Exception as e:
        # Best effort: a language that fails to load is reported and skipped.
        print(f"  Skipping {lang}: {e}")
    if len(all_pairs) >= MIRACL_MAX_SAMPLES:
        break

miracl_count = len(all_pairs)
print(f"\nMIRACL total: {miracl_count} pairs")

In [None]:
# --- mMARCO (multilingual MS MARCO, French) ---
# Streamed, so only as many rows as needed for MMARCO_MAX_SAMPLES pairs are read.
mmarco_count = 0  # defined up front so it exists even if loading fails
try:
    print("Loading mMARCO/french (streaming)...")
    # Like MIRACL above, this dataset is loaded through its Hub loading script;
    # recent `datasets` versions refuse to run such scripts unless
    # trust_remote_code=True is passed, which previously made this cell always skip.
    ds = load_dataset(
        "unicamp-dl/mmarco",
        "french",
        split="train",
        streaming=True,
        trust_remote_code=True,
        token=HF_TOKEN or None,
    )
    for row in ds:
        query = row.get("query", "")
        positive = row.get("positive", "")
        if query and positive:
            all_pairs.append({"anchor": query, "positive": positive})
            mmarco_count += 1
        if mmarco_count >= MMARCO_MAX_SAMPLES:
            break
    print(f"  +{mmarco_count} pairs from mMARCO")
except Exception as e:
    # Best effort: mMARCO is optional; training can proceed on MIRACL alone.
    print(f"  Skipping mMARCO: {e}")

print(f"\nTotal dataset: {len(all_pairs)} pairs")

In [None]:
# Shuffle and split 95% train / 5% eval, with a fixed seed for reproducibility.
# Both loading cells skip failures silently, so fail fast here with a clear message
# instead of an IndexError on `train_dataset[0]` (or training on nothing).
if not all_pairs:
    raise RuntimeError(
        "No training pairs were loaded - check the MIRACL/mMARCO messages above "
        "(network access, HF_TOKEN, datasets version)."
    )

random.seed(42)
random.shuffle(all_pairs)

split_idx = int(len(all_pairs) * 0.95)
train_dataset = Dataset.from_list(all_pairs[:split_idx])
eval_dataset = Dataset.from_list(all_pairs[split_idx:])

print(f"Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")
print(f"\nSample: {train_dataset[0]}")

## 2. Fine-tune with MatryoshkaLoss

In [None]:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Pretrained starting point for fine-tuning.
print(f"Loading {BASE_MODEL}...")
model = SentenceTransformer(BASE_MODEL)
print(f"Loaded. Embedding dim: {model.get_sentence_embedding_dimension()}")

# Loss: MultipleNegativesRankingLoss (in-batch negatives), wrapped by MatryoshkaLoss
# so it is also applied to embeddings truncated to each size in MATRYOSHKA_DIMS,
# with every size weighted equally.
inner_loss = MultipleNegativesRankingLoss(model)
dim_weights = [1, 1, 1, 1]
train_loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=MATRYOSHKA_DIMS, matryoshka_weights=dim_weights)

In [None]:
from sentence_transformers.training_args import BatchSamplers

# Training arguments
args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_DIR,  # intermediate checkpoint-*/ folders are written here too
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    fp16=True,  # mixed precision; the T4 has no bf16 support
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,  # must line up with eval_steps for load_best_model_at_end
    save_total_limit=2,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # MultipleNegativesRankingLoss treats every other text in the batch as a
    # negative. MIRACL yields one row per positive passage, so the same query can
    # appear several times; a plain random batch would then train against false
    # negatives. NO_DUPLICATES builds batches without repeated texts.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)

# Trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
)

# Train!
print("Starting training...")
trainer.train()

In [None]:
# Persist the fine-tuned weights. With load_best_model_at_end=True the trainer has
# already restored the best checkpoint (by eval_loss) into `model`.
# Note: the Trainer's checkpoint-*/ subfolders also live under OUTPUT_DIR.
model.save_pretrained(OUTPUT_DIR)
print("Model saved to " + OUTPUT_DIR)

## 3. Validate Matryoshka Embeddings

In [None]:
import numpy as np


def _matryoshka_view(full_embeddings, dim):
    """Keep the first `dim` components of each row, then rescale rows to unit length."""
    head = full_embeddings[:, :dim]
    return head / np.linalg.norm(head, axis=1, keepdims=True)


# Index 0 is the French query; 1 is a relevant French document; 2 is unrelated noise;
# 3 and 4 are English / Spanish counterparts used for the cross-lingual check.
test_sentences = [
    "Comment resilier un contrat d'assurance ?",
    "La resiliation d'un contrat peut etre effectuee par lettre recommandee avec accuse de reception.",
    "Les previsions meteo annoncent de la pluie demain.",
    "How to cancel an insurance contract?",
    "El contrato puede ser rescindido mediante carta certificada.",
]

embeddings = model.encode(test_sentences, normalize_embeddings=True)

header = f"{'Dim':>6} | {'Sim(q,doc)':>10} | {'Sim(q,noise)':>12} | {'Delta':>8} | {'Status':>8}"
print(header)
print("-" * 60)

# A truncated size is considered healthy when the relevant document still beats
# the noise sentence by more than 0.1 cosine similarity.
for dim in MATRYOSHKA_DIMS:
    view = _matryoshka_view(embeddings, dim)
    sim_relevant = np.dot(view[0], view[1])
    sim_noise = np.dot(view[0], view[2])
    delta = sim_relevant - sim_noise
    status = "OK" if delta > 0.1 else "WARN"
    print(f"{dim:>6} | {sim_relevant:>10.4f} | {sim_noise:>12.4f} | {delta:>8.4f} | {status:>8}")

print("\nCross-lingual (FR query vs):")
views_by_dim = {dim: _matryoshka_view(embeddings, dim) for dim in (1024, 256)}
for label, other in (("FR doc", 1), ("EN query", 3), ("ES doc", 4)):
    for dim, view in views_by_dim.items():
        print(f"  {label} @ {dim}D: {np.dot(view[0], view[other]):.4f}")

## 4. Export ONNX INT8

In [None]:
# Export the fine-tuned model to ONNX with dynamic INT8 quantization.
# Preferred path: sentence-transformers' exporter, which also keeps the Sentence
# Transformers layout (modules / pooling config / tokenizer) next to the weights.
# Fallback: Hugging Face Optimum export + ORTQuantizer.
import shutil
from pathlib import Path  # imported up front: the size report below needs it on both paths


def _export_with_sentence_transformers():
    """Export OUTPUT_DIR to ONNX INT8 with sentence-transformers.

    `export_dynamic_quantized_onnx_model` only accepts a model loaded with
    `backend="onnx"`; passing the training-time PyTorch `model` raises, so the
    fine-tuned checkpoint is reloaded through the ONNX backend first.
    """
    from sentence_transformers import SentenceTransformer, export_dynamic_quantized_onnx_model

    onnx_model = SentenceTransformer(OUTPUT_DIR, backend="onnx")
    onnx_model.save_pretrained(ONNX_OUTPUT_DIR)
    export_dynamic_quantized_onnx_model(
        onnx_model,
        quantization_config="avx512_vnni",
        model_name_or_path=ONNX_OUTPUT_DIR,
    )

    # Ship only the INT8 graph: once a quantized file exists, drop the FP32
    # export (model.onnx plus any external-data file next to it).
    onnx_dir = Path(ONNX_OUTPUT_DIR) / "onnx"
    quantized = sorted(onnx_dir.glob("*qint8*.onnx"))
    if quantized:
        for fp32_file in onnx_dir.glob("model.onnx*"):
            fp32_file.unlink()
        rel_name = quantized[0].relative_to(ONNX_OUTPUT_DIR).as_posix()
        print(f'Load with: SentenceTransformer(path, backend="onnx", model_kwargs={{"file_name": "{rel_name}"}})')


def _export_with_optimum():
    """Fallback: Optimum FP32 ONNX export, INT8 dynamic quantization, copy tokenizer/config."""
    from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    tmp_dir = f"{ONNX_OUTPUT_DIR}-fp32-tmp"
    try:
        # Export FP32
        ort_model = ORTModelForFeatureExtraction.from_pretrained(OUTPUT_DIR, export=True)
        ort_model.save_pretrained(tmp_dir)

        # Quantize to INT8
        quantizer = ORTQuantizer.from_pretrained(tmp_dir)
        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
        quantizer.quantize(save_dir=ONNX_OUTPUT_DIR, quantization_config=qconfig)

        # Copy tokenizer JSONs and config.json; skip other "model"-named JSON files.
        for f in Path(tmp_dir).glob("*.json"):
            if "model" not in f.name.lower() or f.name == "config.json":
                shutil.copy2(f, ONNX_OUTPUT_DIR)
    finally:
        # Always remove the multi-GB FP32 temp export, even if quantization failed.
        shutil.rmtree(tmp_dir, ignore_errors=True)


try:
    print("Exporting with sentence-transformers native method...")
    _export_with_sentence_transformers()
except Exception as e:
    print(f"Native export failed: {e}")
    print("Falling back to HuggingFace Optimum...")
    # Discard any partial output from the native attempt so the two layouts don't mix.
    shutil.rmtree(ONNX_OUTPUT_DIR, ignore_errors=True)
    _export_with_optimum()

print(f"ONNX INT8 saved to {ONNX_OUTPUT_DIR}")

# Show model size
total_size = sum(f.stat().st_size for f in Path(ONNX_OUTPUT_DIR).rglob("*") if f.is_file())
print(f"Total ONNX model size: {total_size / 1e6:.0f} MB")

## 5. Push to HuggingFace Hub

In [None]:
from huggingface_hub import HfApi, login

# OUTPUT_DIR is also the Trainer's output_dir, so it holds intermediate
# checkpoint-*/ folders (full weights + optimizer state each) and possibly
# TensorBoard `runs/` logs. These must not be published with the final model.
TRAINER_ARTIFACT_PATTERNS = ["checkpoint-*", "runs/*"]

if not HF_TOKEN:
    print("No HF_TOKEN set. Run this to push manually:")
    print("  huggingface-cli login")
    exclude_args = " ".join(f'"{pattern}"' for pattern in TRAINER_ARTIFACT_PATTERNS)
    print(f"  huggingface-cli upload {HF_REPO_ID} {OUTPUT_DIR} --exclude {exclude_args}")
    print(f"  huggingface-cli upload {HF_REPO_ID}-onnx-int8 {ONNX_OUTPUT_DIR}")
else:
    login(token=HF_TOKEN)
    api = HfApi()

    # Push PyTorch model (final weights only, no trainer checkpoints/logs)
    print(f"Pushing PyTorch model to {HF_REPO_ID}...")
    api.create_repo(HF_REPO_ID, exist_ok=True, private=False)
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=HF_REPO_ID,
        commit_message="BGE-M3 Matryoshka [1024, 768, 512, 256]",
        ignore_patterns=TRAINER_ARTIFACT_PATTERNS,
    )
    print(f"Done: https://huggingface.co/{HF_REPO_ID}")

    # Push ONNX INT8 model to a sibling repo
    onnx_repo = f"{HF_REPO_ID}-onnx-int8"
    print(f"\nPushing ONNX INT8 model to {onnx_repo}...")
    api.create_repo(onnx_repo, exist_ok=True, private=False)
    api.upload_folder(
        folder_path=ONNX_OUTPUT_DIR,
        repo_id=onnx_repo,
        commit_message="BGE-M3 Matryoshka ONNX INT8 [1024, 768, 512, 256]",
    )
    print(f"Done: https://huggingface.co/{onnx_repo}")

## Done!

Both models (PyTorch and ONNX INT8) are now on the Hugging Face Hub. Next steps:
1. Update `model_manager.py` to reference the new model
2. Add `truncate_dims` support in `OnnxEmbeddingModel`
3. Test in staging with the existing benchmark suite