In [139]:
import base64
import os
import time
from typing import Dict

import pinecone
import polars as pl
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer


# Hugging Face checkpoint of the SPLADE document encoder; get_sparse_doc turns its
# masked-LM logits into vocabulary-indexed term weights.
model_id = 'naver/efficient-splade-VI-BT-large-doc'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    # Append .to(device) to move the model off the CPU; inputs must then live on the same device.
    AutoModelForMaskedLM.from_pretrained(model_id),
)

def get_sparse_doc(
    text: str,
    min_id: int = 1996,
    max_id: int = 29612,
    min_weight: float = 0.5,
) -> Dict[int, float]:
    """
    Generate a sparse (SPLADE-style) document representation of ``text``.

    Each vocabulary id is weighted by log(1 + ReLU(logit)) summed over the input
    tokens. To improve search, only ids in ``[min_id, max_id)`` are kept: the default
    bounds are meant to drop non-English characters, punctuation and single letters
    from this model's vocabulary (verify them before switching tokenizers). These
    could be added back if the model is fine-tuned on the quotes dataset.

    Args:
        text: The document to encode.
        min_id: Smallest vocabulary id to keep (inclusive).
        max_id: Largest vocabulary id to keep (exclusive).
        min_weight: Keep only terms whose (unrounded) weight is strictly greater than this.

    Returns:
        Mapping of vocabulary id -> weight rounded to 2 decimals.
    """
    # Tokenize the single input text (the original referenced an undefined ``texts``).
    tokens = tokenizer(text, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True)
    # Keep the inputs on the same device as the model, in case it was moved with .to(device).
    tokens = tokens.to(model.device)

    # Inference only: no_grad avoids building an autograd graph for every document.
    with torch.no_grad():
        output = model(**tokens)

    # log(1 + ReLU(logits)) per token; padding positions (if any) are zeroed before
    # summing over the sequence axis so they cannot contribute weight.
    activations = torch.log1p(torch.relu(output.logits))
    attention_mask = tokens.get('attention_mask')
    if attention_mask is not None:
        activations = activations * attention_mask.unsqueeze(-1)
    vec = activations.sum(dim=1).squeeze(0)

    # flatten() (not squeeze()) keeps a 1-D tensor even when only one entry is non-zero,
    # so the zip below never has to iterate over a 0-d tensor.
    cols = vec.nonzero(as_tuple=False).flatten()
    weights = vec[cols]

    # Filter by id range and weight on plain Python numbers, then round the kept weights.
    return {
        idx: round(weight, 2)
        for idx, weight in zip(cols.tolist(), weights.tolist())
        if min_id <= idx < max_id and weight > min_weight
    }

def ids_to_tokens(sparse_dict: Dict[int, float]) -> Dict[str, float]:
    """
    Convert the token-id keys of a sparse representation back to token strings.

    Args:
        sparse_dict: Mapping of vocabulary id -> weight (as produced by ``get_sparse_doc``).

    Returns:
        Mapping of token string -> weight, in the same order as ``sparse_dict``.
    """
    # Look up only the ids we need instead of inverting the whole vocabulary on every call.
    token_strings = tokenizer.convert_ids_to_tokens(list(sparse_dict))
    return dict(zip(token_strings, sparse_dict.values()))


In [140]:
### Example Usage
# Encode one sentence, then print its expanded terms as a token -> weight mapping.
text = "This is a sample text for transformation."
sparse_doc = get_sparse_doc(text=text)
print(ids_to_tokens(sparse_dict=sparse_doc))

{'change': 0.92, 'text': 1.91, 'converted': 0.88, 'translation': 1.5, 'document': 0.78, 'transition': 0.89, 'texts': 0.93, 'sample': 1.77, 'conversion': 1.31, 'samples': 0.84, 'transformed': 1.77, 'transformation': 3.52, 'trans': 0.54, 'convert': 1.27, 'transform': 2.56, 'merge': 0.51, 'sampling': 0.61, 'converting': 1.05, 'transforming': 2.23, 'transformers': 0.97, 'preview': 0.63, 'transforms': 1.69, 'transformations': 2.11, 'transitions': 0.64}


### Preprocess Quotes Data

In [46]:
def isEnglish(s):
    """
    Return True if ``s`` consists only of ASCII characters.

    Used as a cheap proxy for "English text": strings containing accented letters,
    other scripts, curly quotes or emoji are rejected. Strings holding lone
    surrogates (which cannot be UTF-8 encoded) are rejected instead of raising.
    """
    # str.isascii() (Python 3.7+) answers directly, without an encode/decode round trip
    # that could raise UnicodeEncodeError.
    return s.isascii()
    
def batch_iter_slice(data, bs):
    """
    Lazily split a sliceable sequence into consecutive chunks of ``bs`` items.

    Every chunk has length ``bs`` except possibly the last, which holds the remainder.
    """
    yield from (data[start:start + bs] for start in range(0, len(data), bs))
        

# Load the quotes, keep only those whose text is pure ASCII (what isEnglish checks),
# then draw a reproducible random sample.
df = pl.read_parquet("quotes.parquet")
df = df.filter(df["quote"].map_elements(isEnglish, return_dtype=pl.Boolean))

SAMPLE_SIZE = 110_000  # max size of free Pinecone tier
SAMPLE_SEED = 42  # fixed seed so the sampled subset is the same on every run
# Cap at the available row count: sampling more rows than exist (without replacement) raises.
df_sample = df.sample(n=min(SAMPLE_SIZE, df.height), seed=SAMPLE_SEED)

### Embed Documents

In [None]:
quotes = df_sample["quote"].to_list()
# One forward pass per quote; each result maps vocabulary id -> weight.
embeddings = [get_sparse_doc(quote) for quote in tqdm(quotes)]

rows = df_sample.to_dicts()
# zip() would silently truncate on a mismatch, so check the alignment explicitly.
assert len(rows) == len(embeddings), "expected exactly one embedding per sampled row"

vectors = []
# Build one upsert record per row that has an author and at least one surviving term.
for row, output in zip(rows, embeddings):
    # get_sparse_doc thresholds before rounding, so re-check on the rounded weights
    # (e.g. 0.501 is stored as 0.5 and is dropped here).
    filtered_output = {k: v for k, v in output.items() if v > 0.5}
    if not filtered_output or row["author"] is None:
        continue

    vectors.append({
        # base64 of "1234" + row_nr; the "1234" prefix is preserved unchanged.
        'id': base64.b64encode(f"1234{row['row_nr']}".encode()).decode(),
        'sparse_values': {'indices': list(filtered_output.keys()), 'values': list(filtered_output.values())},
        # Placeholder dense vector; presumably the index has dimension 1 — confirm against the index config.
        'values': [0.0000001],
        'metadata': {"author": row["author"], "quote": row["quote"]},
    })

### Push to Pinecone

In [70]:
# Open the index by host; the API key and host URL are read from environment variables
# so no credentials live in the notebook.
index = pinecone.Index(os.environ["PINECONE_API_KEY"],
                       os.environ["PINECONE_HOST"])

# Records per upsert request; presumably chosen to stay within Pinecone's per-request
# limits — confirm against the upsert documentation and lower it if requests are rejected.
UPSERT_BATCH_SIZE = 1000
for batch in batch_iter_slice(vectors, UPSERT_BATCH_SIZE):
    index.upsert(vectors=batch)

### Test Endpoint

In [None]:
%%bash
# Raw shell is a SyntaxError in a Python cell, hence the %%bash cell magic above.
# Sends a sample query to a service expected to be listening on 127.0.0.1:8000
# (it is not started by this notebook); the meaning of "filter" is defined by that service.
curl -X POST http://127.0.0.1:8000/query \
     -H "Content-Type: application/json" \
     -d '{"text": "computers", "filter": 0.01}'