In [1]:
from datasets import load_dataset

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.splade.models.transformer_rep import Splade

# MS MARCO v1.1, fetched through the Hugging Face `datasets` loader.
df = load_dataset("microsoft/ms_marco", "v1.1")

# One checkpoint name is shared by the SPLADE encoder and its tokenizer.
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the encoder, switch it to evaluation mode and move it to `device`.
model = Splade(model_type_or_dir, agg="max")
model.eval()
model = model.to(device)

# Tokenizer plus the inverse vocabulary (token id -> token string), used to
# turn the positions of a document representation back into terms.
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {token_id: token for token, token_id in tokenizer.vocab.items()}

  return self.fget.__get__(instance, owner)()
BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [None]:
import csv
import torch
from tqdm import tqdm
from collections import defaultdict
import pickle
import os
import time
from datetime import datetime

os.makedirs("backups", exist_ok=True)

# --- Run configuration ------------------------------------------------------
batch_size = 32
max_length = 512        # tokenizer truncation limit, in tokens
min_weight = 0.01       # term weights at or below this value are not indexed
start_row = 4_276_000   # 1-based row of collection.tsv where this run starts
stop_row = 5_800_000    # the run ends at the first full batch past this row
# 18 000 s, i.e. 5 hours.
# NOTE(review): if a 30-minute interval was intended this should be 30 * 60.
save_interval = 30 * 600

# --- Run state --------------------------------------------------------------
reverse_index = defaultdict(list)   # token string -> [(doc_id, weight), ...]
last_save_time = time.time()
counter = 0                         # rows read from the file so far
total_processed = 0                 # documents encoded during this run


def index_batch(doc_ids, docs):
    """Encode `docs` and append their postings to `reverse_index`.

    For every document, each position of its representation whose weight is
    greater than `min_weight` is mapped to a token through `reverse_voc`, and
    `(doc_id, weight)` is appended to that token's posting list.

    Args:
        doc_ids: identifiers, aligned one-to-one with `docs`.
        docs: passage texts to encode (a partial batch is fine).
    """
    passage_tokens = tokenizer(docs, return_tensors="pt", truncation=True,
                               max_length=max_length, padding=True).to(device)

    with torch.no_grad():
        batch_reps = model(d_kwargs=passage_tokens)["d_rep"]

    # One device-to-host copy per batch instead of two per document.
    # Assumes "d_rep" holds one row per document, indexed by vocabulary id
    # (the reverse_voc lookup below relies on that) — TODO confirm.
    batch_reps = batch_reps.cpu()

    for doc_id, doc_rep in zip(doc_ids, batch_reps):
        doc_rep = doc_rep.squeeze()
        indices = torch.nonzero(doc_rep > min_weight, as_tuple=True)[0]
        weights = doc_rep[indices]

        # Strongest terms first. Each term of a document goes to a different
        # posting list, so this only influences the order in which new keys
        # are inserted into `reverse_index`.
        order = weights.argsort(descending=True)
        for idx, weight in zip(indices[order].tolist(), weights[order].tolist()):
            reverse_index[reverse_voc[idx]].append((doc_id, weight))


def save_backup(documents_done, last_row):
    """Pickle `reverse_index` and write a small progress report next to it.

    Args:
        documents_done: number of documents encoded so far in this run.
        last_row: 1-based number of the last row read from the input file;
            written so that a later run knows where to resume.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    with open(f"backups/reverse_index_{timestamp}.pkl", "wb") as f:
        pickle.dump(dict(reverse_index), f)

    with open(f"backups/progress_{timestamp}.txt", "w") as f:
        f.write(f"Documents processed: {documents_done}\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Last row read: {last_row}\n")

    print(f"\nBackup saved at {timestamp} - {documents_done} documents processed")


# newline="" is what the csv module asks for; the encoding is made explicit so
# the result does not depend on the platform default.
# NOTE(review): with quotechar='"' a field that *starts* with a double quote is
# parsed as a quoted field and may span several physical lines. Switching to
# quoting=csv.QUOTE_NONE could change row numbering, so it is left untouched.
with open("collection.tsv", encoding="utf-8", newline="") as fd:
    rd = csv.reader(fd, delimiter="\t", quotechar='"')

    batch_docs = []
    batch_ids = []

    for row in tqdm(rd):
        counter += 1
        if counter < start_row:
            continue

        batch_ids.append(row[0])
        batch_docs.append(row[1])
        if len(batch_docs) < batch_size:
            continue

        index_batch(batch_ids, batch_docs)
        total_processed += len(batch_docs)
        batch_docs = []
        batch_ids = []

        # NOTE(review): the run stops after *any* backup, including one that
        # was triggered by the timer rather than by stop_row — confirm that
        # this is intended.
        current_time = time.time()
        if current_time - last_save_time >= save_interval or counter > stop_row:
            save_backup(total_processed, counter)
            last_save_time = current_time
            break
    else:
        # The file ended before a backup was triggered: encode the documents
        # still waiting in the partial batch and persist the result, otherwise
        # they would be dropped and nothing would reach the disk.
        if batch_docs:
            index_batch(batch_ids, batch_docs)
            total_processed += len(batch_docs)
            batch_docs = []
            batch_ids = []
        if total_processed:
            save_backup(total_processed, counter)
            last_save_time = time.time()


print(f"\nProcessing complete. Total documents processed: {total_processed}")


5800030it [35:58, 2686.55it/s]  


Backup saved at 20251103_153037 - 1524032 documents processed

Processing complete. Total documents processed: 1524032



