
AssertionError: Torch not compiled with CUDA enabled when using device_map="auto" in Ascend NPU #38468

Open
@jiaqiw09

Description


System Info

Ascend NPU
transformers>=4.50.0
torch 2.1

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When using device_map with Ascend NPU devices in transformers >=4.50.0, model loading fails with AssertionError: Torch not compiled with CUDA enabled. The new loading implementation in _load_state_dict_into_meta_model does not handle integer device indices for NPU devices, whereas previous versions (<4.50.0) used accelerate.utils.set_module_tensor_to_device, which correctly converts integer indices to device strings such as "npu:0".
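A minimal sketch of the underlying PyTorch behavior (assuming torch 2.1 semantics, where a bare integer device is resolved as a CUDA ordinal), which explains why the error message mentions CUDA on a machine with no CUDA at all:

import torch

# A bare integer device is interpreted by PyTorch as a CUDA ordinal,
# so it cannot address an NPU directly:
print(torch.device(0))   # device(type='cuda', index=0)

t = torch.empty(2)
# t.to(0)        # AssertionError: Torch not compiled with CUDA enabled
# t.to("npu:0")  # works once torch_npu is imported and the "npu" backend is registered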

On an Ascend NPU system, attempt to load a model with device mapping:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    device_map="auto",  # or a custom device_map with integer indices
    torch_dtype=torch.float16,
)

Observe the failure: AssertionError: Torch not compiled with CUDA enabled, with a stack trace pointing to _load_state_dict_into_meta_model in modeling_utils.py.

Expected behavior

Model loading with device_map should succeed on NPU, as it did in transformers <=4.49.0, where device mapping went through accelerate.utils.set_module_tensor_to_device, which handles the various device types.

(Screenshot: the accelerate.utils.set_module_tensor_to_device code path used in transformers <=4.49.0.)

In transformers >=4.50.0, the new _load_state_dict_into_meta_model uses the device values from device_map directly, without converting integer indices to device-specific strings.

For NPU devices, an integer index (such as 0) is therefore never converted to the proper device string ("npu:0"); PyTorch resolves the bare integer as a CUDA ordinal, which raises the assertion on a build without CUDA.
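Until this is fixed in transformers, one possible workaround is to normalize the indices yourself before loading. This is only a sketch under the assumptions above; raw_device_map is a hypothetical example of a mapping that contains bare integer indices:

import torch
import torch_npu  # Ascend adapter; registers the "npu" backend
from transformers import AutoModelForCausalLM

# Hypothetical device_map containing a bare integer index
raw_device_map = {"": 0}

# Convert integer indices to explicit NPU device strings
device_map = {
    module: f"npu:{dev}" if isinstance(dev, int) else dev
    for module, dev in raw_device_map.items()
}

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    device_map=device_map,
    torch_dtype=torch.float16,
)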
