## In this jupyter notebook we will calculate the SPLADE representation of the trec-cast dataset (just the queries)

First we will download the trec-cast v1 2020 dataset from [ir-datasets.com](https://ir-datasets.com/trec-cast.html#trec-cast/v1/2020)
In this JN we only look at the queries, since running splade on all documents takes way to long to run locally

In [13]:
import ir_datasets
dataset = ir_datasets.load("trec-cast/v1/2020")
topics = set()
for query in dataset.queries_iter():
    topics.add(query.topic_number)

print(f"amount of topics\t: \t len(topics)")
print(f"amount of queries\t: \t {len(dataset.queries)}")

amount of topics	: 	 len(topics)
amount of queries	: 	 216


Now lets initialize the model (from huggingface, see SPLADE test for more info)

In [36]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

Now lets run the code on the queries

In [35]:
import torch
import csv

def process_queries_and_write_tsv_file(model_id, dataset, N, utterance_type, output_file_path):
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)

    idx2token = {
        idx: token for token, idx in tokenizer.get_vocab().items()
    }

    results = []

    for query in dataset.queries_iter():
        tokens = tokenizer(getattr(query, utterance_type), return_tensors='pt')
        output = model(**tokens)
        query_id = query.query_id

        vec = torch.max(
            torch.log(
                1 + torch.relu(output.logits)
            ) * tokens.attention_mask.unsqueeze(-1),
            dim=1)[0].squeeze()

        cols = vec.nonzero().squeeze().cpu().tolist()

        # Extract the non-zero values
        weights = vec[cols].cpu().tolist()
        # Create a dictionary of token ID to weight
        sparse_dict = dict(zip(cols, weights))

        # Sort the dictionary by weight in descending order
        sparse_dict = dict(sorted(sparse_dict.items(), key=lambda item: item[1], reverse=True))

        # Convert the token IDs to words
        sparse_dict_tokens = [
            idx2token[idx] for idx in cols for _ in range(int(round(sparse_dict[idx] * N, 0)))
        ]

        # Generate the output string in the desired format
        output_str = f"{query_id}\t{' '.join(sparse_dict_tokens)}"
        results.append(output_str)

    # Write the results to a TSV file
    with open(output_file_path, "w", newline="", encoding="utf-8") as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        writer.writerows([line.split('\t') for line in results])

# Usage example for N=1 and N=100, and three utterance types
model_id = 'naver/splade-cocondenser-selfdistil'
output_dir = 'SPLADE_embeddings'

# for N in [1, 100]:
#     for utterance_type in ["raw_utterance", "automatic_rewritten_utterance", "manual_rewritten_utterance"]:
#         output_file_path = f"{output_dir}/output_N{N}_{utterance_type}.tsv"
#         process_queries_and_write_tsv_file(model_id, dataset, N, utterance_type, output_file_path)
