In [None]:
"""
This module parses the config.yaml file and verifies it with pydantic,
creating a Config object populated with the specified configurations.
"""
from pathlib import Path

from hydra import compose, initialize_config_dir
from pydantic import BaseModel, Field


# ROOT is the absolute form of the current working directory at the time this
# code runs, so CONFIG_DIR only points at the intended folder when the code is
# executed from the directory that contains "examples/".
ROOT = Path().resolve()
# Default directory searched by create_config() for the configuration files.
CONFIG_DIR = ROOT / "examples" / "configs" / "project_config"


class _DatabaseConfig(BaseModel):
    """
    Configuration class for the database.

    Attributes:
        project_id (str): The project ID. Required; must be between 1 and
            30 characters long.
    """

    project_id: str = Field(..., min_length=1, max_length=30)


class _TrainingConfig(BaseModel):
    """
    Configuration class for training.

    Attributes:
        savepath_model (str): The path to save the trained model. Must start
            with "models/".
        model_name (str): The name of the model file. Must end with ".pkl".
        dep_variable (str): The dependent variable for the training.
        num_layers (int): The number of layers for the training. Must be at
            least 1.
    """

    savepath_model: str = Field(pattern=r"^models/.*$")
    # The former pattern "^*.pkl$" was a glob written as a regex: "^*" is not a
    # valid quantifier target in Python's `re`, and the unescaped "." matched
    # any character, so the ".pkl" suffix was not actually enforced.
    model_name: str = Field(pattern=r"^.*\.pkl$")
    dep_variable: str
    num_layers: int = Field(..., ge=1)


class _InferenceConfig(BaseModel):
    """
    Configuration class for inference settings.

    Attributes:
        savepath_inference (str): The path to save the inference results.
            Must start with "data/".
        inference_name (str): The name of the inference file. Must end with
            ".json".
    """

    savepath_inference: str = Field(pattern=r"^data/.*$")
    # The former pattern "^*.json$" was a glob written as a regex: "^*" is not
    # a valid quantifier target in Python's `re`, and the unescaped "." matched
    # any character, so the ".json" suffix was not actually enforced.
    inference_name: str = Field(pattern=r"^.*\.json$")


class _PreprocessingConfig(BaseModel):
    """
    Configuration class for preprocessing data.

    Attributes:
        savepath_raw (str): The save path for raw data. Must start with "data/".
        savepath_report (str): The save path for the report. Must start with
            "reports/".
        savepath_processed (str): The save path for the processed data. Must
            start with "data/".
        savepath_features (str): The save path for the features. Must start
            with "data/".
        path_queries (str): The path to the queries.
        raw_name (str): The raw data file name. Must end with ".parquet".
        query_name (str): The query file name. Must end with ".sql".
        report_name (str): The report file name. Must end with ".xlsx".
        report_sheet (str): The name of the report sheet. Must not be empty.
        processed_name (str): The processed data file name. Must end with
            ".parquet".
        train_name (str): The train data file name. Must end with ".parquet".
        test_name (str): The test data file name. Must end with ".parquet".
        feature_names (list[str] | None): The list of feature names
            (optional, defaults to None).
    """

    savepath_raw: str = Field(pattern=r"^data/.*$")
    savepath_report: str = Field(pattern=r"^reports/.*$")
    savepath_processed: str = Field(pattern=r"^data/.*$")
    savepath_features: str = Field(pattern=r"^data/.*$")
    path_queries: str
    # The file-name patterns below used to be glob-style ("^*.parquet$" etc.).
    # "^*" is not a valid quantifier target in Python's `re`, and the unescaped
    # "." matched any character, so the extension was not actually enforced.
    # They are now proper regexes that require the literal extension.
    raw_name: str = Field(pattern=r"^.*\.parquet$")
    query_name: str = Field(pattern=r"^.*\.sql$")
    report_name: str = Field(pattern=r"^.*\.xlsx$")
    report_sheet: str = Field(..., min_length=1)
    processed_name: str = Field(pattern=r"^.*\.parquet$")
    train_name: str = Field(pattern=r"^.*\.parquet$")
    test_name: str = Field(pattern=r"^.*\.parquet$")
    feature_names: list[str] | None = Field(default=None)


class Config(BaseModel):
    """Configuration object populated with the specified configurations.

    All four sections are required; each one is validated by its own model.

    Attributes:
        database (_DatabaseConfig): The database configuration.
        inference (_InferenceConfig): The inference configuration.
        training (_TrainingConfig): The training configuration.
        preprocessing (_PreprocessingConfig): The preprocessing configuration.
    """

    database: _DatabaseConfig
    inference: _InferenceConfig
    training: _TrainingConfig
    preprocessing: _PreprocessingConfig


def create_config(
    *,
    config_path: Path | None = CONFIG_DIR,
    config_name: str | None = "config",
    overrides: list[str] | None = None,
) -> Config:
    """
    Build a configuration object by combining Hydra's configuration system
    with Pydantic's data validation capabilities.

    Args:
        config_path (Path | None): The directory where the configuration files
            are located. ``None`` selects the default ``CONFIG_DIR``.
        config_name (str | None): The name of the configuration to be loaded;
            passed through to Hydra's ``compose`` unchanged.
        overrides (list[str] | None): An optional list of Hydra overrides.
            ``None`` is treated as "no overrides".
    Returns:
        Config: A validated configuration object.
    Raises:
        pydantic.ValidationError: If the composed configuration does not
            satisfy the ``Config`` schema.
    """
    # The signature allows None, but str(None) would hand Hydra the literal
    # directory name "None"; fall back to the default directory instead.
    config_dir = CONFIG_DIR if config_path is None else config_path

    if overrides is None:
        overrides = []

    # Hydra's global state is only initialised for the duration of this block.
    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
        cfg = compose(config_name=config_name, overrides=overrides)

    # Unpack the top-level sections as keyword arguments so that pydantic
    # validates each one against its section model.
    return Config(**dict(cfg))

In [None]:
# Build the validated Config object by calling create_config() with its
# default configuration directory and configuration name.
cfg = create_config()

In [None]:
# The whole configuration is now available as a single Config object
print(cfg)

# Each configuration section is an attribute of that object, e.g.
print(cfg.preprocessing.savepath_report)