Bugfix to forward autodiff causing different datatype 2 #165784
base: main
Conversation
… tensor from a wrapped number.
… a "is_wrapped_number" as true if the derived derivated is also a wrapped number.
…python side to handle dtype promotions.
…numbers. Then using the correct dtype promotions on the python side.
…rch into bugfix/dtype_foward_agrad
…erations caused dtypes to be different.
@pytorchbot label "module: forward ad"
@ezyang - continuing from our conversation in the closed PR...
My intuition is that we should relax the is_wrapped_number setter so we are able to have a 0d ZeroTensor be a wrapped number. This lets us represent a zero 0d regular tensor. Do you know where the check you're running against is? I checked set_wrapped_number but it seems like it should work with a 0d tensor.
> This is why I needed a new property
Ok, this makes sense. I propose we relax the assert: specifically, if you have a zero tensor which is 0d, then it is also OK for the device type to be meta. Another strategy would be to avoid using a zero tensor when 0d, but I do not know how easy that is.
Okay, I'll pursue that avenue of relaxing the assert. I'll check whether it is also a meta device with dim=0 to allow the property to pass through. As for avoiding using …
@ezyang Your thoughts?
Is there a reason you have to update the constructor? I would have just set the property into the field directly after you made the ZeroTensor.
Oh, I just assumed from your comments about pursuing my initial plan that you wanted me to not add code to the autogeneration but do it through the constructor.
We discussed this at PTC; there is a next step.
…l assert to allow meta devices types with 0dims through.
…rch into bugfix/dtype_foward_agrad
@ezyang So I removed the …
```cpp
if (ivalue.isTensor()) {
  auto tensor = std::move(ivalue).toTensor();
  bool is_wrapped_number = (tensor.unsafeGetTensorImpl()->is_wrapped_number()) ? true : false;
```
Not sure why you have to ternary here
Yeah, this was dumb of me...
```cpp
TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
    "obtained using .clone() if you want a mutable tensor.");
tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
Tensor updated_tensor = tensors[j];
```
skip writing into tensors[j] and then writing into it again.
Yep, dumb of me... Will get this updated.
```cpp
r.copy_(self, non_blocking);
if (self.unsafeGetTensorImpl()->is_wrapped_number()) {
  r.unsafeGetTensorImpl()->set_wrapped_number(true);
}
```
These two are kind of sus, are you sure these were necessary?
So when the copy_ gets invoked, the is_wrapped_number does not get copied over which I thought was weird. Without this setter, the flag gets lost.
| "is_mkldnn": ["is_mkldnn: _bool"], | ||
| "is_vulkan": ["is_vulkan: _bool"], | ||
| "is_ipu": ["is_ipu: _bool"], | ||
| "is_wrapped_number": ["is_wrapped_number: _bool"], |
It shouldn't be necessary to expose this in Python; wrapped numbers always turn into plain numbers when you get to Python. Why did you need it? Is it because of ZeroTensor wrapped number? I think ZeroTensor should turn into a plain number too.
Yes, the reason this is needed is that ZeroTensor is a meta tensor with empty storage; there is no number to convert on the Python side.
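As a rough illustration (`_efficientzerotensor` is an internal helper, so its name and behavior are not guaranteed), a 0-dim ZeroTensor carries only metadata; there is no materialized value behind it to hand back as a Python number:

```python
import torch

# Internal helper, used here only for illustration: it builds a ZeroTensor,
# a tensor known to be all zeros that does not allocate real storage.
z = torch._efficientzerotensor((), dtype=torch.float64)
print(z.dim(), z.dtype)  # 0 torch.float64 -- metadata only, no backing values
```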
I think if we turn ZeroTensor into a plain number then this becomes easier; I wouldn't have to expose anything on the Python side. I do wonder what would happen to precision/accuracy if a double-precision ZeroTensor gets converted to just a plain number: the resulting scalar tensor will be lower precision if that scalar tensor is a float.
Although I do think the mul_zerotensor (pytorch/aten/src/ATen/native/BinaryOps.cpp:1012-1019) will be affected. Line 1017 is the one in question. I think it'll be okay if we handle the conversion in the JIT code.
```python
zero_dim_tensor_dtype = get_higher_dtype(
    zero_dim_tensor_dtype, _dtype
)
if x.is_wrapped_number:
```
so this shouldn't be possible here, specifically
So when I go through the stack trace and debug here specifically, the arguments are tensors and they are both meta devices. I think this is because of the forward autograd operation requiring them to be tensors all the way through. I can post the stack trace up until this point. The code uses two meta tensors to get the respective dtypes.

This is why at this point, the code needs to know if those meta devices are wrapped numbers to properly determine the right promotion.
This is promising but I hope it can be tighter
Fixes #160513
The Problem Summary
The issue boiled down to the data type promotion logic. The code base has two different functions that deal with dtype promotion. For purely multi-dimensional tensor operations, the C++ code path is triggered, and it follows the NumPy dtype promotion rules; that is why NDim tensors are fine in #160513, since NDim dtypes take precedence. The issue comes with Python scalars and 0-dim tensors. When "scalars" are detected, a Python implementation of the dtype promotion logic is triggered (torch/_prims_common/__init__.py:1544). Since this runs in Python, the implementation cannot distinguish a scalar that came from a wrapped number from a real 0-dim tensor, so it simply takes the highest dtype, which is the Python double wrapped number.
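For concreteness, here is a minimal repro sketch of the kind of mismatch (assumed from the behavior described above; the exact reproducer in #160513 may differ): under forward-mode AD, multiplying a float32 dual tensor by a Python scalar should keep the tangent in float32, but the Python-side promotion can pick the scalar's implicit double instead.

```python
import torch
import torch.autograd.forward_ad as fwAD

# Hypothetical repro sketch of the dtype drift described above:
# the tangent should stay float32.
primal = torch.randn(3, dtype=torch.float32)
tangent = torch.randn(3, dtype=torch.float32)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 3.0  # the Python scalar becomes a float64 "wrapped number"
    out_primal, out_tangent = fwAD.unpack_dual(out)

print(out_primal.dtype)   # torch.float32
print(out_tangent.dtype)  # expected torch.float32; the bug can yield torch.float64
```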
The Fix
The Python implementation of dtype promotion had to know where the scalar came from; once the scalar can be distinguished, the appropriate dtype can be set. The first approach was to expose the `is_wrapped_number` method, but this came with a big issue: during `forward_ad`, the derivatives of those scalars turn out to be `ZeroTensor`s. `ZeroTensor` internally uses a hack that initializes it like a meta tensor, which skips expensive dispatch operations, but the copy does not carry everything over, in particular the `is_wrapped_number_` property. I thought about modifying the copy, but that seemed to go against the spirit of what the copy was intended for; besides, the tests for `is_wrapped_number_` require `dim > 0`, and a scalar `ZeroTensor` is a meta tensor, which complicates things.

So I chose the route of creating a new property called `was_wrapped_number` and exposed it through the Python tensor API. I had to modify the autograd code generation to set `was_wrapped_number` in the mul, add, and div operations in `VariableType.cpp`. Once this property was set, the dtype promotion logic could be updated to consider both wrapped numbers and 0-dim tensors. Once that hierarchy was taken care of, the buggy behavior was fixed.
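To make the hierarchy concrete, here is a minimal sketch assuming the `was_wrapped_number` property described above; the helper name and structure are illustrative and are not the actual code in torch/_prims_common/__init__.py.

```python
import torch

def result_dtype_sketch(args):
    # Illustrative only: choose a result dtype so that tensors that merely
    # wrap a Python number never out-promote real tensor operands.
    ndim_dtype = zero_dim_dtype = wrapped_dtype = None

    def merge(current, new):
        return new if current is None else torch.promote_types(current, new)

    for x in args:
        if not isinstance(x, torch.Tensor):
            continue
        if getattr(x, "was_wrapped_number", False):  # the property this PR adds
            wrapped_dtype = merge(wrapped_dtype, x.dtype)
        elif x.dim() == 0:
            zero_dim_dtype = merge(zero_dim_dtype, x.dtype)
        else:
            ndim_dtype = merge(ndim_dtype, x.dtype)

    # Hierarchy: NDim tensors, then 0-dim tensors, then wrapped numbers.
    for candidate in (ndim_dtype, zero_dim_dtype, wrapped_dtype):
        if candidate is not None:
            return candidate
    return None
```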
I wrote a new ops testing module, `TestForwardADWithScalars`, since this bug was unique and required a new testing paradigm. It only tests multiply, add, and divide; I chose these because the affected operations boil down to those three. A rough sketch of the kind of check it performs is included at the end of this description.

@ezyang @OihanJoyot
cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel
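Sketch of the kind of check the new tests perform; the names and structure here are illustrative and do not mirror the actual `TestForwardADWithScalars` code.

```python
import torch
import torch.autograd.forward_ad as fwAD

def check_scalar_op_dtype(op, scalar=2.0, dtype=torch.float32):
    # With the fix, the tangent's dtype should follow the tensor operand's
    # dtype rather than the Python scalar's implicit float64.
    primal = torch.randn(3, dtype=dtype)
    tangent = torch.randn(3, dtype=dtype)
    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        out = op(dual, scalar)
        _, out_tangent = fwAD.unpack_dual(out)
    assert out_tangent.dtype == dtype, (op.__name__, out_tangent.dtype, dtype)

# Only mul, add, and div are covered, mirroring the scope described above.
for op in (torch.mul, torch.add, torch.div):
    check_scalar_op_dtype(op)
```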