In [None]:
# !pip install pandas "sentence-transformers[train]" 'accelerate>=0.26.0'

In [None]:
# TODO:
# 1. Analyze the dataset to understand it better and work out how to build a higher-quality one.
# 2. Create a held-out test dataset.
# 3. Switch to step-wise evaluation: an epoch takes too long, so evaluating only once
#    per epoch is unnecessarily coarse for early stopping.

In [None]:
import logging
import os
from datetime import datetime

import pandas as pd
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    InformationRetrievalEvaluator,
    TripletEvaluator,
)
from sentence_transformers.training_args import BatchSamplers
from transformers import EarlyStoppingCallback


# Root logger setup: INFO and above, each record prefixed with a second-resolution timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

In [None]:
# Every tunable of the run lives here; main() reads all of its settings from this dict.
CONFIG = {
    # --- model and data ---
    "model_name": "BAAI/bge-m3",
    "train_file": "malayalam_dict.csv",
    # Optional separate dev CSV; when left as None the dev set is split off train_file.
    "dev_file": None,
    # Timestamped (minute resolution) so that separate runs get separate directories.
    "output_path": f"./finetuned-bge-m3-{datetime.now().strftime('%Y-%m-%d_%H-%M')}",

    # --- optimisation ---
    "batch_size": 16,
    "num_epochs": 20,
    "max_seq_length": 256,
    # Supported values: "triplet" or "multiple_negatives".
    "loss_type": "multiple_negatives",
    "warmup_ratio": 0.1,
    "learning_rate": 2e-5,
    # Forwarded to the trainer's fp16 flag.
    "use_amp": True,

    # --- early stopping (forwarded to EarlyStoppingCallback) ---
    "early_stopping_patience": 3,
    "early_stopping_threshold": 0.05,
}


In [None]:

def load_data(file_path):
    """Load a dictionary CSV and normalise it into anchor/positive/negative text columns.

    The three required columns (``word``, ``definition``,
    ``incorrect_definition``) are converted to stripped strings, with missing
    values becoming ``""``. Rows whose word or definition is empty after
    cleaning are dropped, because an empty anchor or positive carries no
    training signal. An empty ``incorrect_definition`` is kept, since that
    column is not used by every loss type.

    Args:
        file_path: Path to the CSV file.

    Returns:
        pandas.DataFrame in which the required columns are renamed to
        ``anchor``, ``positive`` and ``negative``. Any other columns in the
        file are passed through unchanged.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If one of the required columns is missing.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    df = pd.read_csv(file_path)
    required_cols = ["word", "definition", "incorrect_definition"]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Missing column: {col}")
        # NaN -> "" first so that astype(str) cannot produce the literal text "nan".
        df[col] = df[col].fillna("").astype(str).str.strip()

    # Without this filter, blank cells would be trained on as ("", definition)
    # or (word, "") pairs.
    usable = (df["word"] != "") & (df["definition"] != "")
    dropped = int((~usable).sum())
    if dropped:
        logging.warning(
            f"Dropped {dropped} row(s) with an empty word or definition from {file_path}"
        )
        df = df.loc[usable].reset_index(drop=True)

    return df.rename(
        columns={
            "word": "anchor",
            "definition": "positive",
            "incorrect_definition": "negative",
        }
    )


def prepare_examples(df, loss_type="multiple_negatives"):
    """Convert a cleaned DataFrame into a ``datasets.Dataset`` for the chosen loss.

    Only the columns the loss consumes are exported, in a fixed order:
    ``anchor, positive, negative`` for ``"triplet"`` and ``anchor, positive``
    for ``"multiple_negatives"``. The explicit selection matters because the
    sentence-transformers trainer feeds dataset columns to the loss by
    position, not by name. An extra CSV column, or a CSV whose columns come in
    a different order, would otherwise be passed to the loss as-is.

    Args:
        df: DataFrame with ``anchor``/``positive`` (and, for triplet loss,
            ``negative``) columns.
        loss_type: ``"triplet"`` or ``"multiple_negatives"``.

    Returns:
        A ``datasets.Dataset`` holding exactly the columns listed above.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported values.
    """
    columns_by_loss = {
        "triplet": ["anchor", "positive", "negative"],
        "multiple_negatives": ["anchor", "positive"],
    }
    if loss_type not in columns_by_loss:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    # preserve_index=False stops a non-default DataFrame index from being
    # exported as an additional "__index_level_0__" column.
    return Dataset.from_pandas(df[columns_by_loss[loss_type]], preserve_index=False)


def split_train_dev(df, dev_file=None, split_ratio=0.1, min_dev_size=100):
    """Return ``(train_df, dev_df)``, both with a fresh 0..n-1 index.

    If ``dev_file`` is given and exists, it is loaded with ``load_data`` and
    ``df`` is used for training in full. Otherwise a dev set is sampled from
    ``df`` with a fixed seed and removed from the training rows.

    The sampled dev set has ``max(int(len(df) * split_ratio), min_dev_size)``
    rows. When ``df`` is too small to honour ``min_dev_size`` and still leave
    training rows, the size falls back to the plain ratio, with at least one
    dev row and at least one training row.

    Args:
        df: Full cleaned DataFrame.
        dev_file: Optional path to a separate dev CSV.
        split_ratio: Fraction of ``df`` to hold out when sampling.
        min_dev_size: Preferred lower bound on the sampled dev set size.

    Returns:
        Tuple ``(train_df, dev_df)``.

    Raises:
        ValueError: If a split is needed but ``df`` has too few rows to leave
            both parts non-empty.
    """
    if dev_file and os.path.exists(dev_file):
        return df.reset_index(drop=True), load_data(dev_file).reset_index(drop=True)

    if dev_file:
        # A configured but missing dev file is most likely a typo; say so
        # instead of silently switching to an automatic split.
        logging.warning(
            f"Dev file not found: {dev_file}. Splitting the dev set off the training data instead."
        )

    n_rows = len(df)
    ratio_size = int(n_rows * split_ratio)
    dev_size = max(ratio_size, min_dev_size)
    if dev_size >= n_rows:
        if n_rows < 2:
            raise ValueError(
                f"Need at least 2 rows to split into train and dev sets, got {n_rows}"
            )
        # The min_dev_size floor would swallow the whole dataset (or make
        # sample() fail), so use the ratio alone and keep both parts non-empty.
        dev_size = min(max(ratio_size, 1), n_rows - 1)
        logging.warning(
            f"Only {n_rows} rows available; using {dev_size} dev rows instead of min_dev_size={min_dev_size}"
        )

    # Split on a fresh positional index so repeated index labels in df cannot
    # make drop() remove more rows than were sampled.
    pool = df.reset_index(drop=True)
    dev_df = pool.sample(dev_size, random_state=42)
    train_df = pool.drop(dev_df.index)

    return train_df.reset_index(drop=True), dev_df.reset_index(drop=True)

In [None]:
def create_evaluator(dev_examples, loss_type):
    """Build the dev-set evaluator that matches the training objective.

    For ``"triplet"`` a ``TripletEvaluator`` over the anchor/positive/negative
    columns is returned. For any other loss type an
    ``InformationRetrievalEvaluator`` is returned: every distinct anchor is a
    query, every distinct positive is a corpus document, and the documents
    paired with an anchor in ``dev_examples`` are its relevant results.

    A similarity evaluator with a constant gold score of 1.0 for every pair is
    not usable here: correlation against a constant is undefined, so its
    metrics come out as NaN.

    Args:
        dev_examples: Column-indexable dataset with ``anchor`` and
            ``positive`` columns (plus ``negative`` for triplet loss).
        loss_type: ``"triplet"`` selects the triplet evaluator; anything else
            selects the retrieval evaluator.

    Returns:
        A sentence-transformers evaluator instance.
    """
    if loss_type == "triplet":
        return TripletEvaluator(
            anchors=dev_examples["anchor"],
            positives=dev_examples["positive"],
            negatives=dev_examples["negative"],
            name="mal-triplet-dev",
        )

    # Key by text so a word listed with several definitions becomes one query
    # with several relevant documents, rather than several queries that each
    # count the other definitions as wrong.
    query_ids, doc_ids, relevant_docs = {}, {}, {}
    for anchor, positive in zip(dev_examples["anchor"], dev_examples["positive"]):
        qid = query_ids.setdefault(anchor, f"q{len(query_ids)}")
        did = doc_ids.setdefault(positive, f"d{len(doc_ids)}")
        relevant_docs.setdefault(qid, set()).add(did)

    return InformationRetrievalEvaluator(
        queries={qid: text for text, qid in query_ids.items()},
        corpus={did: text for text, did in doc_ids.items()},
        relevant_docs=relevant_docs,
        name="ml-embeddings-dev",
    )

In [20]:
def main(config):
    """Fine-tune a SentenceTransformer model as described by ``config``.

    Steps: load the base model, read and split the data, build the datasets,
    evaluator and loss for ``config["loss_type"]``, train with per-epoch
    evaluation and early stopping, then save the final model to
    ``config["output_path"]``.

    Args:
        config: Dict with the keys ``model_name``, ``max_seq_length``,
            ``train_file``, ``dev_file``, ``loss_type``, ``output_path``,
            ``num_epochs``, ``batch_size``, ``warmup_ratio``,
            ``learning_rate``, ``use_amp``, ``early_stopping_patience`` and
            ``early_stopping_threshold``.
    """
    model_name = config["model_name"]
    logging.info(f"Loading model: {model_name}")
    model = SentenceTransformer(model_name)
    model.max_seq_length = config["max_seq_length"]

    df = load_data(config["train_file"])
    train_df, dev_df = split_train_dev(df, config["dev_file"])

    logging.info(f"Training samples: {len(train_df)}, Validation samples: {len(dev_df)}")

    train_examples = prepare_examples(train_df, config["loss_type"])
    dev_examples = prepare_examples(dev_df, config["loss_type"])

    evaluator = create_evaluator(dev_examples, config["loss_type"])

    if config["loss_type"] == "triplet":
        train_loss = losses.TripletLoss(model)
    else:
        train_loss = losses.MultipleNegativesRankingLoss(model)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=config["output_path"],
        num_train_epochs=config["num_epochs"],
        per_device_train_batch_size=config["batch_size"],
        # Half the training batch size, but never 0 (batch_size=1 would floor to 0).
        per_device_eval_batch_size=max(1, config["batch_size"] // 2),
        warmup_ratio=config["warmup_ratio"],
        learning_rate=config["learning_rate"],
        fp16=config["use_amp"],
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        batch_sampler=BatchSamplers.NO_DUPLICATES,

        logging_steps=100,
        # Evaluation and checkpointing use the same schedule, which
        # load_best_model_at_end relies on.
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
    )

    early_stopper = EarlyStoppingCallback(
        early_stopping_patience=config["early_stopping_patience"],
        early_stopping_threshold=config["early_stopping_threshold"]
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_examples,
        eval_dataset=dev_examples,
        loss=train_loss,
        evaluator=evaluator,
        callbacks=[early_stopper],
    )

    logging.info("Starting training...")
    trainer.train()

    # Training only leaves rotating checkpoint-* sub-directories behind. Save
    # the model held in memory after training (the best one, given
    # load_best_model_at_end=True) to the output directory itself, so the log
    # line below is true.
    trainer.save_model(config["output_path"])
    logging.info(f"Training completed. Model saved to: {config['output_path']}")

In [None]:
# Entry point: run the fine-tuning pipeline with the CONFIG defined earlier in
# this file, but only when the file is executed as the main module.
if __name__ == "__main__":
    main(CONFIG)