Update on "Implement timeout support for RRefs"
This PR implements timeout semantics for RRefs, for parity with `rpc_sync` and `rpc_async`. How it works:

- A timeout parameter is added to `rpc.remote`. If the `rpc.remote` call times out, the error is not raised to the user in that call, since the call is non-blocking (similar to `rpc_async`). Instead, the timeout error is raised the next time the RRef is used, either when it is pickled or when `to_here` is called (see the sketch after this list).
- Error-handling semantics are added to RRef to deal with timeout errors. Previously, if there was an error creating the OwnerRRef, the callback on the local user would throw, and throwing from inside a callback results in `std::terminate`. Instead, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we have added an `RPCErrorType` enum and defined RRef error handlers for the `RPCErrorType` values (currently just timeout and unknown).
- A timeout parameter is added to `to_here()`, which gives the user control over the maximum amount of time the call can block.
- Before blocking, `to_here` checks whether the RRef creation on the remote node has timed out, and if so, throws an error with that information.
- `ctx.prepareChildForFork()`, which is called when the RRef is pickled (i.e. used as an argument over RPC), checks whether the `rpc.remote()` call has timed out and, if so, raises that error to the user.
- Tests are added, primarily via delay injection.
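
A minimal usage sketch of the new timeout behavior (an illustration, not part of this commit): it assumes `rpc.init_rpc` has already been called on every worker, that a peer named `"worker1"` exists, and that timeouts are given in seconds; the exact exception type and message may vary.

```python
import torch
import torch.distributed.rpc as rpc

# Non-blocking remote call; a creation timeout is *not* raised here.
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1), timeout=5)

try:
    # Blocks for at most ~5 seconds. A creation timeout recorded earlier, or a
    # timeout of this fetch itself, surfaces here; the same error would surface
    # if the RRef were instead pickled as an argument to another RPC.
    result = rref.to_here(timeout=5)
except RuntimeError as err:
    print(f"RRef use failed: {err}")
```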

Differential Revision: [D21588165](https://our.internmc.facebook.com/intern/diff/D21588165/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D21588165/)!

[ghstack-poisoned]
rohan-varma committed Jun 4, 2020
2 parents 56700d1 + ec5d579 commit a72eeed
Showing 14 changed files with 82 additions and 73 deletions.
12 changes: 12 additions & 0 deletions .github/pytorch-circleci-labels.yml
@@ -0,0 +1,12 @@
# For documentation concerning this configuration please refer to,
# https://github.com/pytorch/pytorch-probot#trigger-circleci-workflows
labels_to_circle_params:
ci/binaries:
parameter: run_binary_tests
default_true_on:
branches:
- nightly
- ci-all/.*
- release/.*
tags:
- v[0-9]+(\.[0-9]+)*-rc[0-9]+
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Resize.h
@@ -134,6 +134,7 @@ inline void setStrided(
IntArrayRef size,
IntArrayRef stride,
int64_t storage_offset) {
TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
auto* self_ = self.unsafeGetTensorImpl();
checkInBoundsForStorage(
size, stride, storage_offset, self_->dtype(), self_->storage());
@@ -143,7 +144,6 @@ inline void setStrided(
self_->set_storage_offset(storage_offset);

/* size and stride */
AT_ASSERT(size.size() == stride.size());
if (self_->sizes() == size && self_->strides() == stride) {
return;
}
9 changes: 4 additions & 5 deletions aten/src/THCUNN/generic/RReLU.cu
@@ -20,7 +20,8 @@ void THNN_(RReLU_updateOutput)(
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(generator, at::cuda::detail::getDefaultCUDAGenerator());
if (train)
{
input = THCTensor_(newContiguous)(state, input);
auto inputTensor = THTensor_wrap(input).contiguous();
input = inputTensor.unsafeGetTensorImpl();
THCTensor_(resizeAs)(state, noise, input);
scalar_t *input_data = THCTensor_(data)(state, input);
scalar_t *noise_data = THCTensor_(data)(state, noise);
@@ -50,7 +51,6 @@ void THNN_(RReLU_updateOutput)(
n, rng_engine_inputs, input_data, noise_data, output_data, lower, upper);
}
THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, input);
}
else
{
@@ -82,7 +82,8 @@ void THNN_(RReLU_updateGradInput)(
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, noise);

gradOutput = THCTensor_(newContiguous)(state, gradOutput);
auto gradOutputTensor = THTensor_wrap(gradOutput).contiguous();
gradOutput = gradOutputTensor.unsafeGetTensorImpl();

if (train && upper - lower > 1E-6) // e.g. if upper == lower, RReLU behaves like LeakyReLU
{
@@ -113,8 +114,6 @@
THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor<scalar_t>(negSlope));
}
}

THCTensor_(free)(state, gradOutput);
}

#endif
6 changes: 0 additions & 6 deletions mypy.ini
@@ -340,15 +340,9 @@ ignore_errors = True
[mypy-torch.multiprocessing.spawn]
ignore_errors = True

[mypy-torch.backends.cudnn.rnn]
ignore_errors = True

[mypy-torch.backends.cuda]
ignore_errors = True

[mypy-torch.backends.cudnn]
ignore_errors = True

[mypy-torch.backends.quantized]
ignore_errors = True

10 changes: 7 additions & 3 deletions test/quantization/test_workflow_module.py
@@ -823,10 +823,14 @@ class Model(nn.Module):

def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(2, 2)
self.conv = nn.Conv2d(1, 1, 1)
self.bn = nn.BatchNorm2d(1)
self.relu = nn.ReLU()

def forward(self, x):
x = self.linear(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x

model = Model()
@@ -841,5 +845,5 @@ def forward(self, x):
self.assertEqual(model_device, device)

# ensure that running an input on CUDA works without any needed changes
input = torch.randn(2, device=device)
input = torch.randn(4, 1, 4, 4, device=device)
model(input)
8 changes: 8 additions & 0 deletions test/run_test.py
@@ -73,6 +73,8 @@
'distributed/rpc/faulty_agent/test_dist_autograd_spawn',
'distributed/rpc/faulty_agent/test_rpc_spawn',
'distributed/rpc/jit/test_dist_autograd_spawn',
'distributed/rpc/tensorpipe/test_dist_autograd_spawn',
'distributed/rpc/tensorpipe/test_dist_optimizer_spawn',
'distributed/rpc/tensorpipe/test_rpc_spawn',
'distributed/rpc/test_dist_autograd_spawn',
'distributed/rpc/test_dist_optimizer_spawn',
@@ -89,6 +91,8 @@
'distributed/rpc/faulty_agent/test_rpc_spawn',
'distributed/rpc/jit/test_dist_autograd_spawn',
'distributed/rpc/jit/test_rpc_spawn',
'distributed/rpc/tensorpipe/test_dist_autograd_spawn',
'distributed/rpc/tensorpipe/test_dist_optimizer_spawn',
'distributed/rpc/tensorpipe/test_rpc_spawn',
'distributed/rpc/test_dist_autograd_spawn',
'distributed/rpc/test_dist_optimizer_spawn',
@@ -101,6 +105,8 @@
'distributed/rpc/faulty_agent/test_rpc_spawn',
'distributed/rpc/jit/test_dist_autograd_spawn',
'distributed/rpc/jit/test_rpc_spawn',
'distributed/rpc/tensorpipe/test_dist_autograd_spawn',
'distributed/rpc/tensorpipe/test_dist_optimizer_spawn',
'distributed/rpc/tensorpipe/test_rpc_spawn',
'distributed/rpc/test_dist_autograd_spawn',
'distributed/rpc/test_dist_optimizer_spawn',
@@ -143,6 +149,8 @@
'test_jit_profiling',
'test_torch',
'distributed/test_distributed',
'distributed/rpc/tensorpipe/test_dist_autograd_spawn',
'distributed/rpc/tensorpipe/test_dist_optimizer_spawn',
'distributed/rpc/tensorpipe/test_rpc_spawn',
'distributed/rpc/test_dist_autograd_spawn',
'distributed/rpc/test_rpc_spawn',
8 changes: 8 additions & 0 deletions test/test_torch.py
@@ -10983,6 +10983,14 @@ def test_empty_strided(self, device):
self.assertEqual(empty_strided.shape, as_strided.shape)
self.assertEqual(empty_strided.stride(), as_strided.stride())

def test_strided_mismatched_stride_shape(self, device):
for shape, strides in [((1, ), ()), ((1, 2), (1, ))]:
with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"):
torch.tensor(0.42, device=device).as_strided(shape, strides)

with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"):
torch.tensor(0.42, device=device).as_strided_(shape, strides)

def test_sign(self, device):
for dtype in torch.testing.get_all_math_dtypes(device):
if dtype.is_complex:
7 changes: 7 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -116,12 +116,19 @@ def _get_backcompat_keepdim_warn() -> _bool: ...
def _is_xnnpack_enabled() -> _bool: ...
def _get_mkldnn_enabled() -> _bool: ...
def _set_mkldnn_enabled(arg: _bool) -> None: ...
def _get_cudnn_enabled() -> _bool: ...
def _set_cudnn_enabled(arg: _bool) -> None: ...
def _get_cudnn_benchmark() -> _bool: ...
def _set_cudnn_benchmark(arg: _bool) -> None: ...
def _get_cudnn_deterministic() -> _bool: ...
def _set_cudnn_deterministic(arg: _bool) -> None: ...
def _set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def _set_default_dtype(d: _dtype) -> None: ...
def _initExtension(shm_manager_path: str) -> None: ...
has_openmp: _bool
has_mkldnn: _bool
has_mkl: _bool
has_cudnn: _bool
_GLIBCXX_USE_CXX11_ABI: _bool

# Defined in torch/csrc/jit/python/script_init.cpp
17 changes: 17 additions & 0 deletions torch/_C/_cudnn.pyi
@@ -0,0 +1,17 @@
from enum import Enum

from torch.types import Tuple, Number, _bool

# Defined in torch/csrc/cuda/shared/cudnn.cpp
is_cuda: _bool

def getRuntimeVersion() -> Tuple[int, int, int]: ...
def getCompileVersion() -> Tuple[int, int, int]: ...
def getVersionInt() -> int: ...

class RNNMode(int, Enum):
value: int
rnn_relu = ...
rnn_tanh = ...
lstm = ...
gru = ...
11 changes: 6 additions & 5 deletions torch/backends/cudnn/__init__.py
@@ -7,7 +7,7 @@
try:
from torch._C import _cudnn
except ImportError:
_cudnn = None
_cudnn = None # type: ignore

# Write:
#
@@ -83,11 +83,7 @@ def is_acceptable(tensor):
return True


_handles = {}


def set_flags(_enabled, _benchmark, _deterministic):
global benchmark, deterministic
orig_flags = (torch._C._get_cudnn_enabled(),
torch._C._get_cudnn_benchmark(),
torch._C._get_cudnn_deterministic())
@@ -124,3 +120,8 @@ def __init__(self, m, name):
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
4 changes: 2 additions & 2 deletions torch/backends/cudnn/rnn.py
@@ -5,7 +5,7 @@
except ImportError:
# Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
# so it's safe to not emit any checks here.
_cudnn = None
_cudnn = None # type: ignore


def get_cudnn_mode(mode):
@@ -48,7 +48,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
if dropout_p == 0:
dropout_state[dropout_desc_name] = Unserializable(None)
else:
dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state(
dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state( # type: ignore
dropout_p,
train,
dropout_seed,
19 changes: 10 additions & 9 deletions torch/quantization/quantize.py
@@ -73,26 +73,27 @@ def _observer_forward_hook(self, input, output):
"""
return self.activation_post_process(output)

def add_observer_(module):
def add_observer_(module, device=None):
r"""Add observer for the leaf child of the module.
This function insert observer module to all leaf child module that
has a valid qconfig attribute.
Args:
module: input module with qconfig attributes for all the leaf modules that we want to quantize
device: parent device, if any
Return:
None, module is modified inplace with added observer modules and forward_hooks
"""
# respect device affinity when adding observers
# devices = {p.device for p in module.parameters()}
devices = get_unique_devices_(module)
assert len(devices) <= 1, (
"add_observer_ only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices)) if len(devices) > 0 else None
if device is None:
devices = get_unique_devices_(module)
assert len(devices) <= 1, (
"add_observer_ only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices)) if len(devices) > 0 else None

for child in module.children():
if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional:
@@ -102,7 +103,7 @@ def add_observer_(module):
activation.to(device)
child.activation_post_process = activation
else:
add_observer_(child)
add_observer_(child, device)

# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
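
As context for the `add_observer_` change above, a hedged sketch (not part of this diff) of the eager-mode flow it affects: with the parent device now threaded through the recursive calls, observers inserted into nested children should land on the same device as the model. It assumes a CUDA device is available and uses the standard eager-mode quantization entry points.

```python
import torch
import torch.nn as nn
import torch.quantization as tq

# A small CUDA model with nested leaf modules.
model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU()).cuda()
model.qconfig = tq.default_qconfig

# prepare() calls add_observer_ internally; the inserted observers should
# follow the model's device, so a CUDA input runs without further changes.
tq.prepare(model, inplace=True)
model(torch.randn(4, 1, 4, 4, device="cuda"))
```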
