Commit
Updated embedding svc (#169)
* updates to pdf table parser

* Updated embedding svc class
S1LV3RJ1NX authored May 27, 2024
1 parent f1e1680 commit d8a6257
Showing 3 changed files with 44 additions and 104 deletions.
106 changes: 17 additions & 89 deletions backend/modules/embedder/embedding_svc.py
@@ -1,11 +1,8 @@
-import concurrent.futures
-from typing import Dict, List
+from typing import List, Union

import requests
from langchain.embeddings.base import Embeddings
-from tqdm.auto import tqdm

from backend.logger import logger
from backend.settings import settings


@@ -28,95 +25,26 @@ class EmbeddingSvc(Embeddings):
    def __init__(self, **kwargs) -> None:
        # ideally get url from settings
        self.url = settings.EMBEDDING_SVC_URL
-        # 8 tokens for query and 2 for start and stop tokens
-        self.embedding_ctx_length = 500
-        self.chunk_size = 4

-    def encode(self, text: str) -> Dict:
-        response = requests.post(
-            f"{self.url}/tokenize",
-            json={"inputs": text, "add_special_tokens": False},
-        )
-
-        if response.status_code != 200:
-            logger.error(f"Tokenization failed: {response.text}")
-            return {}
-        return [token["id"] for tokens_list in response.json() for token in tokens_list]
-
-    def decode(self, token_ids: List[int]) -> str:
-        response = requests.post(
-            f"{self.url}/decode",
-            json={"ids": token_ids, "skip_special_tokens": True},
-        )
-        if response.status_code != 200:
-            logger.error(f"Detokenization failed: {response.text}")
-            return []
-        return response.json()
-
-    def embed_all(self, tokens_batch: List[List[str]]) -> List[List[float]]:
-        logger.info("Embedding...")
-        response = requests.post(
-            f"{self.url}/embed_all",
-            json={"inputs": tokens_batch, "truncate": True},
-        )
-        if response.status_code != 200:
-            logger.error(f"Embedding failed: {response.text}")
-            return []
-        logger.info("Embedding done...")
-        return response.json()
-
-    # For retrieval you need to pass this prompt.
-    def transform_query(self, query: str) -> str:
-        """For retrieval, add the prompt for query (not for documents)."""
-        return f"Represent this sentence for searching relevant passages: {query}"
-
-    def _get_len_safe_embeddings(self, texts: str) -> List[List[float]]:
-        """
-        Generate length-safe embeddings for a list of texts.
-        This method handles tokenization and embedding generation,
-        respecting the set embedding context length and chunk size.
-        """
-
-        tokens = []
-        indices = []
-
-        def process_text(text):
-            tokenized = self.encode(text)
-            # Split tokens into chunks respecting the embedding_ctx_length
-            for j in range(0, len(tokenized), self.embedding_ctx_length):
-                token_chunk = tokenized[j : j + self.embedding_ctx_length]
-                # Convert token IDs back to a string
-                chunk_text = self.decode(token_chunk)[0]
-                tokens.append(chunk_text)
-                indices.append(j)
-
-        logger.info("Chunking...")
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [
-                executor.submit(process_text, text)
-                for text in tqdm(texts, total=len(texts))
-            ]
-            concurrent.futures.wait(futures)
-
-        logger.info("Embedding...")
-        _iter = tqdm(range(0, len(tokens), self.chunk_size))
-        batched_embeddings: List[List[float]] = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [
-                executor.submit(self.embed_all, tokens[i : i + self.chunk_size])
-                for i in _iter
-            ]
-            concurrent.futures.wait(futures)
-            for future in futures:
-                embeddings = future.result()
-                batched_embeddings.extend(embeddings)
-        return batched_embeddings[0]
+    def call_embedding_service(
+        self, texts: Union[str, List[str]], type: str
+    ) -> Union[List[float], List[List[float]]]:
+        """Call the embedding service."""
+        if type == "query":
+            response = requests.post(
+                self.url.rstrip("/") + "/embed-query", json={"text": texts}
+            )
+        elif type == "documents":
+            response = requests.post(
+                self.url.rstrip("/") + "/embed-documents", json={"texts": texts}
+            )
+        response.raise_for_status()
+        return response.json()["embeddings"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
-        return self._get_len_safe_embeddings(texts)
+        return self.call_embedding_service(texts, "documents")

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
-        return self.embed_documents([self.transform_query(text)])[0]
+        return self.call_embedding_service(text, "query")
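
The service side of these two routes is not part of this commit. As context, here is a minimal sketch of an embedding service that would satisfy call_embedding_service: it must expose POST /embed-query (body {"text": ...}) and POST /embed-documents (body {"texts": [...]}) and return JSON with an "embeddings" key. The FastAPI app and the mxbai-embed-large backend below are illustrative assumptions, not code from this repository.

# Hypothetical service sketch; only the route names, request keys, and
# "embeddings" response key are taken from the client code above.
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

app = FastAPI()
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")  # assumed backend

class QueryInput(BaseModel):
    text: str

class DocumentsInput(BaseModel):
    texts: List[str]

@app.post("/embed-query")
def embed_query(payload: QueryInput) -> dict:
    # One embedding vector for a single query string.
    return {"embeddings": model.encode(payload.text).tolist()}

@app.post("/embed-documents")
def embed_documents(payload: DocumentsInput) -> dict:
    # One embedding vector per input document.
    return {"embeddings": model.encode(payload.texts).tolist()}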
2 changes: 1 addition & 1 deletion backend/modules/query_controllers/example/payload.py
@@ -1,4 +1,4 @@
PROMPT = "You are an AI assistant specialising in information retieval and analysis. Answer the following question based only on the following context:\nContext: {context} \nQuestion: {question}"
PROMPT = "You are an AI assistant specialising in information retrieval and analysis. Answer the following question based only on the given context:\nContext: {context} \nQuestion: {question}"
QUERY_WITH_VECTOR_STORE_RETRIEVER_SIMILARITY = {
    "collection_name": "creditcard",
    "query": "Explain in detail different categories of credit cards",
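
For reference, the revised PROMPT is a plain str.format template. A minimal illustrative example of filling it; the context and question strings here are made up:

# Illustrative only: fill the template with retrieved context before sending it to the LLM.
from backend.modules.query_controllers.example.payload import PROMPT

prompt = PROMPT.format(
    context="Premium cards offer travel benefits; secured cards require a deposit.",
    question="Explain in detail different categories of credit cards",
)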
40 changes: 26 additions & 14 deletions local.metadata.yaml
@@ -13,17 +13,29 @@
#   config:
#     model: "mixedbread-ai/mxbai-embed-large-v1"

-# Test 2
-collection_name: sdocs2string
-data_source:
-  type: localdir
-  uri: /Users/prathamesh/Desktop/s-docs
-parser_config:
-  chunk_size: 1000
-  chunk_overlap: 40
-  parser_map:
-    ".pdf": PdfTableParser
-embedder_config:
-  provider: mixedbread
-  config:
-    model: "mixedbread-ai/mxbai-embed-large-v1"
+# # Test 2
+# collection_name: sdocs2string
+# data_source:
+#   type: localdir
+#   uri: /Users/prathamesh/Desktop/s-docs
+# parser_config:
+#   chunk_size: 1000
+#   chunk_overlap: 40
+#   parser_map:
+#     ".pdf": PdfTableParser
+# embedder_config:
+#   provider: mixedbread
+#   config:
+#     model: "mixedbread-ai/mxbai-embed-large-v1"
+# Test 3
+# collection_name: sdocs2stringsvc
+# data_source:
+#   type: localdir
+#   uri: /Users/prathamesh/Desktop/s-docs
+# parser_config:
+#   chunk_size: 512
+#   chunk_overlap: 40
+#   parser_map:
+#     ".pdf": PdfTableParser
+# embedder_config:
+#   provider: embedding-svc
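
Test 3 switches the embedder to the new embedding-svc provider. Once EMBEDDING_SVC_URL points at a running service, the client reduces to the two POST calls below. This smoke test is a sketch: the routes, payload keys, and "embeddings" response key come from the diff above, while the localhost URL is an assumption.

# Hedged smoke test for the service contract used by EmbeddingSvc.
import requests

BASE_URL = "http://localhost:8000"  # assumed deployment of the embedding service

query = requests.post(BASE_URL + "/embed-query", json={"text": "What is a secured card?"})
query.raise_for_status()
print(len(query.json()["embeddings"]))  # embedding dimension of the single query vector

docs = requests.post(BASE_URL + "/embed-documents", json={"texts": ["Doc one.", "Doc two."]})
docs.raise_for_status()
print(len(docs.json()["embeddings"]))  # one vector per document -> 2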
