Warning originating in C10 backend does not get translated to Python warning if run from subprocess #75725

@otaj

🐛 Describe the bug

Hi,

I want to record, in Python, a warning that originates in the C10 portion of the code (TORCH_WARN_ONCE) while running in a subprocess because of DDP. However, it seems this warning is impossible to catch because it does not propagate to Python correctly. Below is a simple demo, mostly taken from this tutorial and adapted to catch warnings.

Code and output with warnings
import contextlib
import io
import os
import sys
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()

    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)

        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 10))
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()

        except:
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)

    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)

Output:

[W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.

Caught stderr: 
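
A side note on the output above: the [W reducer.cpp:1289] line appears to be printed by the C++ warning handler directly to the process's native stderr (file descriptor 2), while contextlib.redirect_stderr only swaps the Python-level sys.stderr object, which is why "Caught stderr" stays empty. The helper below is a minimal, hypothetical sketch (not part of torch, and not a fix for the missing Python warning) showing how a file-descriptor-level redirection could at least capture that raw text.

import contextlib
import os
import sys
import tempfile


@contextlib.contextmanager
def capture_native_stderr():
    # Temporarily point fd 2 (the native stderr that C++ writes to) at a
    # temporary file, then restore it and expose the captured text.
    tmp = tempfile.TemporaryFile(mode="w+b")
    sys.stderr.flush()
    saved_fd = os.dup(2)
    os.dup2(tmp.fileno(), 2)
    captured = {}
    try:
        yield captured
    finally:
        sys.stderr.flush()
        os.dup2(saved_fd, 2)
        os.close(saved_fd)
        tmp.seek(0)
        captured["text"] = tmp.read().decode()
        tmp.close()

Wrapping the forward/backward calls in "with capture_native_stderr() as captured:" should make the [W reducer.cpp:...] text show up in captured["text"], but only as plain text, not as a recordable Python warning.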

However, if I make an intentional mistake in order to raise an Exception in a similar code path (such as changing the tensor sizes so that they no longer match), the Exception is correctly propagated to Python as a RuntimeError; see the modified code below.

Code and output with Exception
import contextlib
import io
import os
import sys
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()

    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)

        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 9)) # <--- Change is here, this will create error
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()

        except:
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)

    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)

Output:

Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.

Caught stderr: Traceback (most recent call last):
  File "/home/otaj/files/grid/simple-demo/main.py", line 60, in demo_basic
    outputs = ddp_model(torch.randn(20, 9))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/files/grid/simple-demo/main.py", line 34, in forward
    return self.net2(self.relu(self.net1(x)))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x9 and 10x10)
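
For comparison, a warning raised by TORCH_WARN on the Python calling thread does get translated into a Python warning and can be recorded as expected. A minimal sketch, assuming torch.range still emits its deprecation warning from C++ via TORCH_WARN (that origin is an assumption on my part):

import warnings

import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # torch.range is deprecated; its deprecation warning (assumed to come from
    # TORCH_WARN in C++) surfaces here as an ordinary Python warning.
    torch.range(0, 3)

print([str(w.message) for w in caught])  # the deprecation warning is recorded

The reducer warning in the repro above, by contrast, is emitted inside the spawned DDP process during the backward pass and never reaches the Python warnings machinery.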

The issue was first reported on PyTorch Slack. It is most likely linked to #72948. cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @albanD @mruberry

Thanks a lot!

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 11.2.0
Clang version: Could not collect
CMake version: version 3.23.0
Libc version: glibc-2.35

Python version: 3.9.11 (main, Apr 7 2022, 15:33:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.17.1-zen1-1-zen-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.112
GPU models and configuration: GPU 0: NVIDIA T1200 Laptop GPU
Nvidia driver version: 510.60.02
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.3.3
/usr/lib/libcudnn_adv_infer.so.8.3.3
/usr/lib/libcudnn_adv_train.so.8.3.3
/usr/lib/libcudnn_cnn_infer.so.8.3.3
/usr/lib/libcudnn_cnn_train.so.8.3.3
/usr/lib/libcudnn_ops_infer.so.8.3.3
/usr/lib/libcudnn_ops_train.so.8.3.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.942
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.3
[pip3] pytorch-lightning==1.7.0.dev0
[pip3] torch==1.11.0+cu113
[pip3] torchmetrics==0.7.3
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect
