diff --git a/examples/fodo.py b/examples/fodo.py index c2e6ed5..f425ad8 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -13,6 +13,8 @@ from schema.Line import Line +from utils import io + def main(): drift1 = DriftElement( @@ -52,7 +54,9 @@ def main(): ] ) # Serialize to YAML - yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) + yaml_data = yaml.dump( + io.custom_encoder(line.model_dump()), default_flow_style=False + ) print("Dumping YAML data...") print(f"{yaml_data}") # Write YAML data to file @@ -63,7 +67,7 @@ def main(): with open(yaml_file, "r") as file: yaml_data = yaml.safe_load(file) # Parse YAML data - loaded_line = Line(**yaml_data) + loaded_line = Line(**io.custom_decoder(yaml_data)) # Validate loaded data assert line == loaded_line # Serialize to JSON diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000..df35e53 --- /dev/null +++ b/utils/io.py @@ -0,0 +1,24 @@ +def custom_encoder(data): + if isinstance(data, dict): + kind = data.get("kind") + if kind in ["Drift", "Quadrupole"]: + return {kind: {k: v for k, v in data.items() if k != "kind"}} + return {k: custom_encoder(v) for k, v in data.items()} + elif isinstance(data, list): + return [custom_encoder(item) for item in data] + else: + return data + + +def custom_decoder(data): + if isinstance(data, dict): + if len(data) == 1: + kind, element_data = next(iter(data.items())) + if kind in ["Drift", "Quadrupole"]: + element_data["kind"] = kind + return element_data + return {k: custom_decoder(v) for k, v in data.items()} + elif isinstance(data, list): + return [custom_decoder(item) for item in data] + else: + return data