Compiling with Inductor, DDP, and Dynamic Shapes Results in Errors #125641

Closed
warner-benjamin opened this issue May 6, 2024 · 11 comments
Labels: high priority, module: ddp, module: dynamic shapes, oncall: pt2, triaged


@warner-benjamin

warner-benjamin commented May 6, 2024

🐛 Describe the bug

torch.compile with the inductor backend errors out when combining dynamic shapes with DistributedDataParallel. With torch._dynamo.mark_dynamic, it fails immediately with ConstraintViolationError: Constraints violated (L['x'].size()[1])!. With dynamic=True or dynamic=None, it instead recompiles repeatedly, each time because of a "stride mismatch at index 0" guard failure, until the recompile limit is reached.

These errors occur in both PyTorch 2.3 and the latest PyTorch Nightly.

I've created a reproduction with a simple "transformer" model consisting of just an embedding layer and a linear head layer, so I can vary the sequence length from batch to batch. I get the same errors with a full from-scratch transformer trained with DDP.

I also intermittently get the ConstraintViolationError when using torch._dynamo.mark_dynamic in a non-distributed setting with PyTorch 2.3, specifically with the Hugging Face Transformers Llama implementation, but I have been unable to reproduce it with non-HF code.

Error logs

Using my replication script below, compiling a DDP model for dynamic shapes with the recommended torch._dynamo.mark_dynamic (instead of torch.compile(..., dynamic=True)) via the following command:

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic

results in the following ConstraintViolationError:

[rank0]: torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
[rank0]:   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (970)

You can turn on logging with --logging, but the Dynamo logs don't seem as useful here as they have been for other errors I've seen.

 torch/_guards.py:261] [0/0] Traceback (most recent call last):
 torch/_guards.py:261] [0/0]   File "/torch/_guards.py", line 259, in create
 torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
 torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 torch/_guards.py:261] [0/0]   File "/torch/_dynamo/guards.py", line 1664, in SHAPE_ENV
 torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
 torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 torch/_guards.py:261] [0/0]   File "/torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
 torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
 torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
 torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (986).
 torch/_guards.py:263] [0/0] Created at:
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/convert_frame.py", line 499, in transform
 torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/symbolic_convert.py", line 2143, in __init__
 torch/_guards.py:263] [0/0]     output=OutputGraph(
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/output_graph.py", line 309, in __init__
 torch/_guards.py:263] [0/0]     self.init_ambient_guards()
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/output_graph.py", line 448, in init_ambient_guards
 torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

Running the same command with torch.compile(..., dynamic=True), or with torch.compile(..., dynamic=None) while relying on the compiler to detect dynamic shapes:

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true
# or
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen

results in repeated recompilation until the cache size limit is hit:

torch/_dynamo/convert_frame.py:367] torch._dynamo hit config.cache_size_limit (8)
torch/_dynamo/convert_frame.py:367]    function: 'forward' (/replication.py:51)
torch/_dynamo/convert_frame.py:367]    last reason: tensor 'L['x']' stride mismatch at index 0. expected 1014, actual 1023

The logging output doesn't appear to be very verbose either.

torch/_dynamo/guards.py:2546] [__recompiles_verbose] Recompiling function forward in /replication.py:51
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     triggered by the following guard failure(s):
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 0 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 985, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose] 
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 1 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 976, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose] 
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 2 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 1015, actual 1011
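A toy model (not Dynamo's actual implementation) shows why this loops: if each compiled entry guards on an exact stride()[0], every new sequence length misses every cached entry, and once cache_size_limit entries exist the function is roughly left uncompiled. `ToyFrameCache` and the sequence lengths are hypothetical; only the limit of 8 is taken from the log above.

```python
from typing import List, Optional, Tuple


class ToyFrameCache:
    """Toy stand-in for a per-function cache of compiled entries (hypothetical).

    Each entry guards only on stride()[0] of a 2-D input. A static entry
    stores the exact stride it was compiled for; a dynamic entry (stored
    as None) accepts any stride.
    """

    def __init__(self, cache_size_limit: int = 8, dynamic: bool = False) -> None:
        self.cache_size_limit = cache_size_limit
        self.dynamic = dynamic
        self.entries: List[Optional[int]] = []
        self.compiles = 0

    def __call__(self, shape: Tuple[int, int]) -> str:
        stride0 = shape[1]  # contiguous 2-D tensor: stride()[0] == size()[1]
        for expected in self.entries:
            if expected is None or expected == stride0:
                return "cache hit"
        # Every existing entry failed with "stride mismatch at index 0".
        if len(self.entries) >= self.cache_size_limit:
            return "cache_size_limit hit: run eagerly"
        self.entries.append(None if self.dynamic else stride0)
        self.compiles += 1
        return "recompiled"


# Hypothetical per-batch sequence lengths, all distinct.
seq_lens = [970, 986, 977, 1014, 1023, 985, 976, 1015, 1011]

static_cache = ToyFrameCache()
static_results = [static_cache((16, n)) for n in seq_lens]

dynamic_cache = ToyFrameCache(dynamic=True)
dynamic_results = [dynamic_cache((16, n)) for n in seq_lens]
```

With static guards the toy cache recompiles eight times and then gives up, while a single dynamic entry serves every batch, which is the behavior the report expected from dynamic=True.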

I'm happy to add more logging if wanted.

Minified repro

Replication Script
# based on the PyTorch DDP example

import argparse
import logging
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from typing import Tuple
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel

torch.set_float32_matmul_precision("high")


def pytorch_logs_to_file(file: str = "pytorch.log"):
    torch._logging.set_logs(
        dynamo=logging.INFO,
        aot=logging.INFO,
        inductor=logging.INFO,
        dynamic=logging.INFO,
        distributed=logging.INFO,
        graph_breaks=True,
        guards=True,
        recompiles=True,
        recompiles_verbose=True,
        output_code=True,
        graph_code=True,
        graph=True,
        ddp_graphs=True,
    )
    torch._logging._init_logs(file)

    loggers = logging.Logger.manager.loggerDict.keys()
    for logger_name in loggers:
        if logger_name.startswith("torch"):
            logger = logging.getLogger(logger_name)
            if isinstance(logger, logging.Logger):
                handlers = logger.handlers
                for handler in handlers:
                    if isinstance(handler, logging.StreamHandler):
                        logger.removeHandler(handler)


class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)
        out = self.head(out)
        return out


def get_batch(
    batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, dynamic: bool
) -> Tuple[Tensor, Tensor]:
    if dynamic:
        # `// 8` applies only to the random offset, so each row has
        # sequence_length + 1 - randint(0, 512) // 8 tokens (961..1025 by default).
        input = torch.randint(
            0,
            vocab_size - 1,
            (batch_size, sequence_length - random.randint(0, min(512, sequence_length // 2)) // 8 + 1),
            device=device,
        )
    else:
        input = torch.randint(0, vocab_size - 1, (batch_size, sequence_length + 1), device=device)
    # Shift by one token to build (inputs, next-token targets).
    return input[:, :-1].contiguous(), input[:, 1:].contiguous()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence_length", type=int, default=1024)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--variable_seqlen", action="store_true", help="Batches have variable sequence lengths")
    parser.add_argument("--dynamic_true", action="store_true", help="Compile with dynamic=True instead of None")
    parser.add_argument(
        "--use_mark_dynamic", action="store_true", help="Use torch._dynamo.mark_dynamic for dynamic shapes"
    )
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--vocab_size", type=int, default=8000)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--logging", action="store_true")
    parser.add_argument("--log_name", type=str, default="pytorch.log")

    return parser.parse_args()


def train():
    args = parse_args()

    if args.ddp:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")
    else:
        rank = 0

    if args.logging and rank == 0:
        pytorch_logs_to_file(args.log_name)

    device_id = rank % torch.cuda.device_count()
    model = EmbedHeadModel(args.vocab_size, args.hidden_size).to(device=device_id)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    if args.compile:
        model = torch.compile(model, dynamic=True if args.dynamic_true and not args.use_mark_dynamic else None)

    if args.ddp:
        model = DistributedDataParallel(model, device_ids=[device_id])

    model.train()

    for _ in range(0, args.iterations):
        data, targets = get_batch(
            args.batch_size, args.sequence_length, args.vocab_size, device_id, args.variable_seqlen
        )
        if args.use_mark_dynamic:
            torch._dynamo.mark_dynamic(data, index=1)

        output = model(data)
        loss = criterion(output.view(-1, args.vocab_size), targets.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if args.ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    train()
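The sequence lengths seen in the logs above (for example 970, 977, 1014) follow from the arithmetic in get_batch, and for a contiguous [batch, seq] tensor stride()[0] equals the sequence length, so a guard on a static stride changes with every batch. The sketch below checks this in plain Python, assuming the default sequence_length=1024; `sampled_length` and `contiguous_strides` are hypothetical helpers, not PyTorch functions.

```python
import random
from typing import Tuple


def sampled_length(sequence_length: int, rng: random.Random) -> int:
    # Mirrors get_batch: `// 8` binds to the random offset only, so the
    # sampled row length is sequence_length + 1 - (offset // 8).
    offset = rng.randint(0, min(512, sequence_length // 2))
    return sequence_length - offset // 8 + 1


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Row-major (C-contiguous) strides: each stride is the product of the
    # sizes of all later dimensions.
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


rng = random.Random(0)
# input[:, :-1] drops one column, so the model sees length - 1 tokens.
model_lengths = [sampled_length(1024, rng) - 1 for _ in range(1000)]
print(min(model_lengths), max(model_lengths))
print(contiguous_strides((16, 977)))        # model input x
print(contiguous_strides((16, 977, 2048)))  # embedding output
```

The second stride tuple ties back to the thread: for the [16, 977, 2048] embedding output, stride()[0] is 977 * 2048 = 2000896.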

Versions

I ran my replication script on fresh conda environments:

conda create -n torchnight python=3.11 pytorch torchvision pytorch-cuda=12.4 -c pytorch-nightly -c nvidia -c conda-forge

conda create -n torch23 python=3.11 pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge
PyTorch Nightly Environment
PyTorch version: 2.4.0.dev20240506
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 23.10 (x86_64)
GCC version: (Ubuntu 13.2.0-4ubuntu3) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.38

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True


Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240506
[pip3] torchvision==0.19.0.dev20240506
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] libopenvino-pytorch-frontend 2024.0.0             he02047a_5    conda-forge
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4          py311h64a7726_0    conda-forge
[conda] pillow                    9.3.0           py311h3fd9d12_2    pytorch-nightly
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240506 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240506     py311_cu124    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang

@msaroufim added the triaged, module: ddp, and module: dynamic shapes labels on May 6, 2024
@bdhirsh
Contributor

bdhirsh commented May 7, 2024

@warner-benjamin I ran locally with a nightly and this actually passes for me. Can you try out a nightly? https://pytorch.org/get-started/locally/

@warner-benjamin
Author

@bdhirsh I tested my replication script with yesterday's nightly and with 2.3; you can see my environment in the "PyTorch Nightly Environment" section. These errors only occur with DDP; single-GPU training compiles and runs without issue.

I installed today's nightly, pytorch-2.4.0.dev20240507, with both CUDA 12.4 and 12.1. Setting dynamic shapes with DDP via torch._dynamo.mark_dynamic using the following command still fails with the same ConstraintViolationError.

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic

And setting torch.compile(..., dynamic=True) or torch.compile(..., dynamic=None) with the following commands still recompiles on every batch until torch._dynamo hits config.cache_size_limit (8).

# torch.compile(..., dynamic=True)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true

# torch.compile(..., dynamic=None)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen

@algal

algal commented May 7, 2024

I am seeing the same issue this morning, running the same three commands on the replication script, on my system using CUDA 12.1.

Details below:

PyTorch Nightly Environment details
Collecting environment information...
PyTorch version: 2.4.0.dev20240507
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.1.0-20-amd64-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090

Nvidia driver version: 525.147.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               16
On-line CPU(s) list:                  0-15
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 7 5700G with Radeon Graphics
CPU family:                           25
Model:                                80
Thread(s) per core:                   2
Core(s) per socket:                   8
Socket(s):                            1
Stepping:                             0
Frequency boost:                      enabled
CPU(s) scaling MHz:                   64%
CPU max MHz:                          4672.0698
CPU min MHz:                          1400.0000
BogoMIPS:                             7585.74
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                       AMD-V
L1d cache:                            256 KiB (8 instances)
L1i cache:                            256 KiB (8 instances)
L2 cache:                             4 MiB (8 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-15
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240507
[pip3] torchvision==0.19.0.dev20240507
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] libopenvino-pytorch-frontend 2024.0.0             he02047a_5    conda-forge
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4          py311h64a7726_0    conda-forge
[conda] pillow                    9.3.0           py311h3fd9d12_2    pytorch-nightly
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240507 py3.11_cuda12.1_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.1                 ha16c6d3_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240507     py311_cu121    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly

@ezyang
Contributor

ezyang commented May 8, 2024

I'm going to look into this. But my recollection is that HF added some error checking code which forces specialization, and I haven't gotten around to yelling at them to stop running this logic when being torch compiled.

BTW, the two errors here are one and the same. mark_dynamic is yelling at you because it tried to make it dynamic, but failed due to specialization. You can use TORCH_LOGS=dynamic to find out where the specialization happened.

@warner-benjamin
Author

> I'm going to look into this. But my recollection is that HF added some error checking code which forces specialization, and I haven't gotten around to yelling at them to stop running this logic when being torch compiled.

It's not just HF models that trigger this when using DDP. My replication script uses a simple two-layer model with an Embedding and a Linear layer. A single-layer model doesn't reproduce the issue, so it seems to have something to do with adding a second layer.

class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)
        out = self.head(out)
        return out

> BTW, the two errors here are one and the same. mark_dynamic is yelling at you because it tried to make it dynamic, but failed due to specialization. You can use TORCH_LOGS=dynamic to find out where the specialization happened.

When I run my replication script with TORCH_LOGS=+dynamic

TORCH_LOGS=+dynamic torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic

I get the following output for rank 0:

TORCH_LOGS=+dynamic Rank 0 Output
torch/fx/experimental/symbolic_shapes.py:2268] [0/0] create_env
torch/fx/experimental/symbolic_shapes.py:3239] [0/0] create_symbol s0 = 977 for L['x'].size()[1] [2, 9223372036854775806] at test/replication.py:52 in forward (_dynamo/variables/builder.py:2137 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(16*s0, 16) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(Mod(16, 16*s0), 0) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 2048*s0 > 2048 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(s0, 1) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(16*s0, 16) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(s0, 1) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval s0 > 1 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4634] [0/0] eval 32768*s0 < 2147483648 [guard added] (_inductor/codegen/triton.py:3409 in can_use_32bit_indexing), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="32768*s0 < 2147483648"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 16*s0 < 2147483648 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]
torch/fx/experimental/symbolic_shapes.py:3326] [0/0] produce_guards
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[0] 16 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[1] 977 RelaxedUnspecConstraint(warn_only=False)
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[0] 977 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[1] 1 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].storage_offset() 0 None
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[0] == 16
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[1] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[0] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[1] == 1
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].storage_offset() == 0
torch/_guards.py:261] [0/0] Error while creating guard:
torch/_guards.py:261] [0/0] Name: ''
torch/_guards.py:261] [0/0]     Source: shape_env
torch/_guards.py:261] [0/0]     Create Function: SHAPE_ENV
torch/_guards.py:261] [0/0]     Guard Types: None
torch/_guards.py:261] [0/0]     Code List: None
torch/_guards.py:261] [0/0]     Object Weakref: None
torch/_guards.py:261] [0/0]     Guarded Class Weakref: None
torch/_guards.py:261] [0/0] Traceback (most recent call last):
torch/_guards.py:261] [0/0]   File "torch/_guards.py", line 259, in create
torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).
torch/_guards.py:263] [0/0] Created at:
torch/_guards.py:263] [0/0]   File "torch/_dynamo/convert_frame.py", line 499, in transform
torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/symbolic_convert.py", line 2143, in __init__
torch/_guards.py:263] [0/0]     output=OutputGraph(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 308, in __init__
torch/_guards.py:263] [0/0]     self.init_ambient_guards()
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 447, in init_ambient_guards
torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats evaluate_expr: CacheInfo(hits=342, misses=15, maxsize=256, currsize=15)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _find: CacheInfo(hits=31, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats has_hint: CacheInfo(hits=1, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats size_hint: CacheInfo(hits=1, misses=3, maxsize=256, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats simplify: CacheInfo(hits=6, misses=17, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats replace: CacheInfo(hits=3257, misses=58, maxsize=None, currsize=18)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=1, misses=21, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_implications: CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_axioms: CacheInfo(hits=17, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats safe_expand: CacheInfo(hits=606, misses=53, maxsize=256, currsize=53)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats uninteresting_files: CacheInfo(hits=18, misses=1, maxsize=None, currsize=1)

I'm not seeing anything about specialization, but I might be misinterpreting the logs.

@ezyang
Contributor

ezyang commented May 8, 2024

It's this:

torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]

Very strange though, why is this suppressed? You could get a full backtrace for this log with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)".
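The arithmetic behind these two log lines can be checked directly. One plausible reading (not confirmed in this thread) is that 2048*s0 is stride()[0] of the [16, s0, 2048] embedding output, so equating it with the concrete value 2000896 has exactly one solution, s0 = 977, which is what set_replacement records. By contrast, the earlier 32768*s0 < 2147483648 guard (32768 presumably being batch_size * hidden_size = 16 * 2048) only bounds s0 and leaves it symbolic. `solve_linear_guard` is a hypothetical toy, not part of torch.fx.experimental.symbolic_shapes.

```python
from fractions import Fraction
from typing import Optional, Tuple


def solve_linear_guard(coeff: int, op: str, rhs: int) -> Tuple[str, Optional[int]]:
    """Classify the guard `coeff * s0 <op> rhs` for a positive integer s0.

    Returns ("pins", value) when exactly one s0 satisfies an equality, or
    ("bounds", largest_allowed) for a strict upper bound `<`.
    """
    if op == "==":
        value = Fraction(rhs, coeff)
        if value.denominator != 1:
            raise ValueError("no integer solution")
        return "pins", int(value)
    if op == "<":
        # coeff * s0 < rhs  <=>  s0 <= ceil(rhs / coeff) - 1
        return "bounds", -(-rhs // coeff) - 1
    raise ValueError(f"unsupported op {op!r}")


print(solve_linear_guard(2048, "==", 2000896))     # Eq(2048*s0, 2000896)
print(solve_linear_guard(32768, "<", 2147483648))  # 32768*s0 < 2**31
```

An equality on a single symbol leaves no room for other sequence lengths, which is why it surfaces as "inferred to be a constant (977)" even though the guard itself was suppressed.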

@eellison
Contributor

eellison commented May 8, 2024

@ezyang could it be related to this? https://github.com/pytorch/pytorch/pull/120523/files#diff-cb8e02fc8f37e53904ab1b151c46dd109cf50d8121bbd340834b2e976b22ebc4R74

Maybe the idiom there is not correct. We're trying to update the meta strides without adding guards or specializations.

@ezyang
Contributor

ezyang commented May 8, 2024

Oh yeah, this looks very very naughty. Hmmmm

@ezyang
Contributor

ezyang commented May 8, 2024

As a stopgap, I guess we could prevent replacements from happening when guards are suppressed. This still seems very naughty though.....
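The stopgap can be illustrated with a deliberately tiny model of a shape environment. `ToyShapeEnv`, `suppress_guards`, `ConstraintViolation`, and `allow_replacements_when_suppressed` are hypothetical names for this sketch; they mimic the set_replacement, guard suppressed, and produce_guards log messages above, not PyTorch's real ShapeEnv code. With the current behavior, a suppressed equality still records a replacement, so a dimension marked dynamic looks like a constant at produce_guards time; with the stopgap, the replacement is skipped.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Set


class ConstraintViolation(Exception):
    """Toy counterpart of a constraint violation on a marked-dynamic dim."""


class ToyShapeEnv:
    def __init__(self, allow_replacements_when_suppressed: bool) -> None:
        self.allow_replacements_when_suppressed = allow_replacements_when_suppressed
        self.suppressed = False
        self.replacements: Dict[str, int] = {}
        self.guards: List[str] = []
        self.marked_dynamic: Set[str] = set()

    @contextmanager
    def suppress_guards(self) -> Iterator[None]:
        self.suppressed = True
        try:
            yield
        finally:
            self.suppressed = False

    def eval_eq(self, sym: str, coeff: int, rhs: int) -> None:
        # Treating Eq(coeff*sym, rhs) as true solves to sym == rhs // coeff.
        if self.suppressed:
            if self.allow_replacements_when_suppressed:
                self.replacements[sym] = rhs // coeff  # current behavior
            return  # no guard is recorded either way
        self.replacements[sym] = rhs // coeff
        self.guards.append(f"Eq({coeff}*{sym}, {rhs})")

    def produce_guards(self) -> List[str]:
        for sym in self.marked_dynamic:
            if sym in self.replacements:
                raise ConstraintViolation(
                    f"{sym} was inferred to be a constant ({self.replacements[sym]})"
                )
        return self.guards


def run(allow: bool) -> List[str]:
    env = ToyShapeEnv(allow_replacements_when_suppressed=allow)
    env.marked_dynamic.add("s0")  # analogous to mark_dynamic(x, index=1)
    with env.suppress_guards():
        env.eval_eq("s0", 2048, 2000896)  # e.g. updating meta strides
    return env.produce_guards()
```

Here run(True) raises ConstraintViolation mentioning 977, while run(False) returns no guards and keeps s0 symbolic, at the cost of the toy environment no longer knowing the stride equality it was told about.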

@warner-benjamin
Author

Here's the additional backtrace with the "Eq(2048*s0, 2000896)" guard added:

TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)"
torch/fx/experimental/symbolic_shapes.py:2268] [0/0] create_env
torch/fx/experimental/symbolic_shapes.py:3239] [0/0] create_symbol s0 = 977 for L['x'].size()[1] [2, 9223372036854775806] at test/replication.py:52 in forward (_dynamo/variables/builder.py:2137 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(16*s0, 16) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(Mod(16, 16*s0), 0) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 2048*s0 > 2048 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(s0, 1) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(16*s0, 16) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(s0, 1) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval s0 > 1 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4634] [0/0] eval 32768*s0 < 2147483648 [guard added] (_inductor/codegen/triton.py:3409 in can_use_32bit_indexing), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="32768*s0 < 2147483648"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 16*s0 < 2147483648 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]
torch/fx/experimental/symbolic_shapes.py:3326] [0/0] produce_guards
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[0] 16 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[1] 977 RelaxedUnspecConstraint(warn_only=False)
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[0] 977 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[1] 1 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].storage_offset() 0 None
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[0] == 16
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[1] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[0] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[1] == 1
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].storage_offset() == 0
torch/_guards.py:261] [0/0] Error while creating guard:
torch/_guards.py:261] [0/0] Name: ''
torch/_guards.py:261] [0/0]     Source: shape_env
torch/_guards.py:261] [0/0]     Create Function: SHAPE_ENV
torch/_guards.py:261] [0/0]     Guard Types: None
torch/_guards.py:261] [0/0]     Code List: None
torch/_guards.py:261] [0/0]     Object Weakref: None
torch/_guards.py:261] [0/0]     Guarded Class Weakref: None
torch/_guards.py:261] [0/0] Traceback (most recent call last):
torch/_guards.py:261] [0/0]   File "torch/_guards.py", line 259, in create
torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).
torch/_guards.py:263] [0/0] Created at:
torch/_guards.py:263] [0/0]   File "torch/_dynamo/convert_frame.py", line 499, in transform
torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/symbolic_convert.py", line 2143, in __init__
torch/_guards.py:263] [0/0]     output=OutputGraph(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 308, in __init__
torch/_guards.py:263] [0/0]     self.init_ambient_guards()
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 447, in init_ambient_guards
torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "replication.py", line 139, in <module>
    train()
  File "replication.py", line 127, in train
    output = model(data)
             ^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/parallel/distributed.py", line 1620, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/parallel/distributed.py", line 1438, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 977, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 822, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 410, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/benja/.conda/envs/torchnight/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 703, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 660, in compile_inner
    check_fn = CheckFunctionManager(
               ^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/guards.py", line 2086, in __init__
    guard.create(builder)
  File "torch/_guards.py", line 259, in create
    return self.create_fn(builder, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
    guards = output_graph.shape_env.produce_guards(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats evaluate_expr: CacheInfo(hits=342, misses=15, maxsize=256, currsize=15)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _find: CacheInfo(hits=31, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats has_hint: CacheInfo(hits=1, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats size_hint: CacheInfo(hits=1, misses=3, maxsize=256, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats simplify: CacheInfo(hits=6, misses=17, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats replace: CacheInfo(hits=3257, misses=58, maxsize=None, currsize=18)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=1, misses=21, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_implications: CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_axioms: CacheInfo(hits=17, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats safe_expand: CacheInfo(hits=606, misses=53, maxsize=256, currsize=53)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats uninteresting_files: CacheInfo(hits=18, misses=1, maxsize=None, currsize=1)
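Read together, the log lines trace how the range of s0 shrinks: create_symbol starts it at [2, 9223372036854775806], the can_use_32bit_indexing guard 32768*s0 < 2147483648 caps it at 65535, and the suppressed Eq(2048*s0, 2000896) collapses it to [977, 977], which is the "inferred to be a constant (977)" in the error. A simplified sketch of that interval arithmetic; bound_by_lt and bound_by_eq are hypothetical helpers, not PyTorch's ValueRanges API, and assume positive coefficients:

```python
from typing import Tuple

Range = Tuple[int, int]  # inclusive [lower, upper]


def bound_by_lt(r: Range, coeff: int, limit: int) -> Range:
    """Refine r under the guard coeff * s < limit (assumes coeff > 0)."""
    lower, upper = r
    return (lower, min(upper, (limit - 1) // coeff))


def bound_by_eq(r: Range, coeff: int, rhs: int) -> Range:
    """Refine r under the guard coeff * s == rhs (assumes coeff > 0)."""
    quotient, remainder = divmod(rhs, coeff)
    lower, upper = r
    if remainder != 0 or not lower <= quotient <= upper:
        raise ValueError("guard is unsatisfiable in this range")
    return (quotient, quotient)


r0: Range = (2, 2**63 - 2)           # create_symbol: [2, 9223372036854775806]
r1 = bound_by_lt(r0, 32768, 2**31)   # 32768*s0 < 2147483648
r2 = bound_by_eq(r1, 2048, 2000896)  # Eq(2048*s0, 2000896)
print(r1)  # (2, 65535)
print(r2)  # (977, 977): lower == upper, so s0 is effectively a constant
```

The 32-bit indexing guard is harmless for this workload (977 is far below 65535); it is the equality alone that turns the dynamic dimension into a constant.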

ezyang added a commit that referenced this issue May 14, 2024
Also improve logging when guards are suppressed

Partially addresses #125641

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c139cf64a5c142934cf35d0bba2bd2fd8df048d9
Pull Request resolved: #126210
pytorchmergebot pushed a commit that referenced this issue May 23, 2024
Also improve logging when guards are suppressed

Partially addresses #125641

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #126210
Approved by: https://github.com/jbschlosser
@ezyang
Contributor

ezyang commented Jun 3, 2024

I believe this is fixed

@ezyang ezyang closed this as completed Jun 3, 2024