Commit 3ddf201
Revert "Support setting grad_dtype on leaf tensors (#162815)"
This reverts commit dca7398. Reverted #162815 on behalf of https://github.com/yangw-dev because it broke internal test D83850533; see details in the linked comment ([comment](#162815 (comment))).
1 parent fac6f20 commit 3ddf201

19 files changed: 41 additions & 375 deletions
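For context, below is a minimal sketch of the leaf-tensor API that this commit removes, reconstructed from the docstring and tests deleted in the diff; it only runs on builds that still include #162815.

    import torch

    # Sketch of the reverted API (only available on builds that include #162815).
    leaf = torch.tensor([1.0, 2.0], requires_grad=True)
    assert leaf.grad_dtype == torch.float32  # defaults to the tensor's own dtype

    leaf.grad_dtype = torch.float16  # autograd casts incoming gradients to float16
    leaf.grad_dtype = None           # accept gradients of any dtype, no casting

    # The attribute exists only on leaf tensors; non-leaves raise a RuntimeError.
    non_leaf = leaf * 2
    assert not non_leaf.is_leaf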

aten/src/ATen/core/Tensor.cpp

Lines changed: 0 additions & 8 deletions
@@ -173,12 +173,4 @@ unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)>
   return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
 }
 
-std::optional<ScalarType> TensorBase::grad_dtype() const {
-  return impl::GetVariableHooks()->grad_dtype(*this);
-}
-
-void TensorBase::set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const {
-  return impl::GetVariableHooks()->set_grad_dtype(*this, grad_dtype);
-}
-
 } // namespace at

aten/src/ATen/core/TensorBase.h

Lines changed: 0 additions & 4 deletions
@@ -930,10 +930,6 @@ class TORCH_API TensorBase {
 
   const TensorBase& requires_grad_(bool _requires_grad=true) const;
 
-  std::optional<ScalarType> grad_dtype() const;
-
-  void set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const;
-
   // View Variables
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

aten/src/ATen/core/VariableHooksInterface.h

Lines changed: 0 additions & 2 deletions
@@ -68,8 +68,6 @@ struct TORCH_API VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const = 0;
-  virtual std::optional<c10::ScalarType> grad_dtype(const TensorBase&) const = 0;
-  virtual void set_grad_dtype(const TensorBase&, const std::optional<c10::ScalarType>&) const = 0;
 };
 
 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);

test/dynamo/test_backward_higher_order_ops.py

Lines changed: 3 additions & 3 deletions
@@ -140,7 +140,7 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
 
         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
 
         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -171,7 +171,7 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
 
         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
 
         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -255,7 +255,7 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]",
 
         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
 
         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
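The only change in this file is the shape of the per-output metadata tuple passed to compiled autograd's validate_outputs: the trailing element is dropped. That trailing 6 appears to be the c10 ScalarType code for float32 carrying the expected grad_dtype added by #162815; that interpretation, and the placeholder "size" below, are assumptions for illustration only.

    import torch

    # Hedged illustration of the per-output metadata tuple seen in the expected
    # graphs above; "size" stands in for the SymInt captured in the real test.
    meta_with_grad_dtype = ((None, None, torch.device("cpu"), 6, 0, None), ["size"], False, 6)
    meta_after_revert = ((None, None, torch.device("cpu"), 6, 0, None), ["size"], False)
    assert meta_after_revert == meta_with_grad_dtype[:-1]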

test/inductor/test_compiled_autograd.py

Lines changed: 8 additions & 9 deletions
@@ -3604,12 +3604,12 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         unwrap_maybe_dynamic_int_18 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_23); getitem_23 = None
         unwrap_maybe_dynamic_int_19 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_24); getitem_24 = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True, 6)]); getitem = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
         getitem_25 = validate_outputs[0]; validate_outputs = None
 
         sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_25 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
         getitem_26 = sum_backward0[0]; sum_backward0 = None
-        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True, 6)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
         getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None
 
         getitem_28 = hooks[0]; getitem_28 = None
@@ -3631,7 +3631,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None
         getitem_36 = call_backward[0]
         getitem_37 = call_backward[1]; call_backward = None
-        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False, 6)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
         getitem_39 = validate_outputs_2[0]
 
         call_accumulate_grad_1 = torch__dynamo_external_utils_call_accumulate_grad(getitem_4, getitem_39, False); getitem_4 = getitem_39 = call_accumulate_grad_1 = None
@@ -3866,12 +3866,12 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         unwrap_maybe_dynamic_int_10 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_12); getitem_12 = None
         unwrap_maybe_dynamic_int_11 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_13); getitem_13 = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False, 6)]); getitem = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
         getitem_14 = validate_outputs[0]; validate_outputs = None
 
         sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_14 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
         getitem_15 = sum_backward0[0]; sum_backward0 = None
-        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False, 6)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
         getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None
 
         getitem_17 = hooks[0]
@@ -3883,7 +3883,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None
         getitem_21 = mul_backward0[0]
         getitem_22 = mul_backward0[1]; mul_backward0 = None
-        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False, 6)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
         getitem_23 = validate_outputs_2[0]
         getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None
 
@@ -3892,7 +3892,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None
         cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None
         getitem_27 = cos_backward0[0]; cos_backward0 = None
-        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False, 6)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
+        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
         getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None
         add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None
 
@@ -3901,7 +3901,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
         call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None
         sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
         getitem_31 = sin_backward0[0]; sin_backward0 = None
-        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False, 6)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
+        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
         getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None
 
         call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None
@@ -5266,7 +5266,6 @@ def wrap_test_class(orig_cls):
         "test_dropout_inductor", # functionalize_rng_ops not yet supported
         "test_function_with_kwargs", # functionalize_rng_ops not yet supported
         "test_module", # functionalize_rng_ops not yet supported
-        "test_grad_dtype", # AttributeError: args / Float did not match Double
     },
     "eager": { # will be run without torch.compiling the CA graph
         "test_setup_context_when_forward_has_default_args", # autograd.Function with class methods

test/jit/test_builtins.py

Lines changed: 0 additions & 2 deletions
@@ -326,8 +326,6 @@ def fn(x):
             # This has a longer implementation, maybe not worth copying to
             # TorchScript if named tensors don't work there anyways
             "names",
-            # We don't plan to support grad_dtype in TorchScript
-            "grad_dtype",
         }
 
         for p in properties:

test/test_autograd.py

Lines changed: 0 additions & 124 deletions
@@ -3694,130 +3694,6 @@ def test_sparse_gather_x_scalar(self):
     def test_sparse_gather_both_scalar(self):
         self._test_sparse_gather((), (), 0)
 
-    @skipIfTorchDynamo("grad_dtype not supported in compile")
-    def test_grad_dtype(self):
-        leaf = torch.tensor([1.0, 2.0], requires_grad=True)
-        # Default to tensor's dtype
-        self.assertEqual(leaf.grad_dtype, torch.float32)
-        leaf.grad_dtype = torch.float16
-        self.assertEqual(leaf.grad_dtype, torch.float16)
-        leaf.grad_dtype = None  # Allow any dtype
-        self.assertIsNone(leaf.grad_dtype)
-
-        # get/set grad_dtype is only allowed on leaf tensors
-        non_leaf = leaf * 2
-        self.assertFalse(non_leaf.is_leaf)
-        with self.assertRaisesRegex(
-            RuntimeError, "grad_dtype can only be accessed on leaf tensors"
-        ):
-            _ = non_leaf.grad_dtype
-        with self.assertRaisesRegex(
-            RuntimeError, "grad_dtype can only be set on leaf tensors"
-        ):
-            non_leaf.grad_dtype = torch.float16
-
-        # Manual setting
-        x = torch.tensor([1.0, 2.0], requires_grad=True)
-        grad_match = torch.tensor([1.0, 1.0])
-        x.grad = grad_match
-        self.assertEqual(x.grad.dtype, torch.float32)
-
-        x.grad = None
-        x.grad_dtype = torch.float16
-        grad_mismatch = torch.tensor([1.0, 1.0])
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "attempting to assign a gradient with dtype.*float.*to a tensor with grad_dtype.*Half",
-        ):
-            x.grad = grad_mismatch
-
-        # When grad_dtype is None, any dtype is allowed
-        x.grad = None
-        x.grad_dtype = None
-        grad_any = torch.tensor([1.0, 1.0], dtype=torch.float64)
-        x.grad = grad_any
-        self.assertEqual(x.grad.dtype, torch.float64)
-
-        # Incoming gradient case
-        class MismatchedGradientFunction(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, inp):
-                return inp * 2
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                return grad_output.to(torch.float64)
-
-        d = torch.tensor([1.0, 2.0], requires_grad=True)
-        output = MismatchedGradientFunction.apply(d)
-        loss = output.sum()
-        loss.backward()
-        # Default behavior is to cast to tensor dtype
-        self.assertEqual(d.grad.dtype, torch.float32)
-        self.assertTrue(torch.allclose(d.grad, torch.tensor([1.0, 1.0])))
-
-        e = torch.tensor([3.0, 4.0], requires_grad=True)
-        e.grad_dtype = None
-        output_e = MismatchedGradientFunction.apply(e)
-        loss_e = output_e.sum()
-        loss_e.backward()
-        # No casting is done if set to None.
-        self.assertTrue(
-            torch.allclose(e.grad, torch.tensor([1.0, 1.0], dtype=torch.float64))
-        )
-
-        f = torch.tensor([5.0, 6.0], requires_grad=True)
-        f.grad_dtype = torch.float16  # Expect float16 gradients
-        output_f = MismatchedGradientFunction.apply(f)
-        loss_f = output_f.sum()
-        loss_f.backward()
-        self.assertTrue(
-            torch.allclose(f.grad, torch.tensor([1.0, 1.0], dtype=torch.float16))
-        )
-
-        # Setting grad_dtype when gradient already exists
-        g = torch.tensor([1.0, 2.0], requires_grad=True)
-        g.grad = torch.tensor([1.0, 1.0])
-        g.grad_dtype = torch.float32
-        self.assertEqual(g.grad_dtype, torch.float32)
-        with self.assertRaisesRegex(
-            RuntimeError, "Cannot set grad_dtype.*because there is already a gradient"
-        ):
-            g.grad_dtype = torch.float16
-        g.grad_dtype = None
-        self.assertIsNone(g.grad_dtype)
-        g.grad = None
-        g.grad_dtype = torch.float16
-        self.assertEqual(g.grad_dtype, torch.float16)
-
-        # Test the case where there is an existing accumulate grad
-        h = torch.tensor([1.0, 2.0], requires_grad=True)
-        _ = h.clone()
-        h.grad_dtype = None
-        output = MismatchedGradientFunction.apply(h)
-        output.sum().backward()
-        self.assertEqual(h.grad.dtype, torch.float64)
-
-        # Mixed accumulation cases
-        k = torch.tensor([1.0, 2.0], requires_grad=True)
-        k.grad_dtype = None
-        y = k * 2
-        y.sum().backward()
-        k.grad = k.grad.to(torch.bfloat16)
-        y2 = k * 3
-        # Doesn't type promote to float32, always coerce to current .grad's dtype.
-        # This is because the accumulation is done in-place on the existing grad.
-        self.assertEqual(k.grad.dtype, torch.bfloat16)
-
-        l = torch.tensor([3.0, 4.0], requires_grad=True, dtype=torch.bfloat16)
-        l.grad_dtype = None
-        z = l * 2
-        z.sum().backward()
-        l.grad = l.grad.to(torch.float32)
-        z2 = l * 3
-        z2.sum().backward()
-        self.assertEqual(l.grad.dtype, torch.float32)
-
     def test_gc_in_destructor(self):
         """
         Previously, if a Function destructor triggered a garbage collection,
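The deleted test above also pins down an accumulation detail: with grad_dtype = None, a second backward accumulates into the existing .grad in place, so the gradient keeps whatever dtype .grad already has rather than type-promoting. A minimal sketch of that case, taken from the removed test and only meaningful on builds that still include the reverted change:

    import torch

    l = torch.tensor([3.0, 4.0], requires_grad=True, dtype=torch.bfloat16)
    l.grad_dtype = None                # attribute from the reverted #162815
    (l * 2).sum().backward()           # .grad starts out as bfloat16
    l.grad = l.grad.to(torch.float32)  # manually swap in a float32 gradient
    (l * 3).sum().backward()           # new bfloat16 grad accumulates in place
    assert l.grad.dtype == torch.float32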

torch/_C/__init__.pyi.in

Lines changed: 0 additions & 1 deletion
@@ -1917,7 +1917,6 @@ class TensorBase(metaclass=_TensorMeta):
     names: list[str]
     device: _device
     dtype: _dtype
-    grad_dtype: _dtype | None
     layout: _layout
     real: Tensor
     imag: Tensor

torch/_tensor_docs.py

Lines changed: 0 additions & 30 deletions
@@ -6624,36 +6624,6 @@ def callable(a, b) -> number
 """,
 )
 
-add_docstr_all(
-    "grad_dtype",
-    r"""
-The allowed dtype of :attr:`grad` for this tensor.
-
-:attr:`grad_dtype` can be set to a specific dtype or ``None``. By default,
-``t.grad_dtype == t.dtype``. When not None, the autograd engine casts
-incoming gradients to this dtype. This attribute is only accessible and
-settable for leaf tensors.
-
-.. warning::
-    Use with caution. Diverging the dtypes of a tensor and its gradient may
-    break downstream systems that assume they match.
-
-Example::
-
-    >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
-    >>> x.grad_dtype
-    torch.float32
-
-    >>> x.grad_dtype = torch.float16
-    >>> x.grad_dtype
-    torch.float16
-
-    >>> # Allow any gradient dtype
-    >>> x.grad_dtype = None
-    >>> x.grad_dtype
-""",
-)
-
 add_docstr_all(
     "retain_grad",
     r"""
