In [1]:
from tqdm import tqdm
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

# https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
class TqdmUpTo(tqdm):
    """tqdm progress bar with an absolute-position `update_to` hook (passed as `reporthook` to `urlretrieve` below)."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Advance the bar so that its position equals ``b * bsize``.

        b  : int, optional
            Number of blocks transferred so far [default: 1].
        bsize  : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize  : int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize

        # `update` takes an increment, so convert the absolute amount transferred
        # into the delta from the bar's current position (afterwards self.n == b * bsize).
        transferred = b * bsize
        self.update(transferred - self.n)

# From Pyserini
# https://github.com/castorini/pyserini/blob/1bbf7a72626866c88e8b21da99d48da6cb43673f/pyserini/util.py#L67C1-L101C28
def _compute_md5(path, block_size=2 ** 20):
    """Return the hexadecimal MD5 digest of the file at `path`, read in `block_size`-byte chunks."""
    import hashlib

    digest = hashlib.md5()
    with open(path, 'rb') as stream:
        for block in iter(lambda: stream.read(block_size), b''):
            digest.update(block)
    return digest.hexdigest()


def download_url(url, save_dir, local_filename=None, md5=None, force=False, verbose=True):
    """
    Download `url` into `save_dir` and return the path of the local file.

    Args:
        url: Address of the file to fetch.
        save_dir: Existing directory the file is written to.
        local_filename: Name of the local file. When omitted, the last path component of
            `url` is used (with a trailing Dropbox '?dl=1' parameter removed).
        md5: Optional expected hexadecimal MD5 digest, checked only after a fresh download.
        force: If True, an existing destination file is removed and downloaded again;
            otherwise an existing file is returned as-is without downloading.
        verbose: If True, print progress messages.

    Returns:
        Path of the destination file.

    Raises:
        AssertionError: If `md5` is given and the downloaded file does not match it.
    """
    # Imported here so the function does not depend on which notebook cells ran before it:
    # `re` is not imported anywhere else and `os` is only imported in a later cell.
    import os
    import re

    # If caller does not specify local filename, figure it out from the download URL:
    if not local_filename:
        filename = url.split('/')[-1]
        filename = re.sub(r'\?dl=1$', '', filename)  # Remove the Dropbox 'force download' parameter
    else:
        # Otherwise, use the specified local_filename:
        filename = local_filename

    destination_path = os.path.join(save_dir, filename)

    if verbose:
        print(f'Downloading {url} to {destination_path}...')

    # Check to see if file already exists, if so, simply return (quietly) unless force=True, in which case we remove
    # destination file and download fresh copy.
    if os.path.exists(destination_path):
        if verbose:
            print(f'{destination_path} already exists!')
        if not force:
            if verbose:
                print('Skipping download.')
            return destination_path
        if verbose:
            print(f'force=True, removing {destination_path}; fetching fresh copy...')
        os.remove(destination_path)

    try:
        with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename) as t:
            urlretrieve(url, filename=destination_path, reporthook=t.update_to)
    except BaseException:
        # The destination did not exist when the download started, so anything there now is a
        # partial file. Remove it; otherwise the next call would report "already exists" and
        # silently reuse a truncated download.
        if os.path.exists(destination_path):
            os.remove(destination_path)
        raise

    if md5:
        md5_computed = _compute_md5(destination_path)
        # Raised explicitly (rather than via `assert`) so the check survives `python -O`.
        if md5_computed != md5:
            raise AssertionError(
                f'{destination_path} does not match checksum! Expecting {md5} got {md5_computed}.'
            )

    return destination_path

## Change the DATA_DIR below if needed

In [2]:
import os

# Root directory for the downloaded training data and the trained model.
# expanduser('~') yields the same value as $HOME when that variable is set, but does not
# raise KeyError on systems where it is missing.
DATA_DIR = os.path.join(os.path.expanduser('~'), 'finetune_e5')
os.makedirs(DATA_DIR, exist_ok=True)

In [3]:
# Where the training archive is downloaded from.
DATA_URL = 'https://file.io/BIAWtmnSHrQi'

# Name of the training-data directory and of the archive that carries it.
DST_TRAIN_SUBDIR = 'sbert_train_qty_50000_neg_qty_1'
DST_FILE_NAME = f'{DST_TRAIN_SUBDIR}.tar.bz2'

# Local locations, all under DATA_DIR.
DST_PATH = f'{DATA_DIR}/{DST_FILE_NAME}'          # downloaded archive
DST_TRAIN_DATA_DIR = f'{DATA_DIR}'                # directory that holds the training sub-directory
DST_TRAIN_PATH = f'{DATA_DIR}/{DST_TRAIN_SUBDIR}'  # training data directory

# Output directory for the fine-tuned model.
MODEL_OUTPUT_PATH = f'{DATA_DIR}/trained_model'

# Show the resolved paths, one per line.
print(DST_PATH, MODEL_OUTPUT_PATH, DST_TRAIN_PATH, sep='\n')

/home/leo/finetune_e5/sbert_train_qty_50000_neg_qty_1.tar.bz2
/home/leo/finetune_e5/trained_model
/home/leo/finetune_e5/sbert_train_qty_50000_neg_qty_1


In [4]:
import tarfile
import shutil

DOWNLOAD_DATA=False # Enabled it only once

if DOWNLOAD_DATA:
    if os.path.exists(DST_PATH):
        os.unlink(DST_PATH)
    
    if os.path.exists(DST_TRAIN_PATH):
        shutil.rmtree(DST_TRAIN_PATH)

    download_url(DATA_URL, DATA_DIR, DST_FILE_NAME)

    with tarfile.open(DST_PATH) as tarball:
        tarball.extractall(DATA_DIR)

    os.unlink(DST_PATH)
    os.listdir(DST_TRAIN_PATH)

## Change Sentence BERT training parameters if needed

In [5]:
# Alternative settings for an A100 40GB:
#   BATCH_SIZE = 8
#   LOSS_MINI_BATCH_SIZE = 16

# Settings for a single small GPU
BATCH_SIZE = 2
LOSS_MINI_BATCH_SIZE = 64
EVAL_BATCH_SIZE = 8
LOGGING_STEPS = 64

NUM_EPOCHS = 1

In [6]:
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import RerankingEvaluator
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MultipleNegativesRankingLoss

In [7]:
import torch
from torch import Tensor

def get_position_ids(input_ids: Tensor, max_original_positions: int = 512, encode_max_length: int = 4096) -> Tensor:
    """
    Build position ids for a batch of token ids.

    Positions are 0..seq_len-1. When the sequence length does not exceed
    `max_original_positions`, every position is multiplied by
    `max(encode_max_length // max_original_positions, 1)`; longer sequences keep the
    consecutive positions.

    Args:
        input_ids: 2-D tensor; its second dimension is taken as the sequence length.
        max_original_positions: Length threshold up to which positions are scaled.
        encode_max_length: Together with `max_original_positions` determines the scaling factor.

    Returns:
        A long tensor with the same shape as `input_ids`, on the same device as `input_ids`.
        Every row holds the same positions (it is an expanded view of a single row).
    """
    seq_len = input_ids.size(1)
    factor = max(encode_max_length // max_original_positions, 1)

    # Create the positions directly on the input's device instead of going through a Python list
    # on the CPU, so the result can be combined with `input_ids` without a device mismatch.
    position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
    if seq_len <= max_original_positions:
        position_ids = position_ids * factor

    return position_ids.unsqueeze(0).expand_as(input_ids)

class SentenceTransformerE5(SentenceTransformer):
    """SentenceTransformer whose `tokenize` also emits `position_ids` (E5-base-4k specific)."""

    def tokenize(self, texts) -> dict[str, Tensor]:
        """
        Tokenize `texts` with the first module and attach recalculated `position_ids`.

        This function is E5-base-4k specific: it requires the model's maximum sequence
        length to be 4096.

        Args:
            texts (Union[List[str], List[Dict], List[Tuple[str, str]]]): A list of texts to be tokenized.

        Returns:
            Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
                "attention_mask", and "token_type_ids"; "position_ids" is always added.
        """
        features = self._first_module().tokenize(texts)

        # The position scheme below is hard-wired to 512 -> 4096, so refuse any other configuration.
        assert self.get_max_seq_length() == 4096

        token_ids = features['input_ids']
        features['position_ids'] = get_position_ids(
            token_ids,
            max_original_positions=512,
            encode_max_length=4096,
        )

        return features

In [8]:
from datasets import Dataset
# Load both splits from the extracted training directory.
train = Dataset.from_json(f'{DST_TRAIN_PATH}/train_triplets.jsonl')
test = Dataset.from_json(f'{DST_TRAIN_PATH}/eval.json')

In [9]:
model = SentenceTransformerE5("dwzhu/e5-base-4k")

# Sanity check: the overridden tokenizer must return position ids.
_probe_features = model.tokenize('This is simple')
assert 'position_ids' in _probe_features.keys()

No sentence-transformers model found with name dwzhu/e5-base-4k. Creating a new one with mean pooling.


In [10]:
# Training sometimes doesn't work well unless tokenizer parallelism is disabled,
# so turn it off through the TOKENIZERS_PARALLELISM environment variable.
os.environ['TOKENIZERS_PARALLELISM']='false'

In [11]:
# Training loss for `model`; the loss's mini-batch size comes from LOSS_MINI_BATCH_SIZE.
# NOTE(review): presumably mini_batch_size only bounds memory use and does not change the
# effective batch size — confirm against the sentence-transformers documentation.
loss=CachedMultipleNegativesRankingLoss(model, mini_batch_size=LOSS_MINI_BATCH_SIZE)

In [12]:
# Display the batch-size settings in effect for this run.
BATCH_SIZE, LOSS_MINI_BATCH_SIZE

(2, 64)

In [13]:
import wandb
# Initialize Weights & Biases in "disabled" mode.
wandb.init(mode="disabled")



In [14]:
# Training configuration: output location, number of epochs, logging cadence, mixed precision
# and data-loader settings.
train_args = SentenceTransformerTrainingArguments(
    output_dir=MODEL_OUTPUT_PATH,
    num_train_epochs=NUM_EPOCHS,
    run_name=None,
    # The string "none" switches off all reporting integrations. The Python value None does
    # not: Transformers 4.x treats it as "all" and reports to every installed integration.
    report_to="none",
    logging_steps=LOGGING_STEPS,
    fp16=True,
    dataloader_pin_memory=False,
    per_device_train_batch_size=BATCH_SIZE,
    dataloader_prefetch_factor=2,
    dataloader_num_workers=2,
)


## You must run `accelerate config` before running training

In [15]:
# Trainer that fine-tunes `model` on the training split with the loss and arguments defined above.
trainer = SentenceTransformerTrainer(model=model, train_dataset=train, loss=loss, args=train_args)

# A `dataloader_config = DataLoaderConfiguration(...)` assignment used to follow here. It was removed:
# `DataLoaderConfiguration` is never imported in this notebook (NameError on a fresh kernel)
# and the resulting object was never passed to the trainer or used anywhere else.


In [16]:
# Call model.train() before starting the trainer (presumably switches the model to training
# mode; the evaluation cell later calls model.eval()).
model.train()
# Run the training loop; as the last expression of the cell its return value is displayed.
trainer.train()

Step,Training Loss
64,1.1279
128,0.976
192,1.0196
256,0.858
320,0.9175
384,1.1214
448,1.1232
512,0.9242
576,1.0109
640,0.9373


Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-1000 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-1500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-2000 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-2500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-3000 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory /home/leo/finetune_e5/trained_model/checkpoint-3500 already exists and is non-empty. Saving will proceed bu

TrainOutput(global_step=24972, training_loss=1.3482850604230694, metrics={'train_runtime': 30299.5238, 'train_samples_per_second': 1.648, 'train_steps_per_second': 0.824, 'total_flos': 0.0, 'train_loss': 1.3482850604230694, 'epoch': 1.0})

In [None]:
# Number of evaluation examples to score; set to None to evaluate on the whole `test` set.
N_EVAL = 200

if N_EVAL is None:
    eval_samples = test
else:
    # Clamp to the dataset size so select() is never asked for rows that do not exist.
    eval_samples = test.select(range(min(N_EVAL, len(test))))

eval_obj = RerankingEvaluator(eval_samples, batch_size=EVAL_BATCH_SIZE)
eval_obj.show_progress_bar = True

# Call model.eval() before scoring, mirroring the model.train() call made before training.
model.eval()

results = eval_obj(model)

Batches:   0%|          | 0/17 [00:00<?, ?it/s]

Batches:   0%|          | 0/1678 [00:00<?, ?it/s]

In [18]:
# Display the metrics returned by the evaluator.
results

{'map': 0.06353592870993571,
 'mrr@10': 0.06352124183006536,
 'ndcg@10': 0.07587305098697536}