Skip to content

Conversation

@bdhirsh
Copy link
Contributor

@bdhirsh bdhirsh commented Oct 31, 2025

Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with tracing_mode="fake", or with your own FakeTensorMode, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

image
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()


def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @ezyang @EikanWang @jgong5 @wenzhe-nrv

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166750

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit bffb8b7 with merge base 5c63946 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Oct 31, 2025
@pytorch-bot pytorch-bot bot added release notes: fx release notes category labels Oct 31, 2025
@bdhirsh bdhirsh added release notes: distributed (dtensor) release notes category and removed release notes: fx release notes category labels Oct 31, 2025
Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test would be great! thanks!

Copy link
Contributor

@wconstab wconstab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b" />


```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()


def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```




cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Nov 3, 2025
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 3, 2025
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test added here

@bdhirsh
Copy link
Contributor Author

bdhirsh commented Nov 3, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 3, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b" />


```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()


def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Nov 4, 2025
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b" />


```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()


def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Nov 5, 2025
@bdhirsh
Copy link
Contributor Author

bdhirsh commented Nov 5, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request fx Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (dtensor) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants