fix "Expected all tensors to be on the same device, but found at least two devices" error #11690

Open · wants to merge 23 commits into main · Changes from 11 commits
3 changes: 2 additions & 1 deletion src/diffusers/models/unets/unet_2d_blocks.py
@@ -2557,7 +2557,8 @@ def forward(
b1=self.b1,
b2=self.b2,
)

if hidden_states.device != res_hidden_states.device:
res_hidden_states = res_hidden_states.to(hidden_states.device)
Comment on lines +2560 to +2561
Member (@SunMarc):

We shouldn't need that, since both hidden_states and res_hidden_states should already be on the same device, no? The pre-forward hook added by accelerate should move all the inputs to the same device.
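
For readers following along, here is a minimal, hypothetical sketch of what such a device-aligning pre-forward hook does (this is not accelerate's actual implementation, just the idea): tensor inputs are moved onto the device of the module's own parameters before forward runs.

```python
# Minimal sketch (not accelerate's actual implementation) of a device-aligning
# pre-forward hook: before the module runs, its tensor inputs are moved onto
# the device where the module's own parameters live.
import torch
import torch.nn as nn

def align_inputs_to_module_device(module, args):
    device = next(module.parameters()).device
    return tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
block = nn.Linear(4, 4).to(device)
block.register_forward_pre_hook(align_inputs_to_module_device)

x = torch.randn(2, 4)  # starts on CPU; the hook moves it to block's device
y = block(x)           # no device-mismatch error inside the module
```

Such a hook only fires when a module's forward is entered, which is why a bare functional call like torch.cat between modules is never covered by it.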

Contributor Author (@yao-matrix):

@SunMarc, I suppose this is a corner case? torch.cat is a weight-less function, so it seems it cannot be covered by the pre-forward hook set by accelerate...

Member (@SunMarc):

I mean, since hidden_states and res_hidden_states_tuple are in the forward definition, they should be moved to the same device by the pre-forward hook added by accelerate.

Contributor Author (@yao-matrix), Jun 18, 2025:

@SunMarc We ran into a corner case here. Since we have 8 cards, the device_map determined by https://github.com/huggingface/diffusers/blob/1bc6f3dc0f21779480db70a4928d14282c0198ed/src/diffusers/models/model_loading_utils.py#L64C5-L64C26 is:

device_map: OrderedDict([('conv_in', 0), ('time_proj', 0), ('time_embedding', 0), ('down_blocks.0', 0), ('down_blocks.1.resnets.0', 1), ('up_blocks.0.resnets.0', 1), ('up_blocks.0.resnets.1', 2), ('up_blocks.0.upsamplers', 2), ('up_blocks.1', 3), ('mid_block.attentions', 3), ('conv_norm_out', 4), ('conv_act', 4), ('conv_out', 4), ('mid_block.resnets', 4)])

We can see that UpBlock is not an atomic module here: its submodules are assigned to different devices (up_blocks.0.resnets.0 and up_blocks.0.resnets.1), so a pre-forward hook on UpBlock will not help in this case. And since torch.cat is not pre-hooked (and cannot be, since it is a function rather than a module), the issue happens.

If there were no torch.cat between the sub-blocks in UpBlock, things would be fine.
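
To make the corner case concrete, here is a rough repro sketch (hypothetical module names, assumes at least two CUDA devices): two sub-modules of the same parent block live on different devices, and the functional torch.cat between their outputs has no module hook to realign devices, which is exactly the situation the added .to() guards against.

```python
# Rough repro sketch of the corner case (assumes >= 2 CUDA devices; the Linear
# modules are illustrative stand-ins for up_blocks.0.resnets.0 / .1).
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    resnet0 = nn.Linear(8, 8).to("cuda:0")
    resnet1 = nn.Linear(8, 8).to("cuda:1")

    hidden_states = resnet0(torch.randn(1, 8, device="cuda:0"))      # lives on cuda:0
    res_hidden_states = resnet1(torch.randn(1, 8, device="cuda:1"))  # lives on cuda:1

    try:
        torch.cat([hidden_states, res_hidden_states], dim=1)
    except RuntimeError as e:
        print(e)  # "Expected all tensors to be on the same device ..."

    # The defensive pattern from this PR: align devices before the cat.
    if hidden_states.device != res_hidden_states.device:
        res_hidden_states = res_hidden_states.to(hidden_states.device)
    out = torch.cat([hidden_states, res_hidden_states], dim=1)  # works
```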

Contributor Author (@yao-matrix):

@SunMarc, we need your input on how to proceed with this corner case, thanks.

Contributor Author (@yao-matrix):

@SunMarc We can see a similar case in the transformers unit test pytest -rA tests/models/chameleon/test_modeling_chameleon.py::ChameleonVision2SeqModelTest::test_model_parallel_beam_search with 2 cards. The error is "RuntimeError: Expected all tensors to be on the same device, but found at least two devices" in src/transformers/models/chameleon/modeling_chameleon.py. The reason: even though residual starts on the same device as hidden_states, after both pass through some operators as input and output they end up on different devices, and when they reach the + operation, which is not an nn.Module (so accelerate cannot pre-hook it), the error happens. Do you have any insights on such issues?
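
A stripped-down illustration of that residual-add failure mode (hypothetical shapes, assumes two CUDA devices); the + is a plain tensor op, so no accelerate hook can realign its operands once they have drifted apart:

```python
# Sketch only: `+` is not an nn.Module, so accelerate cannot pre-hook it.
import torch

if torch.cuda.device_count() >= 2:
    hidden_states = torch.randn(1, 16, device="cuda:1")  # moved here by a sharded layer
    residual = torch.randn(1, 16, device="cuda:0")       # left behind on the first device
    try:
        hidden_states + residual
    except RuntimeError as e:
        print(e)  # device-mismatch error, as in the chameleon test
    residual = residual.to(hidden_states.device)  # same defensive fix as in this PR
    out = hidden_states + residual
```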

Contributor Author (@yao-matrix):

@SunMarc, could you share your insights on the issue I mentioned above? Thanks very much.

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if torch.is_grad_enabled() and self.gradient_checkpointing:
5 changes: 2 additions & 3 deletions tests/models/test_modeling_common.py
@@ -70,7 +70,6 @@
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_accelerator,
run_test_in_subprocess,
slow,
@@ -1744,7 +1743,7 @@ def test_push_to_hub_library_name(self):
delete_repo(self.repo_id, token=TOKEN)


@require_torch_gpu
@require_torch_accelerator
@require_torch_2
@is_torch_compile
@slow
@@ -1789,7 +1788,7 @@ def test_compile_with_group_offloading(self):
model.eval()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs = {
"onload_device": "cuda",
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
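
As context for the test change above (the hard-coded "cuda" replaced by torch_device), here is a rough sketch of how such group-offloading kwargs are typically consumed, assuming diffusers' apply_group_offloading helper and a small placeholder model; the actual test code may call a different entry point.

```python
# Sketch under the stated assumptions: group offloading keeps blocks on the
# offload_device and onloads one group at a time onto the accelerator, so the
# onload_device should be the generic torch_device rather than a hard-coded "cuda".
import torch
from diffusers import UNet2DModel
from diffusers.hooks import apply_group_offloading

torch_device = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for the test utility

model = UNet2DModel()  # small placeholder model, default config
apply_group_offloading(
    model,
    onload_device=torch.device(torch_device),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)
out = model(torch.randn(1, 3, 64, 64, device=torch_device), timestep=0).sample
```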