Cannot convert -oo to int #128245

Open
bhack opened this issue Jun 7, 2024 · 4 comments
Labels
module: dynamic shapes, oncall: pt2, triaged

Comments

@bhack
Contributor

bhack commented Jun 7, 2024

🐛 Describe the bug

This is working correctly in eager mode.

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))

    x = windows.view(B, H // window_size, W // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
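As a minimal sketch of one thing that could be tried here: the batch size `B` is computed with true division and then truncated with `int()`, which may be what produces the `TruncToInt(FloatTrueDiv(...))` expression seen in the error below. An integer-only formulation gives the same value whenever `H` and `W` are multiples of `window_size`. The helper names `batch_size_float` and `batch_size_int` are hypothetical and are not part of the reported code, and whether this change avoids the compile error is not confirmed in this thread.

```python
def batch_size_float(num_windows_total, window_size, H, W):
    # Formulation from the report: true division, then truncation to int.
    return int(num_windows_total / (H * W / window_size / window_size))


def batch_size_int(num_windows_total, window_size, H, W):
    # Integer-only variant (assumption: H and W are multiples of window_size).
    windows_per_image = (H // window_size) * (W // window_size)
    return num_windows_total // windows_per_image


if __name__ == "__main__":
    # 2 images of 14x21 split into 7x7 windows -> 2 * 2 * 3 = 12 windows.
    print(batch_size_float(12, 7, 14, 21), batch_size_int(12, 7, 14, 21))
```

Both helpers take plain integers so they can be checked outside of `torch.compile`; in `window_reverse` the first argument would be `windows.shape[0]`.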

Error logs

  File "/workspace/tools/eval.py", line 135, in <module>
    main()
  File "/workspace/tools/eval.py", line 130, in main
    main_worker(0, cfg, enable_amp=args.amp)
  File "/workspace/tools/eval.py", line 30, in main_worker
    evaluator.evaluating()
  File "/workspace/networks/managers/evaluator.py", line 474, in evaluating
    engine.match_propogate_one_frame(current_img)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 432, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/engines/aotv3_engine.py", line 657, in match_propogate_one_frame
    aot_engine.match_propogate_one_frame(img, img_embs=img_embs)
  File "/workspace/networks/engines/aotv3_engine.py", line 374, in match_propogate_one_frame
    def match_propogate_one_frame(self, img=None, img_embs=None):
  File "/workspace/networks/engines/aotv3_engine.py", line 127, in encode_one_img_mask
    def encode_one_img_mask(self, img=None, mask=None, frame_step=-1):
  File "/workspace/networks/models/aotv3.py", line 162, in encode_image
    def encode_image(self, img):
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/encoders/swin/swin_transformer.py", line 702, in forward
    x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
                             ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/encoders/swin/swin_transformer.py", line 426, in forward
    def forward(self, x, H, W):
  File "/workspace/networks/encoders/swin/swin_transformer.py", line 434, in torch_dynamo_resume_in_forward_at_434
    Hp = int(np.ceil(H / self.window_size)) * self.window_size
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1115, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 947, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 471, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 816, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 635, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1184, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 581, in transform
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2455, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 503, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2063, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 747, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 841, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 91, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 753, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2786, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 503, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1504, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 747, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 342, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 91, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 753, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2786, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 503, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2063, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 747, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 91, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 753, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2786, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 503, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2063, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 747, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 692, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 476, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1854, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1786, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1301, in wrap_fake_exception
    return fn()
           ^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1787, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1922, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1906, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1060, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1449, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1152, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1689, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_refs/__init__.py", line 4551, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3624, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_prims_common/__init__.py", line 892, in infer_size
    if d == -1:
       ^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/__init__.py", line 562, in __bool__
    return self.node.bool_()
           ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 475, in bool_
    return self.guard_bool("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 413, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 244, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5153, in evaluate_expr
    static_expr = self._maybe_evaluate_static(expr,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1441, in wrapper
    return fn_cache(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4529, in _maybe_evaluate_static
    out = bound_sympy(new_expr, new_range_env)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/value_ranges.py", line 1019, in bound_sympy
    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in sympy_interp
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in <listcomp>
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in sympy_interp
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in <listcomp>
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in sympy_interp
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in <listcomp>
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in sympy_interp
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 134, in <listcomp>
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/interp.py", line 166, in sympy_interp
    return handler(*args)
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/value_ranges.py", line 522, in int_truediv
    return ValueRanges.coordinatewise_monotone_map(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/value_ranges.py", line 371, in coordinatewise_monotone_map
    products = [
               ^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/value_ranges.py", line 372, in <listcomp>
    fn(a, b)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/functions.py", line 29, in inner
    r = f(*args)
        ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/sympy/core/function.py", line 466, in __new__
    result = super().__new__(cls, *args, **options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/sympy/core/function.py", line 307, in __new__
    evaluated = cls.eval(*args)
                ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_sympy/functions.py", line 513, in eval
    return sympy.Float(int(base) / int(divisor))
                       ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/sympy/core/expr.py", line 325, in __int__
    raise TypeError("Cannot convert %s to int" % r)
torch._dynamo.exc.TorchRuntimeError: Failed running call_method view(*(FakeTensor(..., device='cuda:0',
           size=((((s4 + Mod(7 - Mod(s4, 7), 7))//7))*((((s1//s4) + Mod(7 - Mod(s5, 7), 7))//7)), 7, (((((s4 + Mod(7 - Mod(s4, 7), 7))//(((s4 + Mod(7 - Mod(s4, 7), 7))//7))))*((((s1//s4) + Mod(7 - Mod(s5, 7), 7))//((((s1//s4) + Mod(7 - Mod(s5, 7), 7))//7)))))//7), 256),
           dtype=torch.float16), TruncToInt(FloatTrueDiv(ToFloat((((s4 + Mod(7 - Mod(s4, 7), 7))//7))*((((s1//s4) + Mod(7 - Mod(s5, 7), 7))//7))), FloatTrueDiv(IntTrueDiv(s4*((s1//s4)) + s4*(Mod(7 - Mod(s5, 7), 7)) + ((s1//s4))*(Mod(7 - Mod(s4, 7), 7)) + (Mod(7 - Mod(s4, 7), 7))*(Mod(7 - Mod(s5, 7), 7)), 7), 7.0))), ((s4 + Mod(7 - Mod(s4, 7), 7))//7), (((s1//s4) + Mod(7 - Mod(s5, 7), 7))//7), 7, 7, -1), **{}):
Cannot convert -oo to int

from user code:
   File "/workspace/networks/encoders/swin/swin_transformer.py", line 464, in torch_dynamo_resume_in_forward_at_435
    if self.downsample is not None:
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/networks/encoders/swin/swin_transformer.py", line 303, in forward
    shifted_x = window_reverse(attn_windows, self.window_size, Hp,
  File "/workspace/networks/encoders/swin/swin_transformer.py", line 95, in window_reverse
    x = windows.view(B, H // window_size, W // window_size, window_size,
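The last frames of the traceback above show `int_truediv` being applied to the endpoints of value ranges, with `int(base)` failing because one endpoint is `-oo`. The following is a pure-Python analogy of that failure mode, not PyTorch or SymPy code: `bound_endpoints`, `int_truediv`, and `safe_int_truediv` are hypothetical names, `math.inf` stands in for SymPy's `oo`, and the guard shown is only an illustration of the idea, not the fix applied upstream.

```python
import math
from itertools import product


def int_truediv(base, divisor):
    # Same shape as the failing line: both operands go through int() first.
    return int(base) / int(divisor)


def safe_int_truediv(base, divisor):
    # Hypothetical guard: infinite endpoints are divided as floats instead.
    if math.isinf(base) and math.isinf(divisor):
        raise ValueError("indeterminate: both endpoints are infinite")
    if math.isinf(base) or math.isinf(divisor):
        return float(base) / float(divisor)
    return int(base) / int(divisor)


def bound_endpoints(fn, a_range, b_range):
    """Apply fn to every pair of range endpoints and return (min, max)."""
    values = [fn(a, b) for a, b in product(a_range, b_range)]
    return min(values), max(values)


if __name__ == "__main__":
    unbounded = (-math.inf, 49)  # a range with no lower bound
    seven = (7, 7)
    try:
        bound_endpoints(int_truediv, unbounded, seven)
    except OverflowError as exc:
        # Python's float analogue of "Cannot convert -oo to int".
        print("unguarded:", exc)
    print("guarded:", bound_endpoints(safe_int_truediv, unbounded, seven))
```

In this analogy the unguarded version fails as soon as any range reaching it has an infinite endpoint, regardless of whether the real runtime values are finite.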

Minified repro

No response

Versions

pytorch-nightly

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

@bhack
Contributor Author

bhack commented Jun 7, 2024

@ezyang
Contributor

ezyang commented Jun 7, 2024

The patch that caused this regression was reverted today, but #127693 should help with this.

@bhack
Contributor Author

bhack commented Jun 7, 2024

This fail-fast policy is sometimes very confusing. With a new nightly we might think that another issue has been solved, when in fact a regression in an earlier part of the code is failing first. We could then close an old ticket that is only being masked by a new one.

But I suppose there isn't any technical alternative to fail-fast, right?

@bhack
Contributor Author

bhack commented Jun 8, 2024

I no longer get the error after hot-patching with your PR:
curl -L https://github.com/pytorch/pytorch/pull/127693.diff | patch -p1 -d $(pip show torch | grep Location | awk '{print $2"/"}') --batch

But now the compilation never completes. See #127677 (comment).

@soulitzer added the triaged and module: dynamic shapes labels on Jun 10, 2024
3 participants