Compiling with Inductor, DDP, and Dynamic Shapes Results in Errors #125641
Comments
@warner-benjamin I ran locally with a nightly and this actually passes for me. Can you try out a nightly? https://pytorch.org/get-started/locally/
@bdhirsh I tested my replication script with yesterday's nightly and 2.3. You can see my environment in the "PyTorch Nightly Environment" section. These errors are only with DDP; single GPU compiles and trains without issue. I installed today's nightly and ran:

```sh
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
```

And setting `dynamic` explicitly:

```sh
# torch.compile(..., dynamic=True)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true

# torch.compile(..., dynamic=None)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen
```
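For reference, a minimal sketch of what those three configurations presumably map to in the replication script (the flag-to-code mapping is my assumption, not confirmed from the script):

```python
import torch

model = torch.nn.Linear(8, 8)
batch = torch.randn(4, 8)

# --use_mark_dynamic: keep the default dynamic=None, but explicitly mark the
# varying dimension of the input as dynamic before the first compiled call.
torch._dynamo.mark_dynamic(batch, 0)
compiled = torch.compile(model)
compiled(batch)

# --dynamic_true: ask the compiler to treat sizes as dynamic up front.
compiled_dt = torch.compile(model, dynamic=True)
compiled_dt(batch)

# default: dynamic=None detects dynamism automatically after a recompile.
compiled_dn = torch.compile(model, dynamic=None)
compiled_dn(batch)
```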
I am seeing the same issue this morning, running the same three commands on the replication script, on my system using CUDA 12.1. Details below:

PyTorch Nightly Environment details
I'm going to look into this. But my recollection is that HF added some error-checking code which forces specialization, and I haven't gotten around to yelling at them to stop running this logic when being torch.compiled. BTW, the two errors here are one and the same.
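(An illustration of that failure mode, with hypothetical code rather than the actual HF logic: an eager-mode validation branch that inspects a dynamic dimension makes dynamo guard on the symbolic size, which can force specialization and recompiles as the sequence length varies.)

```python
import torch

class Checked(torch.nn.Module):
    def forward(self, x):
        # Branching on a symbolic size makes dynamo add a guard on the
        # expression below; with varying sequence lengths this can force
        # specialization or repeated recompiles.
        if x.shape[1] % 8 != 0:
            raise ValueError("sequence length must be a multiple of 8")
        return x * 2

compiled = torch.compile(Checked())
compiled(torch.randn(2, 16))
```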
It's not just HF models which trigger this when using DDP. My replication script uses a simple two-layer model with an embedding layer and a linear head:

```python
class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)
        out = self.head(out)
        return out
```
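For context, a hedged sketch of how the replication presumably wires this model up under DDP (the sizes and loop are illustrative, not the script's actual values):

```python
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP

# EmbedHeadModel as defined above.
# Run under torchrun, e.g.: torchrun --nproc_per_node=2 this_script.py
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()

model = EmbedHeadModel(vocab_size=2048, hidden_size=1024).to(rank)
model = DDP(model, device_ids=[rank])
model = torch.compile(model)

# Variable sequence lengths across steps, with dim 1 marked dynamic.
for seqlen in (512, 768, 1024):
    x = torch.randint(0, 2048, (8, seqlen), device=rank)
    torch._dynamo.mark_dynamic(x, 1)
    out = model(x)
```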
When I run my replication script with:

```sh
TORCH_LOGS=+dynamic torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
```

I get the following output for rank 0:

TORCH_LOGS=+dynamic Rank 0 Output

I'm not seeing anything about specialization, but I might be misinterpreting the logs.
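(If useful, guard and recompile logging can be layered on as well; `guards` and `recompiles` are standard TORCH_LOGS artifact names, though I haven't confirmed they surface more detail here.)

```sh
TORCH_LOGS=+dynamic,guards,recompiles torchrun --nproc_per_node=2 \
  replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
```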
It's this:

Very strange though, why is this suppressed? You could get a full backtrace for this log with `TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)"`.
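Concretely, assuming the same replication command as above, that looks like:

```sh
TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)" \
  torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
```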
@ezyang could it be related to this? https://github.com/pytorch/pytorch/pull/120523/files#diff-cb8e02fc8f37e53904ab1b151c46dd109cf50d8121bbd340834b2e976b22ebc4R74 Maybe the idiom there is not correct. We're trying to update the meta strides without adding guards or specializations.
Oh yeah, this looks very very naughty. Hmmmm
As a stopgap, I guess we could prevent replacements from happening when guards are suppressed. This still seems very naughty though...
Here's the additional backtrace with the "Eq(2048*s0, 2000896)" guard added:

TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)"
Also improve logging when guards are suppressed. Partially addresses #125641. Signed-off-by: Edward Z. Yang <ezyang@meta.com>. Pull Request resolved: #126210. Approved by: https://github.com/jbschlosser
I believe this is fixed.
🐛 Describe the bug

`torch.compile` with the inductor backend errors out with dynamic shapes and DistributedDataParallel: either a direct `ConstraintViolationError: Constraints violated (L['x'].size()[1])!` when using `torch._dynamo.mark_dynamic`, or recompiling multiple times until the recompile limit is reached due to a "stride mismatch at index 0" compilation error with `dynamic=True` or `dynamic=None`. These errors occur in both PyTorch 2.3 and the latest PyTorch Nightly.
I've created a replication with a simple "transformer" model, with just an embedding layer and a linear head layer, so I can vary the sequence length of each batch. I get the same errors with a full from-scratch transformer with DDP.
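(For instance, along these lines; the helper and sizes are illustrative, not the script's actual values:)

```python
import torch

# Hypothetical batch generator: a random sequence length per step is what
# makes the compiled DDP model recompile or violate its dynamic-dim constraint.
def make_batch(batch_size=8, vocab_size=2048, min_len=256, max_len=1024):
    seqlen = int(torch.randint(min_len, max_len + 1, (1,)))
    return torch.randint(0, vocab_size, (batch_size, seqlen))
```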
I inconsistently get the `ConstraintViolationError` when using `torch._dynamo.mark_dynamic` in a non-distributed context with PyTorch 2.3, specifically with the Hugging Face Transformers Llama implementation, but I have been unable to replicate it with non-HF code.

Error logs
With my replication script below, compiling a DDP model for dynamic shapes with the recommended `torch._dynamo.mark_dynamic` instead of using `torch.compile(..., dynamic=True)`, using the following command:

```sh
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
```

results with the following `ConstraintViolationError`:

ConstraintViolationError

You can turn on logging with `--logging`, but the dynamo logs don't appear to be that useful compared to other errors I've seen.

The same command using `torch.compile(..., dynamic=True)` or `torch.compile(..., dynamic=None)` and relying on the compiler to detect dynamic shapes:

```sh
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true
# or
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen
```

results in a recompiles error:
The logging output also doesn't appear to be verbose.
I'm happy to add more logging if wanted.
Minified repro
Replication Script
Versions
I ran my replication script on fresh conda environments:
PyTorch Nightly Environment
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang