In [None]:
! pip install -Uq huggingface_hub transformers sentence_transformers datasets

In [None]:
import wandb
# Initialise Weights & Biases in "disabled" mode for this notebook session.
wandb.init(mode="disabled")

In [None]:
# Start from a clean slate: delete the `models/` output directory left over
# from a previous run. `ignore_errors=True` makes this a no-op on a fresh
# runtime where the directory does not exist yet (plain `rm -r` errors there).
import shutil

shutil.rmtree("models", ignore_errors=True)

rm: cannot remove 'models': No such file or directory


### **Load Dataset**

In [None]:
from datasets import load_dataset

# Load the MS MARCO dataset dictionary and display its splits and columns.
msmarco = load_dataset("rasyosef/msmarco")
msmarco

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'positives', 'negatives'],
        num_rows: 502901
    })
    dev: Dataset({
        features: ['query_id', 'query', 'positives', 'negatives'],
        num_rows: 55577
    })
})

In [None]:
# Load the `train` split of the distillation dataset and display its columns.
msmarco_distil = load_dataset("yosefw/msmarco-train-distil-v2", split="train")
msmarco_distil

Dataset({
    features: ['query_id', 'query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'label'],
    num_rows: 496123
})

In [None]:
# Use the full splits; the commented-out suffixes are optional sub-sampling
# for quicker experiments.
train_dataset = msmarco_distil#.shuffle(seed=42).select(range(250_000))
test_dataset = msmarco["dev"]#.select(range(10_000))

train_dataset, test_dataset

(Dataset({
     features: ['query_id', 'query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'label'],
     num_rows: 496123
 }),
 Dataset({
     features: ['query_id', 'query', 'positives', 'negatives'],
     num_rows: 55577
 }))

In [None]:
# Number of distinct query ids in each split (compare with the row counts to spot repeated queries).
len(set(train_dataset['query_id'])), len(set(test_dataset['query_id']))

(496123, 55577)

In [None]:
from datasets import Dataset
import numpy as np
from tqdm import tqdm
import random

# Column names of the eight candidate negatives stored with every training row.
negative_columns = (
    "negative_1", "negative_2", "negative_3", "negative_4",
    "negative_5", "negative_6", "negative_7", "negative_8",
)
# Positions kept after ordering the candidates by ascending label.
kept_ranks = (0, 2, 4, 6)

ds_rows = []
for row in tqdm(train_dataset):
  scores = np.array(row["label"])

  # Pair each negative with its label and order the pairs by label, lowest first.
  ranked = [(row[column], score) for column, score in zip(negative_columns, scores)]
  ranked.sort(key=lambda pair: pair[1])

  # Keep every second candidate of the ranking (four of the eight).
  kept = [ranked[rank] for rank in kept_ranks]

  ds_rows.append({
      "query_id": row["query_id"],
      "query": row["query"],
      "positive": row["positive"],
      "negative_1": kept[0][0],
      "negative_2": kept[1][0],
      "negative_3": kept[2][0],
      "negative_4": kept[3][0],
      # Labels below 0.5 are raised to 0.5.
      "label": [max(score, 0.5) for _text, score in kept],
    })

# Materialise the converted rows as a Dataset and display its schema.
relevance_dataset = Dataset.from_list(ds_rows)
relevance_dataset

100%|██████████| 496123/496123 [01:10<00:00, 7063.45it/s]


Dataset({
    features: ['query_id', 'query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'label'],
    num_rows: 496123
})

In [None]:
# Inspect the first converted training example.
relevance_dataset[0]

{'query_id': 111652,
 'query': 'could Nexium antacid cause sweating',
 'positive': 'Summary: Sweating-excessive is found among people who take Nexium, especially for people who are 60+ old, have been taking the drug for.Personalized health information: on eHealthMe you can find out what patients like me (same gender, age) reported their drugs and conditions on FDA and social media since 1977. I am a 56 year old female who has been taking Nexium for 13 years and has been plagued by shingles.. 2  Support group for people who have Sweating-Excessive.  3 Been on warfarin for 6 days and having sweating at times.',
 'negative_1': 'More questions for: Nexium, Sweating-excessive. You may be interested at these reviews (Write a review): 1  Xarelto caused shortness of breath. 2  After taking Xarelto for 3 years I suddently experienced shortness of breath, sweating and pain in my arms. 3  Myrbetriq & hyperhidrosis (night sweats). I am a 56 year old female who has been taking Nexium for 13 years a

### **Initialize SPLADE**

In [None]:
from sentence_transformers import SparseEncoder, SparseEncoderModelCardData
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling

# 1. Assemble the model to fine-tune: an MLMTransformer built from
#    prajjwal1/bert-tiny followed by SpladePooling with the "max" strategy.
mlm_transformer = MLMTransformer("prajjwal1/bert-tiny")
splade_pooling = SpladePooling(pooling_strategy="max")

# 2. (Optional) metadata for the generated model card.
card_data = SparseEncoderModelCardData(
    language="en",
    license="mit",
    model_name="SPLADE-BERT-Tiny-Distil",
)

model = SparseEncoder(
    modules=[mlm_transformer, splade_pooling],
    model_card_data=card_data,
)

model

SparseEncoder(
  (0): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
  (1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
)

In [None]:
# Total number of parameter elements in the model. The backbone loaded in this
# notebook is prajjwal1/bert-tiny, so the message is labelled accordingly.
model_size = sum(p.numel() for p in model.parameters())
print(f"BERT Tiny size: {model_size/1_000_000:.1f}M parameters")

BERT Mini size: 4.4M parameters


In [None]:
from sentence_transformers.sparse_encoder.losses import SparseMarginMSELoss, SpladeLoss

# 4. Define a loss function
# SpladeLoss wraps a SparseMarginMSELoss and is configured with separate
# regulariser weights for queries (0.5) and documents (0.3).
loss = SpladeLoss(
    model=model,
    loss=SparseMarginMSELoss(model=model),
    query_regularizer_weight=5e-1,
    document_regularizer_weight=3e-1,
)

### **Evaluator**

In [None]:
import hashlib

def md5(text):
  """Return the hexadecimal MD5 digest of `text` (encoded with `str.encode()` defaults)."""
  return hashlib.md5(text.encode()).hexdigest()

# Evaluate on the first 5,000 rows of the dev split only.
dev_dataset = test_dataset.select(range(5_000))

dev_queries = {}        # query_id -> query text
dev_corpus = {}         # md5(passage text) -> passage text (identical passages collapse to one entry)
dev_relevant_docs = {}  # query_id -> set of corpus ids of that query's positives

# Single pass over the dev rows builds all three mappings.
for row in dev_dataset:
  positive_ids = set()
  for positive in row["positives"]:
    doc_id = md5(positive)
    dev_corpus[doc_id] = positive
    positive_ids.add(doc_id)

  for negative in row["negatives"]:
    dev_corpus[md5(negative)] = negative

  dev_queries[row["query_id"]] = row["query"]
  # A set (not a list): a positive passage repeated within a row must count
  # once, otherwise the number of relevant documents per query is overstated.
  dev_relevant_docs[row["query_id"]] = positive_ids

len(dev_corpus), len(dev_queries), len(dev_relevant_docs)

(49036, 5000, 5000)

In [None]:
from sentence_transformers.sparse_encoder.evaluation import SparseInformationRetrievalEvaluator

# Retrieval evaluator over the dev queries, corpus and relevance judgements
# built in the previous cell; it is later passed to the trainer as `evaluator`.
dev_evaluator = SparseInformationRetrievalEvaluator(
    queries=dev_queries,
    corpus=dev_corpus,
    relevant_docs=dev_relevant_docs,
    batch_size=64,
    corpus_chunk_size=2048,
    show_progress_bar=False
)

### **Train**

In [None]:
from sentence_transformers import SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Hyper-parameters shared by the training arguments below.
num_epochs = 6
batch_size = 48

# 5. Training arguments
run_name = "SPLADE-BERT-Tiny-Distil"
args = SparseEncoderTrainingArguments(
    # Where checkpoints are written (required).
    output_dir=f"models/{run_name}",
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    # Schedule and optimisation.
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=8e-5,
    warmup_ratio=0.025,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    # Mixed precision: FP16 on, BF16 off.
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Evaluate, checkpoint and log once per epoch, keeping at most three checkpoints.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    # Reload the checkpoint with the best dev MRR@10 when training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="dot_mrr@10",
)

# Columns handed to the trainer (`query_id` is dropped).
# NOTE(review): the column order presumably matters to the loss - keep it as is.
train_columns = ['query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'label']

# 7. Create a trainer
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=relevance_dataset.select_columns(train_columns),
    loss=loss,
    evaluator=dev_evaluator,
)



Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
# Run training with the trainer configured in the previous cell.
trainer.train()

In [None]:
# 9. Save the trained model to ./final
# (plain string literal: the path has no placeholders, so no f-string is needed)
model.save_pretrained("./final")

In [None]:
# 10. (Optional) Push it to the Hugging Face Hub
# from google.colab import userdata

# # push model to hub
# trainer.model.push_to_hub(run_name, exist_ok=True, private=True, token=userdata.get("HF_WRITE_TOKEN"))