In [16]:
!pip3 install torch torchvision

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [17]:
import json
import random
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

In [18]:
from utils import load_data, get_document_chunks, TripletDataset, LinearAdapter

In [19]:
# Read the shared JSON config that holds the file locations used below.
# The encoding is given explicitly: JSON interchange files are UTF-8, while
# open()'s default follows the platform locale and can mis-decode them.
with open('../globals.json', encoding='utf-8') as config_file:
    config = json.load(config_file)

# The file is already parsed and closed here; only dictionary look-ups remain.
# dict.get is kept on purpose: a missing key yields None instead of raising.
main_file = config.get("main_pdf")
negative_file = config.get("negative_pdf")
train_path = config.get("train_path")
validation_path = config.get("validation_path")

In [20]:
def random_negative():
    """Return one randomly chosen chunk of the negative document.

    The chunks of ``negative_file`` are obtained from ``get_document_chunks``
    on first use and memoised on the function object, so repeated calls do not
    re-read and re-chunk the document each time.  The cache is keyed by the
    current value of ``negative_file``; re-running the defining cell resets it.

    NOTE(review): assumes ``get_document_chunks`` returns the same chunks for
    the same path on every call -- confirm against ``utils``.

    Returns:
        A single element of the chunk sequence, picked with ``random.choice``.

    Raises:
        IndexError: if the chunk sequence is empty (raised by ``random.choice``).
    """
    # Function-attribute cache: avoids adding a new module-level global.
    chunk_cache = random_negative.__dict__.setdefault("_chunks_by_file", {})
    if negative_file not in chunk_cache:
        chunk_cache[negative_file] = get_document_chunks(negative_file)
    return random.choice(chunk_cache[negative_file])


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR whose multiplier ramps up linearly, then decays linearly.

    For steps below ``num_warmup_steps`` the multiplier is
    ``step / max(1, num_warmup_steps)``.  Afterwards it is
    ``(num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)``,
    clamped so it never goes below 0.0.

    Args:
        optimizer: optimizer whose learning rate is scaled by the multiplier.
        num_warmup_steps: number of steps in the linear ramp-up phase.
        num_training_steps: step at which the decayed multiplier reaches zero.

    Returns:
        A ``LambdaLR`` scheduler wrapping ``optimizer``.
    """
    def lr_lambda(current_step):
        warming_up = current_step < num_warmup_steps
        if warming_up:
            numerator = current_step
            denominator = max(1, num_warmup_steps)
        else:
            numerator = num_training_steps - current_step
            denominator = max(1, num_training_steps - num_warmup_steps)

        factor = float(numerator) / float(denominator)
        # Only the decay phase is clamped at zero; the warm-up value is returned as-is.
        return factor if warming_up else max(0.0, factor)

    return LambdaLR(optimizer, lr_lambda)


def train_linear_adapter(base_model, train_data, negative_sampler, num_epochs=10, batch_size=32,
                         learning_rate=2e-5, warmup_steps=100, max_grad_norm=1.0, margin=1.0):
    """Train a ``LinearAdapter`` on query embeddings with a triplet margin loss.

    Only the query embedding of each triplet is passed through the adapter;
    the positive and negative embeddings produced by ``TripletDataset`` are fed
    to the loss unchanged.

    Args:
        base_model: model providing ``get_sentence_embedding_dimension()``; it
            is also handed to ``TripletDataset``.
        train_data: training examples forwarded to ``TripletDataset``.
        negative_sampler: callable forwarded to ``TripletDataset``.
        num_epochs: number of passes over the data loader.
        batch_size: batch size of the shuffled ``DataLoader``.
        learning_rate: peak learning rate given to ``AdamW``.
        warmup_steps: warm-up steps for the linear schedule.
        max_grad_norm: gradient-norm clipping threshold.
        margin: margin of the ``TripletMarginLoss`` (p=2).

    Returns:
        The trained adapter, on CUDA when available and on CPU otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    embedding_dim = base_model.get_sentence_embedding_dimension()
    adapter = LinearAdapter(embedding_dim).to(device)
    criterion = nn.TripletMarginLoss(margin=margin, p=2)
    optimizer = AdamW(adapter.parameters(), lr=learning_rate)

    triplets = TripletDataset(train_data, base_model, negative_sampler)
    loader = DataLoader(triplets, batch_size=batch_size, shuffle=True)
    num_batches = len(loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_batches * num_epochs,
    )

    for epoch_idx in range(num_epochs):
        adapter.train()
        running_loss = 0
        batches = tqdm(
            loader, desc=f"Epoch {epoch_idx+1}/{num_epochs}", leave=False)

        for batch in batches:
            query_emb, positive_emb, negative_emb = (
                tensor.to(device) for tensor in batch)

            # Forward pass: adapt the query only, then score the triplet.
            loss = criterion(adapter(query_emb), positive_emb, negative_emb)

            # Backward pass with gradient clipping, then advance optimizer and schedule.
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(adapter.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()
            running_loss += loss_value
            batches.set_postfix(loss=f"{loss_value:.4f}")

        avg_loss = running_loss / num_batches
        print(f"Epoch {epoch_idx+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    return adapter

In [21]:
# Hyper-parameters passed verbatim to train_linear_adapter; they are also
# stored next to the weights in save_dict below.
adapter_kwargs = {
    'num_epochs': 10,
    'batch_size': 32,
    'learning_rate': 0.003,
    'warmup_steps': 100,
    'max_grad_norm': 1.0,
    'margin': 1.0
}

# Seed both RNGs used by this run so results are repeatable:
# `random` drives the negative sampling (random.choice), and torch's global
# RNG drives torch-side randomness such as the shuffled DataLoader.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

train_data, validation_data = load_data(train_path, validation_path)

base_model = SentenceTransformer('all-MiniLM-L6-v2')
trained_adapter = train_linear_adapter(
    base_model, train_data, random_negative, **adapter_kwargs
)

# Bundle the trained weights with the hyper-parameters that produced them.
save_dict = {
    'adapter_state_dict': trained_adapter.state_dict(),
    'adapter_kwargs': adapter_kwargs
}

                                                  

KeyboardInterrupt: 

In [None]:
# Write the checkpoint built by the training cell (save_dict) to disk.
# The output directory is created first so the save does not fail when
# ../adapters does not exist yet.
adapter_dir = Path('../adapters')
adapter_dir.mkdir(parents=True, exist_ok=True)

# File name encodes the epoch count, e.g. linear_adapter_10epochs.pth.
adapter_path = adapter_dir / f"linear_adapter_{adapter_kwargs['num_epochs']}epochs.pth"
torch.save(save_dict, adapter_path)