
Conversation

skpark-rh (Contributor) commented Oct 17, 2025

Fixes #160513

The Problem Summary

The issue boils down to dtype promotion logic. The code base has two different code paths that handle dtype promotion. For purely multi-dimensional tensor operations, the C++ path is used, and it follows the NumPy promotion rules; that is why the N-dim tensors in #160513 are fine, since their dtypes take precedence. The problem arises with Python scalars and 0-dim tensors. When "scalars" are detected, a Python implementation of the promotion logic runs instead (torch/_prims_common/__init__.py:1544). On the Python side the implementation cannot distinguish a wrapped number from a genuine 0-dim tensor, so it simply takes the highest dtype, which is the wrapped Python double.
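A rough repro sketch (the exact trigger below is an assumption based on the description above, not taken verbatim from the issue):

```python
import torch
import torch.autograd.forward_ad as fwAD

# A float32 dual with a 0-dim primal/tangent is multiplied by a Python float;
# the tangent dtype is expected to stay float32, but the Python-side promotion
# logic reportedly promotes it to float64.
primal = torch.tensor(1.0, dtype=torch.float32)   # 0-dim tensor
tangent = torch.tensor(1.0, dtype=torch.float32)  # 0-dim tensor
with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2.0  # 2.0 gets wrapped into a 0-dim double tensor internally
    out_tangent = fwAD.unpack_dual(out).tangent
print(out_tangent.dtype)  # expected torch.float32; the reported bug promotes it
```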

The Fix

The Python implementation of dtype promotion had to know where the scalar came from; once scalars can be distinguished, the appropriate dtype can be chosen. The first approach was to expose the is_wrapped_number method, but that came with a big issue: during forward AD, the derivatives of those scalars turn out to be ZeroTensors. ZeroTensor internally uses a hack that initializes a meta tensor, which skips expensive dispatch operations, but the copy does not grab everything, in particular the is_wrapped_number_ property. I thought about modifying the copy, but that seemed to go against the spirit of what the copy was intended for; in addition, the tests for is_wrapped_number_ require dim > 0, and a scalar ZeroTensor is a meta tensor, which complicates things.

So I chose the route of creating a new property called was_wrapped_number and exposing it through the Python tensor API. I had to modify the autograd code generation to set was_wrapped_number in the mul, add, and div operations in VariableType.cpp. Once this property was set, the dtype promotion logic could be updated to distinguish wrapped numbers from 0-dim tensors, and once that hierarchy was taken care of, the buggy behavior was fixed.

I wrote a new ops testing module, TestForwardADWithScalars, since this bug seemed unique enough to require a new testing paradigm. It only tests multiply, add, and divide; I chose those because the affected operations boil down to these three.
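A minimal sketch of the kind of check this describes (not the actual TestForwardADWithScalars code; it encodes the expected, post-fix behavior):

```python
import torch
import torch.autograd.forward_ad as fwAD

# The tangent dtype of a float32 dual should be preserved under mul, add,
# and div with a Python scalar.
def tangent_dtype_after(op):
    primal = torch.tensor(1.0, dtype=torch.float32)
    tangent = torch.tensor(1.0, dtype=torch.float32)
    with fwAD.dual_level():
        out = op(fwAD.make_dual(primal, tangent), 3.0)
        return fwAD.unpack_dual(out).tangent.dtype

for op in (torch.mul, torch.add, torch.div):
    assert tangent_dtype_after(op) == torch.float32, op
```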

@ezyang @OihanJoyot

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

… a "is_wrapped_number" as true if the derived derivative is also a wrapped number.
…numbers. Then using the correct dtype promotions on the python side.
@pytorch-bot added the release notes: autograd label Oct 17, 2025
skpark-rh (Contributor Author) commented:
@pytorchbot label "module: forward ad"

skpark-rh (Contributor Author) commented Oct 17, 2025

@ezyang - continuing from our conversation in the closed PR...
The first way I tried has its limitations. _efficientzerotensor_symint copies the size of the input tensor, but that input tensor is an empty meta tensor. Even if I pass in an is_wrapped_number_ option, the ZeroTensor's device type is meta with dim=0, and set_wrapped_number has a strict constraint that a tensor cannot be a scalar tensor and have is_wrapped_number_ set. Either we would have to remove that constraint from the setter, or modify _efficientzerotensor_symint to not use the meta device type. I can't get around having the ZeroTensor; the derivative of these scalars has to be zeros.

ezyang (Contributor) commented Oct 18, 2025

My intuition is that we should relax the is_wrapped_number setter so that a ZeroTensor 0d tensor can be a wrapped number. This lets us represent a zero 0d regular tensor. Do you know where the check you're running up against is? I checked set_wrapped_number, but it seems like it should work with a 0d tensor.

skpark-rh (Contributor Author) commented Oct 18, 2025

In VariableType_0.cpp:13495 -> result_new_fw_grad_opt = other_t * self_p + self_t * other_p; other_t and self_t become ZeroTensors. I can set the is_wrapped_number_ property just fine there. The problem happens when the scalar operations start to execute: once the addition/multiplication happens, the code eventually lands in Python, torch/_prims_common/__init__.py:1537-1726, to compute the correct dtype. When the code reaches the Python side, I need to access the is_wrapped_number_ (or _wrapped_number) property. I tried to expose the property through the Python bindings, but torch/csrc/jit/python/pybind_utils.cpp:590 prevented me because the incoming ZeroTensor has device type meta; there is a hard internal assert, TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());, that keeps me from grabbing the is_wrapped_number_ property.

This is why I needed a new property was_wrapped_number to get around this constraint. I believe this problem becomes difficult because the code suddenly goes to the python side to compute the appropriate dtypes... If I can get away with removing that internal assert from the JIT code then this becomes a simpler problem. I just don't have the expertise or the experience to make that judgement.
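For reference, the promotion rule at play can be illustrated with the public torch.result_type API (this is only an analogy; the forward-AD path goes through torch/_prims_common rather than torch.result_type):

```python
import torch

# A genuine Python scalar does not promote a tensor's floating dtype...
print(torch.result_type(torch.zeros((), dtype=torch.float32), 2.0))
# torch.float32

# ...but once that scalar has been materialized as a real 0-dim double tensor
# (and the Python side cannot tell it used to be a wrapped number),
# the double dtype wins the promotion.
print(torch.result_type(torch.zeros((), dtype=torch.float32),
                        torch.tensor(2.0, dtype=torch.float64)))
# torch.float64
```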

ezyang (Contributor) commented Oct 18, 2025

Ok, this makes sense. I propose we relax the assert: specifically, if you have a ZeroTensor that is 0d, then it is also OK for the device type to be meta.

Another strategy would be to avoid using ZeroTensor in the 0d case, but I do not know how easy that is.

@albanD removed their request for review October 19, 2025 03:29
@albanD added the triaged label Oct 19, 2025
skpark-rh (Contributor Author) commented:
Okay, I'll pursue that avenue of relaxing the assert. I'll check whether the tensor is a meta device with dim=0 and, in that case, allow the property to pass through.

As for avoiding ZeroTensor in the 0-dim case, I do not know either... If the above path becomes a dead end, I can pursue this route for sure.

skpark-rh (Contributor Author) commented Oct 20, 2025

@ezyang
So my comment got lost in the woods with the botched merge in my previous PR, but I need your advice on modifying TensorOptions. Now that I can set the is_wrapped_number property on the EfficientZeroTensor and see it on the Python side, I need to propagate is_wrapped_number from the parent tensor to the child tensor through TensorOptions when the EfficientZeroTensor is created. I am planning to add is_wrapped_number as an optional boolean to EfficientZeroTensor. I think I can minimize the fallout by modifying torchgen/native_function_generation.py to look for the pattern ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None and append bool? is_wrapped_number=None to it. That limits the blast radius from inspecting all native functions to just the ones with that pattern. Since the added boolean is optional, it won't do any harm to other native functions, and it has the added benefit of letting developers use it when the time comes.

Your thoughts?
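For illustration, the kind of pattern-based rewrite being proposed might look roughly like this (the helper name, regex, and the example schema string are assumptions, not actual torchgen/native_function_generation.py code):

```python
import re

# Schemas that end with the standard TensorOptions argument group.
OPTIONS_PATTERN = (
    r"ScalarType\? dtype=None, Layout\? layout=None, "
    r"Device\? device=None, bool\? pin_memory=None"
)

def add_wrapped_number_flag(schema: str) -> str:
    """Append an optional `bool? is_wrapped_number=None` after the group."""
    return re.sub(
        OPTIONS_PATTERN,
        lambda m: m.group(0) + ", bool? is_wrapped_number=None",
        schema,
    )

print(add_wrapped_number_flag(
    "_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, "
    "Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
))
```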

ezyang (Contributor) commented Oct 21, 2025

Is there a reason you have to update the constructor? I would have just set the property directly on the field after you made the ZeroTensor.

skpark-rh (Contributor Author) commented:
Oh, from your comments about pursuing my initial plan I just assumed you wanted me to not add code to the autogeneration but to do it through the constructor (tools/autograd/gen_variable_type.py:766-773,1919-1925). I am basically using a setter there. I assumed that modifying TensorOptions would provide a more robust way to catch any dtype issues with ZeroTensor. I am only using the setter for the mul, add, and div operations (tools/autograd/gen_variable_type.py:1919).

ezyang (Contributor) commented Oct 27, 2025

We discussed this at PTC; there is a next step.

skpark-rh (Contributor Author) commented:
@ezyang So I removed the was_wrapped_number property and used the existing is_wrapped_number property instead. I am exposing is_wrapped_number on the Python side so the dtype promotion logic can distinguish wrapped Python numbers from 0-dim tensor scalars.
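Roughly, the intent on the promotion side is something like the sketch below (the helper and its return values are illustrative, not the actual torch/_prims_common code):

```python
import torch

def classify_operand(x):
    """Illustrative ranking of operands for dtype promotion purposes."""
    if isinstance(x, torch.Tensor):
        # is_wrapped_number is the property this PR exposes; stock tensors
        # won't have it, hence the getattr default.
        if x.dim() == 0 and getattr(x, "is_wrapped_number", False):
            return "wrapped_number"    # from a Python scalar: lowest priority
        if x.dim() == 0:
            return "zero_dim_tensor"   # its dtype should beat wrapped numbers
        return "ndim_tensor"           # highest priority, as before
    return "python_scalar"
```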


if (ivalue.isTensor()) {
auto tensor = std::move(ivalue).toTensor();
bool is_wrapped_number = (tensor.unsafeGetTensorImpl()->is_wrapped_number()) ? true : false;
Contributor:
Not sure why you have to ternary here

Contributor Author:
Yeah, this was dumb of me...

TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
"obtained using .clone() if you want a mutable tensor.");
tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
Tensor updated_tensor = tensors[j];
Contributor:
skip writing into tensors[j] and then writing into it again.

Contributor Author:
Yep, dumb of me... Will get this updated.

r.copy_(self, non_blocking);
if (self.unsafeGetTensorImpl()->is_wrapped_number()) {
r.unsafeGetTensorImpl()->set_wrapped_number(true);
}
Contributor:
These two are kind of sus, are you sure these were necessary?

Contributor Author:
So when copy_ gets invoked, is_wrapped_number does not get copied over, which I thought was weird. Without this setter, the flag gets lost.

"is_mkldnn": ["is_mkldnn: _bool"],
"is_vulkan": ["is_vulkan: _bool"],
"is_ipu": ["is_ipu: _bool"],
"is_wrapped_number": ["is_wrapped_number: _bool"],
Contributor:
It shouldn't be necessary to expose this in Python; wrapped numbers always turn into plain numbers when you get to Python. Why did you need it? Is it because of ZeroTensor wrapped number? I think ZeroTensor should turn into a plain number too.

skpark-rh (Contributor Author), Oct 29, 2025:
Yes, the reason this is needed is that ZeroTensor is a meta tensor with empty storage; there is no number to convert once we get to the Python side.

skpark-rh (Contributor Author), Oct 29, 2025:
I think if we turn ZeroTensor into a plain number then this becomes easier; I wouldn't have to expose anything on the Python side. I do wonder what would happen to precision/accuracy if a double-precision ZeroTensor gets converted to just a plain number: the resulting scalar tensor will be lower precision if that scalar tensor is a float.

Contributor Author:
Although I do think the mul_zerotensor (pytorch/aten/src/ATen/native/BinaryOps.cpp:1012-1019) will be affected. Line 1017 is the one in question. I think it'll be okay if we handle the conversion in the JIT code.

zero_dim_tensor_dtype = get_higher_dtype(
zero_dim_tensor_dtype, _dtype
)
if x.is_wrapped_number:
Contributor:
so this shouldn't be possible here, specifically

skpark-rh (Contributor Author), Oct 29, 2025:
So when I go through the stack trace and debug here specifically, the arguments are tensors and they are both meta devices. I think this is because of the forward autograd operation requiring them to be tensors all the way through. I can post the stack trace up until this point. The code uses two meta tensors to get the respective dtypes.
[Screenshot: debugger stack trace at this point, 2025-10-29]

skpark-rh (Contributor Author), Oct 29, 2025:
This is why at this point, the code needs to know if those meta devices are wrapped numbers to properly determine the right promotion.

ezyang (Contributor) left a comment:
This is promising but I hope it can be tighter



Labels

module: forward ad, oncall: jit, open source, release notes: autograd, triaged


Development

Successfully merging this pull request may close these issues.

Forward autodiff : Multiplying by python float changes the dual dtype in some situations
