Description
I am hitting a strange issue with the low-bit optimizer when combining FSDP2 and CPU offloading:

```
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method lerp(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(2, 1536)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., size=(2, 1536)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), 0.09999999999999998), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.lerp.Scalar, found two different devices cuda:0, cpu')
```
https://github.com/pytorch/ao/blob/main/torchao/optim/adam.py#L129
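For context, the `lerp` at that line is the exponential-moving-average update in the Adam step: `Tensor.lerp_(other, weight)` computes `self + weight * (other - self)`. A minimal sketch of that arithmetic in plain Python (the scalar stand-ins for the real tensors are hypothetical; per the traceback, the failing case has the offloaded optimizer state on `cpu` while the sharded gradient DTensor is on `cuda:0`):

```python
def lerp(start, end, weight):
    # Same formula as torch.Tensor.lerp_(other, weight):
    # start + weight * (end - start)
    return start + weight * (end - start)

beta1 = 0.9
exp_avg = 0.5   # optimizer state (offloaded to CPU in the failing setup)
grad = 2.0      # gradient (a CUDA-sharded DTensor in the failing setup)

# adam.py's exp_avg.lerp_(grad, 1 - beta1): EMA of gradients.
# Note 1 - 0.9 == 0.09999999999999998 in float64, matching the
# weight shown in the traceback above.
exp_avg = lerp(exp_avg, grad, 1 - beta1)
print(exp_avg)  # ~0.65
```

With real tensors this is a pointwise op between the state and the gradient, which is why it surfaces the device mismatch as soon as the two live on different devices.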
It works fine without CPU offloading, but fails as soon as offloading is enabled.
All params are on the cpu device.
It also works fine with regular AdamW.
Torch and torchao are both on nightly builds.
Any ideas? Thanks