
[inductor] ValueError: floor is not defined #93885

@ConnorBaker

Description


🐛 Describe the bug

Starting with #93185, running torch.compile(model, fullgraph=False, dynamic=False) on my model fails because floor isn't defined in the generated Triton kernel. The commit on master immediately before that PR was merged (f1030dcc6d5b0157418e30c6fc96ef6dcf60d878) works without issue.

To be fair, I am using Triton compiled from HEAD, but I don't believe that is the cause.

The salient bits of the trace below:

...
  File "/triton/python/triton/compiler.py", line 156, in get_value
    raise ValueError(f'{name} is not defined')
ValueError: floor is not defined

The above exception was the direct cause of the following exception:
...
File "/tmp/torchinductor_root/y7/cy7fhxcb3tfqtnbml2yrqwbraeebhpahqzs6f27xa4leesxs2hay.py", line 19, in <module>
    triton__0 = async_compile.triton('''
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 664, in triton
    return _load_kernel(source_code)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 530, in _load_kernel
    kernel.precompile()
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 67, in precompile
    self.launchers = [
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 68, in <listcomp>
    self._precompile_config(c, warm_cache_only_with_cc)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 93, in _precompile_config
    binary = triton.compile(
  File "/triton/python/triton/compiler.py", line 1623, in compile
    next_module = compile(module)
  File "/triton/python/triton/compiler.py", line 1552, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/triton/python/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/triton/python/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 16:100:
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1176000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 60)
    x0 = xindex % 60
    x2 = xindex
    tmp0 = (7*((x1 // 490) % 10)) + (((x1 % 49) // 7))
    tmp1 = 64
    tmp2 = tmp0 < tmp1
    tmp3 = (7*((x1 // 49) % 10)) + ((x1 % 49) % 7) + tl.zeros([XBLOCK], tl.int32)
    tmp4 = 64
    tmp5 = tmp3 < tmp4
    tmp6 = tmp5 & tmp2
    tmp7 = tl.load(in_ptr0 + (x0 + (60*((((x1 % 49) % 7) // 1))) + (420*((x1 // 49) % 10)) + (3840*(floor(ModularIndexing(x1, 1, 49)//7))) + (26880*((x1 // 490) % 10)) + (245760*(x1 // 4900)) + tl.zeros([XBLOCK], tl.int32)), tmp6 & xmask, other=0)
                                                                                                    ^
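For context on the offending term: my reading (an assumption, not the actual inductor codegen) is that ModularIndexing(x, d, m) in inductor's sympy IR denotes (x // d) % m, so floor(ModularIndexing(x1, 1, 49)//7) is pure integer arithmetic and should never need a runtime floor. A plain-Python sketch of what the term computes:

```python
# Assumed semantics of inductor's sympy helper:
# ModularIndexing(x, divisor, modulus) == (x // divisor) % modulus
def modular_indexing(x, divisor, modulus):
    return (x // divisor) % modulus

def offending_term(x1):
    # Equivalent of floor(ModularIndexing(x1, 1, 49) // 7) for integer x1.
    # Python's // is already floor division, so no floor() call is needed --
    # the generated Triton kernel should likewise use integer ops here.
    return modular_indexing(x1, 1, 49) // 7
```

If that reading is right, the bug is that the sympy-to-Triton printer emits a bare floor call, which Triton's compiler does not define.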

Error logs

Traceback (most recent call last):
  File "/triton/python/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/local/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/triton/python/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 757, in visit_Call
    args = [self.visit(arg) for arg in node.args]
  File "/triton/python/triton/compiler.py", line 757, in <listcomp>
    args = [self.visit(arg) for arg in node.args]
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 338, in visit_BinOp
    lhs = self.visit(node.left)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 338, in visit_BinOp
    lhs = self.visit(node.left)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 338, in visit_BinOp
    lhs = self.visit(node.left)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 751, in visit_Call
    fn = self.visit(node.func)
  File "/triton/python/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/usr/local/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/triton/python/triton/compiler.py", line 325, in visit_Name
    return self.get_value(node.id)
  File "/triton/python/triton/compiler.py", line 156, in get_value
    raise ValueError(f'{name} is not defined')
ValueError: floor is not defined

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 525, in _worker_compile
    kernel.precompile(warm_cache_only_with_cc=cc)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 67, in precompile
    self.launchers = [
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 68, in <listcomp>
    self._precompile_config(c, warm_cache_only_with_cc)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 81, in _precompile_config
    triton.compile(
  File "/triton/python/triton/compiler.py", line 1623, in compile
    next_module = compile(module)
  File "/triton/python/triton/compiler.py", line 1552, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/triton/python/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/triton/python/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 16:100:
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1176000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 60)
    x0 = xindex % 60
    x2 = xindex
    tmp0 = (7*((x1 // 490) % 10)) + (((x1 % 49) // 7))
    tmp1 = 64
    tmp2 = tmp0 < tmp1
    tmp3 = (7*((x1 // 49) % 10)) + ((x1 % 49) % 7) + tl.zeros([XBLOCK], tl.int32)
    tmp4 = 64
    tmp5 = tmp3 < tmp4
    tmp6 = tmp5 & tmp2
    tmp7 = tl.load(in_ptr0 + (x0 + (60*((((x1 % 49) % 7) // 1))) + (420*((x1 // 49) % 10)) + (3840*(floor(ModularIndexing(x1, 1, 49)//7))) + (26880*((x1 // 490) % 10)) + (245760*(x1 // 4900)) + tl.zeros([XBLOCK], tl.int32)), tmp6 & xmask, other=0)
                                                                                                    ^

Minified repro

Generated with repro_after="aot":

isolate_fails_code_str = None


import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

import torch._dynamo.config
import torch._inductor.config
torch._dynamo.config.load_config(b'\x80\x04\x95\x91\x07\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x14torch._dynamo.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\rtorch._dynamo\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8c?/usr/local/lib/python3.10/site-packages/torch/_dynamo/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8cX/usr/local/lib/python3.10/site-packages/torch/_dynamo/__pycache__/config.cpython-310.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x07abspath\x94\x8c\tposixpath\x94h\x1f\x93\x94\x8c\x07dirname\x94h h"\x93\x94\x8c\tlog_level\x94K\x1e\x8c\x0boutput_code\x94\x89\x8c\rlog_file_name\x94N\x8c\x07verbose\x94\x89\x8c\x11output_graph_code\x94\x89\x8c\x12verify_correctness\x94\x89\x8c\x12minimum_call_count\x94K\x01\x8c\x15dead_code_elimination\x94\x88\x8c\x10cache_size_limit\x94K@\x8c\x14specialize_int_float\x94\x88\x8c\x0edynamic_shapes\x94\x89\x8c\x10guard_nn_modules\x94\x89\x8c\x0cnormalize_ir\x94\x89\x8c\x1btraceable_tensor_subclasses\x94\x8f\x94\x8c\x0fsuppress_errors\x94\x89\x8c\x15replay_record_enabled\x94\x89\x8c 
rewrite_assert_with_torch_assert\x94\x88\x8c\x12print_graph_breaks\x94\x89\x8c\x07disable\x94\x89\x8c*allowed_functions_module_string_ignorelist\x94\x8f\x94(\x8c\x0ctorch._prims\x94\x8c\rtorch._decomp\x94\x8c\x13torch.distributions\x94\x8c\x0btorch._refs\x94\x8c\rtorch.testing\x94\x90\x8c\x0frepro_tolerance\x94G?PbM\xd2\xf1\xa9\xfc\x8c\x16capture_scalar_outputs\x94\x89\x8c\x19enforce_cond_guards_match\x94\x88\x8c\x0coptimize_ddp\x94\x88\x8c\x1araise_on_ctx_manager_usage\x94\x88\x8c\x1craise_on_unsafe_aot_autograd\x94\x89\x8c\rdynamo_import\x94\x8c\rtorch._dynamo\x94\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\x18error_on_nested_fx_trace\x94\x88\x8c\tallow_rnn\x94\x89\x8c\x08base_dir\x94\x8c\'/usr/local/lib/python3.10/site-packages\x94\x8c\x0edebug_dir_root\x94\x8c\x19/bsrt/torch_compile_debug\x94\x8c)DO_NOT_USE_legacy_non_fake_example_inputs\x94\x89\x8c\x15_AccessLimitingConfig\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x0b__setattr__\x94h\x02\x8c!_AccessLimitingConfig.__setattr__\x94\x93\x94h\x03Nu\x8c\x15_allowed_config_names\x94\x8f\x94(\x8c\x03sys\x94hP\x8c\x02os\x94h+h.hEhJ\x8c\x12constant_functions\x94h@h*h4h0h1\x8c\x07logging\x94\x8c\x0brepro_after\x94hMh?hAh\x06h\x04h\x01\x8c\x05torch\x94\x8c\x0eexternal_utils\x94h)hKh%h"h&h8h7h6h\x1fh3\x8c!skipfiles_inline_module_allowlist\x94h\x1eh\x03\x8c\nModuleType\x94h/h(h-h\'hGh\x0fhBh$hCh\x1dh,hD\x8c\x0brepro_level\x94hOhIh5\x8c\x0c__builtins__\x94\x90\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94hc\x93\x94u.')
torch._inductor.config.load_config(b'\x80\x04\x95\x0f\t\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x16torch._inductor.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\x0ftorch._inductor\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8cA/usr/local/lib/python3.10/site-packages/torch/_inductor/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8cZ/usr/local/lib/python3.10/site-packages/torch/_inductor/__pycache__/config.cpython-310.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x05debug\x94\x89\x8c\x10disable_progress\x94\x88\x8c\x10verbose_progress\x94\x89\x8c\x0bcpp_wrapper\x94\x89\x8c\x03dce\x94\x89\x8c\x14static_weight_shapes\x94\x88\x8c\x0csize_asserts\x94\x88\x8c\x10pick_loop_orders\x94\x88\x8c\x0finplace_buffers\x94\x88\x8c\x11benchmark_harness\x94\x88\x8c\x0fepilogue_fusion\x94\x89\x8c\x15epilogue_fusion_first\x94\x89\x8c\x0fpattern_matcher\x94\x88\x8c\nreordering\x94\x89\x8c\x0cmax_autotune\x94\x89\x8c\x17realize_reads_threshold\x94K\x04\x8c\x17realize_bytes_threshold\x94M\xd0\x07\x8c\x1brealize_acc_reads_threshold\x94K\x08\x8c\x0ffallback_random\x94\x89\x8c\x12implicit_fallbacks\x94\x88\x8c\rprefuse_nodes\x94\x88\x8c\x0btune_layout\x94\x89\x8c\x11aggressive_fusion\x94\x89\x8c\x0fmax_fusion_size\x94K@\x8c\x1bunroll_reductions_threshold\x94K\x08\x8c\x0ecomment_origin\x94\x89\x8c\tis_fbcode\x94h\x02h9\x93\x94\x8c\x0fcompile_threads\x94K 
\x8c\x13kernel_name_max_ops\x94K\n\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\rshape_padding\x94\x89\x8c\x0epermute_fusion\x94\x89\x8c\x1aprofiler_mark_wrapper_call\x94\x89\x8c\x03cpp\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x07threads\x94J\xff\xff\xff\xff\x8c\x0fdynamic_threads\x94\x89\x8c\x07simdlen\x94N\x8c\x0emin_chunk_size\x94M\x00\x10\x8c\x03cxx\x94N\x8c\x03g++\x94\x86\x94\x8c\x15enable_kernel_profile\x94\x89h\x03Nu\x8c\x06triton\x94}\x94(hDh\x02\x8c\ncudagraphs\x94\x88\x8c\x10debug_sync_graph\x94\x89\x8c\x11debug_sync_kernel\x94\x89\x8c\x0bconvolution\x94\x8c\x04aten\x94\x8c\x0edense_indexing\x94\x89\x8c\tmax_tiles\x94K\x02\x8c\x12autotune_pointwise\x94\x88\x8c tiling_prevents_pointwise_fusion\x94\x88\x8c tiling_prevents_reduction_fusion\x94\x88\x8c\x14ordered_kernel_names\x94\x89\x8c\x18descriptive_kernel_names\x94\x89h\x03Nu\x8c\x05trace\x94}\x94(hDh\x02\x8c\x07enabled\x94\x89\x8c\tdebug_log\x94\x88\x8c\x08info_log\x94\x89\x8c\x08fx_graph\x94\x88\x8c\x14fx_graph_transformed\x94\x88\x8c\rir_pre_fusion\x94\x88\x8c\x0eir_post_fusion\x94\x88\x8c\x0boutput_code\x94\x88\x8c\rgraph_diagram\x94\x89\x8c\x0fcompile_profile\x94\x89\x8c\nupload_tar\x94Nh\x03Nu\x8c\x15InductorConfigContext\x94}\x94(hDh\x02\x8c\x0f__annotations__\x94}\x94(\x8c\rstatic_memory\x94\x8c\x08builtins\x94\x8c\x04bool\x94\x93\x94\x8c\x0ematmul_padding\x94hoh-ho\x8c\x12triton_convolution\x94hm\x8c\x03str\x94\x93\x94\x8c\x17rematerialize_threshold\x94hm\x8c\x03int\x94\x93\x94\x8c\x1brematerialize_acc_threshold\x94hvu\x8c\x05_save\x94h\x02\x8c\x1bInductorConfigContext._save\x94\x93\x94\x8c\x06_apply\x94h\x02\x8c\x1cInductorConfigContext._apply\x94\x93\x94\x8c\x08__init__\x94h\x02\x8c\x1eInductorConfigContext.__init__\x94\x93\x94\x8c\t__enter__\x94h\x02\x8c\x1fInductorConfigContext.__enter__\x94\x93\x94\x8c\x08__exit__\x94h\x02\x8c\x1eInductorConfigContext.__exit__\x94\x93\x94h\x03Nu\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94h\x87\x93\x94u.')


# REPLACEABLE COMMENT FOR TESTING PURPOSES


# torch version: 2.0.0a0+gitc4ccf7e
# torch cuda version: 11.8
# torch git version: c4ccf7e12147671fdc3535a222260d687c2128a2


# CUDA Info: 
# nvcc: NVIDIA (R) Cuda compiler driver 
# Copyright (c) 2005-2022 NVIDIA Corporation 
# Built on Wed_Sep_21_10:33:58_PDT_2022 
# Cuda compilation tools, release 11.8, V11.8.89 
# Build cuda_11.8.r11.8/compiler.31833905_0 

# GPU Hardware Info: 
# NVIDIA GeForce RTX 4090 : 1 


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', torch.randn([], dtype=torch.float32).cuda())
        self.register_buffer('_tensor_constant1', torch.randn([], dtype=torch.float32).cuda())
        self.register_buffer('_tensor_constant2', torch.randn([], dtype=torch.float32).cuda())



    def forward(self, primals_6, view_1, select, select_1, pow_2, pow_4, bmm, div_2, view_7, permute_8, permute_13, permute_14, permute_15, permute_16, permute_19, tangents_1):
        _tensor_constant0 = self._tensor_constant0
        lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
        maximum = torch.ops.aten.maximum.default(pow_2, lift_fresh_copy);  lift_fresh_copy = None
        expand = torch.ops.aten.expand.default(maximum, [400, 6, 49, 10]);  maximum = None
        div = torch.ops.aten.div.Tensor(select, expand)
        _tensor_constant1 = self._tensor_constant1
        lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1);  _tensor_constant1 = None
        maximum_1 = torch.ops.aten.maximum.default(pow_4, lift_fresh_copy_1);  lift_fresh_copy_1 = None
        expand_1 = torch.ops.aten.expand.default(maximum_1, [400, 6, 49, 10]);  maximum_1 = None
        div_1 = torch.ops.aten.div.Tensor(select_1, expand_1)
        view_4 = torch.ops.aten.view.default(bmm, [400, 6, 49, 49]);  bmm = None
        _tensor_constant2 = self._tensor_constant2
        lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy.default(_tensor_constant2);  _tensor_constant2 = None
        minimum = torch.ops.aten.minimum.default(primals_6, lift_fresh_copy_2);  lift_fresh_copy_2 = None
        exp = torch.ops.aten.exp.default(minimum);  minimum = None
        full = torch.ops.aten.full.default([4, 64, 64, 60], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        slice_scatter_1 = torch.ops.aten.slice_scatter.default(full, tangents_1, 3, 0, 9223372036854775807);  full = tangents_1 = None
        full_1 = torch.ops.aten.full.default([4, 64, 70, 60], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        slice_scatter_2 = torch.ops.aten.slice_scatter.default(full_1, slice_scatter_1, 2, 0, 64);  full_1 = slice_scatter_1 = None
        full_2 = torch.ops.aten.full.default([4, 70, 70, 60], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        slice_scatter_3 = torch.ops.aten.slice_scatter.default(full_2, slice_scatter_2, 1, 0, 64);  slice_scatter_2 = None
        slice_scatter_4 = torch.ops.aten.slice_scatter.default(full_2, slice_scatter_3, 0, 0, 9223372036854775807);  full_2 = slice_scatter_3 = None
        view_10 = torch.ops.aten.view.default(slice_scatter_4, [4, 10, 7, 10, 7, 60]);  slice_scatter_4 = None
        permute_7 = torch.ops.aten.permute.default(view_10, [0, 1, 3, 2, 4, 5]);  view_10 = None
        clone_8 = torch.ops.aten.clone.default(permute_7, memory_format = torch.contiguous_format);  permute_7 = None
        _unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_8, [400, 49, 60]);  clone_8 = None
        view_11 = torch.ops.aten.view.default(_unsafe_view_6, [19600, 60]);  _unsafe_view_6 = None
        mm = torch.ops.aten.mm.default(view_11, permute_8);  permute_8 = None
        permute_9 = torch.ops.aten.permute.default(view_11, [1, 0])
        mm_1 = torch.ops.aten.mm.default(permute_9, view_7);  permute_9 = view_7 = None
        permute_10 = torch.ops.aten.permute.default(mm_1, [1, 0]);  mm_1 = None
        sum_4 = torch.ops.aten.sum.dim_IntList(view_11, [0], True);  view_11 = None
        view_12 = torch.ops.aten.view.default(sum_4, [60]);  sum_4 = None
        permute_11 = torch.ops.aten.permute.default(permute_10, [1, 0]);  permute_10 = None
        view_13 = torch.ops.aten.view.default(mm, [400, 49, 60]);  mm = None
        view_14 = torch.ops.aten.view.default(view_13, [400, 49, 6, 10]);  view_13 = None
        permute_12 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3]);  view_14 = None
        clone_9 = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format);  permute_12 = None
        _unsafe_view_7 = torch.ops.aten._unsafe_view.default(clone_9, [2400, 49, 10]);  clone_9 = None
        bmm_2 = torch.ops.aten.bmm.default(permute_13, _unsafe_view_7);  permute_13 = None
        bmm_3 = torch.ops.aten.bmm.default(_unsafe_view_7, permute_14);  _unsafe_view_7 = permute_14 = None
        view_15 = torch.ops.aten.view.default(bmm_2, [400, 6, 49, 10]);  bmm_2 = None
        view_16 = torch.ops.aten.view.default(bmm_3, [400, 6, 49, 49]);  bmm_3 = None
        mul_1 = torch.ops.aten.mul.Tensor(view_16, div_2);  view_16 = None
        sum_5 = torch.ops.aten.sum.dim_IntList(mul_1, [-1], True)
        mul_2 = torch.ops.aten.mul.Tensor(div_2, sum_5);  div_2 = sum_5 = None
        sub_1 = torch.ops.aten.sub.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
        sum_6 = torch.ops.aten.sum.dim_IntList(sub_1, [0], True)
        mul_3 = torch.ops.aten.mul.Tensor(sub_1, view_4);  view_4 = None
        mul_4 = torch.ops.aten.mul.Tensor(sub_1, exp);  sub_1 = None
        sum_7 = torch.ops.aten.sum.dim_IntList(mul_3, [0, 2, 3], True);  mul_3 = None
        view_17 = torch.ops.aten.view.default(sum_7, [6, 1, 1]);  sum_7 = None
        mul_5 = torch.ops.aten.mul.Tensor(view_17, exp);  view_17 = exp = None
        scalar_tensor = torch.ops.aten.scalar_tensor.default(0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        le = torch.ops.aten.le.Scalar(primals_6, 4.605170185988092);  primals_6 = None
        where = torch.ops.aten.where.self(le, mul_5, scalar_tensor);  le = mul_5 = None
        view_18 = torch.ops.aten.view.default(mul_4, [2400, 49, 49]);  mul_4 = None
        bmm_4 = torch.ops.aten.bmm.default(permute_15, view_18);  permute_15 = None
        bmm_5 = torch.ops.aten.bmm.default(view_18, permute_16);  view_18 = permute_16 = None
        view_19 = torch.ops.aten.view.default(bmm_4, [400, 6, 10, 49]);  bmm_4 = None
        view_20 = torch.ops.aten.view.default(bmm_5, [400, 6, 49, 10]);  bmm_5 = None
        permute_17 = torch.ops.aten.permute.default(view_19, [0, 1, 3, 2]);  view_19 = None
        neg = torch.ops.aten.neg.default(permute_17)
        div_4 = torch.ops.aten.div.Tensor(div_1, expand_1);  div_1 = None
        mul_6 = torch.ops.aten.mul.Tensor(neg, div_4);  neg = div_4 = None
        div_5 = torch.ops.aten.div.Tensor(permute_17, expand_1);  permute_17 = expand_1 = None
        sum_8 = torch.ops.aten.sum.dim_IntList(mul_6, [3], True);  mul_6 = None
        ge = torch.ops.aten.ge.Scalar(pow_4, 1e-12)
        where_1 = torch.ops.aten.where.self(ge, sum_8, scalar_tensor);  ge = sum_8 = None
        div_6 = torch.ops.aten.div.Tensor(select_1, pow_4);  select_1 = None
        eq = torch.ops.aten.eq.Scalar(pow_4, 0);  pow_4 = None
        where_2 = torch.ops.aten.where.self(eq, scalar_tensor, div_6);  eq = div_6 = None
        mul_7 = torch.ops.aten.mul.Tensor(where_1, where_2);  where_1 = where_2 = None
        add_1 = torch.ops.aten.add.Tensor(div_5, mul_7);  div_5 = mul_7 = None
        neg_1 = torch.ops.aten.neg.default(view_20)
        div_8 = torch.ops.aten.div.Tensor(div, expand);  div = None
        mul_8 = torch.ops.aten.mul.Tensor(neg_1, div_8);  neg_1 = div_8 = None
        div_9 = torch.ops.aten.div.Tensor(view_20, expand);  view_20 = expand = None
        sum_9 = torch.ops.aten.sum.dim_IntList(mul_8, [3], True);  mul_8 = None
        ge_1 = torch.ops.aten.ge.Scalar(pow_2, 1e-12)
        where_3 = torch.ops.aten.where.self(ge_1, sum_9, scalar_tensor);  ge_1 = sum_9 = None
        div_10 = torch.ops.aten.div.Tensor(select, pow_2);  select = None
        eq_1 = torch.ops.aten.eq.Scalar(pow_2, 0);  pow_2 = None
        where_4 = torch.ops.aten.where.self(eq_1, scalar_tensor, div_10);  eq_1 = scalar_tensor = div_10 = None
        mul_9 = torch.ops.aten.mul.Tensor(where_3, where_4);  where_3 = where_4 = None
        add_2 = torch.ops.aten.add.Tensor(div_9, mul_9);  div_9 = mul_9 = None
        full_4 = torch.ops.aten.full.default([3, 400, 6, 49, 10], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        select_scatter = torch.ops.aten.select_scatter.default(full_4, view_15, 0, 2);  view_15 = None
        select_scatter_1 = torch.ops.aten.select_scatter.default(full_4, add_1, 0, 1);  add_1 = None
        add_3 = torch.ops.aten.add.Tensor(select_scatter, select_scatter_1);  select_scatter = select_scatter_1 = None
        select_scatter_2 = torch.ops.aten.select_scatter.default(full_4, add_2, 0, 0);  full_4 = add_2 = None
        add_4 = torch.ops.aten.add.Tensor(add_3, select_scatter_2);  add_3 = select_scatter_2 = None
        permute_18 = torch.ops.aten.permute.default(add_4, [1, 3, 0, 2, 4]);  add_4 = None
        clone_10 = torch.ops.aten.clone.default(permute_18, memory_format = torch.contiguous_format);  permute_18 = None
        _unsafe_view_8 = torch.ops.aten._unsafe_view.default(clone_10, [400, 49, 180]);  clone_10 = None
        view_21 = torch.ops.aten.view.default(_unsafe_view_8, [19600, 180]);  _unsafe_view_8 = None
        mm_2 = torch.ops.aten.mm.default(view_21, permute_19);  permute_19 = None
        permute_20 = torch.ops.aten.permute.default(view_21, [1, 0])
        mm_3 = torch.ops.aten.mm.default(permute_20, view_1);  permute_20 = view_1 = None
        permute_21 = torch.ops.aten.permute.default(mm_3, [1, 0]);  mm_3 = None
        sum_10 = torch.ops.aten.sum.dim_IntList(view_21, [0], True);  view_21 = None
        view_22 = torch.ops.aten.view.default(sum_10, [180]);  sum_10 = None
        permute_22 = torch.ops.aten.permute.default(permute_21, [1, 0]);  permute_21 = None
        view_23 = torch.ops.aten.view.default(mm_2, [400, 49, 60]);  mm_2 = None
        slice_9 = torch.ops.aten.slice.Tensor(view_22, 0, 60, 120)
        clone_11 = torch.ops.aten.clone.default(slice_9, memory_format = torch.contiguous_format);  slice_9 = None
        full_like_1 = torch.ops.aten.full_like.default(clone_11, 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format);  clone_11 = None
        slice_scatter_5 = torch.ops.aten.slice_scatter.default(view_22, full_like_1, 0, 60, 120);  view_22 = full_like_1 = None
        view_24 = torch.ops.aten.view.default(view_23, [4, 10, 10, 7, 7, 60]);  view_23 = None
        permute_23 = torch.ops.aten.permute.default(view_24, [0, 1, 3, 2, 4, 5]);  view_24 = None
        clone_12 = torch.ops.aten.clone.default(permute_23, memory_format = torch.contiguous_format);  permute_23 = None
        _unsafe_view_9 = torch.ops.aten._unsafe_view.default(clone_12, [4, 70, 70, 60]);  clone_12 = None
        return [permute_22, permute_11, sum_6, slice_scatter_5, view_12, where, _unsafe_view_9]

args = [((6, 1, 1), (1, 1, 1), torch.float32, 'cuda'), ((19600, 60), (60, 1), torch.float32, 'cuda'), ((400, 6, 49, 10), (8820, 10, 180, 1), torch.float32, 'cuda'), ((400, 6, 49, 10), (8820, 10, 180, 1), torch.float32, 'cuda'), ((400, 6, 49, 1), (294, 49, 1, 1), torch.float32, 'cuda'), ((400, 6, 49, 1), (294, 49, 1, 1), torch.float32, 'cuda'), ((2400, 49, 49), (2401, 49, 1), torch.float32, 'cuda'), ((400, 6, 49, 49), (14406, 2401, 49, 1), torch.float32, 'cuda'), ((19600, 60), (60, 1), torch.float32, 'cuda'), ((60, 60), (60, 1), torch.float32, 'cuda'), ((2400, 49, 49), (2401, 1, 49), torch.float32, 'cuda'), ((2400, 10, 49), (490, 1, 10), torch.float32, 'cuda'), ((2400, 10, 49), (490, 1, 10), torch.float32, 'cuda'), ((2400, 49, 10), (490, 1, 49), torch.float32, 'cuda'), ((180, 60), (60, 1), torch.float32, 'cuda'), ((4, 64, 64, 60), (245760, 3840, 60, 1), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
mod = make_fx(Repro(), tracing_mode='real')(*args)


from functools import partial
from torch._dynamo.debug_utils import (
    isolate_fails,
    dump_compiler_graph_state,
)
from functorch.compile import minifier

env_variables = {"CUDA_VISIBLE_DEVICES": "0"}

minifier(
    mod,
    args,
    module_fails=partial(isolate_fails, env=env_variables, compiler_name="inductor", patch_code=isolate_fails_code_str),
    dump_state=partial(dump_compiler_graph_state, compiler_name="inductor"),
)

Versions

PyTorch version: 2.0.0a0+gitc4ccf7e
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: 16.0.0 (++20230130103025+16a5dd495d02-1~exp1~20230130223133.7)
CMake version: version 3.25.2
Libc version: glibc-2.35

Python version: 3.10.9+ (main, Feb  1 2023, 12:46:32) [Clang 16.0.0 (++20230130103025+16a5dd495d02-1~exp1~20230130223133.7)] (64-bit runtime)
Python platform: Linux-6.1.8-200.fc37.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 525.85.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.4
[pip3] pytorch-lightning==1.9.0
[pip3] torch==2.0.0a0+gitc4ccf7e
[pip3] torch-fidelity==0.3.0
[pip3] torchmetrics==0.11.1
[pip3] torchvision==0.15.0a0
[conda] Could not collect

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh
