This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Fix stack summary order to comply with fx's stack trace order #899

Merged: SherlockNoMad merged 1 commit into gh/SherlockNoMad/3/base from gh/SherlockNoMad/3/head on Aug 19, 2022
Conversation
[ghstack-poisoned]
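For context on the title: Python's standard library builds stack summaries with the outermost caller first and the innermost call last, which appears to be the ordering this fix aligns torchdynamo's stack summaries with so they match fx's stack trace convention (the `# File "...", line ..., in ...` comments in the graphs below). A minimal, illustrative sketch of inspecting that ordering with the standard traceback module; this is background only, not the PR's diff:

```
import traceback

def inner():
    # traceback.extract_stack() returns a StackSummary ordered from the
    # outermost caller down to the innermost frame (this one is last).
    for frame in traceback.extract_stack():
        print(f'  File "{frame.filename}", line {frame.lineno}, in {frame.name}')

def outer():
    inner()

outer()
```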
jansel approved these changes on Aug 19, 2022
SherlockNoMad added a commit that referenced this pull request on Aug 19, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 23, 2022
Precondition: pytorch/torchdynamo#899

Given the following function

```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```

here are the possible results with the various tracing frontends: dynamo, symbolic_trace, make_fx.

- Joint graph with torchdynamo.optimize("aot_nop"). Notice that it has a special stack for the gradient addition node (for multiple uses of a tensor) in backward, and that "No stacktrace found for following nodes" is shown for nodes without a stacktrace.

```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None

    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]); tangents_1 = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default); expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1]; unbind_int = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar); pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default); getitem_1 = cos_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0); getitem = detach_default_1 = None

    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default); mul_tensor = threshold_backward_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0); add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar); add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]); sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```

- Default symbolic_trace. Notice that nodes without a stacktrace are folded under the same region.

```
def forward(self, a, b):
    # No stacktrace found for following nodes
    add = a + b; a = b = None
    square = torch.square(add); add = None
    relu = square.relu()
    sin = square.sin(); square = None
    stack = torch.stack([relu, sin]); relu = sin = None
    sum_1 = stack.sum(); stack = None
    return sum_1
```

- symbolic_trace with record_stack_traces=True

```
def forward(self, a, b):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b; a = b = None
    square = torch.square(add); add = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin(); square = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]); relu = sin = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum(); stack = None
    return sum_1
```

- make_fx without decomposition

```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1); a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2); add_tensor = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar); pow_tensor_scalar = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None
    return sum_default
```

- make_fx with decomposition to prims

```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]); b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default); a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default); add_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default); le_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default); mul_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0); where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2); cat_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32); split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]); convert_element_type_default = None
    return sum_default
```

[ghstack-poisoned]
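The readable graphs in the commit message above can be approximated locally. A minimal sketch, assuming a PyTorch build where torch.fx and make_fx (import path assumed to be torch.fx.experimental.proxy_tensor) are available; the explicit `return` and the 10x10 inputs are assumptions so the traced graphs have an output, and the `# File ...` annotations only appear once stack traces are recorded and the readable-printing change referenced above is present:

```
import torch
from torch.fx import symbolic_trace, Tracer, GraphModule
from torch.fx.experimental.proxy_tensor import make_fx  # import path assumed

def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    return s.sum()  # return added so the traced graph produces an output

# 1. Default symbolic_trace: no per-node stack traces are recorded.
gm = symbolic_trace(func)
print(gm.code)

# 2. symbolic_trace with stack traces recorded (record_stack_traces is a
#    Tracer attribute; the exact knob may differ across PyTorch versions).
tracer = Tracer()
tracer.record_stack_traces = True
graph = tracer.trace(func)
gm_traced = GraphModule(tracer.root, graph)
print(gm_traced.code)

# 3. make_fx traces down to aten ops and propagates stack traces onto nodes.
a, b = torch.randn(10, 10), torch.randn(10, 10)
gm_aten = make_fx(func)(a, b)
print(gm_aten.code)
```

Each fx node also exposes its recorded trace directly via node.stack_trace (empty when nothing was recorded), which is where the printed `# File ...` comments come from.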
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 23, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Aug 24, 2022
Precondition: pytorch/torchdynamo#899

Given the following function:

```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```

Here are the possible results with the various tracing frontends: dynamo, symbolic_trace, make_fx.

- joint graph with torchdynamo.optimize("aot_nop")

Notice that it has a special stack entry for the gradient addition node (for multiple uses of a tensor) in backward.
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace.

```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None
    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]); tangents_1 = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default); expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1]; unbind_int = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar); pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default); getitem_1 = cos_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0); getitem = detach_default_1 = None
    # Gradient addition node due to mulitple use of tensor around:
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default); mul_tensor = threshold_backward_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0); add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar); add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]); sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```

- default symbolic_trace

Notice that nodes without a stack trace are folded under the same region.

```
def forward(self, a, b):
    # No stacktrace found for following nodes
    add = a + b; a = b = None
    square = torch.square(add); add = None
    relu = square.relu()
    sin = square.sin(); square = None
    stack = torch.stack([relu, sin]); relu = sin = None
    sum_1 = stack.sum(); stack = None
    return sum_1
```

- symbolic_trace with record_stack_traces=True

```
def forward(self, a, b):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b; a = b = None
    square = torch.square(add); add = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin(); square = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]); relu = sin = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum(); stack = None
    return sum_1
```

- make_fx without decomposition

```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1); a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2); add_tensor = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar); pow_tensor_scalar = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None
    return sum_default
```

- make_fx with decomposition to prims

```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]); b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default); a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default); add_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default); le_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default); mul_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0); where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2); cat_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32); split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]); convert_element_type_default = None
    return sum_default
```

[ghstack-poisoned]
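As a side note for readers of this PR, the symbolic_trace and make_fx listings above can be reproduced with a short script. This is only a sketch, not part of the change itself: the tensor shapes, the explicit `return`, and the defensive `getattr(node, "stack_trace", None)` are assumptions added for illustration.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    return s.sum()  # the snippet above omits the return; added here so the trace has an output

a, b = torch.randn(10, 10), torch.randn(10)

# Default symbolic_trace: nodes carry no stack trace, so the generated code
# groups them under "No stacktrace found for following nodes".
gm = torch.fx.symbolic_trace(func)
print(gm.code)

# Opting into stack traces: use an fx Tracer with record_stack_traces enabled.
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True
graph = tracer.trace(func)
gm_traced = torch.fx.GraphModule(tracer.root, graph)
for node in gm_traced.graph.nodes:
    print(node.name, getattr(node, "stack_trace", None))

# make_fx traces down to ATen ops; with this stack of PRs each node is
# annotated with the user-level frame it came from.
gm_aten = make_fx(func)(a, b)
print(gm_aten.code)
```

The exact text of the emitted comments depends on the PyTorch build, so treat the printed output as indicative rather than byte-for-byte identical to the listings above.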
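The joint forward/backward listing comes from running the same function under torchdynamo with the aot_nop backend. A rough sketch of that setup follows; it assumes the standalone torchdynamo package accepts the backend by name (as the message's torchdynamo.optimize("aot_nop") suggests), and the input shapes and requires_grad flags are guesses consistent with the expand([2, 10, 10]) and view([10]) nodes in the graph.

```python
import torch
import torchdynamo  # the standalone package this repo provides

def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    return s.sum()

# Shapes inferred from the graph above (assumption).
a = torch.randn(10, 10, requires_grad=True)
b = torch.randn(10, requires_grad=True)

# "aot_nop" routes the captured graph through AOTAutograd with a no-op
# compiler, which is what produces the joint forward/backward graph.
with torchdynamo.optimize("aot_nop"):
    out = func(a, b)
out.backward()
```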
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
Pull Request resolved: #83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 25, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 25, 2022
facebook-github-bot
pushed a commit
to pytorch/functorch
that referenced
this pull request
Aug 26, 2022
Summary: Precondition: pytorch/torchdynamo#899. The commit message is identical to the pytorch/pytorch commit above.

X-link: pytorch/pytorch#83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
Reviewed By: weiwangmeta
Differential Revision: D39008162
Pulled By: SherlockNoMad
fbshipit-source-id: 073ac63c7efb1abee14f83792040dd679223e345
mehtanirav
pushed a commit
to pytorch/pytorch
that referenced
this pull request
Aug 26, 2022
Precondition: pytorch/torchdynamo#899. The commit message is identical to the one above.

Pull Request resolved: #83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
SherlockNoMad
added a commit
to pytorch/pytorch
that referenced
this pull request
Aug 29, 2022
…dable()" Precondition: pytorch/torchdynamo#899. The commit message is identical to the one above.

diff-train-skip-merge
[ghstack-poisoned]
SherlockNoMad
added a commit
to pytorch/pytorch
that referenced
this pull request
Aug 29, 2022
Precondition: pytorch/torchdynamo#899. The commit message is identical to the one above.

diff-train-skip-merge
[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
This is needed for pytorch/pytorch#83706
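As a hedged aside on what "stack trace order" means here (an illustration of the Python traceback convention, not this PR's diff): `traceback.extract_stack()` lists frames from outermost to innermost, i.e. "most recent call last", and the per-node `# File ...` annotations shown in the dumps above are assumed to follow the same convention, so a truncated stack summary should preserve that order rather than reverse it. The function names below are made up for the example.

```python
# Illustration of Python's stack ordering convention: outermost frame first,
# most recent call last.
import traceback


def leaf():
    # extract_stack() returns a StackSummary ordered oldest -> newest frame.
    return traceback.extract_stack()


def caller():
    return leaf()


frames = caller()
# Keep only the innermost few frames, preserving their original order.
summary = traceback.StackSummary.from_list(frames[-3:])
print("".join(summary.format()))
# Prints the module-level call first, then caller(), then leaf() last.
```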