## Training pipeline to fine-tune pre-trained models from [SentenceTransformers](https://www.sbert.net/docs/pretrained_models.html) with contrastive loss and hard negative samples

* Reference: [Contrastive Learning with Hard Negative Samples](https://arxiv.org/pdf/2010.04592.pdf)

In [None]:
!pip install faiss-cpu sentence_transformers

In [None]:
import itertools
import os
import random as rn
import shutil
from scipy.special import comb

import faiss
import numpy as np
import pandas as pd
from sentence_transformers import (
    SentencesDataset,
    SentenceTransformer,
    evaluation,
    losses,
)
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
from tqdm import tqdm, notebook

In [None]:
# Batch size shared by the training DataLoader and the evaluator.
BATCH_SIZE = 32
# Per-group pair budget: data_prep gives a group with P positive pairs
# max(CAP_SIZE - P, P) hard-negative pairs.
CAP_SIZE = 50

# Pre-trained checkpoint to fine-tune and the directory the result is saved to.
MODEL_NAME = "stsb-roberta-base"
MODEL_SAVE_PATH = "finetuned-model/" + MODEL_NAME
MODEL = SentenceTransformer(MODEL_NAME)

In [None]:
def _group_titles(frame):
    """Return the string titles of each ``label_group`` in ``frame``, one list per group.

    Groups are listed in order of first appearance. Non-string titles (e.g. NaN
    from empty CSV cells) are dropped, and groups left without titles are omitted.
    """
    groups = []
    for _, titles in frame.groupby("label_group", sort=False)["title"]:
        valid = [title for title in titles if isinstance(title, str)]
        if valid:
            groups.append(valid)
    return groups


def _build_pairs(groups):
    """Build ``(text_a, text_b, label)`` triples with FAISS-mined hard negatives.

    Every unordered pair of titles inside a group is a positive (label 1).

    Negatives (label 0) pair a randomly chosen group member with the titles whose
    embeddings are nearest (L2) to the group's mean embedding. Any title that also
    occurs in the group is skipped. A group with P positive pairs receives
    ``max(CAP_SIZE - P, P)`` negatives, or fewer if the index runs out of
    candidates.

    Groups with fewer than two titles yield no pairs, because there is no
    positive pair to anchor on. Their titles still serve as negative candidates
    for other groups.
    """
    titles = list(itertools.chain.from_iterable(groups))
    if not titles:
        return []

    # Encode every title once. ``titles`` is the concatenation of ``groups``,
    # so each group's embeddings are a contiguous slice of this array.
    embeddings = np.ascontiguousarray(MODEL.encode(titles), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    pairs = []
    offset = 0
    for group in groups:
        group_embeddings = embeddings[offset : offset + len(group)]
        offset += len(group)

        positive_pairs = list(itertools.combinations(group, 2))
        if not positive_pairs:
            continue
        n_negatives = max(CAP_SIZE - len(positive_pairs), len(positive_pairs))

        query = np.ascontiguousarray(
            group_embeddings.mean(axis=0, keepdims=True), dtype=np.float32
        )
        # Over-fetch so enough candidates survive the removal of group members.
        # Never ask for more neighbours than the index holds: FAISS pads missing
        # results with -1, which would silently select titles[-1].
        k = min(index.ntotal, 2 * n_negatives + len(group))
        _, neighbour_idx = index.search(query, k)
        members = set(group)
        negatives = [
            titles[idx]
            for idx in neighbour_idx[0]
            if idx >= 0 and titles[idx] not in members
        ][:n_negatives]

        pairs.extend((text_a, text_b, 1) for text_a, text_b in positive_pairs)
        pairs.extend((rn.choice(group), negative, 0) for negative in negatives)
    return pairs


def data_prep(csv_path="../input/shopee-product-matching/train.csv", train_frac=0.8):
    """Load the Shopee training CSV and build the train dataloader and evaluator.

    Label groups are split in order of first appearance, without shuffling. The
    first ``train_frac`` of the groups go to training and the rest to evaluation,
    so no group straddles the split.

    Args:
        csv_path: CSV with at least ``label_group`` and ``title`` columns.
        train_frac: Fraction of label groups used for training.

    Returns:
        ``(train_dataloader, evaluator)``: a shuffled ``DataLoader`` of
        ``InputExample`` pairs (labels 1/0) for ``losses.ContrastiveLoss``, and a
        ``BinaryClassificationEvaluator`` over pairs from the held-out groups.

    Raises:
        ValueError: if no evaluation pairs can be built from the held-out groups.
    """
    df = pd.read_csv(csv_path)

    label_groups = df["label_group"].unique()
    split = int(train_frac * len(label_groups))
    train_df = df.loc[df["label_group"].isin(label_groups[:split])]
    eval_df = df.loc[df["label_group"].isin(label_groups[split:])]

    # Train data in the format expected by ContrastiveLoss.
    train_examples = [
        InputExample(texts=[text_a, text_b], label=label)
        for text_a, text_b, label in _build_pairs(_group_titles(train_df))
    ]
    train_dataset = SentencesDataset(train_examples, MODEL)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

    # Eval data in the format expected by BinaryClassificationEvaluator.
    eval_pairs = _build_pairs(_group_titles(eval_df))
    if not eval_pairs:
        raise ValueError("no evaluation pairs could be built from the held-out label groups")
    sentences1, sentences2, labels = (list(column) for column in zip(*eval_pairs))

    evaluator = evaluation.BinaryClassificationEvaluator(
        sentences1=sentences1,
        sentences2=sentences2,
        labels=labels,
        batch_size=BATCH_SIZE,
    )
    return train_dataloader, evaluator

In [None]:
# Build the training batches and the held-out evaluator, then fine-tune MODEL.
train_dataloader, evaluator = data_prep()

train_loss = losses.ContrastiveLoss(model=MODEL)

# Always start from an empty output directory: a previous run saved there is discarded.
if os.path.exists(MODEL_SAVE_PATH):
    shutil.rmtree(MODEL_SAVE_PATH)
os.makedirs(MODEL_SAVE_PATH)

MODEL.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=100,
    warmup_steps=100,
    evaluation_steps=500,
    output_path=MODEL_SAVE_PATH,
)