TORCH_LOGS=dynamo,aot shows pattern_matcher.py usage of AotAutograd #98778

Closed
awgu opened this issue Apr 10, 2023 · 9 comments
Assignees
Labels
bug, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

awgu commented Apr 10, 2023

Define a file repro.py:

import torch
x = torch.randn(3)
@torch.compile()
def f():
    return x + x
f()

Run on the viable/strict branch:

TORCH_LOGS=dynamo,aot python repro.py
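
(For reference, narrowing this to just the graph-dump artifacts, whose names appear in the logger qualnames in the output below, presumably prints the same extra graphs; this assumes TORCH_LOGS accepts artifact names directly:)

TORCH_LOGS=aot_graphs,aot_joint_graph python repro.py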

This shows not only the forward graph for f but also six joint graphs (each with corresponding forward and backward graphs) containing ops that should have nothing to do with f, e.g. bmm, permute, etc.
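
The extra graphs all compute a scaled-dot-product-attention pattern; they appear to be the attention replacement patterns that Inductor's pattern_matcher.py traces through AOTAutograd, and the aot log artifacts are global rather than scoped to the user's compiled function. A rough plain-PyTorch sketch of what each variant computes is below; the shapes, the 4.0 scale, the f32[1, 1, 8, 8] additive mask, and the 0.113377 dropout probability come from the dumps, while the function name sdpa_like and its structure are my own illustration rather than the actual pattern source:

import torch
import torch.nn.functional as F

def sdpa_like(q, k, v, scale, attn_mask=None, dropout_p=0.0):
    # q, k, v: f32[2, 4, 8, 16] in the logged graphs
    scores = q @ k.transpose(-2, -1)         # the expand/view/bmm chain in the dumps
    scores = scores / scale                  # joint graphs 2 and 4 multiply by the scale instead
    if attn_mask is not None:
        scores = scores + attn_mask          # joint graphs 5 and 6 add the f32[1, 1, 8, 8] mask
    attn = torch.softmax(scores, dim=-1)     # the amax/sub/exp/sum/div sequence
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)  # native_dropout in joint graphs 3, 4, and 6
    return attn @ v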

Full output
[2023-04-10 12:28:08,886] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing f
[2023-04-10 12:28:08,891] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing f (RETURN_VALUE)
[2023-04-10 12:28:08,892] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-04-10 12:28:08,991] torch._inductor.utils: [WARNING] make_fallback(aten.cumprod): a decomposition exists, we should switch to it
[2023-04-10 12:28:09,922] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 0 ======
 <eval_with_key>.4 class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3], arg1_1: f32[3]):
        # File: /.../repro.py:8, code: return x + x
        add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        return (add,)
        

[2023-04-10 12:28:11,607] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 1 =====
 <eval_with_key>.6 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(div, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(div, amax);  div = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div_1, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div_1);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul, [-1], True)
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None], self._out_spec)
        

[2023-04-10 12:28:11,630] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 1 ======
 <eval_with_key>.8 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[]):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(div, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(div, amax);  div = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div_1, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, view_1, view, div_1, view_4, primals_4, view_3]
        

[2023-04-10 12:28:11,632] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 1 ======
 <eval_with_key>.9 class GraphModule(torch.nn.Module):
    def forward(self, view_1: f32[8, 16, 8], view: f32[8, 8, 16], div_1: f32[2, 4, 8, 8], view_4: f32[8, 8, 16], primals_4: f32[], view_3: f32[8, 8, 8], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div_1);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul, [-1], True)
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None]
        

[2023-04-10 12:28:11,768] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 2 =====
 <eval_with_key>.14 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(mul, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, amax);  mul = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_1, [-1], True)
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div, sum_2);  div = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(mul_3, [8, 8, 8]);  mul_3 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None], self._out_spec)
        

[2023-04-10 12:28:11,791] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 2 ======
 <eval_with_key>.16 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[]):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(mul, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, amax);  mul = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, view, view_4, div, view_3, view_1, primals_4]
        

[2023-04-10 12:28:11,793] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 2 ======
 <eval_with_key>.17 class GraphModule(torch.nn.Module):
    def forward(self, view: f32[8, 8, 16], view_4: f32[8, 8, 16], div: f32[2, 4, 8, 8], view_3: f32[8, 8, 8], view_1: f32[8, 16, 8], primals_4: f32[], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_1, [-1], True)
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div, sum_2);  div = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(mul_3, [8, 8, 8]);  mul_3 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None]
        

[2023-04-10 12:28:11,945] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 3 =====
 <eval_with_key>.22 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], primals_5, tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(div, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(div, amax);  div = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div_1, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul);  view_10 = mul = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_1, memory_format = torch.contiguous_format);  mul_1 = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div_1);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_2, [-1], True)
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_2, mul_3);  mul_2 = mul_3 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None, None], self._out_spec)
        

[2023-04-10 12:28:11,971] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 3 ======
 <eval_with_key>.24 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], primals_5):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(div, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(div, amax);  div = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div_1, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, getitem_1, view_3, view_1, primals_4, view, view_4, div_1]
        

[2023-04-10 12:28:11,973] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 3 ======
 <eval_with_key>.25 class GraphModule(torch.nn.Module):
    def forward(self, getitem_1: b8[2, 4, 8, 8], view_3: f32[8, 8, 8], view_1: f32[8, 16, 8], primals_4: f32[], view: f32[8, 8, 16], view_4: f32[8, 8, 16], div_1: f32[2, 4, 8, 8], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul);  view_10 = mul = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_1, memory_format = torch.contiguous_format);  mul_1 = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div_1);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_2, [-1], True)
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_2, mul_3);  mul_2 = mul_3 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None, None]
        

[2023-04-10 12:28:12,123] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 4 =====
 <eval_with_key>.30 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], primals_5, tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(mul, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, amax);  mul = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul_1);  view_10 = mul_1 = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_2, memory_format = torch.contiguous_format);  mul_2 = None
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_3, [-1], True)
        mul_4: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div, sum_2);  div = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
        mul_5: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(mul_5, [8, 8, 8]);  mul_5 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None, None], self._out_spec)
        

[2023-04-10 12:28:12,150] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 4 ======
 <eval_with_key>.32 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[], primals_5):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_2, primals_4);  view_2 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(mul, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, amax);  mul = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, primals_4, view_3, view_1, view, view_4, getitem_1, div]
        

[2023-04-10 12:28:12,152] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 4 ======
 <eval_with_key>.33 class GraphModule(torch.nn.Module):
    def forward(self, primals_4: f32[], view_3: f32[8, 8, 8], view_1: f32[8, 16, 8], view: f32[8, 8, 16], view_4: f32[8, 8, 16], getitem_1: b8[2, 4, 8, 8], div: f32[2, 4, 8, 8], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul_1);  view_10 = mul_1 = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_2, memory_format = torch.contiguous_format);  mul_2 = None
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_3, [-1], True)
        mul_4: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div, sum_2);  div = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
        mul_5: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(sub_1, primals_4);  sub_1 = primals_4 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(mul_5, [8, 8, 8]);  mul_5 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None, None]
        

[2023-04-10 12:28:12,293] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 5 =====
 <eval_with_key>.38 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[1, 1, 8, 8], tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, 4.0);  view_2 = None
        add: f32[2, 4, 8, 8] = torch.ops.aten.add.Tensor(div, primals_4);  div = primals_4 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(add, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(add, amax);  add = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div_1, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div_1);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul, [-1], True)
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, 4.0);  sub_1 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None], self._out_spec)
        

[2023-04-10 12:28:12,314] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 5 ======
 <eval_with_key>.40 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[1, 1, 8, 8]):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, 4.0);  view_2 = None
        add: f32[2, 4, 8, 8] = torch.ops.aten.add.Tensor(div, primals_4);  div = primals_4 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(add, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(add, amax);  add = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(div_1, [2, 4, 8, 8])
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, view_4, div_1, view_1, view, view_3]
        

[2023-04-10 12:28:12,315] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 5 ======
 <eval_with_key>.41 class GraphModule(torch.nn.Module):
    def forward(self, view_4: f32[8, 8, 16], div_1: f32[2, 4, 8, 8], view_1: f32[8, 16, 8], view: f32[8, 8, 16], view_3: f32[8, 8, 8], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, div_1);  view_10 = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul, [-1], True)
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, 4.0);  sub_1 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None]
        

[2023-04-10 12:28:12,464] torch._functorch.aot_autograd.__aot_joint_graph: [INFO] TRACED GRAPH
 ====== Joint graph 6 =====
 <eval_with_key>.46 class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[1, 1, 8, 8], primals_5, tangents_1: f32[2, 4, 8, 16], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, 4.0);  view_2 = None
        add: f32[2, 4, 8, 8] = torch.ops.aten.add.Tensor(div, primals_4);  div = primals_4 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(add, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(add, amax);  add = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div_1, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul);  view_10 = mul = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_1, memory_format = torch.contiguous_format);  mul_1 = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div_1);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_2, [-1], True)
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_2, mul_3);  mul_2 = mul_3 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, 4.0);  sub_1 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return pytree.tree_unflatten([view_6, view_13, permute_5, view_9, None, None], self._out_spec)
        

[2023-04-10 12:28:12,489] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 6 ======
 <eval_with_key>.48 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 4, 8, 16], primals_2: f32[2, 4, 8, 16], primals_3: f32[2, 4, 8, 16], primals_4: f32[1, 1, 8, 8], primals_5):
        # No stacktrace found for following nodes
        permute: f32[2, 4, 16, 8] = torch.ops.aten.permute.default(primals_2, [0, 1, 3, 2]);  primals_2 = None
        expand: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_1, [2, 4, 8, 16]);  primals_1 = None
        view: f32[8, 8, 16] = torch.ops.aten.view.default(expand, [8, 8, 16]);  expand = None
        expand_1: f32[2, 4, 16, 8] = torch.ops.aten.expand.default(permute, [2, 4, 16, 8]);  permute = None
        view_1: f32[8, 16, 8] = torch.ops.aten.view.default(expand_1, [8, 16, 8]);  expand_1 = None
        bmm: f32[8, 8, 8] = torch.ops.aten.bmm.default(view, view_1)
        view_2: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm, [2, 4, 8, 8]);  bmm = None
        div: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(view_2, 4.0);  view_2 = None
        add: f32[2, 4, 8, 8] = torch.ops.aten.add.Tensor(div, primals_4);  div = primals_4 = None
        amax: f32[2, 4, 8, 1] = torch.ops.aten.amax.default(add, [-1], True)
        sub: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(add, amax);  add = amax = None
        exp: f32[2, 4, 8, 8] = torch.ops.aten.exp.default(sub);  sub = None
        sum_1: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(div_1, 0.113377, True)
        getitem: f32[2, 4, 8, 8] = native_dropout[0]
        getitem_1: b8[2, 4, 8, 8] = native_dropout[1];  native_dropout = None
        expand_2: f32[2, 4, 8, 8] = torch.ops.aten.expand.default(getitem, [2, 4, 8, 8]);  getitem = None
        view_3: f32[8, 8, 8] = torch.ops.aten.view.default(expand_2, [8, 8, 8]);  expand_2 = None
        expand_3: f32[2, 4, 8, 16] = torch.ops.aten.expand.default(primals_3, [2, 4, 8, 16]);  primals_3 = None
        view_4: f32[8, 8, 16] = torch.ops.aten.view.default(expand_3, [8, 8, 16]);  expand_3 = None
        bmm_1: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_3, view_4)
        view_5: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_1, [2, 4, 8, 16]);  bmm_1 = None
        view_6: f32[2, 4, 8, 16] = torch.ops.aten.view.default(view_5, [2, 4, 8, 16]);  view_5 = None
        return [view_6, view_1, view, getitem_1, view_3, view_4, div_1]
        

[2023-04-10 12:28:12,491] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Backward graph 6 ======
 <eval_with_key>.49 class GraphModule(torch.nn.Module):
    def forward(self, view_1: f32[8, 16, 8], view: f32[8, 8, 16], getitem_1: b8[2, 4, 8, 8], view_3: f32[8, 8, 8], view_4: f32[8, 8, 16], div_1: f32[2, 4, 8, 8], tangents_1: f32[2, 4, 8, 16]):
        # No stacktrace found for following nodes
        view_7: f32[2, 4, 8, 16] = torch.ops.aten.view.default(tangents_1, [2, 4, 8, 16]);  tangents_1 = None
        
        # 
        view_8: f32[8, 8, 16] = torch.ops.aten.view.default(view_7, [8, 8, 16]);  view_7 = None
        permute_1: f32[8, 8, 8] = torch.ops.aten.permute.default(view_3, [0, 2, 1]);  view_3 = None
        bmm_2: f32[8, 8, 16] = torch.ops.aten.bmm.default(permute_1, view_8);  permute_1 = None
        permute_2: f32[8, 16, 8] = torch.ops.aten.permute.default(view_4, [0, 2, 1]);  view_4 = None
        bmm_3: f32[8, 8, 8] = torch.ops.aten.bmm.default(view_8, permute_2);  view_8 = permute_2 = None
        view_9: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_2, [2, 4, 8, 16]);  bmm_2 = None
        view_10: f32[2, 4, 8, 8] = torch.ops.aten.view.default(bmm_3, [2, 4, 8, 8]);  bmm_3 = None
        convert_element_type: f32[2, 4, 8, 8] = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32);  getitem_1 = None
        mul: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(convert_element_type, 1.1278750946005236);  convert_element_type = None
        mul_1: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(view_10, mul);  view_10 = mul = None
        clone: f32[2, 4, 8, 8] = torch.ops.aten.clone.default(mul_1, memory_format = torch.contiguous_format);  mul_1 = None
        mul_2: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(clone, div_1);  clone = None
        sum_2: f32[2, 4, 8, 1] = torch.ops.aten.sum.dim_IntList(mul_2, [-1], True)
        mul_3: f32[2, 4, 8, 8] = torch.ops.aten.mul.Tensor(div_1, sum_2);  div_1 = sum_2 = None
        sub_1: f32[2, 4, 8, 8] = torch.ops.aten.sub.Tensor(mul_2, mul_3);  mul_2 = mul_3 = None
        div_2: f32[2, 4, 8, 8] = torch.ops.aten.div.Tensor(sub_1, 4.0);  sub_1 = None
        view_11: f32[8, 8, 8] = torch.ops.aten.view.default(div_2, [8, 8, 8]);  div_2 = None
        permute_3: f32[8, 16, 8] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        bmm_4: f32[8, 16, 8] = torch.ops.aten.bmm.default(permute_3, view_11);  permute_3 = None
        permute_4: f32[8, 8, 16] = torch.ops.aten.permute.default(view_1, [0, 2, 1]);  view_1 = None
        bmm_5: f32[8, 8, 16] = torch.ops.aten.bmm.default(view_11, permute_4);  view_11 = permute_4 = None
        view_12: f32[2, 4, 16, 8] = torch.ops.aten.view.default(bmm_4, [2, 4, 16, 8]);  bmm_4 = None
        view_13: f32[2, 4, 8, 16] = torch.ops.aten.view.default(bmm_5, [2, 4, 8, 16]);  bmm_5 = None
        permute_5: f32[2, 4, 8, 16] = torch.ops.aten.permute.default(view_12, [0, 1, 3, 2]);  view_12 = None
        return [view_13, permute_5, view_9, None, None]
        

[2023-04-10 12:28:16,974] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper

To rule out a build issue, I had @H-Huang also run the script, and he saw the same extraneous ops/traced graphs.

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

@awgu awgu changed the title TORCH_LOGS=dynamo, aot shows phantom ops TORCH_LOGS=dynamo,aot shows phantom joint graphs Apr 10, 2023
@awgu awgu changed the title TORCH_LOGS=dynamo,aot shows phantom joint graphs TORCH_LOGS=dynamo,aot shows phantom traced graphs Apr 10, 2023
@wconstab
Contributor

Hmm, I just ran this on master @ ff825de and can't repro.

[2023-04-10 22:00:14,108] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
 ====== Forward graph 0 ======
 <eval_with_key>.4 class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3], arg1_1: f32[3]):
        # File: repro_awgu.py:5, code: return x + x
        add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        return (add,)

@awgu
Contributor Author

awgu commented Apr 10, 2023

Could there be some strange caching behavior at play? Both @H-Huang and I ran on the AWS cluster.

@wconstab
Contributor

I also ran on AWS, and as far as I understand only Triton caches to disk at the moment, so caching wouldn't explain the phantom dynamo compilations you are seeing. It could be an actual bug that has since been fixed on master?

@awgu
Contributor Author

awgu commented Apr 10, 2023

Update: I rebased to the same commit as you, and I cannot repro this either. It seems like there was a regression between ff825de and ad88afc (where ad88afc is my base commit).

Let me manually bisect this.

@wconstab
Contributor

OK, thanks. I was just going to say I'd bisect (since I'm on call this week), but I'll let you do it.

@awgu
Contributor Author

awgu commented Apr 10, 2023

It looks like #97741 introduced the issue, which passes the sanity check: some of those phantom joint graphs look like attention-type computations.

@wconstab wconstab added the bug label Apr 10, 2023
@wconstab wconstab added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 10, 2023
@jansel jansel changed the title TORCH_LOGS=dynamo,aot shows phantom traced graphs TORCH_LOGS=dynamo,aot shows pattern_matcher.py usage of AotAutograd Apr 11, 2023
@jansel
Contributor

jansel commented Apr 11, 2023

The usage of AotAutograd in pattern matching is intended behavior and should only happen once, on lazy_init when creating the patterns (and also when the pattern matcher fires).

We should modify logging to hide these by default.

@awgu
Contributor Author

awgu commented Apr 11, 2023

Is there a workaround in the interim? Debugging is a bit unwieldy with these six extra graphs being printed each time.

Edit: Never mind, I will just make lazy_init() do nothing, since I do not need the fused attention patterns.
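
A rough sketch of that interim workaround, assuming lazy_init is a module-level function in torch._inductor.pattern_matcher (that module path is an assumption and may differ across versions):

import torch._inductor.pattern_matcher as pattern_matcher

# Assumption: lazy_init is the hook that traces the fused-attention patterns
# through AotAutograd. Replacing it with a no-op skips those traces (and the
# fused-attention rewrite that depends on them). This only helps if callers
# look lazy_init up through the module attribute at call time.
pattern_matcher.lazy_init = lambda: None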

@jansel
Contributor

jansel commented Apr 11, 2023

You can set torch._inductor.config.pattern_matcher = False.
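
For reference, a minimal sketch of applying that flag before compiling (the compiled function below is a placeholder, not the original repro):

import torch
import torch._inductor.config as inductor_config

# Disable inductor's pattern matcher so its internal AotAutograd pattern traces
# are never generated, and therefore never show up in TORCH_LOGS output.
inductor_config.pattern_matcher = False

@torch.compile
def g(t):  # placeholder function; any compiled code behaves the same way
    return torch.relu(t) * 2

g(torch.randn(4))

Note that this also disables the pattern-matcher optimizations themselves, so it is best treated as a debugging aid rather than a permanent setting.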

jansel added a commit that referenced this issue Apr 12, 2023
jansel added a commit that referenced this issue Apr 12, 2023
Fixes #98778

ghstack-source-id: 60bc387a50b5edf51cdb606b265c8339a607637f
Pull Request resolved: #98936
ZainRizvi pushed a commit that referenced this issue Apr 19, 2023