
[PT2] sum operator int32 compile + eager backend fails #100698

Closed
jay746 opened this issue May 5, 2023 · 6 comments
Labels
actionable good first issue module: primTorch oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


jay746 commented May 5, 2023

🐛 Describe the bug

The sum operator, in its int32 variant with the output initialized to an empty tensor, fails under torch.compile with "dtype argument and out dtype must match in reduction".

Use the code below to reproduce the issue:

import torch

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, param):
        z = torch.sum(**param)
        return z


if __name__ == "__main__":
    model = Repro()
    params = {'dim': [1, -1], 'keepdim': True, "input" : torch.randn([64, 54, 43]).to(torch.int32)}
    params["out"] = torch.empty(0, dtype=torch.int32)
    model = torch.compile(model, backend="eager")
    res = model(params)
    print(res)
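
For comparison, the same call appears to succeed without torch.compile; a minimal eager-mode check (a sketch, assuming PyTorch 2.0):

import torch

# Same call without torch.compile: eager infers dtype from `out` and succeeds.
x = torch.randn([64, 54, 43]).to(torch.int32)
out = torch.empty(0, dtype=torch.int32)
res = torch.sum(x, dim=[1, -1], keepdim=True, out=out)
print(res.dtype)  # torch.int32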

Error logs

Traceback (most recent call last):
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1194, in run_node
    return node.target(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1170, in dispatch
    r = func(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_ops.py", line 287, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_refs/__init__.py", line 2161, in sum
    return _reduction(
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_refs/__init__.py", line 2039, in _reduction
    raise RuntimeError(
RuntimeError: dtype argument and out dtype must match in reduction

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1152, in get_fake_value
    return wrap_fake_exception(
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 808, in wrap_fake_exception
    return fn()
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1153, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1206, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <built-in method sum of type object at 0x7f648627be80>(*(), **{'dim': [1, -1], 'keepdim': True, 'input': FakeTensor(FakeTensor(..., device='meta', size=(64, 54, 43), dtype=torch.int32), cpu), 'out': FakeTensor(FakeTensor(..., device='meta', size=(0,), dtype=torch.int32), cpu)}):
dtype argument and out dtype must match in reduction
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "frexp_repro.py", line 17, in <module>
    res = model(params)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1002, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/variables/torch.py", line 548, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 754, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 789, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1173, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError: 

from user code:
   File "frexp_repro.py", line 8, in forward
    z = torch.sum(**param)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Minified repro

(Same script as the reproduction code above.)

Versions

Name: torch
Version: 2.0.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /home/jthakur/.venv_pt_op_test/lib/python3.8/site-packages
Requires: filelock, jinja2, networkx, sympy, typing-extensions

cc @ezyang @mruberry @ngimel @lezcano @peterbell10 @soumith @msaroufim @wconstab @bdhirsh @anijain2305


ezyang commented May 5, 2023

I'm actually not sure this is a good first issue, but it's all in Python and a strong engineer should be able to figure it out. The PrimTorch decomposition for sum must be adjusted to handle a mismatch between the out precision and the inferred precision. You'll need to first figure out what the eager semantics are, and then replicate them (or conclude the eager semantics are wrong and fix up eager not to allow this program).

cc @cpuhrsch
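
One way to pin down those eager semantics is to probe torch.sum directly; a quick sketch (assuming PyTorch 2.0 eager):

import torch

x = torch.ones(4, 3, dtype=torch.int32)

# No dtype, no out: integral inputs are promoted to int64.
print(torch.sum(x).dtype)  # torch.int64

# out given, no dtype: the out tensor's dtype wins.
out = torch.empty(0, dtype=torch.int32)
torch.sum(x, dim=[0], out=out)
print(out.dtype)  # torch.int32

# dtype and out both given but mismatched: eager raises a RuntimeError.
# torch.sum(x, dim=[0], dtype=torch.int64, out=out)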

@ngimel ngimel added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 6, 2023
@pytorch pytorch deleted a comment from shivamkc01 May 7, 2023

ekamiti commented May 9, 2023

Hi, I'm new to contributing to PyTorch and I'd like to work on this issue. It seems that for integer types the output dtype is set to int64, I assume to handle potential overflow:

if dtype is None:
    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        dtype = torch.int64
    else:
        dtype = a.dtype

But then this results in a failure here:

if out is not None:
    assert isinstance(out, TensorLike)
    if dtype is not None:
        # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
        if dtype != out.dtype:
            raise RuntimeError(
                "dtype argument and out dtype must match in reduction"
            )

This handling is presumably meant to address potential overflow. For floats, similar handling is done here:

computation_dtype, result_dtype = utils.reduction_dtypes(
    a, output_dtype_kind, dtype
)

With the relevant mappings here:

_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}

If I have diagnosed the issue correctly, then potential solutions are:

  1. Remove the special-case handling in the sum code for integer types and warn the user of possible overflow.
  2. Remove the special-case handling in the sum code for integer types and add a computation-precision mapping for integer types, similar to what _computation_dtype_map does for floating-point types.
  3. Leave the special-case handling in place and add checks in the _reduction code to ensure the provided reduction dtype and the output dtype are at least compatible, i.e. both integer types or both floating types, and just warn about possible overflow (see the sketch after this comment).

Let me know if I'm on the right track or if I'm overlooking something important; I assume there must be a good reason why float type promotion is handled more gracefully than integer promotion. Thanks!
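
For what option 3 might look like, here is an illustrative compatibility check (a hypothetical helper, not actual _reduction code):

import torch

def reduction_dtypes_compatible(reduction_dtype, out_dtype):
    # Treat dtypes as compatible when they are of the same kind:
    # both integral/bool, both floating point, or both complex.
    def kind(d):
        if d.is_complex:
            return "complex"
        if d.is_floating_point:
            return "float"
        return "int"  # integral types and bool, for this sketch
    return kind(reduction_dtype) == kind(out_dtype)

assert reduction_dtypes_compatible(torch.int64, torch.int32)
assert not reduction_dtypes_compatible(torch.int64, torch.float32)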


ezyang commented May 15, 2023

The important thing to figure out is how the eager mode test is done. Can you check that out?


ekamiti commented May 16, 2023

The important thing to figure out is how the eager mode test is done. Can you check that out?

OK, I'll figure that out.


ekamiti commented May 30, 2023

For eager mode, the semantics (in terms of the arguments to sum) seem to be:

  • If both out and dtype are specified, they have to match
  • If dtype is not specified but out is, the dtype is set to match the out dtype
  • If neither dtype nor out is set, the dtype is set to kLong if the input is a bool or an integral type

Relevant code:

static ScalarType infer_dtype_from_optional(
    const Tensor& self,
    const optional<ScalarType>& opt_dtype,
    const Tensor& result) {
  // 'opt_dtype' has the priority for both cases.
  if (result.defined()) {
    // Otherwise, get the result type, if defined.
    return opt_dtype.value_or(result.scalar_type());
  } else {
    // Last case is to get the self type.
    // If the self type is an integer, we promote it to kLong.
    return at::native::get_dtype_from_self(self, opt_dtype, true);
  }
}

inline ScalarType get_dtype_from_self(
    const Tensor& self,
    const c10::optional<ScalarType>& dtype,
    bool promote_integers) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  ScalarType src_type = self.scalar_type();
  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
    return kLong;
  }
  return src_type;
}

These semantics make sense to me, so it seems the TorchDynamo semantics should be updated to match them.
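
Translated to Python for clarity (an illustrative rendering of the C++ above, not actual PyTorch code):

import torch

def infer_dtype_from_optional(self_dtype, opt_dtype=None, out_dtype=None):
    # 'opt_dtype' has priority in both cases.
    if out_dtype is not None:  # corresponds to result.defined()
        return opt_dtype if opt_dtype is not None else out_dtype
    if opt_dtype is not None:
        return opt_dtype
    # Integral and bool inputs are promoted to kLong (int64).
    if not (self_dtype.is_floating_point or self_dtype.is_complex):
        return torch.int64
    return self_dtype

assert infer_dtype_from_optional(torch.int32, out_dtype=torch.int32) == torch.int32
assert infer_dtype_from_optional(torch.int32) == torch.int64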


ekamiti commented May 30, 2023

The following patch seems to prevent the failure, but I will not be able to submit a PR for it until at least next week, as I will be away for a few days and still need to go through what is required for testing and the other PyTorch PR requirements.

index cfc81762468..0ca454f5e97 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2176,7 +2176,9 @@ def sum(
     out: Optional[Tensor] = None,
 ) -> TensorLikeType:
     if dtype is None:
-        if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
+        if out is not None:
+            dtype = out.dtype
+        elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
             dtype = torch.int64
         else:
             dtype = a.dtype
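
With this patch applied, the original repro should run to completion; a minimal check (a sketch, assuming a patched build):

import torch

x = torch.randn([64, 54, 43]).to(torch.int32)
out = torch.empty(0, dtype=torch.int32)

@torch.compile(backend="eager")
def f(t):
    return torch.sum(t, dim=[1, -1], keepdim=True, out=out)

res = f(x)
assert res.dtype == torch.int32  # dtype inferred from out, matching eager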

ekamiti added a commit to ekamiti/pytorch that referenced this issue Jun 6, 2023
Fixes pytorch#100698
The current behaviour for dynamo is to set the dtype to torch.int64 for integral types if the dtype is not specified explicitly, which results in behaviour mismatched with eager mode. In eager mode the semantics are:
- If both out is specified and dtype is specified then they have to match
- If dtype is not specified but out is specified then the dtype is set to match the out dtype
- If neither dtype nor out is set then the dtype is set to kLong if it is a bool or an integral type