In [1]:
!pip install -qq git+https://github.com/pinecone-io/pinecone-python-client.git@add-hybridapi-wiring git+https://github.com/naver/splade.git

## Dataset Preparation

We will use the PubMed QA dataset (`pubmed_qa`) from Hugging Face Datasets.

In [1]:
from datasets import load_dataset

pubmed = load_dataset(
    'pubmed_qa',
    'pqa_labeled',
    split='train'
)
pubmed

Reusing dataset pubmed_qa (/Users/jamesbriggs/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)


Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})

In [2]:
pubmed[0]['pubid'], pubmed[0]['context']

(21645374,
 {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondr

We need to cut our contexts into digestible chunks for our models. We'll be using BERT-based models, which have a max sequence length of `512` tokens, *but* many sentence transformers limit this to `128`.

To be safe and ensure we're not over the smaller `128` token limit we will assume an average token length of `3` characters (in reality it is more like *3-5*) and therefore our required length will be `128*3 == 384` characters.
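The character budget follows directly from the token limit. A minimal sketch of that arithmetic, where `max_tokens`, `chars_per_token`, `char_limit`, and `fits_budget` are illustrative names that do not appear in the notebook:

```python
# Conservative character budget derived from a token limit.
max_tokens = 128        # sequence limit we want to stay under
chars_per_token = 3     # deliberately low estimate of characters per token

char_limit = max_tokens * chars_per_token

def fits_budget(text: str, limit: int = char_limit) -> bool:
    """Return True if the text is within the character budget."""
    return len(text) <= limit

print(char_limit)  # 384
print(fits_budget("PCD occurs in the cells at the center of these areoles."))  # True
```

Note that this is only a rough proxy: the actual token count depends on the tokenizer, so a passage within the character budget is likely, not guaranteed, to fit.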

To build passages of this length we will define a processing function called `chunker`.

In [3]:
limit = 384

def chunker(contexts: list):
    chunks = []
    all_contexts = ' '.join(contexts).split('.')
    chunk = []
    for context in all_contexts:
        chunk.append(context)
        if len(chunk) >= 3 and len('.'.join(chunk)) > limit:
            # surpassed limit so add to chunks and reset
            chunks.append('.'.join(chunk).strip()+'.')
            # add some overlap between passages
            chunk = chunk[-2:]
    # if we finish and still have a chunk, add it
    if chunk:
        chunks.append('.'.join(chunk))
    return chunks

In [4]:
chunks = chunker(pubmed[0]['context']['contexts'])
chunks

['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.',
 'The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A.',
 'The role of mitochondria during PCD has been recognized in animals; however, it has been less

We need to give each chunk a unique ID, like so:

In [5]:
ids = []
for i in range(len(chunks)):
    ids.append(f"{pubmed[0]['pubid']}-{i}")
ids

['21645374-0',
 '21645374-1',
 '21645374-2',
 '21645374-3',
 '21645374-4',
 '21645374-5',
 '21645374-6']

We create the full contexts dataset with this logic like so:

In [6]:
data = []
for record in pubmed:
    chunks = chunker(record['context']['contexts'])
    for i, context in enumerate(chunks):
        data.append({
            'id': f"{record['pubid']}-{i}",
            'context': context
        })

data[:10]

[{'id': '21645374-0',
  'context': 'Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.'},
 {'id': '21645374-1',
  'context': 'The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A.'},
 {'id': '21645374-2',
  '

## Model Initialization and Vectors

With our dataset prepared we can move on to initializing the required models and setting up some helper functions to make sparse and dense vector building easy.

### Dense Vectors

Starting with the dense vectors, we will use an off-the-shelf model from the `sentence-transformers` library.

In [7]:
from sentence_transformers import SentenceTransformer

dense_model = SentenceTransformer(
    'multi-qa-mpnet-base-dot-v1'
)
dense_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

We then create an embedding very easily like so:

In [8]:
emb = dense_model.encode(data[0]['context'])
emb.shape

(768,)

The model returns `768`-dimensional dense vectors. This is also reflected in the model attributes.

In [9]:
dim = dense_model.get_sentence_embedding_dimension()
dim

768

### Sparse Vectors

We will also need to create sparse vectors. For that we will use a learned sparse embedding model called SPLADE. SPLADE is a family of models that share a similar embedding method; we will use the `naver/splade-cocondenser-ensembledistil` model.

We initialize the model like so:

In [10]:
from splade.models.transformer_rep import Splade

sparse_model_id = 'naver/splade-cocondenser-ensembledistil'

sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.eval()

The model takes tokenized inputs that are built using a tokenizer initialized with the same model ID.

In [11]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)

tokens = tokenizer(data[0]['context'], return_tensors='pt')

To create sparse vectors we do:

In [12]:
import torch

with torch.no_grad():
    sparse_emb = sparse_model(
        d_kwargs=tokens
    )['d_rep'].squeeze()
sparse_emb.shape



torch.Size([30522])

In [13]:
sparse_emb

tensor([0., 0., 0.,  ..., 0., 0., 0.])

This leaves us with a `30522`-dimensional sparse vector embedding. Pinecone expects the sparse vector in a dictionary-style format. To build it we take a couple more steps.

First we get a list of non-zero positions in the vector.

In [14]:
cols = sparse_emb.nonzero().squeeze().cpu().tolist()
print(len(cols))

174


We have `174` non-zero values. We use them to create a dictionary that maps index positions to scores like so:

In [15]:
weights = sparse_emb[cols].cpu().tolist()
sparse_dict = dict(zip(cols, weights))
sparse_dict

{1000: 0.6246446967124939,
 1039: 0.45678916573524475,
 1052: 0.3088974058628082,
 1997: 0.15812619030475616,
 1999: 0.07194626331329346,
 2003: 0.6496524810791016,
 2024: 0.9411943554878235,
 2049: 0.31614890694618225,
 2083: 0.7597626447677612,
 2094: 1.9501703977584839,
 2173: 0.3237406015396118,
 2239: 0.3950256109237671,
 2278: 0.23537182807922363,
 2290: 0.2457151710987091,
 2306: 0.4253382980823517,
 2331: 1.9602458477020264,
 2415: 0.628946840763092,
 2427: 0.4244104325771332,
 2523: 0.018045416101813316,
 2537: 0.19568762183189392,
 2550: 0.6684792637825012,
 2565: 0.8162299990653992,
 2566: 1.095424771308899,
 2597: 0.19797049462795258,
 2644: 0.22766251862049103,
 2754: 0.013308855704963207,
 2757: 0.9048304557800293,
 2832: 0.6024836301803589,
 2974: 0.610008716583252,
 3030: 0.03979627043008804,
 3081: 0.12952247262001038,
 3102: 0.023475565016269684,
 3252: 0.3975664973258972,
 3269: 1.2144653797149658,
 3274: 0.7056951522827148,
 3280: 1.5106254816055298,
 3370: 0.533285

This is the format that Pinecone requires, so from here we can move on to indexing. The next section is optional: it only tries to build a better understanding of the sparse vectors, so skip it to go straight to indexing.
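For intuition, the same index-to-weight conversion can be reproduced on a toy plain-Python list, without any model. This is only a sketch: `toy_emb` and `to_sparse_dict` are made-up names, and the real vector is a `30522`-dimensional tensor rather than a six-element list.

```python
# Toy stand-in for a sparse embedding: mostly zeros, a few weights.
toy_emb = [0.0, 0.0, 1.5, 0.0, 0.25, 0.0]

def to_sparse_dict(vec):
    """Map index position -> weight, keeping non-zero entries only."""
    return {idx: weight for idx, weight in enumerate(vec) if weight != 0}

toy_sparse = to_sparse_dict(toy_emb)
print(toy_sparse)  # {2: 1.5, 4: 0.25}
```

Only the positions that carry weight are stored, which is why a vector with tens of thousands of dimensions stays cheap to send and store.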

#### Reading Sparse Embedding

Before moving on to indexing everything in Pinecone, let's take a moment to understand what our sparse vector means. We can translate it into a human-readable format to see what it actually represents.

We create a way of mapping from index positions to actual BERT tokenizer tokens.

In [16]:
idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}

Then create the mappings like we did with the Pinecone-friendly sparse format above.

In [17]:
sparse_dict_tokens = {
    idx2token[idx]: round(weight, 2) for idx, weight in zip(cols, weights)
}
# sort so we can see most relevant tokens first
sparse_dict_tokens = {
    k: v for k, v in sorted(
        sparse_dict_tokens.items(),
        key=lambda item: item[1],
        reverse=True
    )
}
sparse_dict_tokens

{'pc': 3.02,
 'lace': 2.95,
 'programmed': 2.36,
 '##for': 2.28,
 'madagascar': 2.26,
 'death': 1.96,
 '##d': 1.95,
 'lattice': 1.81,
 'cell': 1.69,
 '##iensis': 1.64,
 'malaga': 1.6,
 '##get': 1.56,
 'regulated': 1.53,
 'die': 1.51,
 'lacey': 1.5,
 '##ono': 1.46,
 '##ole': 1.45,
 '##oles': 1.45,
 'transverse': 1.39,
 '##scu': 1.39,
 'leaves': 1.34,
 'cells': 1.31,
 'longitudinal': 1.31,
 'plant': 1.21,
 'plants': 1.16,
 'leaf': 1.15,
 'ap': 1.14,
 'organism': 1.12,
 'per': 1.1,
 'regulation': 1.03,
 'veins': 1.02,
 '##work': 1.0,
 'organisms': 1.0,
 'are': 0.94,
 'modified': 0.93,
 'controlled': 0.92,
 'dead': 0.9,
 'occur': 0.9,
 'disorder': 0.87,
 'program': 0.82,
 '##lat': 0.82,
 'through': 0.76,
 '##cl': 0.74,
 'computer': 0.71,
 '##ations': 0.7,
 'abbreviation': 0.69,
 'produced': 0.67,
 'is': 0.65,
 'center': 0.63,
 '"': 0.62,
 'produce': 0.62,
 'technology': 0.61,
 'process': 0.6,
 '##osing': 0.59,
 'matt': 0.54,
 'cc': 0.54,
 '##ation': 0.53,
 'outward': 0.53,
 'gage': 0.52,
 

## Indexing Everything

To build the vector DB we need to index everything. For this we initialize our connection to Pinecone, create an index, and insert every record in the format:

```python
(
    "id",
    [0.1, 0.2, ...],  # dense vec
    {"21": 2.25, "182": 1.77, ...},  # sparse vec
    {"context": "some text here"}  # metadata dict
)
```
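As a rough illustration of that layout, the sketch below checks the structure of one record. `check_record` is a hypothetical helper written for this example; it is not part of the Pinecone client, and the checks only mirror the format shown above.

```python
def check_record(record: tuple) -> bool:
    """Rough structural check for an (id, dense, sparse, metadata) tuple."""
    if len(record) != 4:
        return False
    _id, dense, sparse, metadata = record
    return (
        isinstance(_id, str)
        and isinstance(dense, list)
        and all(isinstance(v, float) for v in dense)
        and isinstance(sparse, dict)
        and all(isinstance(k, str) for k in sparse)  # sparse keys are strings
        and isinstance(metadata, dict)
    )

example = (
    "21645374-0",
    [0.1, 0.2],
    {"21": 2.25, "182": 1.77},
    {"context": "some text here"},
)
print(check_record(example))  # True
```

Note that the sparse keys are strings here, whereas the dictionary built earlier from tensor positions used integer keys, so a conversion step is needed before upserting.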

To make things easier we can create a helper function that transforms a list of records from `data` into this format. We'll call it `builder`:

In [24]:
def builder(records: list):
    ids = [x['id'] for x in records]
    contexts = [x['context'] for x in records]
    # create dense vecs
    dense_vecs = dense_model.encode(contexts).tolist()
    # create sparse vecs
    input_ids = tokenizer(
        contexts, return_tensors='pt',
        padding=True, truncation=True
    )
    with torch.no_grad():
        # keep the batch dimension so a single-record batch still
        # iterates row by row below
        sparse_embs = sparse_model(
            d_kwargs=input_ids
        )['d_rep']
    # convert to dictionary format
    sparse_dicts = []
    for embs in sparse_embs:
        # extract columns where there are non-zero weights
        cols = embs.nonzero().squeeze(1).cpu().tolist()
        weights = embs[cols].cpu().tolist()
        # build sparse dictionary
        cols = [str(idx) for idx in cols]
        sparse_dict = dict(zip(cols, weights))
        sparse_dicts.append(sparse_dict)
    # build metadata dict
    metadata = [{'context': x} for x in contexts]
    to_upsert = list(zip(ids, dense_vecs, sparse_dicts, metadata))
    return to_upsert

In [25]:
builder(data[:3])

[('21645374-0',
  [0.12863560020923615,
   -0.5226883888244629,
   -0.04859215393662453,
   0.012381777167320251,
   0.09862420707941055,
   -0.20195111632347107,
   -0.11263673007488251,
   -0.18145780265331268,
   0.167202889919281,
   0.6121328473091125,
   0.4899376928806305,
   0.4270266592502594,
   -0.11661393195390701,
   0.11439234763383865,
   0.28221872448921204,
   0.4575916826725006,
   0.3398391306400299,
   0.17317768931388855,
   -0.29793834686279297,
   -0.069961316883564,
   -0.018010497093200684,
   0.1724797487258911,
   -0.008942380547523499,
   0.28541284799575806,
   -0.14820480346679688,
   -0.1004815474152565,
   -0.09481605887413025,
   0.18546342849731445,
   0.10759811848402023,
   -0.0789220929145813,
   -0.0750652626156807,
   -0.2787976861000061,
   -0.3459736704826355,
   0.3100650906562805,
   -0.0001287878694711253,
   -0.28226014971733093,
   0.1778789758682251,
   -0.022834639996290207,
   0.019490480422973633,
   0.5969097018241882,
   -0.0247566904

Now we initialize our connection to Pinecone using a [free API key](https://app.pinecone.io/).

In [61]:
import pinecone

pinecone.init(
    api_key="YOUR_API_KEY",  # placeholder: never commit a real key
    environment="us-west1-gcp"
)
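To keep the key out of the notebook entirely, one option is to read it from an environment variable. A minimal sketch, assuming the variable is named `PINECONE_API_KEY` (a name chosen for this example) and using a hypothetical helper `get_api_key`:

```python
import os

def get_api_key(var_name: str = "PINECONE_API_KEY") -> str:
    """Read the API key from the environment, failing loudly if it is missing."""
    key = os.environ.get(var_name)
    if not key:
        raise RuntimeError(f"Set the {var_name} environment variable first")
    return key
```

The result can then be passed as `api_key=get_api_key()` wherever the key is needed.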

Then create a new **hybrid** index using an `s1h` pod:

In [62]:
index_name = 'james-splade-pubmed'

pinecone.create_index(
    index_name,
    dimension=dim,
    metric="dotproduct",
    pod_type="s1h.x1"
)

In [63]:
index = pinecone.Index(index_name)
index.describe_index_stats()

{'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {},
 'total_vector_count': 0}

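Upserting to a hybrid index should be simple, but for now it needs a workaround.

---

#### Temp Upsert

This is a temporary workaround while an issue in the Python client is being fixed: we look up the project name and send the upsert request to the index's REST endpoint directly.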

In [71]:
res = pinecone.whoami()
res

WhoAmIResponse(username='c78f2bd', user_label='default', projectname='9a4fbb6')

In [73]:
res.projectname

'9a4fbb6'

In [74]:
host = f"{index_name}-{res.projectname}.svc.us-west1-gcp.pinecone.io"

In [84]:
def to_requests(builder_output: list):
    vecs = [{
        'id': x[0],
        'values': x[1],
        'sparse_values': x[2],
        'metadata': x[3]
    } for x in builder_output]
    return vecs

In [85]:
import requests

headers = {'Api-Key': "YOUR_API_KEY"}  # placeholder: never commit a real key

res = requests.post(
    f"https://{host}/vectors/upsert",
    headers=headers,
    json={'vectors': to_requests(builder(data[:3]))}
)



In [87]:
res, res.json()

(<Response [200]>, {'upsertedCount': 3})

---

We can repeat this and iterate through (and index) the full dataset:

In [88]:
from tqdm.auto import tqdm

batch_size = 64

for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(i+batch_size, len(data))
    batch = data[i:i_end]
    res = requests.post(
        f"https://{host}/vectors/upsert",
        headers=headers,
        json={'vectors': to_requests(builder(batch))}
    )
    # stop early if any batch fails instead of silently continuing
    res.raise_for_status()




We can check that the number of upserted records matches the length of `data`.

In [89]:
len(data), index.describe_index_stats()

(5930,
 {'dimension': 768,
  'index_fullness': 0.0,
  'namespaces': {'': {'vector_count': 5930}},
  'total_vector_count': 5930})
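The batching arithmetic behind the indexing loop can be checked in isolation. The sketch below uses a hypothetical generator `batches` (not part of the notebook) to slice a list, and confirms that `5930` records in batches of `64` need `93` requests, matching the progress bar total.

```python
import math

def batches(items: list, size: int):
    """Yield consecutive slices of `items` with at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

# 5930 records in batches of 64 gives 93 requests, the last one partial.
n_batches = math.ceil(5930 / 64)
print(n_batches)  # 93
print([len(b) for b in batches(list(range(10)), 4)])  # [4, 4, 2]
```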

And now we can move on to querying...

## Queries

Our queries need to contain both dense and sparse vectors. We will define a function `query` that builds both vectors from the query text and sends the query to Pinecone.

In [96]:
def query(text: str, top_k: int = 3, alpha: float = 0.3):
    # create dense vec
    dense_vec = dense_model.encode(text).tolist()
    # create sparse vec
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_emb = sparse_model(
            d_kwargs=input_ids
        )['d_rep'].squeeze()
    # convert to dictionary format
    cols = sparse_emb.nonzero().squeeze().cpu().tolist()
    weights = sparse_emb[cols].cpu().tolist()
    cols = [str(idx) for idx in cols]
    sparse_dict = dict(zip(cols, weights))
    # query
    xc = index.query(
        vector=dense_vec,
        sparse_vector=sparse_dict,
        top_k=top_k,
        alpha=alpha,
        include_metadata=True
    )
    return xc
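The notebook does not say how `alpha` combines the dense and sparse scores. One common convention, assumed here rather than confirmed for this client branch, is a convex weighting in which dense values are scaled by `alpha` and sparse weights by `1 - alpha`. The helper `weight_by_alpha` below is hypothetical and only illustrates that idea on the client side:

```python
def weight_by_alpha(dense: list, sparse: dict, alpha: float):
    """Scale dense values by alpha and sparse weights by (1 - alpha).

    Assumption: alpha=1 means dense only, alpha=0 means sparse only.
    """
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be between 0 and 1")
    scaled_dense = [v * alpha for v in dense]
    scaled_sparse = {k: w * (1 - alpha) for k, w in sparse.items()}
    return scaled_dense, scaled_sparse

d, s = weight_by_alpha([1.0, 2.0], {"21": 4.0}, alpha=0.25)
print(d, s)  # [0.25, 0.5] {'21': 3.0}
```

Under a dot-product metric, scaling the query vectors this way scales the dense and sparse contributions to the final score by the same factors.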

In [97]:
query('hello')



{'matches': [{'id': '24374414-0',
              'metadata': {'context': 'Broad-based electronic health '
                                      'information exchange (HIE), in which '
                                      "patients' clinical data follow them "
                                      'between care delivery settings, is '
                                      'expected to produce large quality gains '
                                      'and cost savings. Although these '
                                      'benefits are assumed to result from '
                                      'reducing redundant care, there is '
                                      'limited supporting empirical evidence. '
                                      'To evaluate whether HIE adoption is '
                                      'associated with decreases in repeat '
                                      'imaging in emergency departments '
                                      '(EDs).'},
