Dynamo error for autograd function #109910

@drisspg

Description

Summary

Dynamo raises an Unsupported error when compiling an autograd function whose forward returns a dtype alongside a tensor.

Repro

import torch

class dtype_test(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        dtype: torch.dtype,
    ):
        # Remember the original dtype so backward can cast the gradient back.
        ctx.orig_precision = tensor.dtype
        # Returning the dtype alongside the tensor is what Dynamo trips over.
        return tensor.to(dtype), dtype

    @staticmethod
    def backward(ctx, grad_output, grad_dtype):
        # One value per forward input: a gradient for the tensor, None for the dtype.
        return grad_output.to(ctx.orig_precision), None

def main():
    x = torch.randn(16, 16, device="cpu", dtype=torch.float32)
    # Eager execution works fine.
    out = dtype_test.apply(x, torch.bfloat16)

    def test_func(x, dtype):
        return dtype_test.apply(x, dtype)

    # Compiling the same call with fullgraph=True raises torch._dynamo.exc.Unsupported (see Output below).
    compiled_func = torch.compile(test_func, fullgraph=True)
    y = compiled_func(x, torch.bfloat16)

if __name__ == "__main__":
    main()

Output

 File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 329, in call_method
   raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
 File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 176, in unimplemented
   raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method TupleVariable() to [ConstantVariable(dtype)] {}

from user code:
  File "/home/drisspg/meta/float8_experimental/../scripts/compile/autograd_dtype_float.py", line 21, in test_func
   return dtype_test.apply(x, dtype)
 File "/home/drisspg/meta/float8_experimental/../scripts/compile/autograd_dtype_float.py", line 12, in forward
   return tensor.to(dtype), dtype
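
Possible workaround

Since the traceback points at the dtype being returned from forward alongside the tensor, one workaround is to return only the tensor from the autograd function and keep the dtype at the call site. The sketch below is untested here and uses a hypothetical CastWithGrad name for illustration; it avoids the error pattern rather than fixing the underlying Dynamo limitation.

import torch

class CastWithGrad(torch.autograd.Function):  # hypothetical name, not from the original repro
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, dtype: torch.dtype):
        # Remember the original dtype so backward can cast the gradient back.
        ctx.orig_precision = tensor.dtype
        # Return only the tensor; the caller already knows the dtype.
        return tensor.to(dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # A gradient for the tensor input, None for the dtype input.
        return grad_output.to(ctx.orig_precision), None

def test_func(x, dtype):
    # Carry the dtype through at the call site instead of through the Function.
    return CastWithGrad.apply(x, dtype), dtype

compiled_func = torch.compile(test_func, fullgraph=True)
y, out_dtype = compiled_func(torch.randn(16, 16), torch.bfloat16)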

Note

This is a proxy issue I am creating for meta-pytorch/float8_experimental#108.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

Metadata

Labels

module: autograd (Related to torch.autograd, and the autograd engine in general)
module: dynamo
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
