In [1]:
%pip install -U sentence-transformers datasets torch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [6]:

import torch
from sentence_transformers import SparseEncoder
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm

# Загрузка модели SPLADE
print("Загрузка модели SPLADE...")
model = SparseEncoder("naver/splade-cocondenser-ensembledistil")

# Загрузка датасета MS MARCO (только маленькая часть для демонстрации)
print("Загрузка датасета MS MARCO...")
dataset = load_dataset("microsoft/ms_marco", "v1.1")

# Подготовка данных - берем первые 100 запросов для быстрого тестирования
queries = []
qids = []
documents = []
relevant_docs = {}

print("Подготовка данных...")
for i, example in enumerate(dataset['test']):
    query = example['query']
    passages = example['passages']
    
    if query and passages['passage_text']:
        queries.append(query)
        qids.append(example['query_id'])
        
        # Берем первые 10 пассажей для каждого запроса
        query_docs = passages['passage_text'][:10]
        documents.extend(query_docs)
        
        # Сохраняем релевантные документы для каждого запроса
        relevant_docs[example['query_id']] = [
            passages['passage_text'][j] for j in range(len(passages['is_selected'])) 
            if passages['is_selected'][j] == 1
        ]

print(f"Загружено {len(queries)} запросов и {len(documents)} документов")

# Функция для вычисления MRR@10
def calculate_mrr_at_10(queries, documents, relevant_docs, model, top_k=10):
    mrr_scores = []
    
    for i, query in enumerate(tqdm(queries, desc="Обработка запросов")):
        query_embedding = model.encode_query([query])
        
        # Вычисляем схожесть со всеми документами
        similarities = []
        batch_size = 32
        
        for j in range(0, len(documents), batch_size):
            batch_docs = documents[j:j+batch_size]
            doc_embeddings = model.encode_document(batch_docs)
            batch_similarities = model.similarity(query_embedding, doc_embeddings)
            similarities.extend(batch_similarities[0].cpu().numpy())
        
        # Получаем топ-K наиболее релевантных документов
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        top_documents = [documents[idx] for idx in top_indices]
        
        # Проверяем, есть ли релевантные документы в топ-10
        current_qid = qids[i]
        relevant_for_query = relevant_docs.get(current_qid, [])
        
        if relevant_for_query:
            reciprocal_rank = 0
            for rank, doc in enumerate(top_documents, 1):
                if doc in relevant_for_query:
                    reciprocal_rank = 1.0 / rank
                    break
            mrr_scores.append(reciprocal_rank)
        else:
            mrr_scores.append(0.0)
    
    return np.mean(mrr_scores)

# Вычисляем MRR@10
print("Вычисление MRR@10...")
mrr_score = calculate_mrr_at_10(queries, documents, relevant_docs, model, top_k=10)
print(f"MRR@10: {mrr_score:.4f}")

# Дополнительно: пример работы модели на одном запросе
print("\n--- Пример работы ---")
test_query = queries[0]
print(f"Запрос: {test_query}")

query_embedding = model.encode_query([test_query])
doc_embeddings = model.encode_document(documents[:5])

similarities = model.similarity(query_embedding, doc_embeddings)
print("Схожести с первыми 5 документами:")
for i, sim in enumerate(similarities[0]):
    print(f"Документ {i+1}: {sim:.4f}")

Загрузка модели SPLADE...
Загрузка датасета MS MARCO...
Подготовка данных...
Загружено 9650 запросов и 79175 документов
Вычисление MRR@10...


Обработка запросов:   0%|          | 0/9650 [00:00<?, ?it/s]


NotImplementedError: Could not run 'aten::as_strided' with arguments from the 'SparseMPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::as_strided' is only available for these backends: [CPU, MPS, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradMAIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastMTIA, AutocastMAIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCPU_0.cpp:1823 [kernel]
MPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterMPS_0.cpp:3023 [kernel]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterMeta_0.cpp:2431 [kernel]
QuantizedCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterQuantizedCPU_0.cpp:192 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp:23259 [kernel]
Named: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
Conjugate: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:21 [kernel]
Negative: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:22 [kernel]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterZeroTensor_0.cpp:137 [kernel]
ADInplaceOrView: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/ADInplaceOrViewType_0.cpp:4978 [kernel]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradHIP: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradIPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradVE: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradMTIA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradMAIA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradPrivateUse1: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradPrivateUse2: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradPrivateUse3: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradMeta: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
AutogradNestedTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18228 [autograd kernel]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/TraceType_0.cpp:17237 [kernel]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
AutocastMTIA: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:468 [backend fallback]
AutocastMAIA: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:506 [backend fallback]
AutocastXPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:544 [backend fallback]
AutocastMPS: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:735 [kernel]
BatchedNestedTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1079 [kernel]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]
