In [None]:
! pip install -Uq huggingface_hub transformers sentence_transformers datasets

In [None]:
import wandb
# Initialise Weights & Biases in "disabled" mode; presumably this keeps the
# training run below from logging to a W&B project (confirm against the wandb docs).
wandb.init(mode="disabled")

In [None]:
# Remove checkpoints left over from a previous run so training starts from a clean
# output directory. Guarded so that a fresh environment (no `models/` yet) does not
# produce an error, and done in Python so it does not depend on a POSIX shell.
import os
import shutil

if os.path.isdir("models"):
  shutil.rmtree("models")

rm: cannot remove 'models': No such file or directory


### **Load Dataset**

In [None]:
from datasets import load_dataset

# Load every split of the "rasyosef/msmarco" dataset; its "dev" split is used
# further down as the source of evaluation queries and passages.
msmarco = load_dataset("rasyosef/msmarco")
msmarco

In [None]:
# Load only the "train" split of the distillation dataset; the training rows
# (query, positive, eight negatives and a label list) are drawn from it.
msmarco_distil = load_dataset("yosefw/msmarco-train-distil-v2", split="train")
msmarco_distil

In [None]:
# Training subset: shuffle with a fixed seed, then keep the first 250,000 rows.
train_dataset = msmarco_distil.shuffle(seed=42).select(range(250_000))
# Evaluation source: the complete "dev" split.
test_dataset = msmarco["dev"]  # append .select(range(10_000)) to subsample for quicker experiments

train_dataset, test_dataset

(Dataset({
     features: ['query_id', 'query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'label'],
     num_rows: 250000
 }),
 Dataset({
     features: ['query_id', 'query', 'positives', 'negatives'],
     num_rows: 55577
 }))

In [None]:
# Number of distinct query ids in the training subset and in the dev split, as a
# (train, test) tuple; comparing against the row counts reveals repeated queries.
tuple(len(set(split['query_id'])) for split in (train_dataset, test_dataset))

(250000, 55577)

In [None]:
from datasets import Dataset
import numpy as np
from tqdm import tqdm
import random

# Reshape every training row from eight negatives down to four.
#
# The eight negatives are paired with the entries of the row's "label" list
# (presumably one teacher score/margin per negative -- confirm against the dataset
# card), ordered by ascending label, and the negatives at ranks 0, 1, 4 and 7 of
# that ordering are kept. The labels stored for the kept negatives are floored at 1.0.
negative_columns = [f"negative_{i}" for i in range(1, 9)]
kept_ranks = (0, 1, 4, 7)

ds_rows = []
for example in tqdm(train_dataset):
  candidates = [example[column] for column in negative_columns]
  scores = np.array(example["label"])

  # Stable ascending sort on the label; ties keep their original column order.
  ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1])

  record = {
      "query_id": example["query_id"],
      "query": example["query"],
      "positive": example["positive"],
  }
  for slot, rank in enumerate(kept_ranks, start=1):
    record[f"negative_{slot}"] = ranked[rank][0]
  # Labels listed in the same order as negative_1 .. negative_4, never below 1.0.
  record["label"] = [max(ranked[rank][1], 1.0) for rank in kept_ranks]

  ds_rows.append(record)

relevance_dataset = Dataset.from_list(ds_rows)
relevance_dataset

In [None]:
# Display the first reshaped training row as a quick sanity check of its fields.
relevance_dataset[0]

{'query_id': 1182558,
 'query': 'heart specialists in ridgeland ms',
 'positive': 'Dr. George Reynolds Jr, MD is a cardiology specialist in Ridgeland, MS and has been practicing for 35 years. He graduated from Vanderbilt University School Of Medicine in 1977 and specializes in cardiology and internal medicine.',
 'negative_1': "Dr. James Kramer is a Internist in Ridgeland, MS. Find Dr. Kramer's phone number, address and more.",
 'negative_2': "Dr. James Kramer is an internist in Ridgeland, Mississippi. He received his medical degree from Loma Linda University School of Medicine and has been in practice for more than 20 years. Dr. James Kramer's Details",
 'negative_3': 'Chronic Pulmonary Heart Diseases (incl. Pulmonary Hypertension) Coarctation of the Aorta; Congenital Aortic Valve Disorders; Congenital Heart Defects; Congenital Heart Disease; Congestive Heart Failure; Coronary Artery Disease (CAD) Endocarditis; Heart Attack (Acute Myocardial Infarction) Heart Disease; Heart Murmur; He

### **Initialize SPLADE**

In [None]:
from sentence_transformers import SparseEncoder, SparseEncoderModelCardData
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling

# Assemble the sparse encoder from its two modules: an MLM transformer initialised
# from the "prajjwal1/bert-mini" checkpoint, followed by SPLADE pooling that uses
# the "max" strategy.
mlm_backbone = MLMTransformer("prajjwal1/bert-mini")
splade_pooling = SpladePooling(pooling_strategy="max")

# Optional metadata attached to the model (language, license, display name).
card_data = SparseEncoderModelCardData(
    language="en",
    license="mit",
    model_name="SPLADE-BERT-Mini-Distil",
)

model = SparseEncoder(modules=[mlm_backbone, splade_pooling], model_card_data=card_data)

model

In [None]:
# Total number of elements across all parameter tensors of the model,
# reported in millions with one decimal place.
param_counts = [param.numel() for param in model.parameters()]
model_size = sum(param_counts)
print(f"BERT Mini size: {model_size/1_000_000:.1f}M parameters")

BERT Mini size: 11.2M parameters


In [None]:
from sentence_transformers.sparse_encoder.losses import SparseMarginMSELoss, SpladeLoss

# Training objective: a SparseMarginMSELoss wrapped in SpladeLoss, which is given
# separate regularizer weights for queries (0.5) and documents (0.3).
margin_mse_loss = SparseMarginMSELoss(model=model)

loss = SpladeLoss(
    model=model,
    loss=margin_mse_loss,
    document_regularizer_weight=0.3,
    query_regularizer_weight=0.5,
)

### **Evaluator**

In [None]:
import hashlib

def md5(text):
  """Return the hexadecimal MD5 digest of ``text``.

  The string is encoded with ``str.encode``'s default codec before hashing.
  The digest serves as a content-derived identifier for a passage.
  """
  return hashlib.md5(text.encode()).hexdigest()

# Build the three mappings the retrieval evaluator consumes from the first
# 5,000 rows of the dev split.
dev_dataset = test_dataset.select(range(5_000))

# query_id -> query text
dev_queries = dict(zip(dev_dataset["query_id"], dev_dataset["query"]))

# corpus id -> passage text. The id is the MD5 of the passage, so a passage that
# appears under several queries is stored only once.
dev_corpus = {}
# query_id -> set of relevant corpus ids. A set (not a list) matches the type the
# evaluator documents and guarantees a positive listed twice for the same query is
# counted once when the number of relevant documents is used as a denominator.
dev_relevant_docs = {}

for row in dev_dataset:
  positive_ids = set()
  for positive in row["positives"]:
    positive_id = md5(positive)
    dev_corpus[positive_id] = positive
    positive_ids.add(positive_id)

  for negative in row["negatives"]:
    dev_corpus[md5(negative)] = negative

  dev_relevant_docs[row["query_id"]] = positive_ids

len(dev_corpus), len(dev_queries), len(dev_relevant_docs)

(49036, 5000, 5000)

In [None]:
from sentence_transformers.sparse_encoder.evaluation import SparseInformationRetrievalEvaluator

# Retrieval evaluator over the dev subset: every query in `dev_queries` is searched
# against `dev_corpus`, and `dev_relevant_docs` says which corpus ids count as hits.
# No `name` is passed to the evaluator.
dev_evaluator = SparseInformationRetrievalEvaluator(
    queries=dev_queries,
    corpus=dev_corpus,
    relevant_docs=dev_relevant_docs,
    batch_size=64,
    corpus_chunk_size=2048,
    show_progress_bar=False
)

### **Train**

In [None]:
from sentence_transformers import SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Hyper-parameters a reader is most likely to tune.
num_epochs = 6
batch_size = 48

run_name = "SPLADE-BERT-Mini-distil"

# Columns handed to the trainer, in this order: the query, its positive passage,
# the four selected negatives, and finally the label list.
train_columns = ['query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'label']

args = SparseEncoderTrainingArguments(
    # Checkpoints are written below models/<run_name>.
    output_dir=f"models/{run_name}",
    run_name=run_name,
    # Optimisation settings.
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=8e-5,
    warmup_ratio=0.025,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    # Mixed precision: FP16 on, BF16 off. Swap the two on hardware that lacks
    # FP16 support but supports BF16.
    fp16=True,
    bf16=False,
    # Evaluate, checkpoint and log once per epoch, with a checkpoint limit of 3.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    # After training, reload the checkpoint with the best "dot_mrr@10" score.
    load_best_model_at_end=True,
    metric_for_best_model="dot_mrr@10",
)

# Wire the model, arguments, training data, loss and evaluator together.
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=relevance_dataset.select_columns(train_columns),
    loss=loss,
    evaluator=dev_evaluator,
)

In [None]:
# Run fine-tuning with the configured trainer. Evaluation, checkpointing and
# logging follow the per-epoch strategies set in the training arguments.
trainer.train()

In [None]:
# Persist the trained model to the local "final" directory.
model.save_pretrained("./final")

In [None]:
# 10. (Optional) Push it to the Hugging Face Hub
# from google.colab import userdata

# # push model to hub
# trainer.model.push_to_hub(run_name, exist_ok=True, private=True, token=userdata.get("HF_WRITE_TOKEN"))