🐛 Describe the bug
Hi,
I want to record, in Python, a warning that originates in the C10 part of the code (TORCH_WARN_ONCE) while running in a subprocess spawned for DDP. However, it seems this warning is impossible to catch, because it does not propagate to Python correctly. Below is a simple demo, mostly taken from this tutorial and adapted to catching warnings.
Code and output with warnings
```python
import contextlib
import io
import os
import sys
import warnings
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()
    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)
        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 10))
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()
        except:
            print(traceback.format_exc(), file=sys.stderr)
        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)
    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)
```
Output:
```
[W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.
Caught stderr:
```
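As far as I can tell, the reducer warning bypasses `contextlib.redirect_stderr` entirely: that context manager only swaps the Python-level `sys.stderr` object, while the default `TORCH_WARN` handler in the spawned worker appears to write straight to OS file descriptor 2. The only way I found to even see the text inside the worker is to capture at the descriptor level. A minimal sketch (the `capture_fd_stderr` helper and the buffer size are my own, hypothetical choices, not anything provided by `torch`):

```python
import os


def capture_fd_stderr(fn, *args):
    """Run fn while OS-level fd 2 is redirected into a pipe, so that text
    written by C++ code (which never touches sys.stderr) can be read back."""
    saved_fd = os.dup(2)           # keep a handle on the real stderr
    read_fd, write_fd = os.pipe()  # pipe that will receive fd-2 writes
    os.dup2(write_fd, 2)           # from here on, C++ writes land in the pipe
    try:
        fn(*args)
    finally:
        os.dup2(saved_fd, 2)       # restore the real stderr
        os.close(write_fd)
        os.close(saved_fd)
    # caveat: a single read only works while the output fits in the pipe
    # buffer; a robust version would drain the pipe from a background thread
    data = os.read(read_fd, 65536)
    os.close(read_fd)
    return data.decode(errors="replace")
```

Wrapping the forward/backward section of `demo_basic` in this should at least surface the `[W reducer.cpp:1289]` text inside the child process, but of course it only recovers a string, not a Python warning object that `warnings.catch_warnings` could record.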
However, if I make an intentional mistake in a similar code path (such as changing the tensor sizes so that they no longer match), the exception is correctly propagated to Python as a RuntimeError; see the modified code below.
Code and output with Exception
```python
import contextlib
import io
import os
import sys
import warnings
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()
    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)
        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 9))  # <--- Change is here, this will create an error
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()
        except:
            print(traceback.format_exc(), file=sys.stderr)
        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)
    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)
```
Output:
```
Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.
Caught stderr: Traceback (most recent call last):
  File "/home/otaj/files/grid/simple-demo/main.py", line 60, in demo_basic
    outputs = ddp_model(torch.randn(20, 9))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/files/grid/simple-demo/main.py", line 34, in forward
    return self.net2(self.relu(self.net1(x)))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x9 and 10x10)
```
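For contrast, warnings that torch raises through the Python `warnings` module are recordable in exactly this kind of block, just like my "Simple warning" above. A minimal single-process check (my own sketch; it assumes the 1.11 behavior where `torch.meshgrid` warns when the `indexing` argument is omitted):

```python
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # assumption: on torch 1.11, calling meshgrid without indexing= emits a
    # UserWarning about the upcoming requirement to pass that argument
    torch.meshgrid(torch.arange(3), torch.arange(3))

print([str(w.message) for w in caught])  # the meshgrid warning shows up here
```

So recording works for warnings surfaced through the Python layer; it appears to be specifically the warning emitted from C++ during the backward pass that never reaches `warnings.catch_warnings`.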
The issue was first reported on PyTorch Slack. It is most likely linked to this issue: #72948. cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @albanD @mruberry
Thanks a lot!
Versions
```
Collecting environment information...
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Arch Linux (x86_64)
GCC version: (GCC) 11.2.0
Clang version: Could not collect
CMake version: version 3.23.0
Libc version: glibc-2.35
Python version: 3.9.11 (main, Apr 7 2022, 15:33:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.17.1-zen1-1-zen-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.112
GPU models and configuration: GPU 0: NVIDIA T1200 Laptop GPU
Nvidia driver version: 510.60.02
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.3.3
/usr/lib/libcudnn_adv_infer.so.8.3.3
/usr/lib/libcudnn_adv_train.so.8.3.3
/usr/lib/libcudnn_cnn_infer.so.8.3.3
/usr/lib/libcudnn_cnn_train.so.8.3.3
/usr/lib/libcudnn_ops_infer.so.8.3.3
/usr/lib/libcudnn_ops_train.so.8.3.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.942
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.3
[pip3] pytorch-lightning==1.7.0.dev0
[pip3] torch==1.11.0+cu113
[pip3] torchmetrics==0.7.3
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect
```