# Config

In [16]:
from utils import load_config

# Load the notebook settings from the parent directory; a later cell reads the
# "training_data_path" entry from this object.
config = load_config("../config.json")

# Training

1. Train a model with the negatives (triplet loss)

- normal chroma training

In [17]:
from torch.optim.lr_scheduler import LambdaLR


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a ``LambdaLR`` with linear warm-up followed by linear decay.

    The learning-rate multiplier rises linearly from 0 to 1 over the first
    ``num_warmup_steps`` steps, then falls linearly so that it reaches 0 at
    ``num_training_steps``. It is clamped at 0 for any later step.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of steps in the warm-up phase.
        num_training_steps: Total number of scheduled steps.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer``.
    """
    # Both denominators are floored at 1 to avoid dividing by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            steps_left = float(num_training_steps - current_step)
            return max(0.0, steps_left / decay_span)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda)

In [18]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from torch.utils.data import DataLoader

from utils import TripletDataset, DupletDataset, LinearAdapter


def train_linear_adapter(base_model, train_data, num_epochs=10, batch_size=32,
                         learning_rate=2e-5, warmup_steps=100, max_grad_norm=1.0, margin=1.0, use_negatives=True):
    """Train a ``LinearAdapter`` that transforms query embeddings.

    Only the adapter's parameters are optimised. For every batch the query
    embedding is passed through the adapter and compared against the positive
    (and, optionally, negative) embedding, which are given to the loss as-is.

    Args:
        base_model: Model exposing ``get_sentence_embedding_dimension()``; it
            is also handed to the dataset wrapper (presumably to embed the
            examples -- confirm in ``utils``).
        train_data: Examples passed to ``TripletDataset`` / ``DupletDataset``.
        num_epochs: Number of passes over the data.
        batch_size: Batch size used by the ``DataLoader``.
        learning_rate: Base AdamW learning rate, scaled per step by the
            warm-up / linear-decay schedule.
        warmup_steps: Number of optimiser steps in the warm-up phase.
        max_grad_norm: Threshold for gradient-norm clipping.
        margin: Margin of the triplet loss (unused without negatives).
        use_negatives: If true, train with ``TripletMarginLoss`` on
            (query, positive, negative) batches; otherwise with ``MSELoss``
            on (query, positive) batches.

    Returns:
        The trained adapter, left on the device used for training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embedding_dim = base_model.get_sentence_embedding_dimension()
    adapter = LinearAdapter(embedding_dim).to(device)
    optimizer = AdamW(adapter.parameters(), lr=learning_rate)

    # The objective and the dataset wrapper are chosen as a pair.
    if not use_negatives:
        loss_fn = nn.MSELoss()
        dataset = DupletDataset(train_data, base_model)
    else:
        loss_fn = nn.TripletMarginLoss(margin=margin, p=2)
        dataset = TripletDataset(train_data, base_model)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    batches_per_epoch = len(dataloader)
    # The schedule is stepped once per batch, so it spans every batch of every epoch.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=batches_per_epoch * num_epochs,
    )

    for epoch in range(num_epochs):
        adapter.train()
        running_loss = 0
        progress_bar = tqdm(
            dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for batch in progress_bar:
            tensors = [t.to(device) for t in batch]
            if use_negatives:
                query_emb, positive_emb, negative_emb = tensors
                targets = (positive_emb, negative_emb)
            else:
                query_emb, positive_emb = tensors
                targets = (positive_emb,)
            # Only the query side goes through the adapter.
            loss = loss_fn(adapter(query_emb), *targets)

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(adapter.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

        avg_loss = running_loss / batches_per_epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    return adapter

## Load Data

In [19]:
from utils import load_data

# Load the training examples from the file named by the "training_data_path"
# config entry, resolved relative to the parent directory.
# NOTE(review): if `config` is a plain dict, a missing key makes `.get` return
# None and the string concatenation raise TypeError -- confirm the key is
# always present.
data = load_data("../" + config.get("training_data_path"))

In [20]:
def get_adapter_filename(adapter_kwargs, dataset_lang) -> str:
    """Return the checkpoint file name for a training configuration.

    The name has the form
    ``adapter_<dataset_lang>_<num_epochs>_lr<learning_rate>_<mode>.pth`` where
    ``<mode>`` is ``negatives`` or ``no_negatives``.

    Args:
        adapter_kwargs: Mapping with the keys ``num_epochs``,
            ``learning_rate`` and ``use_negatives``.
        dataset_lang: Dataset tag embedded in the name.

    Returns:
        The file name (no directory component).
    """
    epochs = str(adapter_kwargs['num_epochs'])
    lr = str(adapter_kwargs['learning_rate'])
    if adapter_kwargs['use_negatives']:
        mode = 'negatives'
    else:
        mode = 'no_negatives'
    return f"adapter_{dataset_lang}_{epochs}_lr{lr}_{mode}.pth"

## Train a model with the negatives

In [21]:
from sentence_transformers import SentenceTransformer

# Hyper-parameters for the run that trains with negative examples.
adapter_kwargs = dict(
    num_epochs=10,
    batch_size=32,
    learning_rate=0.01,
    warmup_steps=100,
    max_grad_norm=1.0,
    margin=1.0,
    use_negatives=True,
)
# Dataset tag embedded in the checkpoint file name.
DATASET_LANG = 'cpp_python'

base_model = SentenceTransformer("all-MiniLM-L6-v2")
trained_adapter = train_linear_adapter(base_model, data, **adapter_kwargs)

# Bundle the adapter weights with the settings that produced them.
save_dict = dict(
    adapter=trained_adapter.state_dict(),
    adapter_kwargs=adapter_kwargs,
)

                                                                        

Epoch 1/10, Average Loss: 0.8517


                                                                        

Epoch 2/10, Average Loss: 0.7755


                                                                        

Epoch 3/10, Average Loss: 0.7591


                                                                        

Epoch 4/10, Average Loss: 0.7495


                                                                        

Epoch 5/10, Average Loss: 0.7423


                                                                        

Epoch 6/10, Average Loss: 0.7368


                                                                        

Epoch 7/10, Average Loss: 0.7320


                                                                        

Epoch 8/10, Average Loss: 0.7278


                                                                        

Epoch 9/10, Average Loss: 0.7243


                                                                         

Epoch 10/10, Average Loss: 0.7203




In [None]:
import os

# torch.save does not create missing parent directories, so make sure the
# output folder exists; otherwise the write fails only after training is done.
adapter_dir = '../adapters/'
os.makedirs(adapter_dir, exist_ok=True)

# Write the weights + settings bundle under a name derived from the settings.
torch.save(save_dict, adapter_dir +
           get_adapter_filename(adapter_kwargs, DATASET_LANG))

## Train a model without the negatives

In [23]:

# Reuse the previous settings (same dict, modified in place), switching only
# the objective.
adapter_kwargs.update(use_negatives=False)

base_model = SentenceTransformer("all-MiniLM-L6-v2")
trained_adapter = train_linear_adapter(base_model, data, **adapter_kwargs)

# Bundle the adapter weights with the settings that produced them.
save_dict = dict(
    adapter=trained_adapter.state_dict(),
    adapter_kwargs=adapter_kwargs,
)

                                                                        

Epoch 1/10, Average Loss: 0.0021


                                                                        

Epoch 2/10, Average Loss: 0.0013


                                                                        

Epoch 3/10, Average Loss: 0.0012


                                                                        

Epoch 4/10, Average Loss: 0.0011


                                                                        

Epoch 5/10, Average Loss: 0.0010


                                                                        

Epoch 6/10, Average Loss: 0.0010


                                                                        

Epoch 7/10, Average Loss: 0.0009


                                                                        

Epoch 8/10, Average Loss: 0.0008


                                                                        

Epoch 9/10, Average Loss: 0.0008


                                                                         

Epoch 10/10, Average Loss: 0.0007




In [None]:
import os

# torch.save does not create missing parent directories, so make sure the
# output folder exists; otherwise the write fails only after training is done.
adapter_dir = '../adapters/'
os.makedirs(adapter_dir, exist_ok=True)

# Write the weights + settings bundle under a name derived from the settings.
torch.save(save_dict, adapter_dir +
           get_adapter_filename(adapter_kwargs, DATASET_LANG))