In [None]:
import pickle

# Load the pre-built index from disk. Later cells read it as
# `documents[token]` -> iterable of (doc_id, weight) pairs.
# NOTE: pickle can execute arbitrary code on load; only open files you produced yourself.
with open('./merged.pkl', 'rb') as index_file:
    documents = pickle.load(index_file)

In [None]:
%pip install ir_datasets

In [None]:
import ir_datasets

# Dataset handle for the ir_datasets id "msmarco-passage/dev".
dataset = ir_datasets.load("msmarco-passage/dev")

# Count the queries by walking the iterator once.
queries_count = 0
for _query in dataset.queries_iter():
    queries_count += 1
print(f"Number of queries: {queries_count}")

print("=== Примеры релевантных пар query-passage ===")
# Preview only the first 10 relevance judgments.
for shown, qrel in enumerate(dataset.qrels_iter()):
    if shown == 10:
        break
    print(f"Query {qrel.query_id} -> Passage {qrel.doc_id}, Relevance: {qrel.relevance}")

from tqdm import tqdm

# Nested lookup of relevance judgments: {query_id: {doc_id: relevance}}.
qrels_mapping = {}
for judgment in dataset.qrels_iter():
    per_query = qrels_mapping.setdefault(judgment.query_id, {})
    per_query[judgment.doc_id] = judgment.relevance


# Print a single query record to inspect its fields, then stop iterating.
for query in tqdm(dataset.queries_iter()):
    print(query)
    break

### Загружаю Splade библиотечный

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from ..spladev2.cloned.splade.splade.models.transformer_rep import Splade

# NOTE(review): `Splade` comes from the relative import above (`from ..spladev2...`).
# Relative imports normally fail in a notebook kernel because there is no parent
# package -- confirm how this cell is run, or switch to an absolute import.

# Checkpoint id shared by the encoder and its tokenizer.
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Library encoder constructed with agg="max", switched to evaluation mode on `device`.
model = Splade(model_type_or_dir, agg="max")
model.eval()
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
# Inverse vocabulary: token id -> token string.
reverse_voc = {token_id: token for token, token_id in tokenizer.vocab.items()}

### Или Splade самописный

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Checkpoint id. In this cell it is only passed to AutoTokenizer further down;
# SPLADE() itself is constructed with its own default `model_name`.
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SPLADE(nn.Module):
    """Sparse encoder that pools masked-language-model logits into term weights.

    The forward pass applies ReLU to the MLM logits, zeroes the positions that
    the attention mask marks as padding, takes the maximum over dimension 1 and
    applies ``log1p`` to the result.

    Args:
        model_name: Identifier or path passed to
            ``AutoModelForMaskedLM.from_pretrained``.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Encode a batch into one weight vector per input.

        Args:
            input_ids: Token ids forwarded to the MLM model.
            attention_mask: Mask forwarded to the MLM model and used to zero
                out masked positions before pooling (presumably 1 = real
                token, 0 = padding -- confirm against the tokenizer).

        Returns:
            ``log1p`` of the max-pooled, ReLU-activated logits. Assumes the
            logits are (batch, seq_len, vocab), so the result is
            (batch, vocab) -- TODO confirm.
        """
        outputs = self.bert_mlm(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Use torch.relu rather than F.relu: `F` (torch.nn.functional) is not
        # imported in this cell, so the original raised NameError on a fresh kernel.
        relu_logits = torch.relu(logits)
        # Zero masked positions so they can never win the max below.
        relu_logits = relu_logits * attention_mask.unsqueeze(-1)
        pooled, _ = torch.max(relu_logits, dim=1)
        weights = torch.log1p(pooled)
        return weights
    
# Build the custom encoder with its default backbone and restore the weights saved in model.pt.
model = SPLADE()
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load('model.pt', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
# Inverse vocabulary: token id -> token string.
reverse_voc = {token_id: token for token, token_id in tokenizer.vocab.items()}

In [None]:
import torch
import heapq

# Place the active encoder on the target device and switch it to evaluation mode.
model.to(device).eval()

# Running totals for the evaluation loop below:
# number of judged queries processed and the sum of their reciprocal ranks.
count, mrr = 0, 0.0

import csv

def get_document_text(doc_id, file_path="collection.tsv"):
    """Return the text stored for ``doc_id`` in a tab-separated collection file.

    The file is scanned line by line on every call; the first row whose first
    column equals ``str(doc_id)`` wins.

    Args:
        doc_id: Document identifier; converted to ``str`` before comparison.
        file_path: Path to a UTF-8 TSV file with the id in column 0 and the
            text in column 1.

    Returns:
        The second column of the matching row, or ``None`` if no row matches.
    """
    doc_id = str(doc_id)
    # newline="" is what the csv module expects so it can handle line endings itself.
    with open(file_path, "r", encoding="utf-8", newline="") as fd:
        # QUOTE_NONE: the file is raw TSV, so a '"' is ordinary text. With quote
        # handling enabled, a passage starting with '"' would swallow the rows
        # that follow it and those ids would never be found.
        rd = csv.reader(fd, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in rd:
            # Skip blank or malformed rows that lack a text column.
            if len(row) > 1 and row[0] == doc_id:
                return row[1]
    return None


# MRR@10 over the first 100 queries that have at least one relevance judgment.
for query in dataset.queries_iter():
    if count >= 100:
        break

    query_id = query.query_id
    relevant_passages = qrels_mapping.get(query_id, {})

    # Queries without judgments cannot contribute to MRR; skip them uncounted.
    if not relevant_passages:
        continue

    count += 1

    inputs = tokenizer(query.text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        q_rep = model(inputs['input_ids'], inputs['attention_mask'])

    # Keep only the non-zero dimensions of the query vector.
    # (The original repeated this extraction twice; once is enough.)
    q_rep = q_rep.squeeze()
    indices = torch.nonzero(q_rep).flatten()
    weights = q_rep[indices]

    q_indices = indices.tolist()
    q_weights = weights.tolist()

    # Accumulate query-weight * document-weight over the postings of each query term.
    scores = {}
    for idx, q_w in zip(q_indices, q_weights):
        token = reverse_voc[idx]
        if token in documents:
            for doc_id, d_w in documents[token]:
                scores[doc_id] = scores.get(doc_id, 0.0) + q_w * d_w

    top_docs = heapq.nlargest(10, scores.items(), key=lambda item: item[1])

    # 1-based rank of the first relevant passage in the top 10; 0 means none was found.
    rank = 0
    for position, (doc_id, score) in enumerate(top_docs, start=1):
        if doc_id in relevant_passages:
            rank = position
            break

    if rank > 0:
        mrr += 1 / rank

    # No query term may hit the index, leaving top_docs empty; avoid IndexError.
    top_text = get_document_text(top_docs[0][0]) if top_docs else None
    print(query, top_text)
    print(count, rank)

# Report the sum, the number of judged queries and the mean (the actual MRR@10),
# matching the summary printed by the later evaluation cell.
print(mrr, count, mrr / count if count else 0.0)


In [None]:
import torch
import heapq
import torch.nn.functional as F

model.to(device)
model.eval()

# The first SKIP_QUERIES queries yielded by the iterator are ignored
# (the original skipped while its 1-based counter was < 150).
SKIP_QUERIES = 149
# Stop after this many queries with at least one relevance judgment.
MAX_JUDGED_QUERIES = 100

count = 0
mrr = 0.0
for position, query in enumerate(dataset.queries_iter(), start=1):
    if position <= SKIP_QUERIES:
        continue

    if count >= MAX_JUDGED_QUERIES:
        break

    query_id = query.query_id
    relevant_passages = qrels_mapping.get(query_id, {})

    # Queries without judgments cannot contribute to MRR; skip them uncounted.
    if not relevant_passages:
        continue

    count += 1

    inputs = tokenizer(query.text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        q_rep = model(inputs['input_ids'], inputs['attention_mask'])

    # Keep only the non-zero dimensions of the query vector.
    # (The original repeated this extraction twice; once is enough.)
    q_rep = q_rep.squeeze()
    indices = torch.nonzero(q_rep).flatten()
    weights = q_rep[indices]

    q_indices = indices.tolist()
    q_weights = weights.tolist()

    # Accumulate query-weight * document-weight over the postings of each query term.
    scores = {}
    for idx, q_w in zip(q_indices, q_weights):
        token = reverse_voc[idx]
        if token in documents:
            for doc_id, d_w in documents[token]:
                scores[doc_id] = scores.get(doc_id, 0.0) + q_w * d_w

    top_docs = heapq.nlargest(10, scores.items(), key=lambda item: item[1])

    # 1-based rank of the first relevant passage in the top 10; 0 means none was found.
    rank = 0
    for hit_position, (doc_id, score) in enumerate(top_docs, start=1):
        if doc_id in relevant_passages:
            rank = hit_position
            break

    if rank > 0:
        mrr += 1 / rank

    # No query term may hit the index, leaving top_docs empty; avoid IndexError.
    top_text = get_document_text(top_docs[0][0]) if top_docs else None
    print(query, top_text)
    print(count, rank)

# Sum of reciprocal ranks, number of judged queries, and their mean (MRR@10).
# Guard the division in case no judged query was processed.
print(mrr, count, mrr / count if count else 0.0)

### Метрики MRR@10 для различных моделей

- **naver/splade-cocondenser-ensembledistil: 0.392**

- **tuned mlm bert head: 0.078**

    *   Q: cause of pimples on tip of nose  
        A: Causes of pimples on nose. A pimple, zit or bump on nose can have different causes. Some of the possible causes would include the following: 1. Stress. Being emotionally stress has the ability to trigger hormonal changes that can interfere with normal function of the body. Common such change is increased secretion of sebum by the sebaceous glands.  
        RANK: 1

    *   Q: what is web accessibility  
        A: If you want to learn the skills required to ensure your organization's web site is accessible, this training will provide what you need. We'll also provide the resources and information you need to empower your organization to meet all of your future accessibility needs.  
        RANK: 5

    *   Q: who produced transformers  
        A: Assuming the question related to the toy line; Transformers first came out in 1984 after Hasbro bought out the Diaclone, Diakron, and micro toy lines. Hasbro created a carto…on as an ad campaign and named the toys something better than DK-1.  
        RANK: 0

- **pretrained mlm bert head, not trained: ~0**

    *   Q: what is opportunities in acfta  
        A: At the end of a cell cycle, including mitosis, the new cells will have. 23 pairs of chromosomes; 46 total. only the 23 maternal chromosomes. only the 23 paternal chromosomes. 92 chromosomes, as a result of doubling during the S-phase of the cell cycle. Interphase is broken into phases known as. G1, S and G2.  
        RANK: 0

    *   Q: trend what does it means  
        A: Editor s note: This post is a Care2 favorite, back by popular demand. It was originally posted on December 26, 2012. Enjoy! People wear perfume and cologne to feel like they smell clean and fresh or sexy or elegant.  
        RANK: 0  

    *   Q: trending topic meaning  
        A: The term 'jaundice', means 'yellow' in color and refers to the yellowing of a baby's skin, owing to the accumulation of orange-coloured bilirubin. Yellow Jaundice is common in healthy new-born babies; about 60% of all babies get physiological jaundice ( normal neonatal jaundice ). Physiological jaundice in newborns will usually peak at day 5 and 6 and then fade within a week.  
        RANK: 0





