### Загружаются различные модели: 

```python
    model = Splade(model_type_or_dir, agg="max")    # предобученный splade с библиотеки

    model = SPLADE().to(device)                     # самописная обученная версия splade 
    model.load_state_dict(torch.load('model.pt'))
```

In [None]:
from datasets import load_dataset

import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer

# MS MARCO v1.1 as published on the Hugging Face Hub.
# NOTE(review): `df` is not referenced by the visible cells (indexing reads
# collection.tsv instead) -- confirm whether it is still needed.
df = load_dataset("microsoft/ms_marco", "v1.1")

# Checkpoint name used for the tokenizer (and for the library Splade model,
# when that variant is enabled).
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

# Library-provided pretrained SPLADE variant (disabled in favour of the
# custom SPLADE class):
# model = Splade(model_type_or_dir, agg="max")

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class SPLADE(nn.Module):
    """Sparse lexical encoder built on a masked-language-model head.

    The MLM logits are passed through ReLU, padded positions are zeroed
    using the attention mask, the result is max-pooled over dim 1 and
    finally squashed with ``log(1 + x)``.

    Args:
        model_name: Name or path handed to
            ``AutoModelForMaskedLM.from_pretrained``.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return one non-negative weight vector per input sequence.

        Args:
            input_ids: Token ids, forwarded to the MLM unchanged.
            attention_mask: Mask with 1 for real tokens and 0 for padding;
                assumed to be (batch, seq_len) so that ``unsqueeze(-1)``
                broadcasts against (batch, seq_len, vocab) logits.

        Returns:
            ``log1p`` of the max over dim 1 of the masked, rectified logits.
        """
        logits = self.bert_mlm(
            input_ids=input_ids, attention_mask=attention_mask
        ).logits
        # torch.relu rather than F.relu: `F` (torch.nn.functional) is not
        # imported before this class is defined, so relying on it makes the
        # forward pass fail unless some later import has already run.
        activated = torch.relu(logits)
        # Zero out padded positions so they can never win the max below.
        activated = activated * attention_mask.unsqueeze(-1)
        pooled, _ = torch.max(activated, dim=1)
        return torch.log1p(pooled)
    
# Build the custom SPLADE wrapper on the target device and restore the
# trained weights from disk.
model = SPLADE().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads when this runs on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('model.pt', map_location=device))
# Inference only from here on.
model.eval()

# NOTE(review): the tokenizer comes from `model_type_or_dir` while SPLADE()
# uses its own default backbone name; presumably both share the same
# vocabulary -- confirm, because token ids are mapped through this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
# Token id -> token string, used to turn output dimensions into index keys.
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

### Запускаю N инстансов блокнота, каждый из которых обрабатывает часть датасета для обратного индекса

```python
end_position = 9_000_000    # всего документов в датасете
sline_num = 0               # номер текущего инстанса -- только это поле нужно менять
slice_cnt = 3               # общее количество инстансов
```

In [None]:
import csv
import torch
from tqdm import tqdm
from collections import defaultdict
import pickle
import os
import time
from datetime import datetime
import torch.nn.functional as F


# Builds this instance's share of the inverted index from collection.tsv.
# Several notebook instances run the same cell with a different `sline_num`;
# each one covers a disjoint range of rows and writes its own backups.

# Postings: token string -> list of (doc_id, weight) pairs.
reverse_index = defaultdict(list)
batch_size = 32
# Activations at or below this value are not written to the index.
min_weight = 0.01

last_save_time = time.time()
# Seconds between periodic backups (30 * 600 = 18 000 s).
# NOTE(review): this is 5 hours; if 30 minutes was intended use 30 * 60.
save_interval = 30 * 600
# 1-based number of the last row read from collection.tsv.
counter = 0
total_processed = 0

start_position = 0
end_position = 9_000_000    # upper bound on the number of rows to cover
sline_num = 0               # index of this instance -- the only value to change
slice_cnt = 3               # total number of instances
slice_size = (end_position - start_position) // slice_cnt
# This instance owns the 1-based rows actual_start + 1 .. actual_end, so
# neighbouring slices never share a row.
actual_start = start_position + slice_size * sline_num
if sline_num == slice_cnt - 1:
    # The last slice also takes the remainder of the integer division.
    actual_end = end_position
else:
    actual_end = actual_start + slice_size

os.makedirs("backups", exist_ok=True)


def _index_batch(doc_ids, docs):
    """Encode `docs` and append their term weights to `reverse_index`."""
    passage_tokens = tokenizer(docs, return_tensors="pt", truncation=True,
                               max_length=512, padding=True)

    with torch.no_grad():
        batch_reps = model(passage_tokens['input_ids'].to(device),
                           passage_tokens['attention_mask'].to(device))

    for doc_id, doc_rep in zip(doc_ids, batch_reps):
        doc_rep = doc_rep.squeeze()
        keep = doc_rep > min_weight
        indices = torch.arange(doc_rep.size(0), device=device)[keep]
        weights = doc_rep[keep]

        # Strongest terms first, so postings are appended in that order.
        order = weights.argsort(descending=True)
        # tolist() moves the data to the host and yields plain Python
        # ints/floats, which are safe dict keys and pickle compactly.
        for idx, weight in zip(indices[order].tolist(), weights[order].tolist()):
            reverse_index[reverse_voc[idx]].append((doc_id, weight))


def _save_backup():
    """Write a full snapshot of `reverse_index` plus a small progress file.

    Every snapshot contains everything indexed so far, so only the most
    recent snapshot of each instance should be fed into the merge step.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    with open(f"backups/reverse_index_{timestamp}.pkl", "wb") as f:
        pickle.dump(dict(reverse_index), f)

    with open(f"backups/progress_{timestamp}.txt", "w") as f:
        f.write(f"Documents processed: {total_processed}\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

    print(f"\nBackup saved at {timestamp} - {total_processed} documents processed")


# newline="" is what the csv module expects for files it parses itself.
# NOTE(review): with quotechar='"' a passage that starts with a double quote
# is parsed as a quoted field and may swallow following rows; if
# collection.tsv is plain tab-separated text, consider quoting=csv.QUOTE_NONE.
with open("collection.tsv", encoding="utf-8", newline="") as fd:
    rd = csv.reader(fd, delimiter="\t", quotechar='"')

    batch_docs = []
    batch_ids = []

    for counter, row in enumerate(tqdm(rd), start=1):
        if counter <= actual_start:
            continue
        # Stop before touching the first row of the next slice.
        if counter > actual_end:
            break

        batch_ids.append(row[0])
        batch_docs.append(row[1])

        if len(batch_docs) == batch_size:
            _index_batch(batch_ids, batch_docs)
            total_processed += len(batch_docs)
            batch_docs = []
            batch_ids = []

            current_time = time.time()
            if current_time - last_save_time >= save_interval:
                _save_backup()
                last_save_time = current_time

    # The slice length is rarely a multiple of batch_size and the file may
    # end before actual_end: index whatever is still buffered.
    if batch_docs:
        _index_batch(batch_ids, batch_docs)
        total_processed += len(batch_docs)
        batch_docs = []
        batch_ids = []

# Always leave a final snapshot, whichever way the loop ended.
_save_backup()
last_save_time = time.time()

print(f"\nProcessing complete. Total documents processed: {total_processed}")


### После инференса всех частей выполняю merge

In [None]:
from tqdm import tqdm
import pickle

# Merged index of the most recent merge_pickles() call (token -> postings).
result = {}


def merge_pickles(file_paths, output_path):
    """Merge several pickled inverted-index parts into one pickle.

    Each input file must contain a dict mapping a key to a list; lists that
    belong to the same key are concatenated in the order of `file_paths`.
    The merged dict is written to `output_path`, stored in the module-level
    `result` and returned.

    Postings are concatenated as-is: if two inputs contain the same
    documents (for example two snapshots of the same instance), those
    postings appear twice in the output.

    Args:
        file_paths: Iterable of paths to the pickled part dictionaries.
        output_path: Destination path of the merged pickle.

    Returns:
        The merged dictionary.
    """
    global result
    # Start from an empty accumulator on every call; otherwise running the
    # merge a second time would append every posting once more.
    result = {}
    for file_path in tqdm(file_paths):
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # merge files produced by the trusted indexing step.
        with open(file_path, 'rb') as f:
            current = pickle.load(f)

        for key, value in tqdm(current.items()):
            result.setdefault(key, []).extend(value)

    with open(output_path, 'wb') as f:
        pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)

    return result

# One index snapshot per notebook instance, merged in this order.
files = [
    '/home/jupyter/datasphere/project/backups/reverse_index_20251123_161520.pkl',
    '/home/jupyter/datasphere/project/backups/reverse_index_20251123_174727.pkl',
    '/home/jupyter/datasphere/project/backups/reverse_index_20251123_174815.pkl',
]
# Writes the combined index to merged.pkl in the working directory.
merge_pickles(file_paths=files, output_path='merged.pkl')