Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from schema.Line import Line

from utils import io


def main():
drift1 = DriftElement(
Expand Down Expand Up @@ -52,7 +54,9 @@ def main():
]
)
# Serialize to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
yaml_data = yaml.dump(
io.custom_encoder(line.model_dump()), default_flow_style=False
)
print("Dumping YAML data...")
print(f"{yaml_data}")
# Write YAML data to file
Expand All @@ -63,11 +67,13 @@ def main():
with open(yaml_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse YAML data
loaded_line = Line(**yaml_data)
loaded_line = Line(**io.custom_decoder(yaml_data))
# Validate loaded data
assert line == loaded_line
# Serialize to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
json_data = json.dumps(
io.custom_encoder(line.model_dump()), sort_keys=True, indent=2
)
print("Dumping JSON data...")
print(f"{json_data}")
# Write JSON data to file
Expand All @@ -78,7 +84,7 @@ def main():
with open(json_file, "r") as file:
json_data = json.loads(file.read())
# Parse JSON data
loaded_line = Line(**json_data)
loaded_line = Line(**io.custom_decoder(json_data))
# Validate loaded data
assert line == loaded_line

Expand Down
24 changes: 24 additions & 0 deletions utils/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
def custom_encoder(data):
if isinstance(data, dict):
kind = data.get("kind")
if kind in ["Drift", "Quadrupole"]:
return {kind: {k: v for k, v in data.items() if k != "kind"}}
return {k: custom_encoder(v) for k, v in data.items()}
elif isinstance(data, list):
return [custom_encoder(item) for item in data]
else:
return data


def custom_decoder(data):
if isinstance(data, dict):
if len(data) == 1:
kind, element_data = next(iter(data.items()))
if kind in ["Drift", "Quadrupole"]:
element_data["kind"] = kind
return element_data
return {k: custom_decoder(v) for k, v in data.items()}
elif isinstance(data, list):
return [custom_decoder(item) for item in data]
else:
return data