diff --git a/.env b/.env index 03153f0..8228b32 100644 --- a/.env +++ b/.env @@ -14,6 +14,7 @@ WEB_SOURCES=["https://ai-on-openshift.io/getting-started/openshift/", "https://a CHUNK_SIZE=1024 CHUNK_OVERLAP=40 DB_TYPE=DRYRUN +EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2 # === Redis === REDIS_URL=redis://localhost:6379 diff --git a/config.py b/config.py index c5f0f43..bbb7ae8 100644 --- a/config.py +++ b/config.py @@ -42,41 +42,67 @@ def _get_required_env_var(key: str) -> str: raise ValueError(f"{key} environment variable is required.") return value + @staticmethod + def _parse_log_level(log_level_name: str) -> int: + log_levels = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + if log_level_name not in log_levels: + raise ValueError( + f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(log_levels.keys())}" + ) + return log_levels[log_level_name] + @staticmethod def _init_db_provider(db_type: str) -> DBProvider: """ Initialize the correct DBProvider subclass based on DB_TYPE. """ + get = Config._get_required_env_var db_type = db_type.upper() + embedding_model = get("EMBEDDING_MODEL") if db_type == "REDIS": - url = Config._get_required_env_var("REDIS_URL") + url = get("REDIS_URL") index = os.getenv("REDIS_INDEX", "docs") schema = os.getenv("REDIS_SCHEMA", "redis_schema.yaml") - return RedisProvider(url, index, schema) + return RedisProvider(embedding_model, url, index, schema) elif db_type == "ELASTIC": - url = Config._get_required_env_var("ELASTIC_URL") - password = Config._get_required_env_var("ELASTIC_PASSWORD") + url = get("ELASTIC_URL") + password = get("ELASTIC_PASSWORD") index = os.getenv("ELASTIC_INDEX", "docs") user = os.getenv("ELASTIC_USER", "elastic") - return ElasticProvider(url, password, index, user) + return ElasticProvider(embedding_model, url, password, index, user) elif db_type == "PGVECTOR": - url = Config._get_required_env_var("PGVECTOR_URL") - collection = Config._get_required_env_var("PGVECTOR_COLLECTION_NAME") - return PGVectorProvider(url, collection) + url = get("PGVECTOR_URL") + collection = get("PGVECTOR_COLLECTION_NAME") + return PGVectorProvider(embedding_model, url, collection) elif db_type == "SQLSERVER": - return SQLServerProvider() # Handles its own env var loading + host = get("SQLSERVER_HOST") + port = get("SQLSERVER_PORT") + user = get("SQLSERVER_USER") + password = get("SQLSERVER_PASSWORD") + database = get("SQLSERVER_DB") + table = get("SQLSERVER_TABLE") + driver = get("SQLSERVER_DRIVER") + return SQLServerProvider( + embedding_model, host, port, user, password, database, table, driver + ) elif db_type == "QDRANT": - url = Config._get_required_env_var("QDRANT_URL") - collection = Config._get_required_env_var("QDRANT_COLLECTION") - return QdrantProvider(url, collection) + url = get("QDRANT_URL") + collection = get("QDRANT_COLLECTION") + return QdrantProvider(embedding_model, url, collection) elif db_type == "DRYRUN": - return DryRunProvider() + return DryRunProvider(embedding_model) raise ValueError(f"Unsupported DB_TYPE '{db_type}'") @@ -99,44 +125,32 @@ def load() -> "Config": get = Config._get_required_env_var # Initialize logger - log_level_name = get("LOG_LEVEL").lower() - log_levels = { - "debug": 10, - "info": 20, - "warning": 30, - "error": 40, - "critical": 50, - } - if log_level_name not in log_levels: - raise ValueError( - f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(log_levels)}" - ) - log_level = log_levels[log_level_name] - logging.basicConfig(level=log_level) + log_level = get("LOG_LEVEL").upper() + logging.basicConfig(level=Config._parse_log_level(log_level)) logger = logging.getLogger(__name__) - logger.debug("Logging initialized at level: %s", log_level_name.upper()) + logger.debug("Logging initialized at level: %s", log_level) # Initialize db db_type = get("DB_TYPE") db_provider = Config._init_db_provider(db_type) # Web URLs - web_sources_raw = get("WEB_SOURCES") try: - web_sources = json.loads(web_sources_raw) + web_sources = json.loads(get("WEB_SOURCES")) except json.JSONDecodeError as e: raise ValueError(f"WEB_SOURCES must be a valid JSON list: {e}") # Repo sources - repo_sources_json = get("REPO_SOURCES") try: - repo_sources = json.loads(repo_sources_json) + repo_sources = json.loads(get("REPO_SOURCES")) except json.JSONDecodeError as e: raise ValueError(f"Invalid REPO_SOURCES JSON: {e}") from e - # Misc + # Embedding settings chunk_size = int(get("CHUNK_SIZE")) chunk_overlap = int(get("CHUNK_OVERLAP")) + + # Misc temp_dir = get("TEMP_DIR") return Config( diff --git a/vector_db/db_provider.py b/vector_db/db_provider.py index 168510d..123abc0 100644 --- a/vector_db/db_provider.py +++ b/vector_db/db_provider.py @@ -10,10 +10,13 @@ class DBProvider(ABC): """ Abstract base class for vector DB providers. Subclasses must implement `add_documents`. + + Args: + embedding_model (str): Embedding model to use """ - def __init__(self) -> None: - self.embeddings: Embeddings = HuggingFaceEmbeddings() + def __init__(self, embedding_model: str) -> None: + self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model) @abstractmethod def add_documents(self, docs: List[Document]) -> None: diff --git a/vector_db/dryrun_provider.py b/vector_db/dryrun_provider.py index c8fa41a..a220af3 100644 --- a/vector_db/dryrun_provider.py +++ b/vector_db/dryrun_provider.py @@ -13,14 +13,17 @@ class DryRunProvider(DBProvider): chunked documents to stdout. It is useful for debugging document loading, chunking, and metadata before committing to a real embedding operation. + Args: + embedding_model (str): Embedding model to use + Example: >>> from vector_db.dry_run_provider import DryRunProvider - >>> provider = DryRunProvider() + >>> provider = DryRunProvider("sentence-transformers/all-mpnet-base-v2") >>> provider.add_documents(docs) # docs is a List[Document] """ - def __init__(self): - super().__init__() # ensures embeddings are initialized + def __init__(self, embedding_model: str): + super().__init__(embedding_model) def add_documents(self, docs: List[Document]) -> None: """ diff --git a/vector_db/elastic_provider.py b/vector_db/elastic_provider.py index 415c98e..bb363e0 100644 --- a/vector_db/elastic_provider.py +++ b/vector_db/elastic_provider.py @@ -14,6 +14,7 @@ class ElasticProvider(DBProvider): Elasticsearch-based vector DB provider using LangChain's ElasticsearchStore. Args: + embedding_model (str): Embedding model to use url (str): Full URL to the Elasticsearch cluster (e.g. http://localhost:9200) password (str): Authentication password for the cluster index (str): Index name to use for vector storage @@ -21,6 +22,7 @@ class ElasticProvider(DBProvider): Example: >>> provider = ElasticProvider( + ... embedding_model="sentence-transformers/all-mpnet-base-v2", ... url="http://localhost:9200", ... password="changeme", ... index="docs", @@ -29,8 +31,10 @@ class ElasticProvider(DBProvider): >>> provider.add_documents(chunks) """ - def __init__(self, url: str, password: str, index: str, user: str): - super().__init__() + def __init__( + self, embedding_model: str, url: str, password: str, index: str, user: str + ): + super().__init__(embedding_model) self.db = ElasticsearchStore( embedding=self.embeddings, diff --git a/vector_db/pgvector_provider.py b/vector_db/pgvector_provider.py index 2941dca..035156e 100644 --- a/vector_db/pgvector_provider.py +++ b/vector_db/pgvector_provider.py @@ -18,19 +18,21 @@ class PGVectorProvider(DBProvider): document embeddings in a PostgreSQL-compatible backend with pgvector enabled. Args: + embedding_model (str): Embedding model to use url (str): PostgreSQL connection string (e.g. postgresql://user:pass@host:5432/db) collection_name (str): Name of the pgvector table or collection Example: >>> provider = PGVectorProvider( + ... embedding_model="sentence-transformers/all-mpnet-base-v2", ... url="postgresql://user:pass@localhost:5432/mydb", ... collection_name="documents" ... ) >>> provider.add_documents(chunks) """ - def __init__(self, url: str, collection_name: str): - super().__init__() + def __init__(self, embedding_model: str, url: str, collection_name: str): + super().__init__(embedding_model) self.db = PGVector( connection=url, diff --git a/vector_db/qdrant_provider.py b/vector_db/qdrant_provider.py index 48cd679..d8d3c47 100644 --- a/vector_db/qdrant_provider.py +++ b/vector_db/qdrant_provider.py @@ -16,6 +16,7 @@ class QdrantProvider(DBProvider): Qdrant-based vector DB provider using LangChain's QdrantVectorStore. Args: + embedding_model (str): Embedding model to use url (str): Base URL of the Qdrant service (e.g., http://localhost:6333) collection (str): Name of the vector collection to use or create api_key (Optional[str]): API key if authentication is required (optional) @@ -24,6 +25,7 @@ class QdrantProvider(DBProvider): Example: >>> provider = QdrantProvider( + ... embedding_model="sentence-transformers/all-mpnet-base-v2", ... url="http://localhost:6333", ... collection="embedded_docs", ... api_key=None @@ -31,8 +33,14 @@ class QdrantProvider(DBProvider): >>> provider.add_documents(docs) """ - def __init__(self, url: str, collection: str, api_key: Optional[str] = None): - super().__init__() + def __init__( + self, + embedding_model: str, + url: str, + collection: str, + api_key: Optional[str] = None, + ): + super().__init__(embedding_model) self.collection = collection self.url = url diff --git a/vector_db/redis_provider.py b/vector_db/redis_provider.py index 9fc24b5..e1acc2a 100644 --- a/vector_db/redis_provider.py +++ b/vector_db/redis_provider.py @@ -15,6 +15,7 @@ class RedisProvider(DBProvider): Redis-based vector DB provider using RediSearch and LangChain's Redis integration. Args: + embedding_model (str): Embedding model to use url (str): Redis connection string (e.g. redis://localhost:6379) index (str): RediSearch index name (must be provided via .env) schema (str): Path to RediSearch schema YAML file (must be provided via .env) @@ -24,6 +25,7 @@ class RedisProvider(DBProvider): Example: >>> provider = RedisProvider( + ... embedding_model="sentence-transformers/all-mpnet-base-v2", ... url="redis://localhost:6379", ... index="docs", ... schema="redis_schema.yaml" @@ -31,8 +33,8 @@ class RedisProvider(DBProvider): >>> provider.add_documents(chunks) """ - def __init__(self, url: str, index: str, schema: str): - super().__init__() + def __init__(self, embedding_model: str, url: str, index: str, schema: str): + super().__init__(embedding_model) self.url = url self.index = index self.schema = schema diff --git a/vector_db/sqlserver_provider.py b/vector_db/sqlserver_provider.py index 873e7a0..4b0bac2 100644 --- a/vector_db/sqlserver_provider.py +++ b/vector_db/sqlserver_provider.py @@ -15,6 +15,7 @@ class SQLServerProvider(DBProvider): SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore. Args: + embedding_model (str): Embedding model to use host (str): Hostname of the SQL Server port (str): Port number user (str): SQL login username @@ -25,6 +26,7 @@ class SQLServerProvider(DBProvider): Example: >>> provider = SQLServerProvider( + ... embedding_model="sentence-transformers/all-mpnet-base-v2", ... host="localhost", ... port="1433", ... user="sa", @@ -38,6 +40,7 @@ class SQLServerProvider(DBProvider): def __init__( self, + embedding_model: str, host: str, port: str, user: str, @@ -46,7 +49,7 @@ def __init__( table: str, driver: str, ) -> None: - super().__init__() + super().__init__(embedding_model) self.host = host self.port = port