In [159]:
import yaml
import pprint
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass

YAML_PATH = "tests/test_data/prot_dna_experiment/test_config.yaml"



In [160]:
# Load the experiment config. safe_load only constructs plain Python objects, and the
# explicit encoding makes decoding independent of the platform's locale default.
with open(YAML_PATH, "r", encoding="utf-8") as f:
    data = yaml.safe_load(f)

print(data)

{'global_params': {'seed': 0}, 'columns': [{'column_name': 'hello', 'column_type': 'input', 'data_type': 'str', 'parsing': 'OneHotEncoder'}, {'column_name': 'bonjour', 'column_type': 'input', 'data_type': 'str', 'parsing': 'OneHotEncoder'}], 'transforms': [{'transformation_name': 'A', 'columns': [{'column_name': 'col1', 'transformations': [{'name': 'ReverseComplement', 'params': None}]}]}, {'transformation_name': 'B', 'columns': [{'column_name': 'col1', 'transformations': [{'name': 'UniformTextMasker', 'params': {'probability': [0.1, 0.2, 0.3]}}]}]}, {'transformation_name': 'C', 'columns': [{'column_name': 'col1', 'transformations': [{'name': 'ReverseComplement', 'params': None}, {'name': 'UniformTextMasker', 'params': {'probability': [0.1, 0.2, 0.3, 0.4]}}, {'name': 'GaussianNoise', 'params': {'std': [0.1, 0.2, 0.3, 0.4]}}]}, {'column_name': 'col2', 'transformations': [{'name': 'GaussianNoise', 'params': {'std': [0.1, 0.2, 0.1, 0.2]}}]}]}], 'split': [{'split_method': 'random_split', '

In [161]:
# Pick two sample entries of the config's "transforms" list to exercise the expansion
# helpers (in the current test config these are "B" and "C").
try_expand, try_expand_harder = data['transforms'][1], data['transforms'][2]


In [162]:
@dataclass(frozen=True)
class TransformKeys:
    """
    Names of the keys used inside one entry of the config's ``transforms`` list.

    In the visible code the fields are only read as class attributes
    (e.g. ``TransformKeys.COLUMN_KEY``), so the class acts as a namespace of
    string constants.
    """

    # key holding a column's name inside each item of the "columns" list
    COLUMN_NAME_KEY: str = "column_name"
    # key holding the list of per-column items of a transform entry
    COLUMN_KEY: str = "columns"
    # key holding a transformation's name
    NAME_KEY: str = "name"
    # key holding a transformation's params (a dict, or None when it has no params)
    PARAMS_KEY: str = "params"
    # key holding the list of transformations applied to a column
    TRANSFORMATIONS_KEY: str = "transformations"

def get_length_of_params_dict(dict_to_split: dict) -> int:
    """
    Return how many variants a transform entry expands into.

    Scans every transformation of every column in ``dict_to_split`` and inspects the
    list-valued entries of its ``params`` dict. Lists with more than one element hold
    candidate values that are expanded index-by-index, so they must all share the same
    length. Single-element lists are broadcast later and are ignored here.

    Args:
        dict_to_split: one entry of the config's ``transforms`` list, i.e. a dict with a
            ``columns`` list whose items each carry a ``transformations`` list.

    Returns:
        The common length of the multi-element param lists, or 0 when there are none.

    Raises:
        ValueError: if two multi-element param lists have different lengths. Expanding
            them would otherwise fail with an IndexError or silently drop values.
    """
    expected_length = 0
    for column in dict_to_split[TransformKeys.COLUMN_KEY]:
        for transformation in column[TransformKeys.TRANSFORMATIONS_KEY]:
            params = transformation[TransformKeys.PARAMS_KEY]
            # params is None for parameter-free transformations
            if not isinstance(params, dict):
                continue
            for key, value in params.items():
                # only lists with several candidates drive the expansion
                if not isinstance(value, list) or len(value) <= 1:
                    continue
                if expected_length == 0:
                    expected_length = len(value)
                elif len(value) != expected_length:
                    raise ValueError(
                        f"param '{key}' has {len(value)} values, expected {expected_length} "
                        "(all multi-valued params of a transform must have the same length)"
                    )
    return expected_length

def get_transform_base_dict(dict_to_split: dict) -> dict:
    """
    Return a deep copy of a transform entry with every dict-valued ``params`` reset to ``{}``.

    The copy is the template that split_transform_dict fills with one concrete value per
    param. Non-dict params (e.g. None) are kept unchanged, and the input is not modified.

    Args:
        dict_to_split: one entry of the config's ``transforms`` list.

    Returns:
        A new plain dict with the same structure as ``dict_to_split``.
    """
    base_dict = deepcopy(dict_to_split)
    # A missing "columns" key means there is nothing to reset. Use .get() rather than a
    # defaultdict, whose lookup would insert a spurious "columns": "" entry into the result.
    for column in base_dict.get(TransformKeys.COLUMN_KEY, []):
        for transformation in column[TransformKeys.TRANSFORMATIONS_KEY]:
            if isinstance(transformation[TransformKeys.PARAMS_KEY], dict):
                transformation[TransformKeys.PARAMS_KEY] = {}
    return dict(base_dict)

def split_transform_dict(dict_to_split: dict, base_dict: dict, split_index: int) -> dict:
    """
    Materialize variant ``split_index`` of a transform entry.

    Starts from a deep copy of ``base_dict`` (the template). For every transformation
    whose ``params`` in ``dict_to_split`` is a dict, it writes a new params dict in which
    each value is resolved as follows:

    * list with more than one element -> element ``split_index``
    * any other list                  -> its first element (an empty list raises IndexError)
    * anything else                   -> the value unchanged

    Transformations whose params are not a dict (e.g. None) keep whatever ``base_dict``
    holds. Neither input is modified.
    """

    def resolve(candidate):
        # Pick the value of one params entry for the requested variant.
        if not isinstance(candidate, list):
            return candidate
        if len(candidate) > 1:
            return candidate[split_index]
        return candidate[0]

    split_dict = deepcopy(base_dict)
    source_columns = dict_to_split[TransformKeys.COLUMN_KEY]
    for col_pos, source_column in enumerate(source_columns):
        source_transformations = source_column[TransformKeys.TRANSFORMATIONS_KEY]
        for tr_pos, source_transformation in enumerate(source_transformations):
            source_params = source_transformation[TransformKeys.PARAMS_KEY]
            if not isinstance(source_params, dict):
                continue
            resolved_params = {name: resolve(candidate) for name, candidate in source_params.items()}
            target_column = split_dict[TransformKeys.COLUMN_KEY][col_pos]
            target_column[TransformKeys.TRANSFORMATIONS_KEY][tr_pos][TransformKeys.PARAMS_KEY] = resolved_params

    return split_dict

def get_all_transform_dicts(dict_to_split: dict) -> list[dict]:
    """
    Expand one transform entry into a list of entries, each with a single value per param.

    Variant ``i`` takes element ``i`` of every multi-element param list. Single-element
    lists are unwrapped, and non-list params are copied unchanged.

    A transform with no multi-element param lists is not dropped; it yields exactly one
    entry. This covers params None and params holding only single-element lists.

    Args:
        dict_to_split: one entry of the config's ``transforms`` list.

    Returns:
        A list with one expanded transform dict per variant (at least one).
    """
    # get_length_of_params_dict returns 0 when nothing needs expanding; such a transform
    # must still appear once in the output instead of vanishing.
    n_variants = max(1, get_length_of_params_dict(dict_to_split))
    base_dict = get_transform_base_dict(dict_to_split)
    return [split_transform_dict(dict_to_split, base_dict, i) for i in range(n_variants)]



In [163]:

# Expand the harder sample transform at variant index 1 and inspect the result.
base_for_harder = get_transform_base_dict(try_expand_harder)
test_dict = split_transform_dict(try_expand_harder, base_for_harder, 1)
pprint.pprint(test_dict)

{'columns': [{'column_name': 'col1',
              'transformations': [{'name': 'ReverseComplement', 'params': None},
                                  {'name': 'UniformTextMasker',
                                   'params': {'probability': 0.2}},
                                  {'name': 'GaussianNoise',
                                   'params': {'std': 0.2}}]},
             {'column_name': 'col2',
              'transformations': [{'name': 'GaussianNoise',
                                   'params': {'std': 0.2}}]}],
 'transformation_name': 'C'}


In [164]:
# Expand every variant of the harder sample transform.
all_transform_dicts_harder = get_all_transform_dicts(try_expand_harder)
pprint.pprint(all_transform_dicts_harder)


[{'columns': [{'column_name': 'col1',
               'transformations': [{'name': 'ReverseComplement',
                                    'params': None},
                                   {'name': 'UniformTextMasker',
                                    'params': {'probability': 0.1}},
                                   {'name': 'GaussianNoise',
                                    'params': {'std': 0.1}}]},
              {'column_name': 'col2',
               'transformations': [{'name': 'GaussianNoise',
                                    'params': {'std': 0.1}}]}],
  'transformation_name': 'C'},
 {'columns': [{'column_name': 'col1',
               'transformations': [{'name': 'ReverseComplement',
                                    'params': None},
                                   {'name': 'UniformTextMasker',
                                    'params': {'probability': 0.2}},
                                   {'name': 'GaussianNoise',
                                    'para