Closed
Labels
module: autograd (Related to torch.autograd, and the autograd engine in general)
module: dynamo
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Summary
Encountering a Dynamo error when compiling a custom torch.autograd.Function whose forward returns a torch.dtype alongside a tensor.
Repro
import torch


class dtype_test(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        dtype: torch.dtype,
    ):
        orig_precision = tensor.dtype
        ctx.orig_precision = orig_precision
        # Returns the cast tensor plus the (non-tensor) target dtype.
        return tensor.to(dtype), dtype

    @staticmethod
    def backward(ctx, dOut, _dtype_grad):
        # One incoming grad per forward output (None for the dtype output),
        # one returned grad per forward input (tensor, dtype).
        return dOut.to(ctx.orig_precision), None


def main():
    x = torch.randn(16, 16, device="cpu", dtype=torch.float32)
    # Eager call
    out = dtype_test.apply(x, torch.bfloat16)

    def test_func(x, dtype):
        return dtype_test.apply(x, dtype)

    # Compiled call raises torch._dynamo.exc.Unsupported (see Output)
    compiled_func = torch.compile(test_func, fullgraph=True)
    y = compiled_func(x, torch.bfloat16)


if __name__ == "__main__":
    main()
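For reference, a minimal eager-mode sketch (no torch.compile) of the same pattern, assuming the documented autograd.Function semantics: backward receives one gradient per forward output, with None passed for the non-tensor dtype output, and returns one gradient per forward input. It checks that the dtype comes back unchanged and that the gradient reaches x in its original float32 precision, which is the behavior a compiled version would be expected to match. The names DtypeCast and run_eager are hypothetical stand-ins, not part of the original repro.

```python
import torch


class DtypeCast(torch.autograd.Function):
    """Hypothetical stand-in for dtype_test: casts a tensor and also returns the target dtype."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, dtype: torch.dtype):
        # Remember the input precision so backward can cast the gradient back.
        ctx.orig_precision = tensor.dtype
        # The second output is a plain (non-tensor) Python object.
        return tensor.to(dtype), dtype

    @staticmethod
    def backward(ctx, d_out, d_dtype):
        # d_dtype is None: autograd passes None for non-tensor outputs.
        # Return one gradient per forward input: (tensor, dtype).
        return d_out.to(ctx.orig_precision), None


def run_eager():
    x = torch.randn(16, 16, dtype=torch.float32, requires_grad=True)
    out, returned_dtype = DtypeCast.apply(x, torch.bfloat16)
    # Upcast before reducing; the gradient reaching backward is still bfloat16.
    out.float().sum().backward()
    return out, returned_dtype, x.grad


if __name__ == "__main__":
    out, returned_dtype, grad = run_eager()
    print(out.dtype, returned_dtype, grad.dtype)  # torch.bfloat16 torch.bfloat16 torch.float32
```

If this eager path works while the compiled call fails, the problem is isolated to how Dynamo traces the Function rather than to the Function itself.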
Output
  File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 329, in call_method
    raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 176, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method TupleVariable() to [ConstantVariable(dtype)] {}

from user code:
  File "/home/drisspg/meta/float8_experimental/../scripts/compile/autograd_dtype_float.py", line 21, in test_func
    return dtype_test.apply(x, dtype)
  File "/home/drisspg/meta/float8_experimental/../scripts/compile/autograd_dtype_float.py", line 12, in forward
    return tensor.to(dtype), dtype
Note
This is a proxy issue I am creating for meta-pytorch/float8_experimental#108
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng