Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ CHUNK_SIZE=1024
CHUNK_OVERLAP=40
DB_TYPE=DRYRUN
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
EMBEDDING_LENGTH=768

# === Redis ===
REDIS_URL=redis://localhost:6379
Expand Down
18 changes: 8 additions & 10 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, List

from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings

from vector_db.db_provider import DBProvider
from vector_db.dryrun_provider import DryRunProvider
Expand Down Expand Up @@ -108,40 +109,37 @@ def _init_db_provider(db_type: str) -> DBProvider:
"""
get = Config._get_required_env_var
db_type = db_type.upper()
embedding_model = get("EMBEDDING_MODEL")
embedding_length = int(get("EMBEDDING_LENGTH"))
embeddings = HuggingFaceEmbeddings(model_name=get("EMBEDDING_MODEL"))

if db_type == "REDIS":
url = get("REDIS_URL")
index = os.getenv("REDIS_INDEX", "docs")
return RedisProvider(embedding_model, url, index)
return RedisProvider(embeddings, url, index)

elif db_type == "ELASTIC":
url = get("ELASTIC_URL")
password = get("ELASTIC_PASSWORD")
index = os.getenv("ELASTIC_INDEX", "docs")
user = os.getenv("ELASTIC_USER", "elastic")
return ElasticProvider(embedding_model, url, password, index, user)
return ElasticProvider(embeddings, url, password, index, user)

elif db_type == "PGVECTOR":
url = get("PGVECTOR_URL")
collection = get("PGVECTOR_COLLECTION_NAME")
return PGVectorProvider(embedding_model, url, collection, embedding_length)
return PGVectorProvider(embeddings, url, collection)

elif db_type == "MSSQL":
connection_string = get("MSSQL_CONNECTION_STRING")
table = get("MSSQL_TABLE")
return MSSQLProvider(
embedding_model, connection_string, table, embedding_length
)
return MSSQLProvider(embeddings, connection_string, table)

elif db_type == "QDRANT":
url = get("QDRANT_URL")
collection = get("QDRANT_COLLECTION")
return QdrantProvider(embedding_model, url, collection)
return QdrantProvider(embeddings, url, collection)

elif db_type == "DRYRUN":
return DryRunProvider(embedding_model)
return DryRunProvider(embeddings)

raise ValueError(f"Unsupported DB_TYPE '{db_type}'")

Expand Down
26 changes: 13 additions & 13 deletions vector_db/db_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings


Expand All @@ -11,42 +10,43 @@ class DBProvider(ABC):
Abstract base class for vector database providers.

This class standardizes how vector databases are initialized and how documents
are added to them. All concrete implementations (e.g., Qdrant, FAISS) must
are added to them. All concrete implementations (e.g., Qdrant, Redis) must
subclass `DBProvider` and implement the `add_documents()` method.

Attributes:
embeddings (Embeddings): An instance of HuggingFace embeddings based on the
specified model.
embeddings (HuggingFaceEmbeddings): An instance of HuggingFace embeddings.
embedding_length (int): Dimensionality of the embedding vector, measured at construction time by embedding a short probe query.

Args:
embedding_model (str): HuggingFace-compatible model name to be used for computing
dense vector embeddings for documents.
embeddings (HuggingFaceEmbeddings): A preconfigured HuggingFaceEmbeddings instance.

Example:
>>> class MyProvider(DBProvider):
... def add_documents(self, docs):
... print(f"Would add {len(docs)} docs with model {self.embeddings.model_name}")
... print(f"Would add {len(docs)} docs with vector size {self.embedding_length}")

>>> provider = MyProvider("BAAI/bge-small-en")
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
>>> provider = MyProvider(embeddings)
>>> provider.add_documents([Document(page_content="Hello")])
"""

def __init__(self, embedding_model: str) -> None:
def __init__(self, embeddings: HuggingFaceEmbeddings) -> None:
"""
Initialize a DB provider with a specific embedding model.
Initialize a DB provider with a HuggingFaceEmbeddings instance.

Args:
embedding_model (str): The HuggingFace model name to be used for generating embeddings.
embeddings (HuggingFaceEmbeddings): The embeddings object used for vectorization.
"""
self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
self.embeddings: HuggingFaceEmbeddings = embeddings
self.embedding_length: int = len(self.embeddings.embed_query("query"))

@abstractmethod
def add_documents(self, docs: List[Document]) -> None:
"""
Add documents to the vector database.

This method must be implemented by subclasses to define how documents
(with or without precomputed embeddings) are stored in the backend vector DB.
are embedded and stored in the backend vector DB.

Args:
docs (List[Document]): A list of LangChain `Document` objects to be embedded and added.
Expand Down
28 changes: 14 additions & 14 deletions vector_db/dryrun_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

from vector_db.db_provider import DBProvider

Expand All @@ -9,36 +10,35 @@ class DryRunProvider(DBProvider):
"""
A mock vector DB provider for debugging document loading and chunking.

`DryRunProvider` does not persist any documents or perform embedding operations.
Instead, it prints a preview of the documents and their metadata to stdout,
allowing users to validate chunking, structure, and metadata before pushing
to a production vector store.

Useful for development, testing, or understanding how your documents are
being processed.
`DryRunProvider` does not persist any documents and does not embed them. The only
embedding call is the single probe query the base class runs at construction time to
determine `embedding_length`. It prints a preview of the documents and their metadata
to stdout, allowing users to validate chunking, structure, and metadata before pushing
to a production vector store.

Attributes:
embeddings (Embeddings): HuggingFace embedding model for compatibility.
embeddings (HuggingFaceEmbeddings): HuggingFace embedding instance, used for interface consistency.
embedding_length (int): Dimensionality of embeddings (computed for validation, not used).

Args:
embedding_model (str): The model name to initialize HuggingFaceEmbeddings.
Used only for compatibility — no embeddings are generated.
embeddings (HuggingFaceEmbeddings): A HuggingFace embedding model instance.

Example:
>>> from langchain_core.documents import Document
>>> provider = DryRunProvider("BAAI/bge-small-en")
>>> from langchain_huggingface import HuggingFaceEmbeddings
>>> from vector_db.dryrun_provider import DryRunProvider
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
>>> provider = DryRunProvider(embeddings)
>>> docs = [Document(page_content="Hello world", metadata={"source": "test.txt"})]
>>> provider.add_documents(docs)
"""

def __init__(self, embedding_model: str):
def __init__(self, embeddings: HuggingFaceEmbeddings):
"""
Initialize the dry run provider with a placeholder embedding model.

Args:
embedding_model (str): The name of the embedding model (used for interface consistency).
embeddings (HuggingFaceEmbeddings): A HuggingFace embedding model (used for compatibility).
"""
super().__init__(embedding_model)
super().__init__(embeddings)

def add_documents(self, docs: List[Document]) -> None:
"""
Expand Down
29 changes: 16 additions & 13 deletions vector_db/elastic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from langchain_core.documents import Document
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_huggingface import HuggingFaceEmbeddings

from vector_db.db_provider import DBProvider

Expand All @@ -13,25 +14,27 @@ class ElasticProvider(DBProvider):
"""
Vector database provider backed by Elasticsearch using LangChain's ElasticsearchStore.

This provider allows storing and querying vectorized documents in an Elasticsearch
cluster. Documents are embedded using a HuggingFace model and stored with associated
metadata in the specified index.
This provider stores and queries vectorized documents in an Elasticsearch cluster.
Documents are embedded using the provided HuggingFace embeddings model and stored
with associated metadata in the specified index.

Attributes:
db (ElasticsearchStore): LangChain-compatible wrapper around Elasticsearch vector storage.
embeddings (Embeddings): HuggingFace embedding model for generating document vectors.
db (ElasticsearchStore): LangChain-compatible Elasticsearch vector store.
embeddings (HuggingFaceEmbeddings): HuggingFace embedding model instance.

Args:
embedding_model (str): HuggingFace model name for computing embeddings.
url (str): Full URL to the Elasticsearch cluster (e.g. "http://localhost:9200").
embeddings (HuggingFaceEmbeddings): Pre-initialized embeddings instance.
url (str): Full URL to the Elasticsearch cluster (e.g., "http://localhost:9200").
password (str): Password for the Elasticsearch user.
index (str): The index name where documents will be stored.
user (str): Elasticsearch username (default is typically "elastic").

Example:
>>> from langchain_huggingface import HuggingFaceEmbeddings
>>> from vector_db.elastic_provider import ElasticProvider
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
>>> provider = ElasticProvider(
... embedding_model="BAAI/bge-small-en",
... embeddings=embeddings,
... url="http://localhost:9200",
... password="changeme",
... index="rag-docs",
Expand All @@ -42,7 +45,7 @@ class ElasticProvider(DBProvider):

def __init__(
self,
embedding_model: str,
embeddings: HuggingFaceEmbeddings,
url: str,
password: str,
index: str,
Expand All @@ -52,13 +55,13 @@ def __init__(
Initialize an Elasticsearch-based vector DB provider.

Args:
embedding_model (str): The model name for computing embeddings.
embeddings (HuggingFaceEmbeddings): HuggingFace embeddings instance.
url (str): Full URL of the Elasticsearch service.
password (str): Elasticsearch user's password.
index (str): Name of the Elasticsearch index to use.
user (str): Elasticsearch username (e.g., "elastic").
"""
super().__init__(embedding_model)
super().__init__(embeddings)

self.db = ElasticsearchStore(
embedding=self.embeddings,
Expand All @@ -74,8 +77,8 @@ def add_documents(self, docs: List[Document]) -> None:
"""
Add a batch of LangChain documents to the Elasticsearch index.

Each document will be embedded using the configured model and stored
in the specified index with any associated metadata.
Each document is embedded using the provided model and stored
in the specified index with its associated metadata.

Args:
docs (List[Document]): List of documents to index.
Expand Down
51 changes: 12 additions & 39 deletions vector_db/mssql_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pyodbc
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_sqlserver import SQLServer_VectorStore

from vector_db.db_provider import DBProvider
Expand All @@ -16,49 +17,45 @@ class MSSQLProvider(DBProvider):
SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore integration.

This provider connects to a Microsoft SQL Server instance using a full ODBC connection string,
and stores document embeddings in a specified table. If the target database does not exist,
it will be created automatically.
and stores document embeddings in a specified table. The target database will be created if it
does not already exist.

Attributes:
db (SQLServer_VectorStore): Underlying LangChain-compatible vector store.
connection_string (str): Full ODBC connection string to the SQL Server instance.

Args:
embedding_model (str): HuggingFace-compatible embedding model to use.
embeddings (HuggingFaceEmbeddings): Pre-initialized embeddings instance.
connection_string (str): Full ODBC connection string (including target DB).
table (str): Table name to store vector embeddings.
embedding_length (int): Dimensionality of the embeddings (e.g., 768 for all-mpnet-base-v2).

Example:
>>> from langchain_huggingface import HuggingFaceEmbeddings
>>> from vector_db.mssql_provider import MSSQLProvider
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
>>> provider = MSSQLProvider(
... embedding_model="BAAI/bge-large-en-v1.5",
... embeddings=embeddings,
... connection_string="Driver={ODBC Driver 18 for SQL Server};Server=localhost,1433;Database=docs;UID=sa;PWD=StrongPassword!;TrustServerCertificate=yes;Encrypt=no;",
... table="embedded_docs",
... embedding_length=768,
... )
>>> provider.add_documents(docs)
"""

def __init__(
self,
embedding_model: str,
embeddings: HuggingFaceEmbeddings,
connection_string: str,
table: str,
embedding_length: int,
) -> None:
"""
Initialize the MSSQLProvider.

Args:
embedding_model (str): HuggingFace-compatible embedding model to use for generating embeddings.
embeddings (HuggingFaceEmbeddings): HuggingFace-compatible embedding model instance.
connection_string (str): Full ODBC connection string including target database name.
table (str): Table name to store document embeddings.
embedding_length (int): Size of the embeddings (number of dimensions).

Raises:
RuntimeError: If the database specified in the connection string cannot be found or created.
"""
super().__init__(embedding_model)
super().__init__(embeddings)

self.connection_string = connection_string
self.table = table
Expand All @@ -77,36 +74,18 @@ def __init__(
connection_string=self.connection_string,
embedding_function=self.embeddings,
table_name=self.table,
embedding_length=embedding_length,
embedding_length=self.embedding_length,
)

def _extract_server_address(self) -> str:
"""
Extract the server address (host,port) from the connection string.

Returns:
str: The server address portion ("host,port") or "unknown" if not found.
"""
match = re.search(r"Server=([^;]+)", self.connection_string, re.IGNORECASE)
return match.group(1) if match else "unknown"

def _extract_database_name(self) -> Optional[str]:
"""
Extract the database name from the connection string.

Returns:
str: Database name if found, else None.
"""
match = re.search(r"Database=([^;]+)", self.connection_string, re.IGNORECASE)
return match.group(1) if match else None

def _build_connection_string_for_master(self) -> str:
"""
Modify the connection string to point to the 'master' database.

Returns:
str: Modified connection string.
"""
parts = self.connection_string.split(";")
updated_parts = [
"Database=master" if p.strip().lower().startswith("database=") else p
Expand All @@ -116,12 +95,6 @@ def _build_connection_string_for_master(self) -> str:
return ";".join(updated_parts) + ";"

def _ensure_database_exists(self) -> None:
"""
Connect to the SQL Server master database and create the target database if missing.

Raises:
RuntimeError: If the database cannot be created or accessed.
"""
database = self._extract_database_name()
if not database:
raise RuntimeError("No database name found in connection string.")
Expand Down
Loading