Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ WEB_SOURCES=["https://ai-on-openshift.io/getting-started/openshift/", "https://a
CHUNK_SIZE=1024
CHUNK_OVERLAP=40
DB_TYPE=DRYRUN
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2

# === Redis ===
REDIS_URL=redis://localhost:6379
Expand Down
80 changes: 47 additions & 33 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,41 +42,67 @@ def _get_required_env_var(key: str) -> str:
raise ValueError(f"{key} environment variable is required.")
return value

@staticmethod
def _parse_log_level(log_level_name: str) -> int:
    """Translate a log level name into its ``logging`` module constant.

    Args:
        log_level_name: One of "DEBUG", "INFO", "WARNING", "ERROR",
            "CRITICAL" (case-sensitive, upper-case).

    Returns:
        The numeric logging level (e.g. ``logging.DEBUG``).

    Raises:
        ValueError: If ``log_level_name`` is not a recognized level name.
    """
    valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    if log_level_name not in valid_levels:
        raise ValueError(
            f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(valid_levels)}"
        )
    # Every valid name is also an attribute of the logging module
    # (logging.DEBUG, logging.INFO, ...), so resolve it directly.
    return getattr(logging, log_level_name)

@staticmethod
def _init_db_provider(db_type: str) -> DBProvider:
"""
Initialize the correct DBProvider subclass based on DB_TYPE.
"""
get = Config._get_required_env_var
db_type = db_type.upper()
embedding_model = get("EMBEDDING_MODEL")

if db_type == "REDIS":
url = Config._get_required_env_var("REDIS_URL")
url = get("REDIS_URL")
index = os.getenv("REDIS_INDEX", "docs")
schema = os.getenv("REDIS_SCHEMA", "redis_schema.yaml")
return RedisProvider(url, index, schema)
return RedisProvider(embedding_model, url, index, schema)

elif db_type == "ELASTIC":
url = Config._get_required_env_var("ELASTIC_URL")
password = Config._get_required_env_var("ELASTIC_PASSWORD")
url = get("ELASTIC_URL")
password = get("ELASTIC_PASSWORD")
index = os.getenv("ELASTIC_INDEX", "docs")
user = os.getenv("ELASTIC_USER", "elastic")
return ElasticProvider(url, password, index, user)
return ElasticProvider(embedding_model, url, password, index, user)

elif db_type == "PGVECTOR":
url = Config._get_required_env_var("PGVECTOR_URL")
collection = Config._get_required_env_var("PGVECTOR_COLLECTION_NAME")
return PGVectorProvider(url, collection)
url = get("PGVECTOR_URL")
collection = get("PGVECTOR_COLLECTION_NAME")
return PGVectorProvider(embedding_model, url, collection)

elif db_type == "SQLSERVER":
return SQLServerProvider() # Handles its own env var loading
host = get("SQLSERVER_HOST")
port = get("SQLSERVER_PORT")
user = get("SQLSERVER_USER")
password = get("SQLSERVER_PASSWORD")
database = get("SQLSERVER_DB")
table = get("SQLSERVER_TABLE")
driver = get("SQLSERVER_DRIVER")
return SQLServerProvider(
embedding_model, host, port, user, password, database, table, driver
)

elif db_type == "QDRANT":
url = Config._get_required_env_var("QDRANT_URL")
collection = Config._get_required_env_var("QDRANT_COLLECTION")
return QdrantProvider(url, collection)
url = get("QDRANT_URL")
collection = get("QDRANT_COLLECTION")
return QdrantProvider(embedding_model, url, collection)

elif db_type == "DRYRUN":
return DryRunProvider()
return DryRunProvider(embedding_model)

raise ValueError(f"Unsupported DB_TYPE '{db_type}'")

Expand All @@ -99,44 +125,32 @@ def load() -> "Config":
get = Config._get_required_env_var

# Initialize logger
log_level_name = get("LOG_LEVEL").lower()
log_levels = {
"debug": 10,
"info": 20,
"warning": 30,
"error": 40,
"critical": 50,
}
if log_level_name not in log_levels:
raise ValueError(
f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(log_levels)}"
)
log_level = log_levels[log_level_name]
logging.basicConfig(level=log_level)
log_level = get("LOG_LEVEL").upper()
logging.basicConfig(level=Config._parse_log_level(log_level))
logger = logging.getLogger(__name__)
logger.debug("Logging initialized at level: %s", log_level_name.upper())
logger.debug("Logging initialized at level: %s", log_level)

# Initialize db
db_type = get("DB_TYPE")
db_provider = Config._init_db_provider(db_type)

# Web URLs
web_sources_raw = get("WEB_SOURCES")
try:
web_sources = json.loads(web_sources_raw)
web_sources = json.loads(get("WEB_SOURCES"))
except json.JSONDecodeError as e:
raise ValueError(f"WEB_SOURCES must be a valid JSON list: {e}")

# Repo sources
repo_sources_json = get("REPO_SOURCES")
try:
repo_sources = json.loads(repo_sources_json)
repo_sources = json.loads(get("REPO_SOURCES"))
except json.JSONDecodeError as e:
raise ValueError(f"Invalid REPO_SOURCES JSON: {e}") from e

# Misc
# Embedding settings
chunk_size = int(get("CHUNK_SIZE"))
chunk_overlap = int(get("CHUNK_OVERLAP"))

# Misc
temp_dir = get("TEMP_DIR")

return Config(
Expand Down
7 changes: 5 additions & 2 deletions vector_db/db_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ class DBProvider(ABC):
"""
Abstract base class for vector DB providers.
Subclasses must implement `add_documents`.

Args:
embedding_model (str): Embedding model to use
"""

def __init__(self) -> None:
self.embeddings: Embeddings = HuggingFaceEmbeddings()
def __init__(self, embedding_model: str) -> None:
self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

@abstractmethod
def add_documents(self, docs: List[Document]) -> None:
Expand Down
9 changes: 6 additions & 3 deletions vector_db/dryrun_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ class DryRunProvider(DBProvider):
chunked documents to stdout. It is useful for debugging document loading,
chunking, and metadata before committing to a real embedding operation.

Args:
embedding_model (str): Embedding model to use

Example:
>>> from vector_db.dry_run_provider import DryRunProvider
>>> provider = DryRunProvider()
>>> provider = DryRunProvider("sentence-transformers/all-mpnet-base-v2")
>>> provider.add_documents(docs) # docs is a List[Document]
"""

def __init__(self):
super().__init__() # ensures embeddings are initialized
def __init__(self, embedding_model: str):
super().__init__(embedding_model)

def add_documents(self, docs: List[Document]) -> None:
"""
Expand Down
8 changes: 6 additions & 2 deletions vector_db/elastic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ class ElasticProvider(DBProvider):
Elasticsearch-based vector DB provider using LangChain's ElasticsearchStore.

Args:
embedding_model (str): Embedding model to use
url (str): Full URL to the Elasticsearch cluster (e.g. http://localhost:9200)
password (str): Authentication password for the cluster
index (str): Index name to use for vector storage
user (str): Username for Elasticsearch (default: "elastic")

Example:
>>> provider = ElasticProvider(
... embedding_model="sentence-transformers/all-mpnet-base-v2",
... url="http://localhost:9200",
... password="changeme",
... index="docs",
Expand All @@ -29,8 +31,10 @@ class ElasticProvider(DBProvider):
>>> provider.add_documents(chunks)
"""

def __init__(self, url: str, password: str, index: str, user: str):
super().__init__()
def __init__(
self, embedding_model: str, url: str, password: str, index: str, user: str
):
super().__init__(embedding_model)

self.db = ElasticsearchStore(
embedding=self.embeddings,
Expand Down
6 changes: 4 additions & 2 deletions vector_db/pgvector_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@ class PGVectorProvider(DBProvider):
document embeddings in a PostgreSQL-compatible backend with pgvector enabled.

Args:
embedding_model (str): Embedding model to use
url (str): PostgreSQL connection string (e.g. postgresql://user:pass@host:5432/db)
collection_name (str): Name of the pgvector table or collection

Example:
>>> provider = PGVectorProvider(
... embedding_model="sentence-transformers/all-mpnet-base-v2",
... url="postgresql://user:pass@localhost:5432/mydb",
... collection_name="documents"
... )
>>> provider.add_documents(chunks)
"""

def __init__(self, url: str, collection_name: str):
super().__init__()
def __init__(self, embedding_model: str, url: str, collection_name: str):
super().__init__(embedding_model)

self.db = PGVector(
connection=url,
Expand Down
12 changes: 10 additions & 2 deletions vector_db/qdrant_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class QdrantProvider(DBProvider):
Qdrant-based vector DB provider using LangChain's QdrantVectorStore.

Args:
embedding_model (str): Embedding model to use
url (str): Base URL of the Qdrant service (e.g., http://localhost:6333)
collection (str): Name of the vector collection to use or create
api_key (Optional[str]): API key if authentication is required (optional)
Expand All @@ -24,15 +25,22 @@ class QdrantProvider(DBProvider):

Example:
>>> provider = QdrantProvider(
... embedding_model="sentence-transformers/all-mpnet-base-v2",
... url="http://localhost:6333",
... collection="embedded_docs",
... api_key=None
... )
>>> provider.add_documents(docs)
"""

def __init__(self, url: str, collection: str, api_key: Optional[str] = None):
super().__init__()
def __init__(
self,
embedding_model: str,
url: str,
collection: str,
api_key: Optional[str] = None,
):
super().__init__(embedding_model)
self.collection = collection
self.url = url

Expand Down
6 changes: 4 additions & 2 deletions vector_db/redis_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class RedisProvider(DBProvider):
Redis-based vector DB provider using RediSearch and LangChain's Redis integration.

Args:
embedding_model (str): Embedding model to use
url (str): Redis connection string (e.g. redis://localhost:6379)
index (str): RediSearch index name (must be provided via .env)
schema (str): Path to RediSearch schema YAML file (must be provided via .env)
Expand All @@ -24,15 +25,16 @@ class RedisProvider(DBProvider):

Example:
>>> provider = RedisProvider(
... embedding_model="sentence-transformers/all-mpnet-base-v2",
... url="redis://localhost:6379",
... index="docs",
... schema="redis_schema.yaml"
... )
>>> provider.add_documents(chunks)
"""

def __init__(self, url: str, index: str, schema: str):
super().__init__()
def __init__(self, embedding_model: str, url: str, index: str, schema: str):
super().__init__(embedding_model)
self.url = url
self.index = index
self.schema = schema
Expand Down
5 changes: 4 additions & 1 deletion vector_db/sqlserver_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SQLServerProvider(DBProvider):
SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore.

Args:
embedding_model (str): Embedding model to use
host (str): Hostname of the SQL Server
port (str): Port number
user (str): SQL login username
Expand All @@ -25,6 +26,7 @@ class SQLServerProvider(DBProvider):

Example:
>>> provider = SQLServerProvider(
... embedding_model="sentence-transformers/all-mpnet-base-v2",
... host="localhost",
... port="1433",
... user="sa",
Expand All @@ -38,6 +40,7 @@ class SQLServerProvider(DBProvider):

def __init__(
self,
embedding_model: str,
host: str,
port: str,
user: str,
Expand All @@ -46,7 +49,7 @@ def __init__(
table: str,
driver: str,
) -> None:
super().__init__()
super().__init__(embedding_model)

self.host = host
self.port = port
Expand Down