
Any plans to support USE_DISTRIBUTED=0 pytorch? #3452

@rene-descartes2021

Description


Dec 6th EDIT: simplified and expanded the error and reproduction example based on the conversation below.

If not, then please document the requirement somewhere in the readme/requirements; the error below was cryptic.

Error that led me to this question:

Traceback (most recent call last):
  File "/data/data/com.termux/files/home/dev/llm/sd/test/./to.py", line 3, in <module>
    import torchao
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/__init__.py", line 127, in <module>
    from torchao.quantization import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/quantization/__init__.py", line 6, in <module>
    from .autoquant import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/quantization/autoquant.py", line 11, in <module>
    from torchao.dtypes import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/dtypes/__init__.py", line 1, in <module>
    from . import affine_quantized_tensor_ops
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 14, in <module>
    from torchao.dtypes.floatx.cutlass_semi_sparse_layout import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/dtypes/floatx/__init__.py", line 4, in <module>
    from .float8_layout import Float8Layout
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/dtypes/floatx/float8_layout.py", line 21, in <module>
    from torchao.float8.inference import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/float8/__init__.py", line 12, in <module>
    from torchao.float8.float8_linear_utils import (
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/float8/float8_linear_utils.py", line 14, in <module>
    from torchao.float8.float8_linear import Float8Linear
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/float8/float8_linear.py", line 15, in <module>
    from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/float8/distributed_utils.py", line 14, in <module>
    from torchao.float8.float8_training_tensor import Float8TrainingTensor
  File "/data/data/com.termux/files/home/dev/llm/sd/ao/torchao/float8/float8_training_tensor.py", line 10, in <module>
    from torch.distributed._tensor import DTensor
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py", line 25, in <module>
    sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module(
                                                            ^^^^^^^^^^^^^^
  File "/data/data/com.termux/files/usr/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py", line 4, in <module>
    import torch.distributed.tensor._ops  # force import all built-in dtensor ops
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py", line 2, in <module>
    from ._conv_ops import *  # noqa: F403
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py", line 5, in <module>
    from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py", line 6, in <module>
    from torch.distributed.tensor.placement_types import (
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py", line 8, in <module>
    import torch.distributed._functional_collectives as funcol
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py", line 9, in <module>
    import torch.distributed.distributed_c10d as c10d
  File "/data/data/com.termux/files/usr/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 23, in <module>
    from torch._C._distributed_c10d import (
ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package

Reproduction example that leads to the error above when run against a pytorch built with USE_DISTRIBUTED=0 USE_CUDA=0:

import torchao

I tried guarding the imports/usage of DTensor with something like:

import torch.distributed

# torch.distributed is still importable on a USE_DISTRIBUTED=0 build,
# but is_available() returns False there
is_torch_distributed_available = torch.distributed.is_available()
if is_torch_distributed_available:
    from torch.distributed._tensor import DTensor

But DTensor usage turned out to be prolific and deeply integrated. Maybe there is some sort of subclass solution that could guard anything DTensor-specific? I'm out of my depth here, but a rough sketch of what I mean is below.
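For what it's worth, here is roughly the shape I had in mind: a minimal sketch, not actual torchao code (the stub class and the error message are made up), assuming the guard lives wherever DTensor is imported. The idea is that isinstance(x, DTensor) checks stay valid (they are simply always False), and only an actual attempt to construct a DTensor fails with a readable message.

import torch.distributed

if torch.distributed.is_available():
    from torch.distributed._tensor import DTensor
else:
    # Stub placeholder for non-distributed builds: never instantiated, so
    # isinstance() checks against it are always False.
    class DTensor:
        def __new__(cls, *args, **kwargs):
            raise RuntimeError(
                "DTensor requires a pytorch build with USE_DISTRIBUTED=1"
            )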

I don't know of anything else to try with torchao, so it looks like I must look into other quantization backends.

Here are a couple of pytorch issues about the error that pops up when other libraries don't guard their DTensor usage [1][2]. Per that second issue, upstream pytorch may eventually allow importing the DTensor machinery without failing (cryptically) under USE_DISTRIBUTED=0. So waiting on that and adding a note to the readme/requirements would probably be a low-effort solution. That's my take after some reading. Thanks for reading and have a good day.
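As an aside, even without full support, the cryptic failure could be made explicit with a small check at import time. A minimal sketch, purely illustrative and not torchao's actual code (the message text is made up), assuming it sat near the top of torchao/__init__.py:

import torch.distributed

# On a USE_DISTRIBUTED=0 build, fail early with a readable message instead of
# a deep ModuleNotFoundError about torch._C._distributed_c10d.
if not torch.distributed.is_available():
    raise ImportError(
        "torchao currently requires a pytorch build with USE_DISTRIBUTED=1"
    )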

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @malfet @seemethere
