## Implementation of sparse embeddings using SPLADE

References:
* https://www.pinecone.io/learn/splade/

In [2]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import tqdm
import numpy as np
import scipy
from typing import List
from llmsearch.utils import set_cache_folder
from llmsearch.config import get_config, Document
from llmsearch.parsers.splitter import DocumentSplitter

In [3]:
torch.cuda.empty_cache()
torch.cuda.memory_allocated()

0

In [4]:
CACHE_FOLDER = "/storage/llm/cache"
CONFIG_PATH = "dev_config.yaml"

In [179]:
# Use the currently selected CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = f"cuda:{torch.cuda.current_device()}"
else:
    device = "cpu"

print(device)

cuda:0


In [116]:

set_cache_folder(CACHE_FOLDER)

config = get_config(CONFIG_PATH)

[32m2023-08-19 20:40:13.581[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m33[0m - [1mSetting SENTENCE_TRANSFORMERS_HOME folder: /storage/llm/cache[0m
[32m2023-08-19 20:40:13.582[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m34[0m - [1mSetting TRANSFORMERS_CACHE folder: /storage/llm/cache/transformers[0m
[32m2023-08-19 20:40:13.582[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m35[0m - [1mSetting HF_HOME: /storage/llm/cache/hf_home[0m
[32m2023-08-19 20:40:13.582[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m36[0m - [1mSetting MODELS_CACHE_FOLDER: /storage/llm/cache[0m
[32m2023-08-19 20:40:13.586[0m | [1mINFO    [0m | [36mllmsearch.config[0m:[36mvalidate_params[0m:[36m115[0m - [1mLoading model paramaters in configuration class OpenAIModelConfig[0m


In [117]:
model_id = 'naver/splade-cocondenser-ensembledistil'

# No `device` argument here: the tokenizer only produces the encoded batch,
# which is moved to the target device with `.to(device)` at encoding time.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForMaskedLM.from_pretrained(model_id)

In [118]:
model.to(device)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [119]:
model.device

device(type='cuda', index=0)

In [120]:
splitter = DocumentSplitter(config)

In [121]:
docs = splitter.split()

[32m2023-08-19 20:40:18.104[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36msplit[0m:[36m42[0m - [1mScanning path for extension: md[0m


[32m2023-08-19 20:40:18.154[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m87[0m - [1mWill add the following passage prefix: passage: [0m
[32m2023-08-19 20:40:18.155[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m90[0m - [1mProcessing path using custom splitter: /home/snexus/projects/knowledge-base/financial-risk-management.md, chunk size: 1024[0m
[32m2023-08-19 20:40:18.155[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m96[0m - [1m/home/snexus/projects/knowledge-base/financial-risk-management.md[0m
[32m2023-08-19 20:40:18.155[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m90[0m - [1mProcessing path using custom splitter: /home/snexus/projects/knowledge-base/deployment.md, chunk size: 1024[0m
[32m2023-08-19 20:40:18.1

In [122]:
docs[0].page_content

'passage: # Techincal indicators\n* Market breadh (% of stocks above 200 SMA) - indicates how healthy is the broad market rally is - **S5TH**  in trading view\n\n* Shot term RSI divergence - weak trend. RSI + trend in the same direction - strong trend:\n* ![[Pasted image 20220110140302.png]]\n\n\n# Macro indicators\n* [[financial-risk-macro-indicators]]\n* [[dao-of-capital-austrian-investing]]'

In [123]:
def split(list_a, chunk_size):
    """Lazily yield consecutive slices of `list_a`, each at most `chunk_size` long.

    The last slice is shorter when `len(list_a)` is not a multiple of
    `chunk_size`; an empty input yields nothing.
    """
    starts = range(0, len(list_a), chunk_size)
    yield from (list_a[start:start + chunk_size] for start in starts)

In [124]:
def get_splade_embeddings(docs: List[str], device: str) -> np.ndarray:
    """Encode a batch of texts into SPLADE term-weight vectors.

    Uses the notebook-level `tokenizer` and `model`; the encoded batch is moved
    to `device`, so `model` is expected to be on that same device already.

    Args:
        docs: Texts to encode together as one padded, truncated batch.
        device: Torch device string the tokenised batch is moved to.

    Returns:
        Term weights as a NumPy array on the CPU, one row per input text.
        Because of the trailing `squeeze()`, a single-text batch is returned
        as a 1-D vector rather than a one-row matrix.
    """
    tokens = tokenizer(
        docs, return_tensors='pt',
        padding=True, truncation=True
    ).to(device)

    # Inference only: disabling autograd avoids building a graph that would be
    # thrown away immediately, which lowers peak memory per batch.
    with torch.no_grad():
        output = model(**tokens)

        # Aggregate the token-level vectors into one vector per text:
        # log(1 + relu(logits)), zeroed at padding positions through the
        # attention mask, then max-pooled over the sequence axis.
        token_weights = (
            torch.log(1 + torch.relu(output.logits))
            * tokens.attention_mask.unsqueeze(-1)
        )
        vecs = torch.max(token_weights, dim=1)[0].squeeze().cpu().numpy()

    del output
    del token_weights
    del tokens
    # Only touch the CUDA runtime when it exists; calling synchronize() on a
    # CPU-only setup raises instead of being a no-op.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return vecs

In [None]:
### Generate SPLADE Embeddings

In [128]:
batch_size = 5
# Ceiling division, so the progress bar knows how many batches the generator yields.
n_batches = -(-len(docs) // batch_size)

vecs = []
for chunk in tqdm.tqdm(split(docs, chunk_size=batch_size), total=n_batches):
    texts = [d.page_content for d in chunk]
    # A one-document batch comes back squeezed to 1-D; restore the batch axis
    # so every entry of `vecs` is a (documents, vocabulary) matrix.
    vecs.append(np.atleast_2d(get_splade_embeddings(texts, device=device)))



[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

In [129]:
# Stack the per-batch matrices row-wise in one vectorised call. np.vstack also
# promotes a 1-D entry (a batch that held a single document) to one row, so a
# trailing partial batch cannot break the stacking.
embeddings = np.vstack(vecs)
csr_embeddings = scipy.sparse.csr_matrix(embeddings)

In [196]:
csr_embeddings.shape

(5760, 30522)

In [168]:
print("Size of the csr matrix: ", csr_embeddings.data.nbytes)

Size of the csr matrix:  3353656


In [169]:
print("Sparsity: ", csr_embeddings.count_nonzero() / (csr_embeddings.shape[0] * csr_embeddings.shape[1]))

Sparsity:  0.004768953086662444


In [161]:
scipy.sparse.save_npz("splade_embeddings.npz", csr_embeddings)

### Get documents relevant to the query

In [206]:
query = "What type of hashing schemes exist in databases?"
embed_query = get_splade_embeddings([query], device=device)

In [207]:
embed_query

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [208]:
l2_norm_matrix = scipy.sparse.linalg.norm(csr_embeddings, axis=1)
l2_norm_query =  scipy.linalg.norm(embed_query)
print(l2_norm_matrix, l2_norm_query)

[10.31852    4.5923758 10.460884  ... 10.833925   7.622963   7.701627 ] 5.328778266906738


In [209]:
# Cosine similarity between every document row and the query vector.
dot_products = csr_embeddings.dot(embed_query)
norm_products = l2_norm_matrix * l2_norm_query
# An all-zero embedding has zero norm; plain division would give NaN there,
# and np.argsort orders NaN last, i.e. at the "most similar" end of the
# ranking. Score those rows as 0 instead.
cosine_similarity = np.divide(
    dot_products,
    norm_products,
    out=np.zeros_like(norm_products),
    where=norm_products > 0,
)

In [210]:
most_similar = np.argsort(cosine_similarity)

In [211]:
top_similar_indices = most_similar[-10:][::-1]

In [212]:
top_similar_indices

array([1094, 1088, 1089, 1091, 1097, 1095, 2020, 2018, 1096, 1090])

In [213]:
cosine_similarity[most_similar]

array([0.        , 0.        , 0.        , ..., 0.31968024, 0.3263687 ,
       0.33867854], dtype=float32)

In [214]:
for ind in top_similar_indices:
    print("---------------------")
    print(docs[ind].page_content)

---------------------
passage: Metadata applicable to the next chunk of text delimited by five stars:
>> METADATA START
Document name: cmu-databases-hash-tables-07.md
Subsection of: Dynamic Hashing Schemes
>> METADATA END

*****
# Dynamic Hashing Schemes


Dynamic hash tables resize themselves on demand
*****
---------------------
passage: Metadata applicable to the next chunk of text delimited by five stars:
>> METADATA START
Document name: cmu-databases-hash-tables-07.md
Subsection of: Static Hashing Schemes
>> METADATA END

*****
# Static Hashing Schemes

>[!INFO] Specifying number of potential locations ahead of time
*****
---------------------
passage: Metadata applicable to the next chunk of text delimited by five stars:
>> METADATA START
Document name: cmu-databases-hash-tables-07.md
Subsection of: Static Hashing Schemes
>> METADATA END

*****
## Linear Probe Hashing

>[!INFO] Most frequently used hashing technique

hash(key) mod N

See also [python-dd-hashmaps-theory](python-dd

## Testing packaged code

In [1]:
from llmsearch.splade import SparseEmbeddingsSplade
from llmsearch.config import get_config
from llmsearch.parsers.splitter import DocumentSplitter
import scipy
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060


In [2]:

CONFIG_PATH = "dev_config.yaml"
config = get_config(CONFIG_PATH)

[32m2023-08-20 17:59:13.462[0m | [1mINFO    [0m | [36mllmsearch.config[0m:[36mvalidate_params[0m:[36m115[0m - [1mLoading model paramaters in configuration class OpenAIModelConfig[0m


In [3]:
splitter = DocumentSplitter(config)
docs = splitter.split()

[32m2023-08-20 17:59:13.481[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36msplit[0m:[36m44[0m - [1mScanning path for extension: md[0m
[32m2023-08-20 17:59:13.492[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m89[0m - [1mWill add the following passage prefix: passage: [0m
[32m2023-08-20 17:59:13.492[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m92[0m - [1mProcessing path using custom splitter: /home/snexus/projects/knowledge-base/financial-risk-management.md, chunk size: 1024[0m
[32m2023-08-20 17:59:13.493[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[0m:[36m98[0m - [1m/home/snexus/projects/knowledge-base/financial-risk-management.md[0m
[32m2023-08-20 17:59:13.493[0m | [1mINFO    [0m | [36mllmsearch.parsers.splitter[0m:[36m_get_documents_from_custom_splitter[

[32m2023-08-20 17:59:13.602[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mpostprocess_sections[0m:[36m370[0m - [1mRemoving first section ...[0m
[32m2023-08-20 17:59:13.603[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mmarkdown_splitter[0m:[36m315[0m - [1mGot 21 text chunks:[0m
[32m2023-08-20 17:59:13.603[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mmarkdown_splitter[0m:[36m317[0m - [1m	Chunk length: 281[0m
[32m2023-08-20 17:59:13.603[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mmarkdown_splitter[0m:[36m317[0m - [1m	Chunk length: 362[0m
[32m2023-08-20 17:59:13.604[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mmarkdown_splitter[0m:[36m317[0m - [1m	Chunk length: 417[0m
[32m2023-08-20 17:59:13.604[0m | [1mINFO    [0m | [36mllmsearch.parsers.markdown[0m:[36mmarkdown_splitter[0m:[36m317[0m - [1m	Chunk length: 318[0m
[32m2023-08-20 17:59:13.604[0m |

In [4]:
def get_docs_by_ids(docs, ids):
    """Return the documents whose `metadata['document_id']` appears in `ids`.

    The result follows the order of `ids` (first occurrence wins), so passing
    ids ranked by relevance yields documents in that same ranking. Documents
    that share a document_id keep their relative order from `docs`.

    Args:
        docs: Iterable of objects exposing a `metadata` mapping with a
            'document_id' key.
        ids: Iterable of hashable document ids to select.

    Returns:
        List of matching documents; empty when nothing matches.
    """
    # Position of each id in the requested order; duplicates keep the first rank.
    rank = {}
    for position, doc_id in enumerate(ids):
        rank.setdefault(doc_id, position)

    matches = [d for d in docs if d.metadata['document_id'] in rank]
    # list.sort is stable, so ties (same document_id) stay in corpus order.
    matches.sort(key=lambda d: rank[d.metadata['document_id']])
    return matches

In [5]:
len(docs)

5760

In [6]:

splade = SparseEmbeddingsSplade(config=config)

[32m2023-08-20 17:59:17.519[0m | [1mINFO    [0m | [36mllmsearch.splade[0m:[36m__init__[0m:[36m24[0m - [1mSetting device to cuda:0[0m
[32m2023-08-20 17:59:17.520[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m33[0m - [1mSetting SENTENCE_TRANSFORMERS_HOME folder: /storage/llm/cache[0m
[32m2023-08-20 17:59:17.520[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m34[0m - [1mSetting TRANSFORMERS_CACHE folder: /storage/llm/cache/transformers[0m
[32m2023-08-20 17:59:17.520[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m35[0m - [1mSetting HF_HOME: /storage/llm/cache/hf_home[0m
[32m2023-08-20 17:59:17.521[0m | [1mINFO    [0m | [36mllmsearch.utils[0m:[36mset_cache_folder[0m:[36m36[0m - [1mSetting MODELS_CACHE_FOLDER: /storage/llm/cache[0m


In [7]:
_ = splade.generate_embeddings_from_docs(docs)

[32m2023-08-20 17:59:20.459[0m | [1mINFO    [0m | [36mllmsearch.splade[0m:[36mgenerate_embeddings_from_docs[0m:[36m85[0m - [1mCalculating SPLADE embeddings for 5760 documents.[0m
1152it [01:23, 13.86it/s]
[32m2023-08-20 18:00:44.579[0m | [1mINFO    [0m | [36mllmsearch.splade[0m:[36mgenerate_embeddings_from_docs[0m:[36m107[0m - [1mSaved embeddings to /storage/llm/temp_embeddings/splade/splade_embeddings.npz[0m


In [8]:

query = "How to merge new updates to a delta table?"
ids, scores = splade.query(search= query, n = 10)
print(scores)

[32m2023-08-20 18:00:44.586[0m | [1mINFO    [0m | [36mllmsearch.splade[0m:[36mquery[0m:[36m125[0m - [1mLoading embeddings...[0m
[32m2023-08-20 18:00:44.610[0m | [1mINFO    [0m | [36mllmsearch.splade[0m:[36mload[0m:[36m74[0m - [1mLoaded embeddings from /storage/llm/temp_embeddings/splade/splade_embeddings.npz[0m


[0.00566773 0.         0.02607495 ... 0.00730678 0.0052176  0.00503043]
[0.30086386 0.28034347 0.27854884 0.2742592  0.26999858 0.2664942
 0.2602524  0.24732761 0.24313846 0.23810072]


In [11]:
relevant_docs = get_docs_by_ids(docs, ids)

In [12]:
for d in relevant_docs:
    print("XXXXXXXXXXXXXXXXX")
    print(d.page_content)

XXXXXXXXXXXXXXXXX
passage: Metadata applicable to the next chunk of text delimited by five stars:
>> METADATA START
Document name: github.md
Subsection of: Branches
>> METADATA END

*****
## Creating a new branch
* Checkout moves HEAD to new branch
* Switch branching updates the working directory


Following is a code section in bash, delimited by triple backticks:
```bash

# Create branch1
git branch branch1
git checkout branch1

# Create and switch to branch1
git checkout -c branch1

# Checking difference between branches
git diff main..branch1

# Merge changes from branch1
# 1. Switch back to main
git switch main
# 2. merge branch1
git merge branch1
# 3. delete branch1 because it is not needed
git branch -d branch1


```
*****
XXXXXXXXXXXXXXXXX
passage: Metadata applicable to the next chunk of text delimited by five stars:
>> METADATA START
Document name: databricks-silver-ingestion-patterns.md
Subsection of: Streaming Deduplication
>> METADATA END

*****
## Batch deduplication


Fo