# EBR Fine-tuning, Inference, and Embedding Visualization
This notebook consolidates the existing `ebr_finetune.py`, `ebr_infer.py`, and `ebr_dim_reduction_vis.py` scripts into a single, parameterized workflow. Adjust the configuration cells as needed, then execute the desired sections (fine-tuning, batch inference, visualization) end-to-end.


## How to Use
1. Review the dependency versions in `requirements.txt` and install them in the current environment.
2. Update the configuration dictionaries (paths, batch sizes, prompt text, etc.) to match your data layout.
3. Set the relevant `RUN_*` flags to `True` in the final code cell before executing it; each flag gates one stage (fine-tuning, inference, visualization), and all of them default to `False`.
4. Optional: skip costly steps (e.g., training) by leaving the corresponding flag set to `False`.


In [None]:
import os
import logging
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import umap
import dask.dataframe as dd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from sklearn.manifold import TSNE
from sklearn.metrics import (
    silhouette_score,
    calinski_harabasz_score,
    davies_bouldin_score,
)

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction
from sentence_transformers.losses import MultipleNegativesSymmetricRankingLoss
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from peft import LoraConfig, TaskType


In [None]:
def generate_negative(dataset_dict: Dict[str, List[str]]) -> List[str]:
    """Sample one random negative text per positive from the pool of positives.

    For every entry of ``dataset_dict["positive"]`` a different text is drawn
    at random (using the module-level ``random`` generator) from the distinct
    positives, so the result lines up index-by-index with the positives.

    Args:
        dataset_dict: Any object supporting ``dataset_dict["positive"]`` that
            yields an iterable of texts. Only subscripting is used, so a plain
            dict and other column-indexable containers both work.

    Returns:
        One negative per positive. Empty when there is no ``"positive"`` entry
        or it is empty. When only one distinct positive exists, that text is
        returned as its own "negative" because nothing else can be picked.
    """
    try:
        positives = list(dataset_dict["positive"])
    except KeyError:
        positives = []
    if not positives:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # order strings by their per-process hash, so seeding ``random`` would not
    # be enough to reproduce the sampled negatives across runs.
    candidate_pool = list(dict.fromkeys(positives))
    if len(candidate_pool) == 1:
        return candidate_pool * len(positives)

    negatives: List[str] = []
    for pos in positives:
        neg = random.choice(candidate_pool)
        # Re-draw until the negative differs from its own positive; this
        # terminates because the pool holds at least two distinct texts.
        while neg == pos:
            neg = random.choice(candidate_pool)
        negatives.append(neg)
    return negatives


def add_prompt_to_text(text: str, prompt: str) -> str:
    """Return ``text`` prefixed with ``prompt`` and a single space.

    Anything that is not a string, and any string that is empty or
    whitespace-only, is handed back unchanged.
    """
    if not isinstance(text, str):
        return text
    if not text.strip():
        return text
    return "{} {}".format(prompt, text)


def add_prompt_to_example(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
    """Prepend ``prompt`` to every text field of ``example``, in place.

    String values are rewritten through ``add_prompt_to_text``; list values
    have each element rewritten the same way. Values of any other type are
    left as they are. The same (mutated) mapping is returned.
    """
    for field in list(example.keys()):
        current = example[field]
        if isinstance(current, str):
            example[field] = add_prompt_to_text(current, prompt)
        elif isinstance(current, list):
            example[field] = [add_prompt_to_text(item, prompt) for item in current]
    return example


def print_trainable_parameters(model: SentenceTransformer) -> None:
    """Print how many of the model's parameters are trainable.

    Walks ``model.parameters()`` once, summing ``numel()`` over all parameters
    and separately over those with ``requires_grad`` set, then prints both
    counts and the trainable share as a percentage (0 for a model without
    parameters).
    """
    n_all = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size

    if n_all:
        share = 100 * n_trainable / n_all
    else:
        share = 0
    print(
        f"Trainable params: {n_trainable} | All params: {n_all} | Trainable %: {share:.2f}"
    )


In [None]:
@dataclass
class FinetuneConfig:
    """Settings read by ``train_sentence_transformer`` for LoRA fine-tuning."""

    # --- Base model -------------------------------------------------------
    # Path or name given to ``SentenceTransformer`` to load the base model; its
    # final path component also names the run directory.
    model_name: str = "./model/KaLM-embedding-multilingual-mini-instruct-v2.5"
    # NOTE(review): presumably meant to be forwarded to the model loader —
    # confirm the training function actually reads it.
    trust_remote_code: bool = True

    # --- LoRA adapter (passed to ``LoraConfig`` as r / lora_alpha / lora_dropout)
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.1

    # --- Data -------------------------------------------------------------
    # Directory whose ``*.json`` files are loaded as the training data.
    data_path: str = "./train_text"
    # When True, ``prompt`` is prepended to the text fields of every example.
    use_prompt: bool = True
    prompt: str = "Instruct: Retrieve semantically similar text.\nQuery:"
    # Fraction held out for evaluation and the seed of that train/test split.
    test_size: float = 0.2
    seed: int = 12

    # --- Training arguments -------------------------------------------------
    # Parent directory; the run itself is saved under
    # ``<output_dir>/<model_name's last path component>-peft-lora``.
    output_dir: str = "./saved_model"
    num_epochs: int = 1
    # Per-device batch sizes for training and evaluation.
    train_batch_size: int = 6
    eval_batch_size: int = 4
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    # Step intervals for evaluation, checkpointing, warm-up and logging.
    eval_steps: int = 500
    save_steps: int = 500
    warmup_steps: int = 300
    logging_steps: int = 100
    # Mixed-precision switches, forwarded unchanged to the training arguments.
    fp16: bool = True
    bf16: bool = False
    # NOTE(review): presumably the token limit for input texts — confirm the
    # training function applies it to the model.
    max_seq_length: int = 256


@dataclass
class InferenceConfig:
    """Settings read by ``run_batch_inference`` for batch embedding export."""

    # Model directory loaded with ``SentenceTransformer``.
    model_path: str = "./saved_model/KaLM-embedding-multilingual-mini-instruct-v2.5-peft-lora/checkpoint-500"
    # Passed to ``SentenceTransformer`` as ``truncate_dim``.
    truncate_dim: Optional[int] = 256
    # Directory whose ``*.csv`` files are read and concatenated.
    infer_data_dir: str = "./infer_data"
    # Rows per processing chunk; also passed to ``encode`` as its batch size.
    batch_size: int = 512
    # Column holding the text to embed.
    description_column: str = "description"
    # Field separator used for reading the inputs and writing the output.
    csv_sep: str = "\001"
    # Extra keyword arguments for ``pandas.read_csv`` (None means none).
    read_csv_kwargs: Optional[Dict[str, Any]] = None
    # Prompt handed to ``model.encode``.
    prompt: str = "Instruct: Retrieve semantically similar text.\nQuery:"
    # Destination file; parent directories are created as needed.
    output_path: str = "./faiss/embeds_2.txt"
    # Name of the column holding each embedding as space-separated numbers.
    vector_column: str = "vector"
    # When set, only these columns are written, in this order; otherwise all
    # input columns followed by the vector column are written.
    output_columns: Optional[List[str]] = None


@dataclass
class VisualizationConfig:
    """Settings read by ``visualize_embeddings`` for the t-SNE / UMAP plot."""

    # Header-less delimited file containing the embeddings.
    embeddings_path: str = "./faiss/embeds_2.txt"
    # Field separator of that file.
    csv_sep: str = "\001"
    # Names assigned to the file's columns when reading: the label column
    # first, then the vector column (space-separated numbers).
    vector_column: str = "vector"
    label_column: str = "label"
    # Passed to ``dask.dataframe.read_csv`` as ``blocksize``.
    chunk_size: Optional[int] = None
    # Font names to try for CJK text; None falls back to a built-in list.
    chinese_fonts: Optional[List[str]] = None
    # t-SNE hyper-parameters.
    tsne_perplexity: int = 30
    tsne_learning_rate: int = 200
    # UMAP ``n_neighbors``.
    umap_neighbors: int = 15
    # Random seed shared by t-SNE and UMAP.
    random_state: int = 0


In [None]:
def train_sentence_transformer(cfg: FinetuneConfig) -> str:
    """Fine-tune a SentenceTransformer with a LoRA adapter and save it.

    Steps: load the base model, attach the LoRA adapter, load every ``*.json``
    file under ``cfg.data_path``, drop rows with blank text, optionally prepend
    the prompt, split into train/eval, then train with
    ``MultipleNegativesSymmetricRankingLoss`` while evaluating with a
    ``TripletEvaluator`` whose negatives are sampled from the eval positives.

    Args:
        cfg: Fine-tuning settings.

    Returns:
        The run directory the model was saved to, as a string:
        ``<cfg.output_dir>/<last component of cfg.model_name>-peft-lora``.

    Raises:
        FileNotFoundError: If ``cfg.data_path`` contains no ``*.json`` file.
        ValueError: If the data has no ``anchor`` or ``positive`` column.
    """
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    # Honour the config switches that were previously declared but never
    # applied: remote-code trust when loading, and the input token limit.
    base_model = SentenceTransformer(
        cfg.model_name, trust_remote_code=cfg.trust_remote_code
    )
    if cfg.max_seq_length:
        base_model.max_seq_length = cfg.max_seq_length

    lora_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "down_proj",
            "gate_proj",
            "up_proj",
        ],
    )
    base_model.add_adapter(lora_config)
    print_trainable_parameters(base_model)

    data_dir = Path(cfg.data_path)
    json_files = [str(p) for p in data_dir.glob("*.json")]
    if not json_files:
        raise FileNotFoundError(f"No json files found under {data_dir}")

    dataset = load_dataset("json", data_files=json_files)

    text_columns = ("anchor", "positive")
    missing = [c for c in text_columns if c not in dataset["train"].column_names]
    if missing:
        raise ValueError(f"Training data is missing required column(s): {missing}")

    def _is_blank(value: Any) -> bool:
        return value is None or (isinstance(value, str) and not value.strip())

    # The filter callback receives a whole row, so comparing the row itself to
    # "" never removed anything; inspect the text columns instead.
    dataset = dataset.filter(
        lambda example: not any(_is_blank(example[c]) for c in text_columns)
    )

    if cfg.use_prompt:
        dataset = dataset.map(lambda x: add_prompt_to_example(x, cfg.prompt))

    dataset_dict = dataset["train"].train_test_split(
        test_size=cfg.test_size, seed=cfg.seed
    )
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    loss = MultipleNegativesSymmetricRankingLoss(base_model)
    run_name = f"{Path(cfg.model_name).name}-peft-lora"
    output_dir = Path(cfg.output_dir) / run_name

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.num_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="steps",
        eval_steps=cfg.eval_steps,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        save_total_limit=3,
        warmup_steps=cfg.warmup_steps,
        logging_steps=cfg.logging_steps,
        logging_dir=str(output_dir),
        fp16=cfg.fp16,
        bf16=cfg.bf16,
    )

    # Materialise the eval columns as plain lists and hand the negative sampler
    # a real dict, so nothing depends on the dataset object offering `.get()`.
    eval_anchors = list(eval_dataset["anchor"])
    eval_positives = list(eval_dataset["positive"])
    evaluator = TripletEvaluator(
        anchors=eval_anchors,
        positives=eval_positives,
        negatives=generate_negative({"positive": eval_positives}),
        main_similarity_function=SimilarityFunction.COSINE,
        name="sts-dev",
    )

    trainer = SentenceTransformerTrainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=[evaluator],
    )

    trainer.train()
    base_model.save_pretrained(str(output_dir))
    return str(output_dir)


In [None]:
def run_batch_inference(cfg: InferenceConfig) -> Path:
    """Embed the description column of every input CSV and write the result.

    All ``*.csv`` files under ``cfg.infer_data_dir`` are read, concatenated and
    de-duplicated. Rows whose description is missing are skipped (with a
    warning) instead of being embedded as the literal text ``"nan"``. The
    remaining rows are encoded in chunks of ``cfg.batch_size`` and appended to
    ``cfg.output_path`` without a header, each embedding stored as one
    space-separated string in ``cfg.vector_column``.

    Args:
        cfg: Inference settings.

    Returns:
        Path of the written embeddings file.

    Raises:
        FileNotFoundError: If no ``*.csv`` file exists under the input dir.
        ValueError: If ``cfg.description_column`` is not among the columns.
    """
    files = sorted(Path(cfg.infer_data_dir).glob("*.csv"))
    if not files:
        raise FileNotFoundError(f"No csv files found under {cfg.infer_data_dir}")

    # Merge once, outside the loop; a "sep" inside read_csv_kwargs now simply
    # overrides cfg.csv_sep instead of raising a duplicate-keyword TypeError.
    read_kwargs = {"sep": cfg.csv_sep, **(cfg.read_csv_kwargs or {})}
    frames = [pd.read_csv(file_path, **read_kwargs) for file_path in files]
    data = pd.concat(frames, ignore_index=True).drop_duplicates()

    if cfg.description_column not in data.columns:
        raise ValueError(f"Column '{cfg.description_column}' not found in data")

    missing_text = data[cfg.description_column].isna()
    if missing_text.any():
        logging.warning(
            "Skipping %d row(s) with a missing '%s' value",
            int(missing_text.sum()),
            cfg.description_column,
        )
        data = data[~missing_text]

    model = SentenceTransformer(cfg.model_path, truncate_dim=cfg.truncate_dim)
    output_path = Path(cfg.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # encode() prepends the prompt verbatim, whereas the training texts are
    # built as f"{prompt} {text}" (see add_prompt_to_text). Add the same
    # separating space so inference inputs match what the model was tuned on.
    encode_prompt = f"{cfg.prompt} " if cfg.prompt else cfg.prompt

    # newline="" lets pandas control line endings on the handle it is given.
    with output_path.open("w", encoding="utf-8", newline="") as f:
        for start_idx in range(0, len(data), cfg.batch_size):
            batch = data.iloc[start_idx : start_idx + cfg.batch_size]
            texts = batch[cfg.description_column].astype(str).tolist()
            embeds = model.encode(
                texts,
                normalize_embeddings=True,
                batch_size=cfg.batch_size,
                show_progress_bar=True,
                prompt=encode_prompt,
            )
            embed_text = [" ".join(str(x) for x in row) for row in embeds]
            result = pd.DataFrame({cfg.vector_column: embed_text})
            merged = pd.concat([batch.reset_index(drop=True), result], axis=1)
            if cfg.output_columns:
                merged = merged[cfg.output_columns]
            # The handle is already open, so successive writes append to it.
            merged.to_csv(f, index=False, header=False, sep=cfg.csv_sep)
    return output_path


In [None]:
def _configure_fonts(font_candidates: Optional[List[str]]) -> None:
    """Make matplotlib prefer the first installed font from the candidates.

    The chosen font is moved to the front of ``font.sans-serif`` while the
    existing entries are kept behind it as fallbacks. If none of the candidates
    is installed the font list is left alone and a warning is logged, because
    CJK text in figures will then probably not render. ``axes.unicode_minus``
    is always switched off.

    Args:
        font_candidates: Font names in order of preference; ``None`` or an
            empty list selects a built-in list of common CJK-capable fonts.
    """
    import matplotlib.font_manager as font_manager

    if not font_candidates:
        font_candidates = [
            "SimHei",
            "Microsoft YaHei",
            "WenQuanYi Micro Hei",
            "Arial Unicode MS",
            "STHeiti",
            "STSong",
        ]
    available_fonts = {f.name for f in font_manager.fontManager.ttflist}
    chosen = next((font for font in font_candidates if font in available_fonts), None)
    if chosen is None:
        logging.warning(
            "None of the candidate fonts %s is installed; CJK text may not render",
            font_candidates,
        )
    else:
        fallbacks = [name for name in plt.rcParams["font.sans-serif"] if name != chosen]
        plt.rcParams["font.sans-serif"] = [chosen, *fallbacks]
    plt.rcParams["axes.unicode_minus"] = False


def _compute_cluster_metrics(x: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    metrics = {"silhouette": np.nan, "calinski_harabasz": np.nan, "davies_bouldin": np.nan}
    unique_labels = np.unique(labels)
    if len(unique_labels) < 2 or x.shape[0] <= len(unique_labels):
        return metrics
    try:
        metrics["silhouette"] = silhouette_score(x, labels)
    except ValueError:
        pass
    try:
        metrics["calinski_harabasz"] = calinski_harabasz_score(x, labels)
    except ValueError:
        pass
    try:
        metrics["davies_bouldin"] = davies_bouldin_score(x, labels)
    except ValueError:
        pass
    return metrics


def _format_metrics_text(name: str, metrics: Dict[str, float]) -> str:
    def fmt(value: float) -> str:
        return "nan" if np.isnan(value) else f"{value:.3f}"

    return (
        f"{name}\n"
        f"Silhouette: {fmt(metrics['silhouette'])}\n"
        f"C-H: {fmt(metrics['calinski_harabasz'])}\n"
        f"D-B: {fmt(metrics['davies_bouldin'])}"
    )


def visualize_embeddings(cfg: VisualizationConfig, save_path: str = "./embedding_scatter.png") -> Path:
    """Plot t-SNE and UMAP projections of stored embeddings side by side.

    The embeddings file is read without a header; when the configured label
    and vector column names differ, its columns are named label then vector.
    Each vector is parsed from whitespace-separated numbers. Cluster metrics
    are computed for the original vectors and for both 2-D projections and
    drawn onto the figure, which is saved to ``save_path``.

    Args:
        cfg: Visualization settings.
        save_path: Where to write the PNG; parent directories are created.

    Returns:
        Path of the saved figure.

    Raises:
        ValueError: If the vector column is absent, fewer than two embeddings
            are available, or the vectors are empty or of differing length.
    """
    _configure_fonts(cfg.chinese_fonts)
    df = dd.read_csv(
        cfg.embeddings_path,
        sep=cfg.csv_sep,
        header=None,
        names=[cfg.label_column, cfg.vector_column]
        if cfg.label_column != cfg.vector_column
        else None,
        blocksize=cfg.chunk_size,
    ).compute()

    if cfg.vector_column not in df.columns:
        raise ValueError(f"Column '{cfg.vector_column}' not found in file {cfg.embeddings_path}")

    # split() without an argument tolerates repeated or trailing whitespace,
    # which split(" ") would turn into empty tokens that float() rejects.
    vectors = [[float(tok) for tok in str(raw).split()] for raw in df[cfg.vector_column]]
    if len(vectors) < 2:
        raise ValueError(
            f"At least two embeddings are required, found {len(vectors)} in {cfg.embeddings_path}"
        )
    dims = {len(vec) for vec in vectors}
    if len(dims) != 1 or 0 in dims:
        raise ValueError(
            f"Embeddings in {cfg.embeddings_path} must be non-empty and share one dimension; "
            f"found lengths {sorted(dims)}"
        )
    embeddings = np.array(vectors)

    if cfg.label_column not in df.columns:
        labels = np.zeros(len(df))
    else:
        labels = df[cfg.label_column].values

    # t-SNE requires perplexity < n_samples; clamp so small files still plot.
    n_samples = embeddings.shape[0]
    perplexity = cfg.tsne_perplexity
    if perplexity >= n_samples:
        perplexity = n_samples - 1
        logging.warning(
            "tsne_perplexity=%s is too large for %d samples; using %s instead",
            cfg.tsne_perplexity,
            n_samples,
            perplexity,
        )

    tsne_model = TSNE(
        n_components=2,
        random_state=cfg.random_state,
        perplexity=perplexity,
        learning_rate=cfg.tsne_learning_rate,
    )
    tsne_2d = tsne_model.fit_transform(embeddings)

    umap_model = umap.UMAP(
        n_components=2,
        random_state=cfg.random_state,
        n_neighbors=cfg.umap_neighbors,
    )
    umap_2d = umap_model.fit_transform(embeddings)

    metrics_original = _compute_cluster_metrics(embeddings, labels)
    metrics_tsne = _compute_cluster_metrics(tsne_2d, labels)
    metrics_umap = _compute_cluster_metrics(umap_2d, labels)

    # One colour per distinct label, shared by both panels.
    unique_labels = np.unique(labels)
    palette = sns.color_palette("husl", len(unique_labels) or 1)
    color_map = dict(zip(unique_labels, palette))

    projections = [
        ("t-SNE Projection", tsne_2d, metrics_tsne),
        ("UMAP Projection", umap_2d, metrics_umap),
    ]

    target = Path(save_path)
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    try:
        for ax, (title, values, metric_values) in zip(axes, projections):
            tmp_df = pd.DataFrame(values, columns=["x", "y"])
            tmp_df["label"] = labels
            sns.scatterplot(
                x="x",
                y="y",
                hue="label",
                data=tmp_df,
                ax=ax,
                palette=color_map,
                alpha=0.6,
                s=50,
                legend=False,
            )
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
            # Per-panel metrics box in the top-left corner of the axes.
            ax.text(
                0.02,
                0.98,
                _format_metrics_text(title, metric_values),
                transform=ax.transAxes,
                ha="left",
                va="top",
                bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
            )

        # Metrics of the un-projected vectors, centred under both panels.
        fig.text(
            0.5,
            0.02,
            _format_metrics_text("Original Embeddings", metrics_original),
            ha="center",
            va="bottom",
            fontsize=12,
            bbox=dict(boxstyle="round,pad=0.4", fc="white", alpha=0.8),
        )
        # Caption (Chinese): "Reading the metrics: higher Silhouette/CH is
        # better, lower Davies-Bouldin is better".
        fig.text(
            0.5,
            0.95,
            "指标解读：Silhouette/CH 越大越好，Davies-Bouldin 越小越好",
            ha="center",
            va="top",
            fontsize=12,
            color="dimgray",
        )

        target.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(target, dpi=180, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return target


In [None]:
# Fine-tuning: base model location, training data directory and the parent
# directory for saved runs. All other fields keep their dataclass defaults.
FINETUNE_CONFIG = FinetuneConfig(
    model_name="./model/KaLM-embedding-multilingual-mini-instruct-v2.5",
    data_path="./train_text",
    output_dir="./saved_model",
)

# Inference: the input files are read as a single column named "description"
# (via read_csv_kwargs), and only the description plus its vector are written.
INFERENCE_CONFIG = InferenceConfig(
    model_path="./saved_model/KaLM-embedding-multilingual-mini-instruct-v2.5-peft-lora/checkpoint-500",
    infer_data_dir="./infer_data",
    description_column="description",
    read_csv_kwargs={"names": ["description"], "engine": "python"},
    output_columns=["description", "vector"],
)

# Visualization: reads the same path the inference config writes by default.
# NOTE(review): with the output_columns above, the first column of that file
# is the description text, which visualize_embeddings will name "label" —
# confirm the file carries real class labels before trusting the metrics.
VIS_CONFIG = VisualizationConfig(
    embeddings_path="./faiss/embeds_2.txt",
    label_column="label",
    vector_column="vector",
)


In [None]:
# Stage switches: every stage is off by default, so running this cell as-is
# does nothing. Set a flag to True to execute that stage.
RUN_FINETUNE = False
RUN_INFERENCE = False
RUN_VISUALIZATION = False

# Defaults to the configured checkpoint; replaced below if fine-tuning runs.
trained_model_path = INFERENCE_CONFIG.model_path

if RUN_FINETUNE:
    trained_model_path = train_sentence_transformer(FINETUNE_CONFIG)
    # Point inference at the directory the freshly trained model was saved to.
    INFERENCE_CONFIG.model_path = trained_model_path
    print(f"Model saved to {trained_model_path}")

if RUN_INFERENCE:
    output_file = run_batch_inference(INFERENCE_CONFIG)
    print(f"Embeddings written to {output_file}")

if RUN_VISUALIZATION:
    vis_path = visualize_embeddings(VIS_CONFIG)
    print(f"Visualization saved to {vis_path}")


---
Need to customize something else? Record the change (and any open questions) as comments inside the corresponding config cell so future runs stay reproducible. Happy fine-tuning!
