This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Conversation

@SherlockNoMad (Contributor) commented Aug 18, 2022

SherlockNoMad merged commit e162faf into gh/SherlockNoMad/3/base on Aug 19, 2022
SherlockNoMad added a commit that referenced this pull request Aug 19, 2022
SherlockNoMad added a commit that referenced this pull request Aug 19, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
Precondition: pytorch/torchdynamo#899

Given the following function:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
    return s
```

Here are the possible results with the various tracing frontends: dynamo, symbolic_trace, make_fx.
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack trace for the gradient addition node (inserted when a tensor has multiple uses) in the backward pass.
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace.
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    
    # No stacktrace found for following nodes 
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None
    
    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
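A sketch of how a joint graph like the one above could be produced (assumptions: the standalone torchdynamo package from 2022, input shapes inferred from the traced graphs, and `func` from the snippet at the top; the listing itself comes from the readable-print path this stack works on, not from anything this snippet prints):
```
import torch
import torchdynamo  # standalone 2022 package, as referenced in this description

a = torch.randn(10, 10, requires_grad=True)  # shapes inferred from the traced graphs
b = torch.randn(10, requires_grad=True)

# "aot_nop" hands the captured graph to AOTAutograd with a no-op compiler,
# so a joint forward/backward graph like the listing above gets traced.
with torchdynamo.optimize("aot_nop"):
    out = func(a, b)  # func as defined at the top of this description
out.backward()
```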
- default symbolic_trace
Notice that nodes without a stack trace are folded under the same region.
```
def forward(self, a, b):
    
    # No stacktrace found for following nodes 
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
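For the default trace, something like the following should suffice (the print_readable() call is an assumption based on this stack's title; gm.code shows the same generated forward without the per-node comments):
```
import torch

gm = torch.fx.symbolic_trace(func)  # func from the snippet at the top; the default tracer records no stack traces
gm.print_readable()                 # assumed printer; gm.code gives the bare forward() source
```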
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```
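symbolic_trace itself has no record_stack_traces keyword; as far as I can tell the switch is a class attribute on the FX tracer, so a sketch like the following is assumed for the annotated trace above (the attribute name is real, the wiring is my reconstruction):
```
import torch
from torch.fx import Tracer, GraphModule

class StackTraceTracer(Tracer):
    # record_stack_traces is a class attribute on FX's tracer base; when True,
    # every node created during tracing gets a Python stack trace attached.
    record_stack_traces = True

tracer = StackTraceTracer()
graph = tracer.trace(func)            # func from the snippet at the top
gm = GraphModule(tracer.root, graph)  # same construction symbolic_trace performs
gm.print_readable()                   # per-node stack traces show up as the comments above
```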
    
- make_fx without decomposition
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
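The ATen-level trace comes from make_fx applied to real inputs; a sketch with assumed shapes (taken from the graphs above):
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

a = torch.randn(10, 10)  # shapes inferred from the traced graphs
b = torch.randn(10)

gm = make_fx(func)(a, b)  # no decomposition_table, so ops stay at the aten level
gm.print_readable()
```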
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
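make_fx also takes a decomposition_table for lowering further; the exact aten-to-prims table used for the listing above isn't shown here, so it is left as a placeholder in this sketch:
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

a = torch.randn(10, 10)
b = torch.randn(10)

# Placeholder assumption: the real table mapping aten ops to torch.ops.prims
# decompositions is not part of this description. An empty table keeps the
# trace at the aten level; a populated one yields the prims graph above.
prims_decomposition_table = {}

gm = make_fx(func, decomposition_table=prims_decomposition_table)(a, b)
gm.print_readable()
```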

SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
…dable()"


Precondition: pytorch/torchdynamo#899

Given following function
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```

Here are the possible result with various tracing frontend: dynamo, symbolic_trace, make_fx
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack for gradient addition node (for multiple uses of tensor) in backward
Notice that "No stacktrace found for following nodes" are shown for nodes with stacktrace
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    
    # No stacktrace found for following nodes 
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None
    
    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
- default symbolic_trace
Notice that nodes without stacktrace are folded under same region 
```
def forward(self, a, b):
    
    # No stacktrace found for following nodes 
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```
    
- make_fx without decomposition
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
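The two make_fx listings above can be reproduced with a sketch along these lines. `f` and the input shapes are illustrative assumptions, and the decomposition step is only hinted at via the `decomposition_table` argument; the exact table behind the prims listing is not shown in this description.
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):                                 # hypothetical stand-in function
    return (x + y).relu().sum()

x, y = torch.randn(10, 10), torch.randn(10)  # illustrative shapes only

# ATen-level graph, no decomposition: nodes are torch.ops.aten.* calls.
gm_aten = make_fx(f)(x, y)
print(gm_aten.code)

# To decompose further (e.g. toward prims), pass a decomposition table:
#   make_fx(f, decomposition_table=...)(x, y)
```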

[ghstack-poisoned]
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
…dable()
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
Pull Request resolved: #83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
…dable()
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 24, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 25, 2022
…dable()"


Precondition: pytorch/torchdynamo#899

Given following function
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```

Here are the possible result with various tracing frontend: dynamo, symbolic_trace, make_fx
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack for gradient addition node (for multiple uses of tensor) in backward
Notice that "No stacktrace found for following nodes" are shown for nodes with stacktrace
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    
    # No stacktrace found for following nodes 
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None
    
    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
- default symbolic_trace
Notice that nodes without stacktrace are folded under same region 
```
def forward(self, a, b):
    
    # No stacktrace found for following nodes 
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```
    
- make_fx without decomposition
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None
    
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
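The prims-level graph comes from passing a decomposition table into `make_fx`. The sketch below only illustrates that mechanism, using `torch._decomp.get_decompositions` as an assumed source of decompositions; the exact table that lowers all the way to `torch.ops.prims` is not shown in the PR, so treat this as a hypothetical configuration rather than the author's setup.
```python
import torch
from torch._decomp import get_decompositions  # assumed helper for building a decomposition table

# hypothetical table: which ops to decompose is up to the caller
decomp_table = get_decompositions([
    torch.ops.aten.relu,
    torch.ops.aten.stack,
])

gm_decomposed = make_fx(func, decomposition_table=decomp_table)(a, b)
print(gm_decomposed.code)  # stack-trace annotations are carried onto the decomposed nodes
```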

[ghstack-poisoned]
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 25, 2022
facebook-github-bot pushed a commit to pytorch/functorch that referenced this pull request Aug 26, 2022
X-link: pytorch/pytorch#83706
Approved by: https://github.com/Chillee, https://github.com/ezyang

Reviewed By: weiwangmeta

Differential Revision: D39008162

Pulled By: SherlockNoMad

fbshipit-source-id: 073ac63c7efb1abee14f83792040dd679223e345
mehtanirav pushed a commit to pytorch/pytorch that referenced this pull request Aug 26, 2022
Pull Request resolved: #83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 29, 2022
SherlockNoMad added a commit to pytorch/pytorch that referenced this pull request Aug 29, 2022
@facebook-github-bot facebook-github-bot deleted the gh/SherlockNoMad/3/head branch September 18, 2022 14:20