In [None]:
from typing import Any, Tuple, Sequence
import yaml
import numpy as np
import copy
from functools import reduce

In [None]:
def sweep_constructor(
    loader: yaml.SafeLoader, node: yaml.nodes.MappingNode
) -> dict[str, Any]:
    """Build the marker dict for a ``!sweep`` tag.

    The value of the first entry in the tagged mapping is constructed as a
    sequence and becomes the list of values to sweep over.
    """
    _, values_node = node.value[0]
    return {"type": "sweep", "values": loader.construct_sequence(values_node)}


def coupled_sweep_constructor(
    loader: yaml.SafeLoader, node: yaml.nodes.MappingNode
) -> dict[str, Any]:
    """Build the marker dict for a ``!coupled-sweep`` tag.

    The first entry's scalar is a dotted path (e.g. ``"model.lr"``) naming the
    sweep this one is coupled to; it is split on ``"."``. The second entry's
    value is constructed as the sequence of partner values.
    """
    target_scalar = node.value[0][1].value
    dotted_path = target_scalar.split(".")
    values_seq = node.value[1][1]
    return {
        "target": dotted_path,
        "values": loader.construct_sequence(values_seq),
        "type": "coupled-sweep",
    }


def range_constructor(
    loader: yaml.SafeLoader, node: yaml.nodes.MappingNode
) -> dict[str, Any]:
    """Build the marker dict for a ``!range`` tag.

    The tagged mapping's values are read positionally as start, end and an
    optional step (default ``1``). Each scalar is resolved through the loader,
    so ``0`` yields the int ``0`` rather than the raw string ``"0"``. As a
    result, ``step`` has the same type whether it was written or defaulted.
    """
    # Only the first three entries are meaningful: start, end, [step].
    entries = [loader.construct_object(value_node) for _, value_node in node.value[:3]]
    start = entries[0]
    end = entries[1]
    step = entries[2] if len(entries) > 2 else 1
    return {"type": "range", "start": start, "end": end, "step": step}


def get_loader():
    """Return a SafeLoader subclass that understands ``!sweep``, ``!coupled-sweep`` and ``!range``.

    The constructors are registered on a fresh subclass rather than on
    ``yaml.SafeLoader`` itself. ``add_constructor`` on the base class would make
    every ``yaml.safe_load`` in the process resolve these tags too.
    """

    class SweepLoader(yaml.SafeLoader):
        pass

    SweepLoader.add_constructor("!sweep", sweep_constructor)
    SweepLoader.add_constructor("!coupled-sweep", coupled_sweep_constructor)
    SweepLoader.add_constructor("!range", range_constructor)
    return SweepLoader

In [None]:
# Parse the sweep config with the tag-aware loader; `with` closes the file
# handle as soon as parsing is done.
with open("./config.yaml", encoding="utf-8") as config_file:
    cfg = yaml.load(config_file, Loader=get_loader())

In [None]:
# Inspect the parsed config: tagged fields appear as marker dicts such as
# {"type": "sweep", "values": [...]} produced by the custom constructors.
cfg

In [None]:
def assign_at_path(cfg: dict, path: Sequence[str], value: Any) -> None:
    """Store ``value`` at the nested location named by ``path`` inside ``cfg``.

    Every key except the last must already exist. The innermost container is
    mutated in place. An empty ``path`` raises IndexError, and a missing
    intermediate key raises KeyError.
    """
    parent = reduce(lambda node, key: node[key], path[:-1], cfg)
    parent[path[-1]] = value


def get_at_path(cfg: dict, path: Sequence[str]) -> Any:
    """Return the value stored at the nested location named by ``path`` in ``cfg``.

    An empty ``path`` raises IndexError, and a missing key raises KeyError.
    """
    container = reduce(lambda node, key: node[key], path[:-1], cfg)
    return container[path[-1]]

In [None]:
class ConfigHandler:
    """Expand a parsed config containing sweep markers into concrete run configs.

    Sweep markers are the dicts produced by the YAML constructors:
    ``{"type": "sweep", "values": [...]}`` and
    ``{"type": "coupled-sweep", "target": [...], "values": [...]}``.
    Every sweep adds one axis to a Cartesian product. A coupled sweep adds no
    axis; it advances in lockstep with the sweep whose key equals the last
    element of its ``target`` path.

    Attributes:
        config: the original config; it is not modified.
        run_configs: one deep copy of ``config`` per point of the product, with
            every sweep / coupled-sweep marker replaced by the chosen value.
    """

    def __init__(self, config: dict):
        self.config = config
        self.run_configs = []
        sweep_targets = {}
        coupled_targets = {}

        self._extract_sweep_dims([], self.config, sweep_targets, coupled_targets)
        self.run_configs = self._construct_run_configs(sweep_targets)

    def _extract_sweep_dims(
        self, k, cfg_node: dict, sweep_targets, coupled_targets
    ) -> dict[str, dict[str, Any]]:
        """Find all sweep markers under ``cfg_node`` and link coupled sweeps to their targets.

        Args:
            k: key path from the config root to ``cfg_node`` (not mutated).
            cfg_node: the (sub-)config to search.
            sweep_targets: filled in place. Keys are each sweep's own (leaf)
                key; values hold ``path``, ``values``, ``partner`` and, when
                coupled, ``partner_path``.
            coupled_targets: filled in place, keyed by each coupled sweep's own key.

        Returns:
            ``sweep_targets``.

        Raises:
            KeyError: a coupled sweep targets a key that has no sweep.
            ValueError: two markers share a leaf key, or one sweep has two partners.
        """
        self._collect_sweep_dims(list(k), cfg_node, sweep_targets, coupled_targets)
        # Link only after the whole tree has been walked: a coupled sweep may be
        # visited before (or in another subtree than) the sweep it follows.
        self._link_coupled_targets(sweep_targets, coupled_targets)
        return sweep_targets

    def _collect_sweep_dims(self, prefix, cfg_node: dict, sweep_targets, coupled_targets) -> None:
        """Recursively record every sweep / coupled-sweep marker below ``cfg_node``.

        ``prefix`` is the key path to ``cfg_node``. Each recorded entry gets its
        own fresh path list, so sibling keys never share or clobber a prefix.
        """
        for key, node in cfg_node.items():
            if not isinstance(node, dict):
                # Scalars and lists are left alone; markers inside lists are not searched.
                continue
            path = [*prefix, key]
            marker = node.get("type")
            if marker == "sweep":
                if key in sweep_targets:
                    raise ValueError(f"duplicate sweep key {key!r}; sweep keys must be unique")
                sweep_targets[key] = {
                    "path": path,
                    "values": node["values"],
                    "partner": None,
                }
            elif marker == "coupled-sweep":
                if key in coupled_targets:
                    raise ValueError(
                        f"duplicate coupled-sweep key {key!r}; coupled-sweep keys must be unique"
                    )
                coupled_targets[key] = {
                    "path": path,
                    "target": node["target"],
                    "values": node["values"],
                }
            else:
                # A regular section, possibly with its own unrelated "type" field
                # (e.g. an optimizer name): descend so nested sweeps are found.
                self._collect_sweep_dims(path, node, sweep_targets, coupled_targets)

    def _link_coupled_targets(self, sweep_targets, coupled_targets) -> None:
        """Attach each coupled sweep's values and path to the sweep it targets."""
        for name, coupled in coupled_targets.items():
            leaf = coupled["target"][-1]
            if leaf not in sweep_targets:
                target_name = ".".join(map(str, coupled["target"]))
                raise KeyError(
                    f"coupled-sweep {name!r} targets {target_name!r}, "
                    f"but no sweep with key {leaf!r} was found"
                )
            sweep = sweep_targets[leaf]
            if sweep.get("partner") is not None:
                raise ValueError(f"sweep {leaf!r} has more than one coupled-sweep partner")
            sweep["partner"] = coupled["values"]
            sweep["partner_path"] = coupled["path"]

    def _construct_cartesian_product(
        self, elements, current_list, possible_partner, all_lists, *args, i=0
    ) -> None:
        """Append every combination of the remaining sweep axes to ``elements``.

        ``current_list``/``possible_partner`` are the values (and optional
        lockstep partner values) of the current axis. ``all_lists[i + 1:]``
        holds the remaining ``(values, partner)`` axes. ``args`` carries the
        choices already made on earlier axes. Each appended element is a flat
        list: one entry per uncoupled axis and two (value, partner) per coupled axis.
        """
        i += 1  # index of the next axis in all_lists
        for idx, value in enumerate(current_list):
            if possible_partner is None:
                chosen = (value,)
            else:
                chosen = (value, possible_partner[idx])
            if i < len(all_lists):
                next_values, next_partner = all_lists[i]
                self._construct_cartesian_product(
                    elements, next_values, next_partner, all_lists, *args, *chosen, i=i
                )
            else:
                elements.append([*args, *chosen])

    def _construct_run_configs(self, sweep_targets) -> list[dict[str, Any]]:
        """Build one deep copy of ``self.config`` per combination of sweep values.

        Raises:
            ValueError: a coupled sweep's value list has a different length than
                the sweep it is paired with.
        """
        # lookup_keys is flat and ordered exactly like each product element:
        # a sweep's path, immediately followed by its partner's path if coupled.
        lookup_keys = []
        lists = []
        for name, target in sweep_targets.items():
            values = target["values"]
            partner = target.get("partner")
            lookup_keys.append(tuple(target["path"]))
            if partner is not None:
                if len(partner) != len(values):
                    raise ValueError(
                        f"coupled-sweep paired with sweep {name!r} has {len(partner)} "
                        f"values, expected {len(values)}"
                    )
                lookup_keys.append(tuple(target["partner_path"]))
            lists.append((values, partner))

        if not lists:
            # No sweeps: the product of zero axes is one run with the config as-is.
            return [copy.deepcopy(self.config)]

        elements = []
        self._construct_cartesian_product(
            elements, lists[0][0], lists[0][1], lists, i=0
        )

        configs = []
        for element in elements:
            run_cfg = copy.deepcopy(self.config)
            for path, value in zip(lookup_keys, element):
                # Copy the value as well so mutable sweep values (e.g. lists)
                # are not shared between run configs.
                assign_at_path(run_cfg, path, copy.deepcopy(value))
            configs.append(run_cfg)
        return configs