[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/sparse/splade/splade-embedding-generation.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/sparse/splade/splade-embedding-generation.ipynb)

# SPLADE Sparse-Dense Embedding Generation

## Overview

SPLADE is a class of models that produce sparse embeddings. Unlike dense embeddings, which can be difficult to interpret, sparse embeddings assign weights to individual vocabulary tokens, which makes them much easier to interpret. SPLADE models have been shown to outperform dense models in many evaluations, particularly in out-of-domain settings.
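
To make the sparse format concrete: a sparse vector can be stored as parallel lists of vocabulary `indices` and `values` (the layout used throughout this guide), and the similarity between two sparse vectors is a dot product over the indices they share. The sketch below is plain Python for illustration only. `sparse_dot` is a hypothetical helper and the weights are made up; only the index/token pairs (3007 = `capital`, 2605 = `france`, 3000 = `paris`, 2103 = `city`) come from the SPLADE output later in this guide.

```python
def sparse_dot(a: dict, b: dict) -> float:
    """Dot product of two sparse vectors given as {'indices': [...], 'values': [...]}.

    Hypothetical helper: only indices present in both vectors contribute.
    """
    b_weights = dict(zip(b["indices"], b["values"]))
    return sum(v * b_weights.get(i, 0.0) for i, v in zip(a["indices"], a["values"]))


# Toy vectors with made-up weights. Each index is a vocabulary token id, so the
# dimensions are directly readable: 3007="capital", 2605="france",
# 3000="paris", 2103="city".
query = {"indices": [3007, 2605], "values": [3.0, 2.5]}
doc = {"indices": [2605, 3000, 2103], "values": [1.0, 2.0, 0.5]}

print(sparse_dot(query, doc))  # only "france" (2605) overlaps: 2.5 * 1.0 = 2.5
```

Because each dimension is a token, you can read off exactly which terms contributed to a match, which is the interpretability advantage described above.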

The following guide shows you how to construct SPLADE embeddings for use with Pinecone's sparse-dense index. See the [companion guide]() to learn how to skip the embedding generation step.

## Install & Import

In [1]:
!pip install \
          git+https://git@github.com/pinecone-io/pinecone-python-client.git#egg=pinecone-client[grpc] \
          polars \
          transformers \
          torch \
          sentence_transformers \
          gcsfs \
          pyarrow

In [2]:
import polars as pl

## Sparse Embeddings with SPLADE

In the following example we will use the SPLADE model [naver/splade-cocondenser-ensembledistil](https://huggingface.co/naver/splade-cocondenser-ensembledistil).

In [72]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

class SPLADE:
    def __init__(self, model: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model)

    def __call__(self, text: str) -> dict:
        # Truncate inputs longer than the model's maximum sequence length
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)

        with torch.no_grad():
            # MLM logits with shape (1, seq_len, vocab_size)
            logits = self.model(**inputs).logits

        # SPLADE activation log(1 + ReLU(logit)) for every (input position, vocab term) pair
        inter = torch.log1p(torch.relu(logits[0]))
        # Max over input positions: one weight per vocabulary term
        token_max = torch.max(inter, dim=0)
        # Keep only vocabulary terms with a non-zero weight
        nz_tokens = torch.where(token_max.values > 0)[0]
        nz_weights = token_max.values[nz_tokens]

        # Sort the terms by weight, highest first
        order = torch.sort(nz_weights, descending=True)
        nz_weights = nz_weights[order.indices]
        nz_tokens = nz_tokens[order.indices]
        return {'indices': nz_tokens.numpy().tolist(), 'values': nz_weights.numpy().tolist()}
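
The pooling step in `__call__` computes, for each vocabulary term $j$, the weight $w_j = \max_i \log(1 + \mathrm{ReLU}(z_{ij}))$, where $z_{ij}$ is the MLM logit for term $j$ at input position $i$. The class above encodes one text at a time, so there is no padding to worry about. If you adapt it to batched inputs, padded positions should be excluded before taking the max. The NumPy sketch below uses a hypothetical `splade_pool` helper on toy logits (not the real model) to show both steps.

```python
import numpy as np

def splade_pool(logits: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """SPLADE max pooling: w_j = max_i log(1 + ReLU(logits[i, j])) over real tokens.

    logits: (seq_len, vocab_size) MLM logits for one input.
    attention_mask: (seq_len,) array, 1 for real tokens and 0 for padding.
    """
    weights = np.log1p(np.maximum(logits, 0.0))  # log-saturated ReLU, always >= 0
    # Zeroing padded rows is safe because every weight is >= 0, so they cannot win the max
    weights = weights * attention_mask[:, None]
    return weights.max(axis=0)  # one weight per vocabulary term

# Toy input: 3 positions (the last is padding) over a 5-term vocabulary
logits = np.array([
    [2.0, -1.0,  0.0, 0.5, -3.0],
    [0.1,  4.0, -2.0, 0.0, -1.0],
    [9.0,  9.0,  9.0, 9.0,  9.0],  # padding row: must not affect the result
])
mask = np.array([1, 1, 0])

w = splade_pool(logits, mask)
print(np.nonzero(w)[0].tolist())  # -> [0, 1, 3], the terms with a positive logit
```

The log saturation keeps a single very large logit from dominating the vector, and the max means a term only needs to be strongly predicted at one position to receive a high weight.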

In [85]:
splade = SPLADE("naver/splade-cocondenser-ensembledistil")

In [86]:
doc = "what is the capital of france?"
sparse_vector = splade(doc)

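The result is a set of vocabulary indices with weights. Mapping the indices back to tokens shows SPLADE query & document expansion at work: besides the query terms `capital` and `france`, the model adds related terms that never appear in the query, such as `french`, `europe`, and `paris`.

In [None]:
pl.DataFrame({
    **sparse_vector,
    'tokens': splade.tokenizer.convert_ids_to_tokens(sparse_vector['indices'])  # map indices back to tokens
})[:10]

| indices (i64) | values (f64) | tokens (str) |
|---|---|---|
| 3007 | 3.099394 | capital |
| 2605 | 2.851794 | france |
| 2413 | 2.387511 | french |
| 2885 | 1.737867 | europe |
| 9424 | 1.73612 | capitol |
| 2103 | 1.334281 | city |
| 10505 | 0.802148 | geography |
| 3000 | 0.707464 | paris |
| 8709 | 0.678849 | michel |
| 5288 | 0.534112 | switzerland |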

## Dense Embeddings

For dense embeddings we use `all-MiniLM-L6-v2`, a popular model from sentence-transformers that produces 384-dimensional vectors.

In [76]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
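
The index created below uses the `dotproduct` metric, which sparse-dense vectors require. For the dense part of the score, the dot product of two L2-normalized vectors equals their cosine similarity; `SentenceTransformer.encode` accepts `normalize_embeddings=True` if you want to enforce unit length explicitly. A minimal pure-Python sketch of that relationship, using hypothetical helpers `l2_normalize`, `dot`, and `cosine`:

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a, b = [3.0, 4.0], [4.0, 3.0]
# For unit-length vectors, dot product and cosine similarity coincide (about 0.96 here)
print(dot(l2_normalize(a), l2_normalize(b)), cosine(a, b))
```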

## Processing the Quora Dataset

Quora is a popular dataset for evaluating text search. To keep this example quick, we only load the first 500 records.

In [77]:
sample = 500

In [65]:
df = pl.read_parquet('https://storage.googleapis.com/gareth-pinecone-datasets/quora.parquet').select([pl.col(['id', 'text'])]).head(sample)

In [66]:
df.head(10)

| id (i64) | text (str) |
|---|---|
| 1 | " What is the s..." |
| 2 | " What is the s..." |
| 3 | " What is the s..." |
| 4 | " What would ha..." |
| 5 | " How can I inc..." |
| 6 | " How can Inter..." |
| 7 | " Why am I ment..." |
| 8 | " Find the rema..." |
| 9 | " Which one dis..." |
| 10 | " Which fish wo..." |


In [87]:
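# map_elements is called `apply` in polars versions before 0.19
df = df.with_columns([
    # SPLADE sparse vector per row, stored as a struct of {'indices', 'values'}
    pl.col('text').map_elements(lambda x: splade(x)).alias('sparse_values'),
    # 384-dimensional dense embedding per row
    pl.col('text').map_elements(lambda x: model.encode(x).tolist()).alias('values'),
])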

## Upsert to Pinecone

In [24]:
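# Load the Pinecone API key and environment from environment variables
import os
import pinecone

api_key = os.getenv('PINECONE_API_KEY')
# Assumed variable name; set it to the environment shown next to your API key in the Pinecone console
environment = os.getenv('PINECONE_ENVIRONMENT')
pinecone.init(
    api_key=api_key,
    environment=environment
)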

In [25]:
index_name = "splade-embedding-generation"
batch_size = 300
dimension = 384  # output size of all-MiniLM-L6-v2

In [None]:
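# Sparse-dense vectors require the dotproduct metric
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        index_name,
        pod_type='s1',
        metric='dotproduct',
        dimension=dimension,
        metadata_config={"indexed": []}  # store metadata (the text) without indexing any field for filtering
    )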
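
Before upserting, it can help to check each record against the index configuration. The helper below is hypothetical (it is not part of the Pinecone client), and its rules are assumptions about what a well-formed sparse-dense record looks like: the dense vector length must match `dimension`, and the sparse `indices` and `values` must have equal length, with unique, non-negative integer indices.

```python
def check_record(dense, sparse, dimension=384):
    """Hypothetical pre-upsert sanity check for one sparse-dense record.

    Raises ValueError if the record violates the assumed constraints.
    """
    if len(dense) != dimension:
        raise ValueError(f"dense vector has {len(dense)} dims, index expects {dimension}")
    indices, values = sparse["indices"], sparse["values"]
    if len(indices) != len(values):
        raise ValueError("sparse indices and values must have the same length")
    if len(set(indices)) != len(indices):
        raise ValueError("sparse indices must be unique")
    if any(not isinstance(i, int) or i < 0 for i in indices):
        raise ValueError("sparse indices must be non-negative integers")


check_record([0.1] * 384, {"indices": [3007, 2605], "values": [3.1, 2.9]})  # passes silently
```

In this notebook you could run it over `df.rows(named=True)`, passing `row['values']` and `row['sparse_values']`.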

In [88]:
from pinecone import GRPCVector, GRPCSparseValues
from google.protobuf.struct_pb2 import Struct
from tqdm import tqdm

with pinecone.GRPCIndex(index_name) as index:
    for i in tqdm(range(0, len(df), batch_size)):
        batch = df[i:min(i+batch_size, len(df))]
        upserts = []
        for row in batch.rows(named=True):
            # gRPC vectors take their metadata as a protobuf Struct
            metadata = Struct()
            metadata.update(dict(text=row['text']))
            u = GRPCVector(
                id=str(row['id']),
                values=row['values'],                                    # dense embedding
                sparse_values=GRPCSparseValues(**row['sparse_values']),  # SPLADE indices and weights
                metadata=metadata
            )
            upserts.append(u)
        index.upsert(vectors=upserts, async_req=False)

100%|██████████| 2/2 [00:00<00:00,  3.27it/s]
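
The loop above slices the DataFrame into chunks of `batch_size` rows, which is why the progress bar shows 2 batches for 500 rows (300 + 200). If your records live in a plain Python iterable instead, a small generic helper does the same chunking. `batched` here is a hypothetical helper; Python 3.12+ also ships `itertools.batched`, which yields tuples rather than lists.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield successive lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch

print([len(b) for b in batched(range(500), 300)])  # -> [300, 200]
```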


## Querying

In [89]:
index = pinecone.Index(index_name)

In [90]:
index.query(id="1", top_k=10)

{'matches': [{'id': '1', 'score': 34.0199394, 'values': []},
             {'id': '2', 'score': 30.4394989, 'values': []},
             {'id': '260', 'score': 11.6323013, 'values': []},
             {'id': '259', 'score': 9.89988232, 'values': []},
             {'id': '498', 'score': 8.63788795, 'values': []},
             {'id': '301', 'score': 8.40716457, 'values': []},
             {'id': '202', 'score': 8.00751114, 'values': []},
             {'id': '212', 'score': 7.43195486, 'values': []},
             {'id': '403', 'score': 6.58435822, 'values': []},
             {'id': '410', 'score': 6.3796258, 'values': []}],
 'namespace': ''}
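
With the `dotproduct` metric, we assume the score of a sparse-dense match is the dense dot product plus the sparse dot product. That would explain why record `1` matches itself with a score of about 34 rather than 1: if its dense embedding has unit length, the dense part contributes only about 1 and the SPLADE weights contribute the rest. A pure-Python sketch of that assumed scoring rule (`hybrid_dot` is a hypothetical helper, not a Pinecone API):

```python
def hybrid_dot(q_dense, q_sparse, d_dense, d_sparse):
    """Assumed sparse-dense dotproduct score: dense dot product + sparse dot product."""
    dense_part = sum(x * y for x, y in zip(q_dense, d_dense))
    d_weights = dict(zip(d_sparse["indices"], d_sparse["values"]))
    sparse_part = sum(v * d_weights.get(i, 0.0)
                      for i, v in zip(q_sparse["indices"], q_sparse["values"]))
    return dense_part + sparse_part

# Toy record queried against itself: a unit-length dense vector plus SPLADE-style weights
dense = [1.0, 0.0]
sparse = {"indices": [3007, 2605, 2413], "values": [3.0, 2.0, 1.0]}
print(hybrid_dot(dense, sparse, dense, sparse))  # -> 15.0 (1.0 dense + 14.0 sparse)
```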

### Hybrid Query

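We use a weighted (convex) combination of the sparse and dense query vectors: `alpha` scales the dense vector and `1 - alpha` scales the sparse vector, so `alpha=1` gives a pure dense search and `alpha=0` a pure sparse search.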

In [96]:
def hybrid_score_norm(dense, sparse, alpha: float):
    """Hybrid score using a convex combination

    alpha * dense + (1 - alpha) * sparse

    Args:
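        dense: list of floats representing the dense query vector
        sparse: a dict of `indices` and `values` representing the sparse query vector
        alpha: weight between 0 and 1; 1 gives a pure dense query, 0 a pure sparse query
    """
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")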
    hs = {
        'indices': sparse['indices'],
        'values':  [v * (1 - alpha) for v in sparse['values']]
    }
    return [v * alpha for v in dense], hs

In [97]:
text = "how to invest in india"
sparse = splade(text)
dense = model.encode(text).tolist()
hdense, hsparse = hybrid_score_norm(dense, sparse, 0.85)
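
Because `hybrid_score_norm` only scales the query, and the score is (we assume) a sum of dense and sparse dot products, the final score works out to `alpha * dense_score + (1 - alpha) * sparse_score`. The self-contained sketch below checks this on toy vectors; `weight_query` mirrors `hybrid_score_norm`, and `hybrid_dot` is a hypothetical scoring helper, not a Pinecone API.

```python
def weight_query(dense, sparse, alpha):
    """Same convex weighting as hybrid_score_norm: dense * alpha, sparse * (1 - alpha)."""
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be between 0 and 1")
    scaled_sparse = {"indices": sparse["indices"],
                     "values": [v * (1 - alpha) for v in sparse["values"]]}
    return [v * alpha for v in dense], scaled_sparse

def hybrid_dot(q_dense, q_sparse, d_dense, d_sparse):
    """Assumed sparse-dense score: dense dot product + sparse dot product."""
    dense_part = sum(x * y for x, y in zip(q_dense, d_dense))
    d_weights = dict(zip(d_sparse["indices"], d_sparse["values"]))
    sparse_part = sum(v * d_weights.get(i, 0.0)
                      for i, v in zip(q_sparse["indices"], q_sparse["values"]))
    return dense_part + sparse_part

# Toy query and document: raw dense score 0.5, raw sparse score 2.0 * 3.0 = 6.0
q_dense, q_sparse = [1.0, 0.0], {"indices": [7], "values": [2.0]}
d_dense, d_sparse = [0.5, 0.5], {"indices": [7, 9], "values": [3.0, 1.0]}

scores = {}
for alpha in (0.0, 0.85, 1.0):
    hd, hs = weight_query(q_dense, q_sparse, alpha)
    scores[alpha] = hybrid_dot(hd, hs, d_dense, d_sparse)

print(scores)  # alpha=0.0 -> 6.0 (sparse only), alpha=1.0 -> 0.5 (dense only)
```

With `alpha=0.85`, as used above, the dense similarity dominates, but a strong SPLADE match can still lift a result.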

In [99]:
from pinecone import SparseValues

index.query(top_k=10, vector=hdense, sparse_vector=SparseValues(**hsparse), include_metadata=True)

{'matches': [{'id': '1',
              'metadata': {'text': ' What is the step by step guide to invest '
                                   'in share market in india?'},
              'score': 4.3153429,
              'values': []},
             {'id': '260',
              'metadata': {'text': ' How do I access Google.com from India?'},
              'score': 2.86029983,
              'values': []},
             {'id': '2',
              'metadata': {'text': ' What is the step by step guide to invest '
                                   'in share market?'},
              'score': 2.7389214,
              'values': []},
             {'id': '498',
              'metadata': {'text': ' How can I best invest ₹5000 over the next '
                                   '6 months?'},
              'score': 2.55986595,
              'values': []},
             {'id': '259',
              'metadata': {'text': ' How do I access Torbox in India?'},
              'score': 2.49220681,
              'va