Description
🐛 Describe the bug
Here's a simple unit test repro:
import torch

@torch.compile
def f():
    y = torch.tensor([0, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192], dtype=torch.int32, device="cuda")
    return (y,)

index = 0
with torch.cuda._DeviceGuard(index):
    torch.cuda.set_device(index)
    result = f()
    assert result[0].device == torch.device("cuda:0")

index = 1
with torch.cuda._DeviceGuard(index):
    torch.cuda.set_device(index)
    result = f()
    assert result[0].device == torch.device("cuda:1")  # Fails
When a constant tensor is created with torch.tensor, Dynamo should guard on the specific device index of the tensor being created, because in eager mode f() always returns a tensor on the current CUDA device. However, AOTAutograd embeds the constant into the graph, so a guard is needed for Dynamo to recompile correctly when the ambient device changes.
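For reference, this is the eager-mode behavior the compiled version should match; a minimal sketch (no torch.compile), assuming at least two visible CUDA devices:

import torch

def f_eager():
    # In eager mode, device="cuda" resolves to the *current* CUDA device
    # at call time, not to a fixed device index.
    y = torch.tensor([0, 1024, 2048], dtype=torch.int32, device="cuda")
    return (y,)

with torch.cuda._DeviceGuard(1):
    torch.cuda.set_device(1)
    out = f_eager()
    # Eager semantics: the constant lands on cuda:1 because cuda:1 is
    # the current device inside the guard.
    assert out[0].device == torch.device("cuda:1")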
This also affects AOTAutogradCache. If you run the same example with a torch._dynamo.reset() in between, with FXGraphCache and AOTAutogradCache enabled, you get a cache hit and the same wrong-device result.
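A sketch of that cache variant; the config knobs used to enable the two caches (torch._inductor.config.fx_graph_cache and torch._functorch.config.enable_autograd_cache) are my assumption of the current names and may differ across nightlies:

import torch
import torch._dynamo

# Assumed flags for enabling FXGraphCache and AOTAutogradCache.
torch._inductor.config.fx_graph_cache = True
torch._functorch.config.enable_autograd_cache = True

@torch.compile
def f():
    y = torch.tensor([0, 1024, 2048], dtype=torch.int32, device="cuda")
    return (y,)

with torch.cuda._DeviceGuard(0):
    torch.cuda.set_device(0)
    f()  # populates the caches with a graph whose constant lives on cuda:0

torch._dynamo.reset()  # clears compiled code, but not the on-disk caches

with torch.cuda._DeviceGuard(1):
    torch.cuda.set_device(1)
    result = f()  # cache hit replays the cuda:0 constant
    assert result[0].device == torch.device("cuda:1")  # fails, same issue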
There are a few possible fixes here: AOTAutograd should probably add a guard on the ambient device index when it converts a tensor into a graph constant, and that index should also be part of the cache key. In theory, when creating the constant tensor, AOTAutograd has to rely on something in the Dynamo graph to tell it how to create the tensor. Will dig in more.
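Roughly, the first fix amounts to a guard that pins the compiled artifact to the device index that was current when the constant was baked in, plus folding that index into the cache key. The predicate below is purely hypothetical pseudocode for what such a guard would check, not an existing Dynamo/AOTAutograd API:

import torch

def _ambient_device_guard(expected_index: int) -> bool:
    # Hypothetical guard: a graph that embedded a constant on
    # cuda:<expected_index> is only valid while that device is still the
    # current one; otherwise Dynamo should recompile.
    return torch.cuda.current_device() == expected_index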
Versions
latest torch nightly