
dynamo should recompile with constant tensors that use ambient device guards #147405

Closed
@jamesjwu

Description


🐛 Describe the bug

Here's a simple unit-test repro (it needs at least two CUDA devices):

import torch

@torch.compile
def f():
    # Constant tensor created on the current (ambient) CUDA device
    y = torch.tensor([0, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192], dtype=torch.int32, device="cuda")
    return (y,)

index = 0
with torch.cuda._DeviceGuard(index):
    torch.cuda.set_device(index)
    result = f()
    assert result[0].device == torch.device("cuda:0")

index = 1
with torch.cuda._DeviceGuard(index):
    torch.cuda.set_device(index)
    result = f()
    assert result[0].device == torch.device("cuda:1")  # Fails: the compiled f() still returns a cuda:0 tensor

When a constant tensor is created with torch.tensor, Dynamo should guard on the specific device index of the tensor being created, because in eager mode f() always returns a tensor on the current CUDA device.

However, AOTAutograd embeds such constants into the graph, so Dynamo needs a guard on the ambient device index in order to recompile when the current device changes.
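To make the failure mode concrete, here is a minimal pure-Python toy model, not PyTorch code and not how Dynamo actually implements guards. It assumes a global "ambient device index" and a compile cache whose entries optionally record that index as a guard. All names (`set_ambient_device`, `eager_f`, `ToyCompiler`, `run`) are hypothetical. Without the guard, the second call reuses the graph whose constant was baked for device 0; with the guard, the changed device index forces a recompile.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical stand-in for the current CUDA device that
# torch.cuda.set_device / _DeviceGuard change in real PyTorch.
_ambient_device = 0


def set_ambient_device(index: int) -> None:
    global _ambient_device
    _ambient_device = index


def eager_f() -> str:
    # Eager semantics: the constant lands on whatever device is current *now*.
    return f"cuda:{_ambient_device}"


class ToyCompiler:
    """Toy compile cache: entries are (guarded device index or None, compiled fn)."""

    def __init__(self, guard_on_device: bool) -> None:
        self.guard_on_device = guard_on_device
        self.cache: List[Tuple[Optional[int], Callable[[], str]]] = []
        self.compiles = 0

    def __call__(self) -> str:
        for guarded_index, fn in self.cache:
            # None means "no device guard": the entry matches unconditionally.
            if guarded_index is None or guarded_index == _ambient_device:
                return fn()
        self.compiles += 1
        baked = f"cuda:{_ambient_device}"  # constant frozen into the "graph"
        guard = _ambient_device if self.guard_on_device else None
        self.cache.append((guard, lambda: baked))
        return baked


def run(compiler: ToyCompiler) -> List[Tuple[str, str]]:
    # Mirror the repro: call on device 0, then on device 1.
    outputs = []
    for index in (0, 1):
        set_ambient_device(index)
        outputs.append((eager_f(), compiler()))
    return outputs


print(run(ToyCompiler(guard_on_device=False)))  # [('cuda:0', 'cuda:0'), ('cuda:1', 'cuda:0')]
print(run(ToyCompiler(guard_on_device=True)))   # [('cuda:0', 'cuda:0'), ('cuda:1', 'cuda:1')]
```

The unguarded compiler compiles once and returns a stale `cuda:0` result on device 1; the guarded one compiles twice and matches eager.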

This also affects AOTAutogradCache. If you run the same example with torch._dynamo.reset() between the two calls, and with FXGraphCache and AOTAutogradCache enabled, the second call gets a cache hit and returns a tensor on the wrong device in the same way.

There are several possible fixes. AOTAutograd should probably add a guard on the ambient device index when it converts a tensor into a constant, and the device index should also be part of the cache key. In principle, when creating the constant tensor, AOTAutograd must rely on something in the Dynamo graph that tells it how to create the tensor. Will dig in more.
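The cache-key part of the proposed fix can be illustrated with another pure-Python toy, assuming a persistent dictionary that survives a "reset" and a key derived from the graph plus, optionally, the ambient device index. This is not the real AOTAutogradCache key computation; `persistent_cache`, `cache_key` and `load_or_compile` are hypothetical names. Leaving the device index out of the key reproduces the stale hit described above; mixing it in turns the device-1 lookup into a miss.

```python
import hashlib
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for a persistent cache (survives in-memory resets).
persistent_cache: Dict[str, str] = {}


def cache_key(graph_src: str, device_index: Optional[int]) -> str:
    # Hash the graph; optionally mix in the ambient device index.
    h = hashlib.sha256(graph_src.encode())
    if device_index is not None:
        h.update(f"|device_index={device_index}".encode())
    return h.hexdigest()


def load_or_compile(graph_src: str, ambient_index: int, include_device: bool) -> Tuple[str, bool]:
    key = cache_key(graph_src, ambient_index if include_device else None)
    if key in persistent_cache:
        return persistent_cache[key], True  # cache hit
    artifact = f"constant on cuda:{ambient_index}"  # constant baked at compile time
    persistent_cache[key] = artifact
    return artifact, False  # cache miss, freshly compiled


GRAPH = "f_graph"  # hypothetical identifier for f's traced graph

persistent_cache.clear()
print(load_or_compile(GRAPH, 0, include_device=False))  # ('constant on cuda:0', False)
print(load_or_compile(GRAPH, 1, include_device=False))  # ('constant on cuda:0', True): stale hit

persistent_cache.clear()
print(load_or_compile(GRAPH, 0, include_device=True))   # ('constant on cuda:0', False)
print(load_or_compile(GRAPH, 1, include_device=True))   # ('constant on cuda:1', False)
```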

Versions

latest torch nightly

cc @chauhang @penguinwu @zou3519 @bdhirsh

Metadata


Labels

actionable
dynamo-must-fix (These bugs affect TorchDynamo reliability.)
module: aotdispatch (umbrella label for AOTAutograd issues)
module: pt2-dispatcher (PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, …)
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
