Skip to content

Commit

Permalink
[FX] Fix python code having spurious newlines from placeholders
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
James Reed committed Dec 22, 2020
1 parent 5b163e2 commit cca698a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torch/fx/graph.py
Expand Up @@ -619,6 +619,8 @@ def delete_unused_values(user : Node):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
return
if user.op == 'output':
body.append('\n')
return
Expand All @@ -637,7 +639,7 @@ def emit_node(node : Node):
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
if raw_name != node.name:
body.append(f'{node.name} = {raw_name}')
body.append(f'{node.name} = {raw_name}\n')
return
elif node.op == 'call_method':
assert isinstance(node.target, str)
Expand Down

0 comments on commit cca698a

Please sign in to comment.