FSDP2 + CPU Offload + AdamW8bit issue #1931

@psinger

Description

I am hitting a strange issue with the low-bit optimizers when combining FSDP2 and CPU offloading:

torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method lerp(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(2, 1536)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., size=(2, 1536)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), 0.09999999999999998), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.lerp.Scalar, found two different devices cuda:0, cpu')
https://github.com/pytorch/ao/blob/main/torchao/optim/adam.py#L129

It works fine without CPU offloading, but fails once offloading is enabled.
All params are on the CPU device.

It works fine with regular AdamW.

Both torch and torchao are nightly builds.

Any ideas? Thanks
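For context, a minimal sketch of the operation the traceback points at (the linked line in torchao/optim/adam.py updates the first-moment EMA via lerp, with weight 1 - beta1 = 0.09999999999999998 as seen in the error). On a single device this works; the failure above arises because, under FSDP2 CPU offload, the optimizer state lives on CPU while the sharded DTensor reports cuda, so fake-tensor propagation for aten.lerp sees two devices. The tensor shapes and beta value below are illustrative, not taken from the actual model:

```python
import torch

# torchao's Adam step updates the exp_avg buffer roughly like:
#   exp_avg.lerp_(grad, 1 - beta1)
# Same-device case (the normal path) works fine:
beta1 = 0.9
exp_avg = torch.zeros(4)  # optimizer state
grad = torch.ones(4)      # incoming gradient

exp_avg.lerp_(grad, 1 - beta1)  # exp_avg = exp_avg + 0.1 * (grad - exp_avg)
print(exp_avg)  # each element is 0.1

# Under CPU offload, exp_avg would be on CPU while the sharded grad
# DTensor is on cuda, which is the device mismatch in the traceback.
```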
