In [1]:
import copy
from typing import Any, Dict, Optional, Mapping

import torch

from dim_est.datasets.data_generation import make_data_generator
from dim_est.config.dataset_defaults import DATASET_DEFAULTS

def merge_with_validation(
    defaults: Mapping[str, Any],
    overrides: Mapping[str, Any],
    error_prefix: str = "",
    _path: str = "",
) -> dict:
    """
    Recursively merge `overrides` into `defaults`, validating keys.

    Args:
        defaults: Reference configuration. Its keys (and nested dict keys)
            define what may be overridden.
        overrides: Values to apply on top of `defaults`.
        error_prefix: Optional label; when non-empty it is prepended, followed
            by ": ", to every error message.
        _path: Internal. Dotted path of the level currently being merged,
            used only to build readable error messages during recursion.

    Returns:
        A deep copy of `defaults` with the overrides applied. Neither input is
        mutated, and leaf override values are deep-copied so that later
        mutation of the result cannot leak back into `overrides`.

    Raises:
        KeyError: a key in `overrides` does not exist in `defaults` at the
            same nesting level.
        TypeError: the default value is a dict but the override is not a
            mapping.
    """
    prefix = f"{error_prefix}: " if error_prefix else ""

    # Always work on a deep copy so we never mutate the defaults in-place
    merged = copy.deepcopy(defaults)

    for k, v in overrides.items():
        full_path = f"{_path}{k}"

        if k not in defaults:
            raise KeyError(
                f"{prefix}Invalid override key '{full_path}'. "
                f"Allowed keys at this level: {list(defaults.keys())}"
            )

        default_val = defaults[k]

        # If the default is a dict, we expect a mapping and recurse
        if isinstance(default_val, dict):
            if not isinstance(v, Mapping):
                raise TypeError(
                    f"{prefix}Override for '{full_path}' must be a mapping, "
                    f"got {type(v).__name__}"
                )

            merged[k] = merge_with_validation(
                default_val, v, error_prefix=error_prefix, _path=f"{full_path}."
            )
        else:
            # Leaf value: copy it so the merged config does not alias mutable
            # objects (lists, dicts, ...) owned by the caller's overrides.
            merged[k] = copy.deepcopy(v)

    return merged

In [5]:
# Choose the dataset family and the settings that differ from its defaults.
dataset_type = "gaussian_mixture"
dataset_overrides = {
    "latent": {"n_peaks": 8, "mi_bits_peak": 2.0, "mu": 2.0, "sig": 1.0},
    "transform": {"mode": "identity", "observe_dim_x": None, "observe_dim_y": None},
}

# Copy the registered defaults, then apply the overrides; unknown keys are rejected.
ds_defaults = copy.deepcopy(DATASET_DEFAULTS[dataset_type])
ds_cfg = merge_with_validation(
    defaults=ds_defaults,
    overrides=dataset_overrides,
    error_prefix="dataset overrides",
)

In [6]:
# Display the merged configuration (dataset defaults with the overrides applied).
ds_cfg


{'latent': {'n_peaks': 8,
  'mu': 2.0,
  'sig': 1.0,
  'mi_bits_peak': 2.0,
  'latent_dim': 1},
 'transform': {'mode': 'identity',
  'observe_dim_x': None,
  'observe_dim_y': None,
  'sig_embed_x': 0.0,
  'sig_embed_y': 0.0,
  'noise_mode': 'white_relative'}}

In [7]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
# NOTE(review): assumes make_data_generator accepts device="cpu" — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
data_generator = make_data_generator(dataset_type, ds_cfg, device=device)


In [8]:
# Call the generator with 100 (presumably the number of samples — confirm)
# and unpack the two values it returns.
# NOTE(review): the recorded output below (torch.Size([100, 1])) is not produced
# by this assignment alone; presumably a trailing `.shape` expression was
# dropped from the cell — confirm and restore it.
zx, zy = data_generator(100)

torch.Size([100, 1])