[RPC] print exception message on workers that run python functions #46372

Closed
wants to merge 6 commits
2 changes: 2 additions & 0 deletions torch/distributed/rpc/internal.py
@@ -2,6 +2,7 @@
 import copyreg
 import io
 import pickle
+import sys
 import threading
 import traceback
 from enum import Enum
@@ -168,6 +169,7 @@ def _run_function(python_udf):
             f"On {_get_current_rpc_agent().get_worker_info()}:\n"
             f"{repr(e)}\n{traceback.format_exc()}"
         )
+        print(except_str, file=sys.stderr)
         result = RemoteException(except_str, type(e))
     return result

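Note for reviewers: the two added lines above are the whole behavioral change. `_run_function` already wrapped the exception and traceback in a `RemoteException` so the caller could re-raise it; the new `print(except_str, file=sys.stderr)` additionally echoes that text on the worker that actually ran the function. A minimal sketch of the resulting behavior, assuming two RPC workers already set up with `rpc.init_rpc` (the worker names and the `boom` helper are illustrative, not from this PR):

```python
import torch.distributed.rpc as rpc

def boom():
    raise ValueError("Expected error")

# On the caller ("worker0"), with "worker1" already initialized:
fut = rpc.rpc_async("worker1", boom)
try:
    fut.wait()  # re-raises the remote ValueError locally, as before
except ValueError as e:
    print("caller saw:", e)

# With this patch, "worker1"'s own stderr now also shows
# "On <WorkerInfo ...>:" followed by repr(e) and the full traceback,
# instead of the error text being visible only on the caller side.
```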
11 changes: 11 additions & 0 deletions torch/testing/_internal/common_distributed.py
@@ -1,5 +1,7 @@
 
 from multiprocessing import Manager
+from contextlib import contextmanager
+from io import StringIO
 import os
 import sys
 import tempfile
@@ -174,6 +176,15 @@ def create_device(interface=None):
 def get_timeout(test_id):
     return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)
 
+@contextmanager
+def captured_output():
+    new_out, new_err = StringIO(), StringIO()
+    old_out, old_err = sys.stdout, sys.stderr
+    try:
+        sys.stdout, sys.stderr = new_out, new_err
+        yield sys.stdout, sys.stderr
+    finally:
+        sys.stdout, sys.stderr = old_out, old_err
 
 def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
     """
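For context, `captured_output` swaps `sys.stdout`/`sys.stderr` for in-memory `StringIO` buffers and restores the originals on exit, so tests can assert on anything printed inside the block. A quick usage sketch (not part of the diff):

```python
import sys
from torch.testing._internal.common_distributed import captured_output

with captured_output() as (out, err):
    print("to stdout")
    print("to stderr", file=sys.stderr)

# The real streams are restored here; the buffers keep what was printed.
assert out.getvalue() == "to stdout\n"
assert err.getvalue() == "to stderr\n"
```

One caveat: rebinding `sys.stdout`/`sys.stderr` only intercepts writes that go through those module attributes in the same process; output from child processes, or from C++ code writing straight to file descriptors 1 and 2, is not captured.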
16 changes: 2 additions & 14 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -11,7 +11,6 @@
 from contextlib import contextmanager, suppress
 from datetime import timedelta
 from functools import reduce
-from io import StringIO
 from typing import Union, NamedTuple
 
 import torch
@@ -35,6 +34,7 @@
     skip_if_no_gpu,
     require_n_gpus_for_nccl_backend,
     requires_nccl_version,
+    captured_output,
 )
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
@@ -133,18 +133,6 @@ def forward(self, x):
 BN_NET = BatchNormNet()
 ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
 
-
-@contextmanager
-def _captured_output():
-    new_out, new_err = StringIO(), StringIO()
-    old_out, old_err = sys.stdout, sys.stderr
-    try:
-        sys.stdout, sys.stderr = new_out, new_err
-        yield sys.stdout, sys.stderr
-    finally:
-        sys.stdout, sys.stderr = old_out, old_err
-
-
 def get_timeout(test_id):
     test_name = test_id.split(".")[-1]
     if test_name in CUSTOMIZED_TIMEOUT:
@@ -377,7 +365,7 @@ def _init_multigpu_helper(self):
         return rank_to_GPU
 
     def test_dump_DDP_relevant_env_vars(self):
-        with _captured_output() as (out, err):
+        with captured_output() as (out, _):
             _dump_DDP_relevant_env_vars()
         lines = out.getvalue().splitlines()
 
19 changes: 12 additions & 7 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -22,7 +22,7 @@
     _internal_rpc_pickler,
     _build_rpc_profiling_key,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu, captured_output
 from torch.testing._internal.common_utils import IS_MACOS, load_tests
 from torch.testing._internal.dist_utils import (
     dist_init,
@@ -313,8 +313,9 @@ def my_script_func(tensor):
     return torch.add(tensor, tensor)
 
 
+expected_err = "Expected error"
 def raise_func():
-    raise ValueError("Expected error")
+    raise ValueError(expected_err)
 
 
 global_rref = None
@@ -1954,11 +1955,15 @@ def test_py_function_exception(self):
 
     @dist_init
     def test_py_raise_in_user_func(self):
-        n = self.rank + 1
-        dst_rank = n % self.world_size
-        fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
-        with self.assertRaises(ValueError):
-            fut.wait()
+        with captured_output() as (_, err):
+            n = self.rank + 1
+            dst_rank = n % self.world_size
+            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
+            with self.assertRaisesRegex(ValueError, expected_err):
+                fut.wait()
+        lines = err.getvalue()
+
+        self.assertTrue(expected_err in lines)
 
     @dist_init
     def test_nested_rpc(self):
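A subtlety that appears to make the updated test work: `captured_output` swaps `sys.stderr` only in the process that enters the block, while the new `print` runs on the destination worker. The test still passes because every rank executes the same test body, so while rank r waits on its RPC to rank (r + 1) % world_size, it is also serving `raise_func` for its other neighbor and printing the message into its own, captured, stderr. A standalone illustration of the mechanism (plain Python, not PyTorch code):

```python
import sys
from io import StringIO

# Swapping sys.stderr captures any print issued by *this* process, which is
# exactly what happens when a rank services a peer's RPC and _run_function
# calls print(except_str, file=sys.stderr) on a server thread.
old = sys.stderr
sys.stderr = buf = StringIO()
try:
    print("Expected error", file=sys.stderr)  # stands in for the worker-side print
finally:
    sys.stderr = old

assert "Expected error" in buf.getvalue()
```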