Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,17 +2143,23 @@ def deserialize_meta_func(serialized_target: str):
def import_nn_module_stack(key, path, ty):
return key, (path, ty)

# Helper function that splits strings by commas except for those
# encapsulated by parens, which are valid traces.
# TODO: Currently this is needed due to indexing Sequential
# layers introducing names in the form "layer.slice(1, None, None)".
# If that naming is improved, this fancier splitting can probably be
# reverted to a simple split by comma.
# Helper function to split string by commas, accounting for nested parentheses/brackets
def metadata_split(metadata):
# Remove the parentheses and commas inside them
metadata = re.sub(r"\(.*?\)", "", metadata)
# Split the string by comma, except for those inside parentheses
return re.split(r"(?<!\()\s*,\s*(?!\()", metadata)
out = []
start, depth = 0, 0
for position, char in enumerate(metadata):
if char in "[(":
depth += 1
elif char in ")]":
depth -= 1
if depth < 0:
raise ValueError(f"Mismatched brackets in metadata: {metadata}")
elif char == "," and depth == 0:
out.append(metadata[start:position].strip())
start = position + 1
out.append(metadata[start:].strip())
assert len(out) == 3
return out

nn_module_stack = dict(
import_nn_module_stack(*metadata_split(item))
Expand Down
Loading