
Cannot mask a DTensor #152717

@pbontrager

🐛 Describe the bug

I'm attempting to mask a sharded DTensor, and everything I've tried so far has failed. I have a DTensor sharded on the last dim, and I create a boolean mask to select a subset of the tensor (selecting from earlier dims, not the sharded dim). I have tried selecting with both a local mask and a replicated DTensor mask, and I've tried both torch.masked_select and regular tensor[mask] indexing. I'm not sure whether this is a bug in DTensor or whether I'm attempting to do this the wrong way.

import torch
from torch.distributed import init_process_group, get_rank
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

init_process_group('nccl')
torch.cuda.set_device(get_rank())
mesh = init_device_mesh('cuda', (2,), mesh_dim_names=('tp',))

# Shard the tensor on its last dim across the 2-GPU mesh
embed = torch.randn(4, 32, 64)
d_embed = distribute_tensor(embed, mesh, placements=[Shard(2)])

# Boolean mask over the first two (non-sharded) dims
mask = (torch.arange(4 * 32, device='cuda').reshape(4, 32) % 2).bool()
d_embed[mask]

Run the above with torchrun --nproc_per_node 2 mask_dtensor.py and it fails with RuntimeError: aten.index.Tensor: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

I can also distribute the mask with d_mask = distribute_tensor(mask, mesh) and index with that instead, which gives me torch._subclasses.fake_tensor.DynamicOutputShapeException: aten.nonzero.default.
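
My understanding (possibly wrong) is that the second failure happens because boolean indexing lowers to aten.nonzero, whose output shape depends on the mask's values, and DTensor's sharding propagation can't handle data-dependent output shapes. One workaround sketch, using DTensor's full_tensor() to gather into a regular tensor, is to apply the mask locally and then re-distribute the result; this runs, but it gives up the memory savings of sharding:

# Workaround sketch: gather, mask locally, re-shard (loses the memory benefit of sharding)
full = d_embed.full_tensor()  # all-gather into a regular torch.Tensor on every rank
selected = full[mask]         # shape (mask.sum(), 64); data-dependent shapes are fine locally
d_selected = distribute_tensor(selected, mesh, placements=[Shard(1)])  # re-shard on the last dim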

Versions

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250416+cu126
[pip3] torchao==0.10.0
[pip3] torchaudio==2.6.0.dev20250416+cu126
[pip3] torchdata==0.11.0
[pip3] torchtune==0.0.0
[pip3] torchvision==0.22.0.dev20250416+cu126
[pip3] triton==3.3.0
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.5.1.17 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 dev_0
[conda] pytorch-triton 3.3.0+git96316ce5 pypi_0 pypi
[conda] torch 2.8.0.dev20250416+cu126 pypi_0 pypi
[conda] torchao 0.10.0 pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250416+cu126 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchtune 0.0.0 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250416+cu126 pypi_0 pypi
[conda] triton 3.3.0 pypi_0 pypi

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l @XilunWu

Labels: module: dtensor, oncall: distributed, triaged