@@ -2143,17 +2143,21 @@ def deserialize_meta_func(serialized_target: str):
21432143 def import_nn_module_stack (key , path , ty ):
21442144 return key , (path , ty )
21452145
2146- # Helper function that splits strings by commas except for those
2147- # encapsulated by parens, which are valid traces.
2148- # TODO: Currently this is needed due to indexing Sequential
2149- # layers introducing names in the form "layer.slice(1, None, None)".
2150- # If that naming is improved, this fancier splitting can probably be
2151- # reverted to a simple split by comma.
2146+ # Helper function to split string by commas, accounting for nested parentheses/brackets
21522147 def metadata_split (metadata ):
2153- # Remove the parentheses and commas inside them
2154- metadata = re .sub (r"\(.*?\)" , "" , metadata )
2155- # Split the string by comma, except for those inside parentheses
2156- return re .split (r"(?<!\()\s*,\s*(?!\()" , metadata )
2148+ out = []
2149+ start , depth = 0 , 0
2150+ for position , char in enumerate (metadata ):
2151+ if char in "[(" :
2152+ depth += 1
2153+ elif char in ")]" :
2154+ depth -= 1
2155+ elif char == "," and depth == 0 :
2156+ out .append (metadata [start :position ])
2157+ start = position + 1
2158+ out .append (metadata [start :])
2159+ assert len (out ) == 3
2160+ return out
21572161
21582162 nn_module_stack = dict (
21592163 import_nn_module_stack (* metadata_split (item ))
0 commit comments