# BGE-M3 Matryoshka Fine-tuning + ONNX INT8

Fine-tune `BAAI/bge-m3` with **MatryoshkaLoss** for truncatable embeddings `[1024, 768, 512, 256]`, then export to **ONNX INT8**.

**Requirements:** A100 GPU (Colab Pro), ~1.5-2h training time.

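**Output:** Two repositories pushed to the Hugging Face Hub: the fine-tuned PyTorch model (`HF_REPO_ID`) and its ONNX INT8 export (`HF_REPO_ID` + `-onnx-int8`).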

---

## 0. Setup

In [None]:
!pip install -q sentence-transformers>=3.3.0 datasets>=3.0.0 accelerate>=1.0.0 optimum[onnxruntime]>=1.22.0 onnxruntime>=1.19.0 huggingface_hub>=0.25.0

In [None]:
import torch, gc, os, random
import numpy as np

# Release anything left over from a previous run of this notebook so the
# model load later on starts from a clean slate.
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Warn instead of crashing so the configuration below is still defined;
    # the training cell itself requests bf16 and will need a GPU.
    print("WARNING: No CUDA GPU detected. Select a GPU runtime before training.")

# ============================================================
# CONFIGURATION - edit these values
# ============================================================

BASE_MODEL = "BAAI/bge-m3"
HF_REPO_ID = "tss-deposium/bge-m3-matryoshka-1024d"
MATRYOSHKA_DIMS = [1024, 768, 512, 256]

NUM_EPOCHS = 3
BATCH_SIZE = 32           # A100 80GB handles 32 easily
LEARNING_RATE = 2e-5

ALLNLI_MAX_SAMPLES = 0    # 0 = all (~560k)
NQ_MAX_SAMPLES = 0        # 0 = all (~100k)
GOOAQ_MAX_SAMPLES = 50000

OUTPUT_DIR = "./bge-m3-matryoshka-deposium"
ONNX_OUTPUT_DIR = "./bge-m3-matryoshka-onnx-int8"

# HF token: prefer the Colab secret, then the HF_TOKEN environment variable.
HF_TOKEN = ""
try:
    from google.colab import userdata
    # `or ""` normalises a missing/empty secret so HF_TOKEN is always a str.
    HF_TOKEN = userdata.get("HF_TOKEN") or ""
except Exception:
    # Deliberately broad: covers "not running on Colab" (ImportError) as well
    # as Colab-specific errors that cannot be imported by name elsewhere.
    HF_TOKEN = ""

if HF_TOKEN:
    print("HF_TOKEN loaded from Colab secrets")
else:
    HF_TOKEN = os.getenv("HF_TOKEN", "")

if not HF_TOKEN:
    print("WARNING: No HF_TOKEN. Set it in Colab: Settings > Secrets > HF_TOKEN")
    print("The model will still train, but won't push to Hub automatically.")

## 1. Load Datasets

In [None]:
from datasets import load_dataset, Dataset

from itertools import islice


def _extend_pairs(sink, rows, anchor_key, positive_key, limit):
    """Append (anchor, positive) pairs built from `rows` to `sink`.

    Args:
        sink: List that receives ``{"anchor": ..., "positive": ...}`` dicts.
        rows: Iterable of mapping-like rows (regular or streaming dataset).
        anchor_key: Row key copied into ``"anchor"``.
        positive_key: Row key copied into ``"positive"``.
        limit: Maximum number of rows to take; a falsy value (0) means no cap.

    Returns:
        The number of pairs appended.
    """
    selected = islice(rows, limit) if limit else rows
    start = len(sink)
    sink.extend(
        {"anchor": row[anchor_key], "positive": row[positive_key]} for row in selected
    )
    return len(sink) - start


all_pairs = []
# Pre-bind the dataset handles so the cleanup at the end of this cell cannot
# raise NameError when one of the downloads below was skipped.
nli = nq = gooaq = None

# --- all-nli (~560k NLI pairs) ---
print("Loading sentence-transformers/all-nli...")
try:
    nli = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
    added = _extend_pairs(all_pairs, nli, "anchor", "positive", ALLNLI_MAX_SAMPLES)
    print(f"  +{added} pairs")
except Exception as e:
    # Best effort: a failing source is skipped; pairs appended before the
    # failure are kept and counted below.
    print(f"  Skipping: {e}")

nli_count = len(all_pairs)

# --- natural-questions (~100k Q&A) ---
print("Loading sentence-transformers/natural-questions...")
try:
    nq = load_dataset("sentence-transformers/natural-questions", split="train")
    added = _extend_pairs(all_pairs, nq, "query", "answer", NQ_MAX_SAMPLES)
    print(f"  +{added} pairs")
except Exception as e:
    print(f"  Skipping: {e}")

nq_count = len(all_pairs) - nli_count

# --- gooaq (50k from 3M+ Q&A) ---
# Streamed so only the requested prefix is downloaded. As with the other
# *_MAX_SAMPLES settings, 0 means "no cap".
print(f"Loading sentence-transformers/gooaq (max {GOOAQ_MAX_SAMPLES})...")
try:
    gooaq = load_dataset("sentence-transformers/gooaq", split="train", streaming=True)
    added = _extend_pairs(all_pairs, gooaq, "question", "answer", GOOAQ_MAX_SAMPLES)
    print(f"  +{added} pairs")
except Exception as e:
    print(f"  Skipping: {e}")

gooaq_count = len(all_pairs) - nli_count - nq_count

if not all_pairs:
    # Every source failed; stop here rather than building empty datasets.
    raise RuntimeError("No training pairs were loaded; check the dataset errors above.")

# Shuffle and split (fixed seed for a reproducible 95/5 train/eval split)
random.seed(42)
random.shuffle(all_pairs)
split_idx = int(len(all_pairs) * 0.95)
train_dataset = Dataset.from_list(all_pairs[:split_idx])
eval_dataset = Dataset.from_list(all_pairs[split_idx:])

print(f"\nTotal: {len(all_pairs)} pairs (nli: {nli_count}, nq: {nq_count}, gooaq: {gooaq_count})")
print(f"Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")

# Free raw data
del all_pairs, nli, nq
gc.collect()

## 2. Fine-tune with MatryoshkaLoss

In [None]:
# Clear GPU before loading model
gc.collect()
torch.cuda.empty_cache()

import math

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

# Load model
print(f"Loading {BASE_MODEL}...")
model = SentenceTransformer(BASE_MODEL)
print(f"Loaded. Embedding dim: {model.get_sentence_embedding_dimension()}")

# Loss: the ranking loss is applied at every truncation length in
# MATRYOSHKA_DIMS, each weighted equally.
train_loss = MatryoshkaLoss(
    model,
    MultipleNegativesRankingLoss(model),
    matryoshka_dims=MATRYOSHKA_DIMS,
    matryoshka_weights=[1] * len(MATRYOSHKA_DIMS),
)

# Training
args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    bf16=True,
    # MultipleNegativesRankingLoss treats the other samples of a batch as
    # negatives, so a text repeated within a batch becomes a false negative.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    eval_steps=500,
    # Checkpoints are saved on the same schedule as evaluation so the best
    # checkpoint can be restored at the end.
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    gradient_checkpointing=True,
    dataloader_pin_memory=False,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
)

# Approximate step counts for a single device: the last partial batch is not
# dropped, hence ceil; the no-duplicates sampler may shift this slightly.
steps_per_epoch = math.ceil(len(train_dataset) / BATCH_SIZE)

print(f"\nConfig: batch={BATCH_SIZE}, epochs={NUM_EPOCHS}, lr={LEARNING_RATE}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total steps: {steps_per_epoch * NUM_EPOCHS}")
print("\nStarting training...")
trainer.train()

In [None]:
# Write the model that was handed to the trainer to OUTPUT_DIR, the directory
# the export and upload cells read from.
model.save_pretrained(path=OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")

## 3. Validate Matryoshka Embeddings

In [None]:
def _truncate_and_renormalize(full_embeddings, dim):
    """Keep the first `dim` columns of `full_embeddings` and L2-normalise each row."""
    truncated = full_embeddings[:, :dim]
    return truncated / np.linalg.norm(truncated, axis=1, keepdims=True)


# Index 0 is the query, 1 a matching document, 2 an unrelated sentence;
# 3 and 4 are the English and Spanish sentences used for the cross-lingual check.
test_sentences = [
    "Comment resilier un contrat d'assurance ?",
    "La resiliation d'un contrat peut etre effectuee par lettre recommandee.",
    "Les previsions meteo annoncent de la pluie demain.",
    "How to cancel an insurance contract?",
    "El contrato puede ser rescindido mediante carta certificada.",
]

embeddings = model.encode(test_sentences, normalize_embeddings=True)

# Truncate once per dimension and reuse the result in both reports below.
truncated_by_dim = {
    dim: _truncate_and_renormalize(embeddings, dim) for dim in MATRYOSHKA_DIMS
}

print(f"{'Dim':>6} | {'Sim(q,doc)':>10} | {'Sim(q,noise)':>12} | {'Delta':>8} | {'Status':>8}")
print("-" * 60)

for dim in MATRYOSHKA_DIMS:
    embs = truncated_by_dim[dim]
    # Rows are unit-length, so the dot product is the cosine similarity.
    sim_ok = np.dot(embs[0], embs[1])
    sim_bad = np.dot(embs[0], embs[2])
    delta = sim_ok - sim_bad
    print(f"{dim:>6} | {sim_ok:>10.4f} | {sim_bad:>12.4f} | {delta:>8.4f} | {'OK' if delta > 0.1 else 'WARN':>8}")

# Compare the largest and smallest configured dimensions (1024 and 256 with
# the default MATRYOSHKA_DIMS) instead of hard-coding them.
cross_dims = sorted({max(MATRYOSHKA_DIMS), min(MATRYOSHKA_DIMS)}, reverse=True)

print("\nCross-lingual (FR query vs):")
for lang, idx in [("FR doc", 1), ("EN query", 3), ("ES doc", 4)]:
    for dim in cross_dims:
        embs = truncated_by_dim[dim]
        print(f"  {lang} @ {dim}D: {np.dot(embs[0], embs[idx]):.4f}")

## 4. Export ONNX INT8

In [None]:
import shutil
from pathlib import Path


def _copy_model_metadata(src_dir, dst_dir):
    """Copy tokenizer/config/module files from `src_dir` into `dst_dir`.

    Weight files, trainer checkpoints and log directories are skipped so only
    the small files needed to load the exported model are copied.
    """
    shutil.copytree(
        src_dir,
        dst_dir,
        dirs_exist_ok=True,
        ignore=shutil.ignore_patterns(
            "checkpoint-*", "runs", "*.safetensors", "*.bin", "*.pt", "*.pth",
            "*.onnx", "*.onnx_data",
        ),
    )


try:
    from sentence_transformers import SentenceTransformer, export_dynamic_quantized_onnx_model

    print("Exporting ONNX INT8 (sentence-transformers native)...")
    # The exporter only accepts a model loaded with backend="onnx"; handing it
    # the PyTorch `model` raises, which would always trigger the fallback.
    onnx_model = SentenceTransformer(OUTPUT_DIR, backend="onnx")
    export_dynamic_quantized_onnx_model(
        onnx_model,
        quantization_config="avx512_vnni",
        model_name_or_path=ONNX_OUTPUT_DIR,
    )
    del onnx_model
    # NOTE(review): the exporter appears to write only the quantized graph
    # (under an "onnx/" subfolder) - confirm; the tokenizer and module configs
    # are copied next to it so the folder is loadable on its own.
    _copy_model_metadata(OUTPUT_DIR, ONNX_OUTPUT_DIR)
    print(f"Saved to {ONNX_OUTPUT_DIR}")

except Exception as e:
    # Deliberately broad: any failure of the native path (missing function,
    # export error) falls back to a plain Optimum export + quantization.
    print(f"Native export failed: {e}")
    print("Falling back to HuggingFace Optimum...")

    from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    fp32_tmp_dir = f"{ONNX_OUTPUT_DIR}-fp32-tmp"
    try:
        # Export an FP32 ONNX graph to a temporary folder first; the quantizer
        # is then created from that folder.
        ort_model = ORTModelForFeatureExtraction.from_pretrained(OUTPUT_DIR, export=True)
        ort_model.save_pretrained(fp32_tmp_dir)

        quantizer = ORTQuantizer.from_pretrained(fp32_tmp_dir)
        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
        quantizer.quantize(save_dir=ONNX_OUTPUT_DIR, quantization_config=qconfig)

        # Tokenizer and module configs come from the saved PyTorch model ...
        _copy_model_metadata(OUTPUT_DIR, ONNX_OUTPUT_DIR)
        # ... and JSON files written by the export take precedence over them.
        for f in Path(fp32_tmp_dir).glob("*.json"):
            if "model" not in f.name.lower():
                shutil.copy2(f, ONNX_OUTPUT_DIR)
    finally:
        # Remove the FP32 intermediate even when quantization fails.
        shutil.rmtree(fp32_tmp_dir, ignore_errors=True)
    print(f"Saved to {ONNX_OUTPUT_DIR}")

# Total on-disk size of everything in the export folder.
total_size = sum(f.stat().st_size for f in Path(ONNX_OUTPUT_DIR).rglob("*") if f.is_file())
print(f"ONNX model size: {total_size / 1e6:.0f} MB")

## 5. Push to HuggingFace Hub

In [None]:
from huggingface_hub import HfApi

# OUTPUT_DIR is also the trainer's output directory, so it contains
# intermediate checkpoints (and possibly logs) that must not be published.
TRAINING_ARTIFACT_PATTERNS = ["checkpoint-*/*", "runs/*"]

if not HF_TOKEN:
    print("No HF_TOKEN. To push manually:")
    print(f"  !huggingface-cli login")
    print(f'  !huggingface-cli upload {HF_REPO_ID} {OUTPUT_DIR} --exclude "checkpoint-*/*" "runs/*"')
    print(f"  !huggingface-cli upload {HF_REPO_ID}-onnx-int8 {ONNX_OUTPUT_DIR}")
else:
    # Pass token directly to HfApi (login() has auth issues on Colab)
    api = HfApi(token=HF_TOKEN)

    print(f"Pushing PyTorch model to {HF_REPO_ID}...")
    # exist_ok=True makes re-running this cell safe when the repo already exists.
    api.create_repo(HF_REPO_ID, exist_ok=True, private=False)
    api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID,
                      ignore_patterns=TRAINING_ARTIFACT_PATTERNS,
                      commit_message="BGE-M3 Matryoshka [1024, 768, 512, 256]")
    print(f"Done: https://huggingface.co/{HF_REPO_ID}")

    # The ONNX export goes to a sibling repo with an "-onnx-int8" suffix.
    onnx_repo = f"{HF_REPO_ID}-onnx-int8"
    print(f"\nPushing ONNX INT8 to {onnx_repo}...")
    api.create_repo(onnx_repo, exist_ok=True, private=False)
    api.upload_folder(folder_path=ONNX_OUTPUT_DIR, repo_id=onnx_repo,
                      commit_message="BGE-M3 Matryoshka ONNX INT8 [1024, 768, 512, 256]")
    print(f"Done: https://huggingface.co/{onnx_repo}")