
[Distributed Inference] moving stage.submod to non-fp32 (bf16, fp16) results in dtensor assert "self.mask_buffer.data is not None" #1086

@lessw2020

Description

🐛 Describe the bug

Using our prototype parallel blocks for built-in distributed inference, we can run TP + PP in fp32 successfully.
However, moving the stage's submodule to bfloat16 or fp16 results in an embedding assert (a minimal repro-style sketch follows the traceback below):

[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 530, in forward_one_chunk
[rank0]:[rank0]:     output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 464, in forward_maybe_with_nosync
[rank0]:[rank0]:     out_val = self.submod(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchchat/build/model_dist.py", line 105, in forward
[rank0]:[rank0]:     x: DTensor = self.tok_embeddings(x)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1801, in _call_impl
[rank0]:[rank0]:     hook_result = hook(self, args, result)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 895, in <lambda>
[rank0]:[rank0]:     lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/tensor/parallel/style.py", line 251, in _prepare_output_fn
[rank0]:[rank0]:     outputs = outputs.redistribute(placements=output_layouts, async_op=True)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 541, in redistribute
[rank0]:[rank0]:     return Redistribute.apply(self, device_mesh, placements, async_op)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/_redistribute.py", line 295, in forward
[rank0]:[rank0]:     output = redistribute_local_tensor(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/_redistribute.py", line 196, in redistribute_local_tensor
[rank0]:[rank0]:     new_local_tensor = partial_spec._reduce_value(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/ops/_embedding_ops.py", line 119, in _reduce_value
[rank0]:[rank0]:     assert self.mask_buffer.data is not None
[rank0]:[rank0]: AssertionError
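
For context, here is a minimal sketch of the pattern described above. It is purely illustrative, not the torchchat code: the module name TinyStage, the sizes, and the launch command are assumptions. It exercises the same path as the traceback, a rowwise-sharded tok_embeddings whose partial output is redistributed by the TP output hook, first in fp32 and then after converting the already-parallelized stage to bf16.

import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard  # torch.distributed.tensor on newer builds
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module


class TinyStage(nn.Module):
    # Stand-in for the first pipeline stage submodule; only the failing piece.
    def __init__(self, vocab_size: int = 1024, dim: int = 64):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(self, x):
        return self.tok_embeddings(x)


def main():
    # Launch with e.g.: torchrun --nproc-per-node 2 repro.py
    tp_size = int(os.environ.get("WORLD_SIZE", "1"))
    mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))

    model = TinyStage().cuda()
    # Rowwise-shard the embedding; its partial output is redistributed to the
    # requested output layout, which is the _reduce_value path in the traceback.
    parallelize_module(
        model,
        mesh,
        {"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1))},
    )

    # fp32 works; converting the already-parallelized stage to bf16/fp16 is the
    # step reported to trigger "assert self.mask_buffer.data is not None".
    model = model.to(torch.bfloat16)

    tokens = torch.randint(0, 1024, (2, 16), device="cuda")
    out = model(tokens)
    print(out.shape, out.dtype)


if __name__ == "__main__":
    main()

If this TP-only sketch reproduces the assert, that would rule out the pipeline split and point at the dtype conversion of the sharded embedding itself.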

This issue is to track the debugging and resolution.

Versions

N/A

Labels

Distributed (Issues related to all things distributed)
