In [None]:
# Sample nested inputs spec mixing a list of names, a dict of names,
# a bare name, and deeper dict/list nesting.
tst = dict(
    tst1=["1", "2", "3"],
    tst2={"1": "4", "2": "5", "3": "6"},
    tst3="7",
    tst4={"1": ["8", "9"], "2": {"1": "10", "2": "11"}},
)
# Degenerate spec: a single dataset name with no container around it.
tst2 = "tst2"

In [None]:
from __future__ import annotations
from typing import Any

# Recursive spec of dataset names: a single name, a list of specs, or a dict of specs.
NestedInputs = str | list['NestedInputs'] | dict[str, 'NestedInputs']

def _flatten_inputs(inputs: NestedInputs) -> str | list[str]:
    if isinstance(inputs, str):
        return inputs
    
    out = []
    for nested in inputs.values() if isinstance(inputs, dict) else inputs:
        data = _flatten_inputs(nested)
        if isinstance(data, str):
            out.append(data)
        else:
            out.extend(data)
    
    return out

# Every leaf name of the fixture, flattened in depth-first insertion order.
_flatten_inputs(tst)

['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']

In [None]:
# A bare-string spec is returned as-is, not wrapped in a list.
_flatten_inputs(tst2)

'tst2'

In [None]:
# Fake datasets keyed by name: "i" -> integer i for i in 0..11.
tst_data = dict(zip(map(str, range(12)), range(12)))
tst_data

{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11}

In [None]:
def _bind_inputs(inputs: NestedInputs, datasets: dict[str, Any]):
    if isinstance(inputs, str):
        return datasets[inputs]
    
    if isinstance(inputs, dict):
        out = {}
        for name, nested in inputs.items():
            out[name] = _bind_inputs(nested, datasets)
        return out
    
    assert isinstance(inputs, list)

    out = []
    for nested in inputs:
        out.append(_bind_inputs(nested, datasets))
    return out

# Swap each leaf name in `tst` for its value from `tst_data`, keeping the nesting.
out = _bind_inputs(tst, tst_data)
out

{'tst1': [1, 2, 3],
 'tst2': {'1': 4, '2': 5, '3': 6},
 'tst3': 7,
 'tst4': {'1': [8, 9], '2': {'1': 10, '2': 11}}}

In [None]:

NestedOutputs = Any | list["NestedInputs"] | dict[str, "NestedInputs"]

def _flatten_outputs(nested: NestedInputs, outputs: NestedOutputs) -> dict[str, Any]:
    if isinstance(nested, str):
        return {nested: outputs}

    out = {}
    if isinstance(nested, dict):
        assert isinstance(outputs, dict)
        for idx, vals in nested.items():
            data = _flatten_outputs(vals, outputs[idx])
            out.update(data)
    else:
        assert isinstance(outputs, list) and isinstance(nested, list)
        for vals, outs in zip(nested, outputs):
            data = _flatten_outputs(vals, outs)
            out.update(data)

    return out

# Walk `tst` and the bound `out` in lockstep to map each leaf name back to its value.
_flatten_outputs(tst, out)

{'1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11}