[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/sparse/splade/splade-queries.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/sparse/splade/splade-queries.ipynb)

# Sparse-Dense Vector Search with SPLADE

## Overview

SPLADE is a class of models that produce sparse embeddings. Unlike dense embeddings, which can be difficult to interpret, sparse embeddings map to vocabulary tokens, which makes them easier to interpret. SPLADE models have been shown to outperform dense models, particularly in out-of-domain settings.

This guide shows how to construct SPLADE embeddings for use with Pinecone's sparse-dense index. See the [companion guide]() to learn how to generate embeddings.
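
To make the interpretability point concrete, here is a minimal sketch that decodes a sparse embedding back into weighted tokens. The vocabulary, token ids, and weights are made up for illustration; a real SPLADE model uses its tokenizer's much larger vocabulary.

```python
# Hypothetical toy vocabulary: token id -> token string (not a real tokenizer).
toy_vocab = {0: "capital", 1: "city", 2: "france", 3: "paris", 4: "visit"}

# A sparse embedding stores only the non-zero dimensions.
sparse = {"indices": [0, 2, 3], "values": [1.8, 2.4, 0.9]}


def decode_sparse(sparse, vocab):
    """Return (token, weight) pairs sorted by descending weight."""
    pairs = [(vocab[i], w) for i, w in zip(sparse["indices"], sparse["values"])]
    return sorted(pairs, key=lambda p: p[1], reverse=True)


for token, weight in decode_sparse(sparse, toy_vocab):
    print(f"{token}: {weight}")
```

Each non-zero dimension corresponds to one vocabulary entry, so the largest weights can be read directly as the terms the model considers most relevant.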

## Install

In [1]:
!pip install -qU \
          git+https://git@github.com/pinecone-io/pinecone-python-client.git#egg=pinecone-client[grpc] \
          polars \
          transformers \
          torch \
          sentence_transformers \
          gcsfs \
          pyarrow

In [2]:
import polars as pl
from tqdm import tqdm
import os
import pinecone

## Init Pinecone

In [3]:
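# Init Pinecone
# Read credentials from environment variables. PINECONE_ENVIRONMENT is an
# assumed variable name; alternatively, assign `environment` directly.
api_key = os.getenv('PINECONE_API_KEY')
environment = os.getenv('PINECONE_ENVIRONMENT')
if (api_key is None) or (environment is None):
    raise ValueError('You must specify an environment and API Key')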

pinecone.init(
    api_key=api_key,
    environment=environment
)

## Quora Dataset

Load the popular Quora dataset with embeddings precomputed using:

* Dense: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
* Sparse: [naver/splade-cocondenser-ensembledistil](https://huggingface.co/naver/splade-cocondenser-ensembledistil)


In [4]:
df = pl.read_parquet('https://storage.googleapis.com/gareth-pinecone-datasets/quora.parquet')

In [None]:
len(df)

In [5]:
df.head()

| values | id | sparse_values | text |
| --- | --- | --- | --- |
| list[f32] | i64 | struct[2] | str |
| [0.402489, -0.234254, ... 0.222465] | 1 | {[0.08416, 0.055447, ... 0.197036],[1012, 2000, ... 20138]} | " What is the s... |
| [0.511194, -0.198763, ... 0.238244] | 2 | {[0.189603, 0.21525, ... 0.147162],[1000, 1012, ... 20138]} | " What is the s... |
| [-0.223715, 0.741517, ... 0.191113] | 3 | {[0.134441, 0.188095, ... 0.021336],[1005, 1006, ... 17070]} | " What is the s... |
| [-0.37124, 0.709703, ... -0.049561] | 4 | {[0.029134, 0.360959, ... 0.042988],[1005, 1011, ... 21156]} | " What would ha... |
| [-0.166566, 0.218813, ... -0.097715] | 5 | {[0.033697, 0.44331, ... 0.629484],[2006, 2017, ... 21628]} | " How can I inc... |


## Index Creation

In [10]:
index_name = "splade-embedding-query"
batch_size = 300
dimension = 384

In [11]:
pinecone.create_index(
    index_name,
    pod_type='s1',
    metric='dotproduct',
    dimension=dimension,
    metadata_config={"indexed": []}
)

## Upsert

In [14]:
from pinecone import GRPCVector, GRPCSparseValues
from google.protobuf.struct_pb2 import Struct

with pinecone.GRPCIndex(index_name) as index:
    for i in tqdm(range(0, len(df), batch_size)):
        batch = df[i:min(i+batch_size, len(df))]
        upserts = []
        for row in batch.rows(named=True):
            metadata = Struct()
            metadata.update(dict(text=row['text']))
            u = GRPCVector(
                values=row['values'],
                id=str(row['id']),
                metadata=metadata,
                sparse_values=GRPCSparseValues(**row['sparse_values'])
            )
            upserts.append(u)
        index.upsert(vectors=upserts, async_req=False)

100%|██████████| 1744/1744 [03:50<00:00,  7.55it/s]
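
The batching arithmetic in the upsert loop can be checked in isolation. The sketch below uses a hypothetical helper, `iter_batches`, and a made-up row count; it is not part of the Pinecone client.

```python
import math


def iter_batches(n_rows: int, batch_size: int):
    """Yield (start, stop) index pairs covering range(n_rows) in order."""
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


# Hypothetical row count, chosen only to show the final short batch.
n_rows, batch_size = 1000, 300
batches = list(iter_batches(n_rows, batch_size))
print(batches)
print(len(batches) == math.ceil(n_rows / batch_size))
```

The progress bar total in the upsert output counts batches, not rows, because `tqdm` wraps `range(0, len(df), batch_size)`.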


## Sparse-Dense Queries with SPLADE


In [15]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

class SPLADE:
    def __init__(self, model):
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model)

    def __call__(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits

        inter = torch.log1p(torch.relu(logits[0]))
        token_max = torch.max(inter, dim=0)  # max over input tokens
        nz_tokens = torch.where(token_max.values > 0)[0]
        nz_weights = token_max.values[nz_tokens]

        order = torch.sort(nz_weights, descending=True)
        nz_weights = nz_weights[order[1]]
        nz_tokens = nz_tokens[order[1]]
        return {'indices': nz_tokens.numpy().tolist(), 'values': nz_weights.numpy().tolist()}

def hybrid_score_norm(dense, sparse, alpha: float):
    """Hybrid score using a convex combination

    alpha * dense + (1 - alpha) * sparse

    Args:
        dense: Array of floats representing the dense vector
        sparse: a dict of `indices` and `values`
        alpha: scale between 0 and 1
    """
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    hs = {
        'indices': sparse['indices'],
        'values':  [v * (1 - alpha) for v in sparse['values']]
    }
    return [v * alpha for v in dense], hs
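
Why does scaling only the query vectors produce a convex combination of scores? The sketch below assumes the index scores a match as the dense dot product plus the sparse dot product, which is consistent with `metric='dotproduct'` above but is stated here as an assumption. Because the dot product is linear, scaling the query by `alpha` and `1 - alpha` is equivalent to weighting the two scores. All vectors are made up, and `scale_query` repeats the scaling logic of `hybrid_score_norm` so the block runs on its own.

```python
def dense_dot(a, b):
    """Dot product of two equal-length dense vectors."""
    return sum(x * y for x, y in zip(a, b))


def sparse_dot(a, b):
    """Dot product of two sparse vectors given as indices/values dicts."""
    b_map = dict(zip(b["indices"], b["values"]))
    return sum(v * b_map.get(i, 0.0) for i, v in zip(a["indices"], a["values"]))


def scale_query(dense, sparse, alpha):
    """Scale dense by alpha and sparse by (1 - alpha)."""
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [v * (1 - alpha) for v in sparse["values"]],
    }
    return [v * alpha for v in dense], scaled_sparse


# Made-up query (q) and document (d) vectors.
q_dense, q_sparse = [0.5, -1.0, 2.0], {"indices": [3, 7], "values": [1.0, 2.0]}
d_dense, d_sparse = [1.0, 0.5, 0.25], {"indices": [7, 9], "values": [0.5, 4.0]}
alpha = 0.85

hd, hs = scale_query(q_dense, q_sparse, alpha)
score_scaled_query = dense_dot(hd, d_dense) + sparse_dot(hs, d_sparse)
score_convex = (alpha * dense_dot(q_dense, d_dense)
                + (1 - alpha) * sparse_dot(q_sparse, d_sparse))
print(abs(score_scaled_query - score_convex) < 1e-9)
```

Under that assumption, `alpha = 1` yields a purely dense score and `alpha = 0` a purely sparse one, and the stored document vectors never need to be rescaled.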

In [17]:
from sentence_transformers import SentenceTransformer

splade = SPLADE("naver/splade-cocondenser-ensembledistil")
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
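
The pooling step inside `SPLADE.__call__` can be illustrated without loading a model. The sketch below mirrors the same steps (`log1p(relu(logits))`, then a max over input tokens, then dropping zeros and sorting by weight) on a tiny hand-written logits matrix. The function name `splade_pool` and the toy numbers are illustrative only.

```python
import math


def splade_pool(logits):
    """Pool per-token logits into a sparse vector.

    logits: list of rows, one row per input token, one column per vocab entry.
    """
    vocab_size = len(logits[0])
    # Element-wise log(1 + relu(x)).
    activated = [[math.log1p(max(x, 0.0)) for x in row] for row in logits]
    # Max over input tokens for each vocabulary entry.
    pooled = [max(row[j] for row in activated) for j in range(vocab_size)]
    # Keep non-zero entries, sorted by descending weight.
    nz = [(j, w) for j, w in enumerate(pooled) if w > 0]
    nz.sort(key=lambda p: p[1], reverse=True)
    return {"indices": [j for j, _ in nz], "values": [w for _, w in nz]}


# Two input tokens, four vocabulary entries (made-up values).
toy_logits = [
    [-1.0, 2.0, 0.0, 0.5],
    [-0.5, 1.0, 0.0, 3.0],
]
result = splade_pool(toy_logits)
print(result["indices"])
```

Vocabulary entries whose logits are never positive for any input token end up with zero weight and are dropped, which is what makes the output sparse.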

In [18]:
index = pinecone.Index(index_name)

In [24]:
text = "what is the capital of france"
sparse = splade(text)
dense = model.encode(text).tolist()
hdense, hsparse = hybrid_score_norm(dense, sparse, 0.85)

In [25]:
from pinecone import SparseValues

index.query(top_k=10, vector=hdense, sparse_vector=SparseValues(**hsparse), include_metadata=True)

{'matches': [{'id': '29784',
              'metadata': {'text': ' What is the capital city of France?'},
              'score': 9.20412,
              'values': []},
             {'id': '184797',
              'metadata': {'text': ' What are best tourist spots in france?'},
              'score': 6.47260284,
              'values': []},
             {'id': '380665',
              'metadata': {'text': ' What is France famous for?'},
              'score': 6.34292221,
              'values': []},
             {'id': '380666',
              'metadata': {'text': ' What is France most famous for?'},
              'score': 6.23249435,
              'values': []},
             {'id': '223775',
              'metadata': {'text': ' Is France safe to visit?'},
              'score': 6.20398569,
              'values': []},
             {'id': '99552',
              'metadata': {'text': ' What are some popular landmarks in '
                                   'France?'},
              'score': 6.