From a7a5992d7dcfb4694b9fe6a1efc5b13a92ca3cf4 Mon Sep 17 00:00:00 2001 From: Jeffrey Wan Date: Tue, 25 May 2021 13:05:44 -0700 Subject: [PATCH 01/18] Add no-grad inference mode note (#58513) Summary: Adds a note explaining the difference between several often conflated mechanisms in the autograd note Also adds a link to this note from the docs in `grad_mode` and `nn.module`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/58513 Reviewed By: gchanan Differential Revision: D28651129 Pulled By: soulitzer fbshipit-source-id: af9eb1749b641fc1b632815634eea36bf7979156 --- docs/cpp/source/notes/inference_mode.rst | 2 - docs/source/autograd.rst | 4 + docs/source/notes/autograd.rst | 201 +++++++++++++++++------ torch/autograd/grad_mode.py | 17 ++ torch/nn/modules/module.py | 6 + torch/nn/parameter.py | 2 +- 6 files changed, 179 insertions(+), 53 deletions(-) diff --git a/docs/cpp/source/notes/inference_mode.rst b/docs/cpp/source/notes/inference_mode.rst index 2ceb2dcdb762..efb1b9de2d1a 100644 --- a/docs/cpp/source/notes/inference_mode.rst +++ b/docs/cpp/source/notes/inference_mode.rst @@ -30,8 +30,6 @@ Inside an ``InferenceMode`` block, we make the following performance guarantees: - Inplace operations on inference tensors are guaranteed not to do a version bump. For more implementation details of ``InferenceMode`` please see the `RFC-0011-InferenceMode `_. -Currently this guard is only available in C++ frontend, adding python frontend support -is tracked in #56608. Migration guide from ``AutoNonVariableTypeMode`` ------------------------------------------------ diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst index 5bc588b0fa8b..566808036701 100644 --- a/docs/source/autograd.rst +++ b/docs/source/autograd.rst @@ -50,6 +50,10 @@ you can use it as ``functional.jacobian(lambda x: f(x, constant, flag=flag), inp Locally disabling gradient computation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +See :ref:`locally-disable-grad-doc` for more information on the differences +between no-grad and inference mode as well as other related mechanisms that +may be confused with the two. + .. autosummary:: :toctree: generated :nosignatures: diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index c15a0d0340a5..6d0e0e83d3d2 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -8,56 +8,6 @@ operations. It's not strictly necessary to understand all this, but we recommend getting familiar with it, as it will help you write more efficient, cleaner programs, and can aid you in debugging. -.. _excluding-subgraphs: - -Excluding subgraphs from backward ---------------------------------- - -Every Tensor has a flag: :attr:`requires_grad` that allows for fine grained -exclusion of subgraphs from gradient computation and can increase efficiency. - -.. _excluding-requires_grad: - -``requires_grad`` -^^^^^^^^^^^^^^^^^ - -If there's a single input to an operation that requires gradient, its output -will also require gradient. Conversely, only if all inputs don't require -gradient, the output also won't require it. Backward computation is never -performed in the subgraphs, where all Tensors didn't require gradients. - -.. 
code:: - - >>> x = torch.randn(5, 5) # requires_grad=False by default - >>> y = torch.randn(5, 5) # requires_grad=False by default - >>> z = torch.randn((5, 5), requires_grad=True) - >>> a = x + y - >>> a.requires_grad - False - >>> b = a + z - >>> b.requires_grad - True - -This is especially useful when you want to freeze part of your model, or you -know in advance that you're not going to use gradients w.r.t. some parameters. -For example if you want to finetune a pretrained CNN, it's enough to switch the -:attr:`requires_grad` flags in the frozen base, and no intermediate buffers will -be saved, until the computation gets to the last layer, where the affine -transform will use weights that require gradient, and the output of the network -will also require them. - -.. code:: - - model = torchvision.models.resnet18(pretrained=True) - for param in model.parameters(): - param.requires_grad = False - # Replace the last fully-connected layer - # Parameters of newly constructed modules have requires_grad=True by default - model.fc = nn.Linear(512, 100) - - # Optimize only the classifier - optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9) - .. _how-autograd-encodes-history: How autograd encodes the history @@ -86,6 +36,157 @@ flow statements, that can change the overall shape and size of the graph at every iteration. You don't have to encode all possible paths before you launch the training - what you run is what you differentiate. +.. _locally-disable-grad-doc: + +Locally disabling gradient computation +-------------------------------------- + +There are several mechanisms available from Python to locally disable gradient +computation: + +To disable gradients across entire blocks of code, there are context managers +like no-grad mode and inference mode. +For more fine-grained exclusion of subgraphs from gradient computation, +there is setting the ``requires_grad`` field of a tensor. + +Below, in addition to discussing the mechanisms above, we also describe +evaluation mode (:meth:`nn.Module.eval()`), a method that is not actually used +to disable gradient computation but, because of its name, is often mixed up with the three. + +Setting ``requires_grad`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +:attr:`requires_grad` is a flag that allows for fine-grained exclusion of +subgraphs from gradient computation. It takes effect in both the forward +and backward passes: + +During the forward pass, an operation is only recorded in the backward graph if +at least one of its input tensors require grad. +During the backward pass (``.backward()``), only leaf tensors with +``requires_grad=True`` will have gradients accumulated into their ``.grad`` +fields. + +It is important to note that even though every tensor has this flag, +*setting* it only makes sense for leaf tensors (tensors that do not have a +``grad_fn``, e.g., a ``nn.Module``'s parameters). +Non-leaf tensors (tensors that do have ``grad_fn``) are tensors that have a +backward graph associated with them. Thus their gradients will be needed +as an intermediary result to compute the gradient for a leaf tensor that +requires grad. From this definition, it is clear that all non-leaf tensors +will automatically have ``require_grad=True``. + +Setting ``requires_grad`` should be the main way you control which parts +of the model are part of the gradient computation, for example, if you need to +freeze parts of your pretrained model during model fine-tuning. 
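+
+For example, a minimal sketch of this pattern (assuming ``torchvision`` is
+available and ``nn`` and ``optim`` refer to ``torch.nn`` and ``torch.optim``),
+fine-tuning a pretrained ResNet-18 on a hypothetical 100-class task:
+
+.. code::
+
+    model = torchvision.models.resnet18(pretrained=True)
+    # Freeze the pretrained backbone: gradients will no longer be
+    # computed for these parameters
+    for param in model.parameters():
+        param.requires_grad_(False)
+    # Replace the last fully-connected layer.
+    # Parameters of newly constructed modules have requires_grad=True by default.
+    model.fc = nn.Linear(512, 100)
+    # Optimize only the classifier
+    optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)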
+
+To freeze parts of your model, simply apply ``.requires_grad_(False)`` to
+the parameters that you don't want updated. And as described above,
+since computations that use these parameters as inputs would not be recorded in
+the forward pass, they won't have their ``.grad`` fields updated in the backward
+pass because they won't be part of the backward graph in the first place, as
+desired.
+
+Because this is such a common pattern, ``requires_grad`` can also be set at
+the module level with :meth:`nn.Module.requires_grad_()`.
+When applied to a module, ``.requires_grad_()`` takes effect on all
+of the module's parameters (which have ``requires_grad=True`` by default).
+
+Grad Modes
+^^^^^^^^^^
+
+Apart from setting ``requires_grad``, there are also three modes that can be
+enabled from Python and that affect how computations in PyTorch are
+processed by autograd internally: default mode (grad mode), no-grad mode,
+and inference mode, all of which can be toggled via context managers and
+decorators.
+
+Default Mode (Grad Mode)
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The "default mode" is actually the mode we are implicitly in when no other modes like
+no-grad and inference mode are enabled. To contrast it with
+"no-grad mode", the default mode is also sometimes called "grad mode".
+
+The most important thing to know about the default mode is that it is the only
+mode in which ``requires_grad`` takes effect. ``requires_grad`` is always overridden
+to be ``False`` in both of the other two modes.
+
+No-grad Mode
+^^^^^^^^^^^^
+
+Computations in no-grad mode behave as if none of the inputs require grad.
+In other words, computations in no-grad mode are never recorded in the backward graph
+even if there are inputs that have ``requires_grad=True``.
+
+Enable no-grad mode when you need to perform operations that should not be
+recorded by autograd, but you’d still like to use the outputs of these
+computations in grad mode later. This context manager makes it convenient to
+disable gradients for a block of code or function without
+having to temporarily set tensors to have ``requires_grad=False``, and then
+back to ``True``.
+
+For example, no-grad mode might be useful when writing an optimizer: when
+performing the training update you’d like to update parameters
+in-place without the update being recorded by autograd.
+You also intend to use the updated parameters for computations in
+grad mode in the next forward pass.
+
+The implementations in :ref:`nn-init-doc` also
+rely on no-grad mode when initializing the parameters so as to avoid
+autograd tracking when updating the initialized parameters in-place.
+
+Inference Mode
+^^^^^^^^^^^^^^
+
+Inference mode is the extreme version of no-grad mode. Just like in no-grad
+mode, computations in inference mode are not recorded in the backward graph, but
+enabling inference mode will allow PyTorch to speed up your model even more.
+This better runtime comes with a drawback: tensors created in inference mode
+cannot be used in computations that are recorded by autograd after
+exiting inference mode.
+
+Enable inference mode when you are performing computations that don’t need
+to be recorded in the backward graph, AND you don’t plan on using the tensors
+created in inference mode in any computation that is to be recorded by autograd later.
+
+It is recommended that you try out inference mode in the parts of your code
+that do not require autograd tracking (e.g., data processing and model evaluation).
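+
+As a minimal sketch of how the two modes differ (assuming only ``torch`` is
+imported; the exact error message is not shown here):
+
+.. code::
+
+    x = torch.ones(2, requires_grad=True)
+
+    with torch.no_grad():
+        y = x * 2        # not recorded; y.requires_grad is False
+    w = y * x            # ok: y can still be used in grad mode later
+
+    with torch.inference_mode():
+        z = x * 2        # not recorded, and z is an inference tensor
+    # z * x              # would raise an error: z was created in inference mode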
+If it works out of the box +for your use case it’s a free performance win. If you run into errors after +enabling inference mode, check that you are not using tensors created in +inference mode in computations that are recorded by autograd after exiting inference +mode. If you cannot avoid such use in your case, you can always switch back +to no-grad mode. + +For details on inference mode please see +`Inference Mode `_. + +For implementation details of inference mode see +`RFC-0011-InferenceMode `_. + +Evaluation Mode (``nn.Module.eval()``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Evaluation mode is not actually a mechanism to locally disable gradient computation. +It is included here anyway because it is sometimes confused to be such a mechanism. + +Functionally, ``module.eval()`` (or equivalently ``module.train()``) are completely +orthogonal to no-grad mode and inference mode. How ``model.eval()`` affects +your model depends entirely on the specific modules used in your model and +whether they define any training-mode specific behavior. + +You are responsible for calling ``model.eval()`` and ``model.train()`` if your +model relies on modules such as :class:`torch.nn.Dropout` and +:class:`torch.nn.BatchNorm2d` that may behave +differently depending on training mode, for example, to avoid updating your +BatchNorm running statistics on validation data. + +It is recommended that you always use ``model.train()`` when +training and ``model.eval()`` when evaluating your model (validation/testing) even +if you aren’t sure your model has training-mode specific behavior, because a +module you are using might be updated to behave differently in training and +eval modes. + In-place operations with autograd --------------------------------- diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 7cbd5516e563..1cabb72b1e38 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -97,6 +97,10 @@ class no_grad(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + No-grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -136,6 +140,10 @@ class enable_grad(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + enable_grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -178,6 +186,10 @@ class set_grad_enabled(object): (``False``). This can be used to conditionally enable gradients. + .. note:: + set_grad_enabled is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -222,6 +234,11 @@ class inference_mode(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + Inference mode is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. 
+ Args: mode (bool): Flag whether to enable or disable inference mode diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 5aa97c93e156..3739bb2c8848 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1651,6 +1651,9 @@ def eval(self: T) -> T: This is equivalent with :meth:`self.train(False) `. + See :ref:`locally-disable-grad-doc` for a comparison between + `.eval()` and several similar mechanisms that may be confused with it. + Returns: Module: self """ @@ -1666,6 +1669,9 @@ def requires_grad_(self: T, requires_grad: bool = True) -> T: This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 7a2d1e4e839b..df562e27382c 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -18,7 +18,7 @@ class Parameter(torch.Tensor): Args: data (Tensor): parameter tensor. requires_grad (bool, optional): if the parameter requires gradient. See - :ref:`excluding-subgraphs` for more details. Default: `True` + :ref:`locally-disable-grad-doc` for more details. Default: `True` """ def __new__(cls, data=None, requires_grad=True): if data is None: From 813adf1076fa217d132d1090f8bc02af9c13e110 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 25 May 2021 13:16:46 -0700 Subject: [PATCH 02/18] [Pytorch Delegated Backend] Save operator name and function name in (#57441) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57441 debug info Previous diffs did not save operator name in debug info. For delegated backends that only idenfity op for profiling with debug handle, operator name should be stores as well. Furthermore to complete debug informaton also serialize function name. 
Test Plan: Existing lite interpreter and backend tests Existing lite interpreter and backend tests Imported from OSS Differential Revision: D28144581 D28144581 Reviewed By: raziel Pulled By: kimishpatel fbshipit-source-id: 415210f147530a53b444b07f1d6ee699a3570d99 --- test/cpp/jit/test_backend.cpp | 10 +++--- .../jit/test_cs_debug_info_serialization.cpp | 20 ++++++----- test/cpp/jit/test_lite_interpreter.cpp | 33 +++++++++--------- .../delegated_submodule_with_debug_info.ptl | Bin 9937 -> 10129 bytes .../test_lite_interpreter_runtime.cpp | 2 +- .../jit/backends/backend_debug_handler.cpp | 2 +- .../csrc/jit/backends/backend_debug_handler.h | 22 ++++++------ torch/csrc/jit/ir/scope.h | 9 ++++- torch/csrc/jit/mobile/debug_info.cpp | 18 +++++++--- torch/csrc/jit/mobile/debug_info.h | 2 +- .../callstack_debug_info_serialization.cpp | 32 ++++++++++------- .../callstack_debug_info_serialization.h | 4 +-- .../csrc/jit/serialization/export_module.cpp | 10 +++--- 13 files changed, 95 insertions(+), 69 deletions(-) diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index b85994b3ee31..400ef1c8fc6d 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -190,7 +190,7 @@ TEST(BackendTestDebugInfo, TestCompiler) { lm._save_for_mobile(ss, ExtraFilesMap(), true); auto mlm = _load_for_mobile(ss); std::string error_pattern = R"( - Module hierarchy:top(backend_with_compiler_demoLoweredModule) + Module hierarchy:top(backend_with_compiler_demoLoweredModule).aten::add Traceback of TorchScript (most recent call last): File "", line 5, in FunctionName_UNKNOWN typed_inputs: List[Any] = [x, h, ] @@ -244,7 +244,7 @@ TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) { lm._save_for_mobile(ss, ExtraFilesMap(), true); auto mlm = _load_for_mobile(ss); std::string error_pattern = R"( - Module hierarchy:top(backend_with_compiler_demoLoweredModule).A0(A) + Module hierarchy:top(backend_with_compiler_demoLoweredModule).A0(A).aten::add Traceback of TorchScript (most recent call last): File "", line 5, in FunctionName_UNKNOWN typed_inputs: List[Any] = [x, y, ] @@ -337,7 +337,7 @@ TEST( * */ std::string error_pattern = R"( - Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A) + Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A).aten::add Traceback of TorchScript (most recent call last): File "", line 5, in FunctionName_UNKNOWN typed_inputs: List[Any] = [x, y, ] @@ -424,7 +424,7 @@ TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) { c._save_for_mobile(ss, ExtraFilesMap(), true); auto c_loaded = _load_for_mobile(ss); std::string error_pattern = R"( - Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule) + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).aten::add Traceback of TorchScript (most recent call last): File "", line 3, in FunctionName_UNKNOWN @@ -545,7 +545,7 @@ TEST( * * */ std::string error_pattern = R"( - Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add Traceback of TorchScript (most recent call last): File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/jit/test_cs_debug_info_serialization.cpp b/test/cpp/jit/test_cs_debug_info_serialization.cpp index f5b816cbacfd..8db461003094 100644 --- a/test/cpp/jit/test_cs_debug_info_serialization.cpp +++ b/test/cpp/jit/test_cs_debug_info_serialization.cpp @@ -25,21 +25,23 @@ 
namespace jit { namespace { bool validate_debug_info( - const DebugInfoPair& pre_serialize, - const DebugInfoPair& post_serialize) { - auto sr1 = pre_serialize.first; - auto sr2 = post_serialize.first; + const DebugInfoTuple& pre_serialize, + const DebugInfoTuple& post_serialize) { + auto sr1 = std::get(pre_serialize); + auto sr2 = std::get(post_serialize); if (sr1 != sr2) { return false; } - if (!pre_serialize.second.defined()) { - return !post_serialize.second.defined(); + auto csptr1 = std::get(pre_serialize); + auto csptr2 = std::get(post_serialize); + if (!csptr1.defined()) { + return !csptr2.defined(); } - if (!post_serialize.second.defined()) { + if (!csptr2.defined()) { return false; } - auto vec1 = pre_serialize.second->vec(); - auto vec2 = post_serialize.second->vec(); + auto vec1 = csptr1->vec(); + auto vec2 = csptr2->vec(); if (vec1.size() != vec2.size()) { return false; } diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index ece646f6ede8..fa4b3294fe8b 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -496,8 +496,7 @@ TEST(LiteInterpreterTest, ModuleInfoBasic) { } } - std::unordered_set expected_result({"top(M)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(M).aten::mul")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -559,8 +558,9 @@ TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) { } } - std::set expected_result({"top(B)", "top(B).A0(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::mul")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -594,7 +594,6 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { std::string module_info = bc.get_forward_method_debug_info(pc); if (!module_info.empty() && (module_info.find("debug_handle") == std::string::npos)) { - std::cout << "Module info:" << module_info << std::endl; module_debug_info_set.insert(module_info); } ++pc; @@ -603,9 +602,9 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { } } - std::set expected_result( - {"top(C)", "top(C).A0(A)", "top(C).B0(B)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); } TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) { @@ -790,9 +789,9 @@ TEST(LiteInterpreterTest, SequentialModuleInfo) { // def forward(self, x): // return self.A0.forward(self.B0.forward(x)) - std::set expected_result( - {"top(C)", "top(C).A0(A)", "top(C).B0(B)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).prim::Return")); + AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -838,9 +837,9 @@ TEST(LiteInterpreterTest, HierarchyModuleInfo) { // "top(C).forward": for the add operator in top. // "top(C).B0(B).forward": for the add operator in B0. // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0. 
- std::set expected_result( - {"top(C)", "top(C).B0(B)", "top(C).B0(B).A0(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).A0(A).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -898,9 +897,9 @@ TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) { // "top(B).A0(A).forward": for the add operator in A0. // "top(B).A1(A).forward": for the add operator in A1. - std::set expected_result( - {"top(B)", "top(B).A0(A)", "top(B).A1(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A1(A).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl b/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl index 06300d1136ff74942ed53719ccbdf24a2c6da353..8c0b7528993ab53a82ae08943df1f792e32f176b 100644 GIT binary patch delta 1394 zcmZ`(c}&v>811j6i)m%QY?$B{KrEgS@? zL4y%H#1j=ERKyD{L#Rw|S3$lghs6>Ig>0cHtzg^c3_*@iEDBEVNpArZRj{i%%8 z2nKT+mEmk~z8DVLja!Z5uVbSJ@PUdhEY}dhE3`;79j)tA>2>HT%Tov*{lRjR1n@ur zZO{Q-&;xx~0R~_QMnD7-7y}thz!c2D94x>RtiT#5U<0;51sd1^9ae%pIDjKK!75;Y zGq`{&xPd!(fG2o?H!#5md|@@L0YC5u76bqr0wGA0A5$tVl_myZFz7R@PzkIB4g|wG z-~taqfDh{-6v7}J03sj~Hb4|a!$ydKScrpo_zDsr5t1MoQs8UY1gWqYwtxWAARUAt zf~~L(G9VMOAREN69dcj?q8+{AkuXw%sVopM>x6JJ+(bGeLDaqD}PS#R}7%{hu+u{!ocdqAz_ zMe292ZK9;!pT#IAjP$9?r9nVF*P|#K|b?v0)9!JAie_Yrh z)e9RfRwc7Zj2nf9rDYG*2#RN{e0;T$=`E;gQ}#5+g-5&@AmhT{jvgFM=ni(s8@VBs z#Fmgt9$*S5s))nHojOK+CQ-j+N2jb|7yZ^h-8#{r=sIVveEHPiqVL@vy1F&1gPht@ zzDGWg9anF04OMe?rgbp*p~gR@n>cT6Jz~_0ot?;{~-ES{Xmj%GBOl2fFK7l-ye|Cv1DT{HZ~n3t9; zX3i`WBpGH6SbJTW#CxiIqP8yVZz7c+ZV%CL_y!3Ty00yyaS>c8VKip4%V_UZfa#*q zUp#&Qq02h^Ge>xZ|y{0%h>-O7grRXMQGd|36Iq?XkHd7vdAW z=0@c832N4ekC-M-%hi8@3m=u}P)QZ9-j}#&3AANA=}l#j>quBhwKfC%$AuAUDjm%6F zWrycN5)$d@(RuiWDvg4Es`0y{XRk_?g8Q@eiwEY&FUz9VFw{n>n9I|XfXwb@~;lXP{JQU-`ze_m1Lns`@(^4Zr zp>TRzoyA7!^dS{R5AzR6C6Ytpe!*NACsBxMK_N`>b{1pbb15i;kB(GUAR8{%k|iv9 zaa8=+I8`bW&lv|QxeH7NcY`TlD!2#S3#NhT;689am;q*j2f%}17I+BE26Mo};1TdB zcnr)1kArz&K6nBw0DlF41Ahk#K^0g8o&-;Ur@>;d1S|#1KsPmb20RO%1J8pOz;f^+ zcnPckFN2j}6<7_{fL9Q*e%7|_+qTd55VsnQ73)k4m8=Epz$-T*tmzrjwh3;YMX3El#`!5;87cn7=--UEBV`{2J|ANT-#2tEQI zgHOPx;D4agGq4|g4!!_if)asJfeut3}Tv=ncykWdG#S{ zF~@Wyv?T?PO#P(F`to)40lFHIotI?ZzS+hVo?A_A6Q6#lZ=<;)$xRX$p7O=&JaxiX zBVCH3PuzXbw)tN1E4 zD;k}Yb9P^kcXuj{Z1tZUm~7#DUG3UyUaYaRG`eBmndqRq7%6El{C zC7e1R7-QNoBc@m>teMwPzbD+M#Xjl>)q1OX^oD(+^((C&W-J9*H8@!kG2Uw|P=m)BV#zv89iBgy-`1@!`uIE3G1W zBmD~;U3E>CTD+Jv-})vL0tzFcWhR5N{9y2D}T65XJT zCkC5_wQVg-eqdg=x!2NbW@wjHTdBWW|NA9B+g*5x*Sb;r&ehdfQSBPqB8!-aI3lc3=!W+l?SufNcq`~9XU!%cvfXgv8 z6W7}4nhoAGK{7ZVnGJxYHeN`x7--Ywn&KE+M?F3H8!o_4L-TbPVIxM_nrOTk{|^=> B_dfst diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index e76e36b3ff95..3072f21d4687 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -142,7 +142,7 @@ TEST(RunTimeTest, DelegateException) { inputs.emplace_back(torch::rand({13, 9})); std::string error_pattern = R"( - Module 
hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add Traceback of TorchScript (most recent call last): File "", line 3, in FunctionName_UNKNOWN diff --git a/torch/csrc/jit/backends/backend_debug_handler.cpp b/torch/csrc/jit/backends/backend_debug_handler.cpp index b0d4fd3daa36..d21e4efd5681 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.cpp +++ b/torch/csrc/jit/backends/backend_debug_handler.cpp @@ -20,7 +20,7 @@ int64_t BackendDebugInfoRecorder::getNextDebugHandle(const Node* node) { DebugHandleType debug_handle = unique_debug_handle_; const SourceRange& range = node->sourceRange(); handles_to_inlined_callstack_ptrs_[debug_handle] = - std::make_pair(range, cs_ptr); + std::make_tuple(range, node->kind().toQualString(), cs_ptr); // This increment is with seq memory order. // Not trying to perf optimizing this for now. unique_debug_handle_++; diff --git a/torch/csrc/jit/backends/backend_debug_handler.h b/torch/csrc/jit/backends/backend_debug_handler.h index 1e121f0ad04c..60727bfcc242 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.h +++ b/torch/csrc/jit/backends/backend_debug_handler.h @@ -13,7 +13,7 @@ namespace jit { * BackendDebugHandleManager is responsible for issuing debug handles to * backends. Debug handles are associated with nodes of a graph. * BackendDebugHandleManager also maintains a map - * [debug-handle, DebugInfoPair = {source range, inlined callstack ptr]} that + * [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr]} that * will help generate a callstack for exception raised using debug handles. * Effectively debug handles are something that is given to backend and later * when an exception occurs in the backend, backend can tell, using debug @@ -21,14 +21,14 @@ namespace jit { * callstack correspoding to the exception. * There are two parts to BackendDebugHandleManager: * 1. static std::atomic debug_handle - * 2. Map of [debug-handle, DebugInfoPair] + * 2. Map of [debug-handle, DebugInfoTuple] * * About 1: * Why do they have to be unique. The reason is that by ensuring * uniqueness of debug handles, we remove the burden of another layer of * mapping where we need to say this set of debug handles were generated for * this lowered module or this bytecode function. This simplifies the API for - * serialization since debug handles can uniquely identify DebugInfoPair. + * serialization since debug handles can uniquely identify DebugInfoTuple. * Thus simplifies the runtime API for throwing exception. Exception throwing * only needs to know debug_handle and not which module or method threw it. * There are 2 issues to keep in mind, though,for static std::atomic @@ -40,8 +40,8 @@ namespace jit { * done. * * Now about 2: - * There are two usecases for [debug-handle, DebugInfoPair] - * A. During bytecode generation the DebugInfoPair corresponding to the nodes + * There are two usecases for [debug-handle, DebugInfoTuple] + * A. During bytecode generation the DebugInfoTuple corresponding to the nodes * of the inlined graph being serialized, are stored in this object and a * unique debug handle is returned. This unique debug handle is stored in * mobile_debug info for pytorch lite models. It will be used for raising @@ -52,13 +52,13 @@ namespace jit { * the debug handles provide a way to map nodes of the graph to the model level * debug info. 
* - * During byte-code model serialization, [debug-handle, DebugInfoPair] is + * During byte-code model serialization, [debug-handle, DebugInfoTuple] is * serialized. Now we know a. debug handles and b. how to map debug handles to * model source code. Thus we can either do eager symbolication by converting * debug handles to corresponding source code at runtime, or do lazy * symbolicattion offline. * - * Note that it is not necessary to serialize [debug-handle, DebugInfoPair] + * Note that it is not necessary to serialize [debug-handle, DebugInfoTuple] * corresponding to lowered backend if the lowering process, that is * preprocess/compile, and execution happens in the same session, then eager * symbolication can be employed. @@ -66,15 +66,15 @@ namespace jit { * Now how does BackendDebugHandleManager capture all of the above? * By providing two API. * 1. getNextDebugHandle which given a Node* returns a unique debug handle, - * that will uniquely identify DebugInfoPair. + * that will uniquely identify DebugInfoTuple. * and * 2. getCallStackPtrMap which returns the map - * [debug-handle, DebugInfoPair] + * [debug-handle, DebugInfoTuple] * * 1 provides debug handles to backends and 2 provides runtime a way to map * debug handles to source level debug info. * - * So why does debug handle map to DebugInfoPair = {source range and inlined + * So why does debug handle map to DebugInfoTuple = {source range and inlined * cs}? {debug_handle, source_range_tag, serialized_callstack} Take this * example: class L(nn.Module): def __init__(self): * ... @@ -112,7 +112,7 @@ namespace jit { using DebugHandleType = int64_t; using BackendDebugInfoMapType = - std::unordered_map; + std::unordered_map; /* * This class is used to generate debug info map. diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index c0155e5db94b..a8166129cb6f 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -175,6 +175,13 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { } }; -using DebugInfoPair = std::pair; +// {source range, node name, InlinedCallStack} +// We store node name because same debug infor will be used for +// profiling as well, so we need to know op names as well. +using DebugInfoTuple = + std::tuple; +constexpr size_t kDebugInfoTupleSourceRangeIndex{0}; +constexpr size_t kDebugInfoTupleNodeNameIndex{1}; +constexpr size_t kDebugInfoTupleInlinedCSIndex{2}; } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index 9f0c8b7f2843..07b797b3d691 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -14,13 +14,15 @@ namespace jit { namespace { std::pair, std::string> getStackTraceWithModuleHierarchy( - const DebugInfoPair& source_callstack) { + const DebugInfoTuple& source_callstack) { constexpr size_t kSourceRange = 1; constexpr size_t kModuleInstanceInfo = 2; std::vector entries; - const SourceRange& range = source_callstack.first; - InlinedCallStackPtr callstack_ptr = source_callstack.second; + const SourceRange& range = + std::get(source_callstack); + InlinedCallStackPtr callstack_ptr = + std::get(source_callstack); std::string module_info; if (!callstack_ptr) { // If not cs then top level node @@ -70,7 +72,7 @@ std::pair, std::string> getStackTraceWithModuleHierarchy // will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv) // Source level stack information will be from model source code. 
std::pair getStackTraceWithModuleHierarchy( - const std::vector& source_callstacks, + const std::vector& source_callstacks, const std::string& root_scope_string, const std::string& top_module_type_name) { std::vector stack_entries; @@ -82,6 +84,12 @@ std::pair getStackTraceWithModuleHierarchy( stack_entries.insert(stack_entries.end(), entries.begin(), entries.end()); module_info += debug_info_pair.second; } + // Only last entry in the callstack will have a node name of interest. + // Rest are likely CallMethod/CallFunction nodes + auto last_entry = source_callstacks.back(); + const std::string& node_name = + std::get(last_entry); + module_info += "." + node_name; std::ostringstream ss; ss << "Module hierarchy:" << module_info << "\n"; format_stack_trace(ss, stack_entries); @@ -177,7 +185,7 @@ std::pair MobileDebugTable:: getSourceDebugModuleHierarchyInfo( const std::vector& debug_handles, const std::string& top_module_type_name) const { - std::vector debug_infos; + std::vector debug_infos; bool debug_handle_not_found{false}; for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) { auto debug_handle = *it; diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h index 66258ac353dd..444573ccd013 100644 --- a/torch/csrc/jit/mobile/debug_info.h +++ b/torch/csrc/jit/mobile/debug_info.h @@ -40,7 +40,7 @@ class MobileDebugTable { std::pair getSourceDebugModuleHierarchyInfo( const std::vector& debug_handles, const std::string& top_module_type_name = "ModuleTypeUnknown") const; - ska::flat_hash_map callstack_ptr_map_; + ska::flat_hash_map callstack_ptr_map_; }; } // namespace jit diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 2fb6f075b06e..6f0bdc6389b0 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -71,7 +71,7 @@ c10::IValue InlinedCallStackSerializer::serialize_module_instance_info( } std::vector CallStackDebugInfoPickler::pickle( - const std::unordered_map& callstack_ptrs, + const std::unordered_map& callstack_ptrs, const SourceRangeTagMap& source_range_tags) { std::vector ivalues; for (const auto& it : callstack_ptrs) { @@ -85,15 +85,20 @@ std::vector CallStackDebugInfoPickler::pickle( elements.reserve(3); elements.emplace_back(debug_handle); int64_t source_range_tag{kInvalidSourceRangeTag}; - const SourceRange& sr = it.second.first.findSourceRangeThatGenerated() - ? it.second.first.findSourceRangeThatGenerated().value() - : it.second.first; + const auto source_range = + std::get(it.second); + const SourceRange& sr = source_range.findSourceRangeThatGenerated() + ? 
source_range.findSourceRangeThatGenerated().value() + : source_range; auto sr_it = source_range_tags.find(sr); if (sr_it != source_range_tags.end()) { source_range_tag = sr_it->second; } elements.emplace_back(source_range_tag); - elements.emplace_back(css_.serialize(it.second.second, source_range_tags)); + elements.emplace_back(std::get(it.second)); + const auto inlined_cs_ptr = + std::get(it.second); + elements.emplace_back(css_.serialize(inlined_cs_ptr, source_range_tags)); c10::IValue tuple = c10::ivalue::Tuple::create(elements); ivalues.emplace_back(tuple); } @@ -190,23 +195,24 @@ c10::optional InlinedCallStackDeserializer:: return cached_module_instance_info_[tup]; } -ska::flat_hash_map CallStackDebugInfoUnpickler:: +ska::flat_hash_map CallStackDebugInfoUnpickler:: unpickle( at::DataPtr&& data, size_t size, const ska::flat_hash_map& source_range_map, const std::shared_ptr& cu) { auto ival = jit::unpickle(reinterpret_cast(data.get()), size); - ska::flat_hash_map callstack_ptrs; + ska::flat_hash_map callstack_ptrs; auto& ivalues = ival.toTuple()->elements(); for (auto& val : ivalues) { const auto tup_elems = val.toTuple()->elements(); TORCH_CHECK( - tup_elems.size() == 3, - "Pickled map must have three elements: " - "debug_handle, source_range_tag, IValue(inlined_call_stack)"); + tup_elems.size() == 4, + "Pickled map must have four elements: " + "debug_handle, source_range_tag, op name, IValue(inlined_call_stack)"); int64_t debug_handle = tup_elems[0].toInt(); int64_t source_range_tag = tup_elems[1].toInt(); + const std::string& node_name = tup_elems[2].toStringRef(); auto source_range_it = source_range_map.find(source_range_tag); TORCH_CHECK( source_range_it != source_range_map.end(), @@ -215,8 +221,10 @@ ska::flat_hash_map CallStackDebugInfoUnpickler:: TORCH_CHECK( callstack_ptrs.count(debug_handle) == 0, "Debug handles should be unique."); - callstack_ptrs[debug_handle] = std::make_pair( - source_range, csds_.deserialize(tup_elems[2], source_range_map, cu)); + callstack_ptrs[debug_handle] = std::make_tuple( + source_range, + node_name, + csds_.deserialize(tup_elems[3], source_range_map, cu)); } return callstack_ptrs; } diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.h b/torch/csrc/jit/serialization/callstack_debug_info_serialization.h index 219b0713ca00..ac1bdf8d3b1d 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.h +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.h @@ -49,7 +49,7 @@ class TORCH_API CallStackDebugInfoPickler { CallStackDebugInfoPickler() = default; std::vector pickle( - const std::unordered_map& callstack_ptrs, + const std::unordered_map& callstack_ptrs, const SourceRangeTagMap& source_range_tags); private: @@ -77,7 +77,7 @@ class InlinedCallStackDeserializer { class TORCH_API CallStackDebugInfoUnpickler { public: - ska::flat_hash_map unpickle( + ska::flat_hash_map unpickle( at::DataPtr&& data, size_t size, const ska::flat_hash_map& source_range_map, diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 6bb56d0954bb..65b7cf467ef9 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -341,12 +341,14 @@ SourceRangeRecords getBackendSourceRanges(const Module& m) { const auto& map = backend_debug_info->getDebugInfoMap(); if (map) { const auto& map_val = map.value(); - // This map is map of debug handle-to-delegateDebugInfoType - // DebugInfoPair = + // This map is map of 
debug handle-to-DebugInfoTuple + // DebugInfoTuple= for (const auto& it : map_val) { + auto& source_range = + std::get(it.second); sr_records.emplace_back( - std::numeric_limits::max(), it.second.first); - auto cs_ptr = it.second.second; + std::numeric_limits::max(), source_range); + auto cs_ptr = std::get(it.second); if (cs_ptr) { for (const auto& e : cs_ptr->vec()) { const auto sr = std::get(e); From ede3f5421f415171cb5bdf2610be502bef87c4e9 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 25 May 2021 13:16:46 -0700 Subject: [PATCH 03/18] [Pytorch Delegated Backend] Save function name in debug info (#57481) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57481 This diff introduces function name to InlinedCallStack. Since we are using InlinedCallStack for debug information in lite interpreter as well as delegate backends, where InlinedCallStack cannot be constructed from model source code, we need to save function name. In the absence of function name Function* is used to get name of the function. This is when JIT compiles code at runtime. When that is not possible, this diff introduces a way to obtain function name. Test Plan: test_backend test_cs_debug_info_serialization test_backend test_cs_debug_info_serialization Imported from OSS Differential Revision: D28159097 D28159097 Reviewed By: raziel, ZolotukhinM Pulled By: kimishpatel fbshipit-source-id: deacaea3325e27273f92ae96cf0cd0789bbd6e72 --- test/cpp/jit/test_backend.cpp | 12 ++--- .../jit/test_cs_debug_info_serialization.cpp | 29 +++++++--- test/cpp/jit/test_lite_interpreter.cpp | 51 ++++++++++++++++++ .../delegated_submodule_with_debug_info.ptl | Bin 10129 -> 10193 bytes .../test_lite_interpreter_runtime.cpp | 4 +- torch/csrc/jit/ir/scope.cpp | 36 +++++++++++-- torch/csrc/jit/ir/scope.h | 15 ++++++ torch/csrc/jit/mobile/debug_info.cpp | 30 ++++++++--- .../callstack_debug_info_serialization.cpp | 30 +++++++++-- 9 files changed, 176 insertions(+), 31 deletions(-) diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index 400ef1c8fc6d..bf4b48d3e233 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -259,7 +259,7 @@ Traceback of TorchScript (most recent call last): return self.A0.forward(x, y) + self.B0.forward(x) ~~~~~~~~~~~~~~~ <--- HERE - File "", line 3, in FunctionName_UNKNOWN + File "", line 3, in forward def forward(self, x, y): return x + y @@ -352,13 +352,13 @@ Traceback of TorchScript (most recent call last): return self.B0.forward(x, y) + 3 ~~~~~~~~~~~~~~~ <--- HERE - File "", line 3, in FunctionName_UNKNOWN + File "", line 3, in forward def forward(self, x, y): return self.A0.forward(x, y) + 2 ~~~~~~~~~~~~~~~ <--- HERE - File "", line 3, in FunctionName_UNKNOWN + File "", line 3, in forward def forward(self, x, y): return x + y @@ -432,7 +432,7 @@ Traceback of TorchScript (most recent call last): return self.A0.forward(x, y) + self.B0.forward(x) ~~~~~~~~~~~~~~~ <--- HERE - File "", line 5, in FunctionName_UNKNOWN + File "", line 5, in forward typed_inputs: List[Any] = [x, y, ] if self.__backend.is_available() : _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) @@ -553,7 +553,7 @@ Traceback of TorchScript (most recent call last): return self.A0.forward(x, y) + self.B0.forward(x) ~~~~~~~~~~~~~~~ <--- HERE - File "", line 5, in FunctionName_UNKNOWN + File "", line 5, in forward typed_inputs: List[Any] = [x, y, ] if self.__backend.is_available() : _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) @@ -566,7 
+566,7 @@ Traceback of TorchScript (most recent call last): return self.AA0.forward(x, y) + 3 ~~~~~~~~~~~~~~~~ <--- HERE - File "", line 3, in FunctionName_UNKNOWN + File "", line 3, in forward def forward(self, x, y): return x + y diff --git a/test/cpp/jit/test_cs_debug_info_serialization.cpp b/test/cpp/jit/test_cs_debug_info_serialization.cpp index 8db461003094..c34f0da1b636 100644 --- a/test/cpp/jit/test_cs_debug_info_serialization.cpp +++ b/test/cpp/jit/test_cs_debug_info_serialization.cpp @@ -45,20 +45,37 @@ bool validate_debug_info( if (vec1.size() != vec2.size()) { return false; } - for (size_t i = 0; i < vec1.size(); i++) { - auto rhs_sr = std::get<1>(vec1[i]); - auto lhs_sr = std::get<1>(vec2[i]); - auto rhs_module = std::get<2>(vec1[i]); - auto lhs_module = std::get<2>(vec2[i]); + while (csptr1) { + auto rhs_sr = csptr1->source_range(); + auto lhs_sr = csptr2->source_range(); + auto rhs_module = csptr1->module_instance(); + auto lhs_module = csptr2->module_instance(); + std::string rhs_fn_name, lhs_fn_name; + if (csptr1->function()) { + rhs_fn_name = csptr1->function()->name(); + } else { + rhs_fn_name = csptr1->function_name(); + } + if (csptr2->function()) { + lhs_fn_name = csptr2->function()->name(); + } else { + lhs_fn_name = csptr2->function_name(); + } if (!((rhs_module.has_value() == lhs_module.has_value()) && (rhs_module.has_value() && (rhs_module.value().class_type()->name().value() == lhs_module.value().class_type()->name().value()) && (rhs_module.value().instance_name() == lhs_module.value().instance_name())) && - (rhs_sr == lhs_sr))) { + (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) { return false; } + if (csptr1->callee()) { + csptr1 = csptr1->callee().value(); + csptr2 = csptr2->callee().value(); + } else { + csptr1 = c10::intrusive_ptr(); + } } return true; } diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index fa4b3294fe8b..c5acb0f68d9d 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -1285,6 +1285,57 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) { testLiteModuleCompareResultTensors(m, inputs); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) { + Module a("A"); + a.define(R"( + def bar(self, x, y): + return x + y + )"); + Module b("B"); + b.register_module("A0", a); + b.define(R"( + def foo(self, x, y): + return self.A0.bar(x, y) + 2 + )"); + Module c("C"); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.B0.foo(x, y) + 3 + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + std::stringstream ss; + c._save_for_mobile(ss, ExtraFilesMap(), true); + auto lite_m = _load_for_mobile(ss); + std::string error_pattern = R"( + Module hierarchy:top(C).B0(B).A0(A).aten::add +Traceback of TorchScript (most recent call last): + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.B0.foo(x, y) + 3 + ~~~~~~~~~~~ <--- HERE + + File "", line 3, in foo + + def foo(self, x, y): + return self.A0.bar(x, y) + 2 + ~~~~~~~~~~~ <--- HERE + + File "", line 3, in bar + + def bar(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern); +} + namespace { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static auto reg = diff --git 
a/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl b/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl index 8c0b7528993ab53a82ae08943df1f792e32f176b..901724d82225b87b78f97590d209235e207b45ee 100644 GIT binary patch delta 936 zcmV;Z16TZ!Pti}XKplT=l-pMmI~2xCL7-eJ6e{-u>N%(99If@P^;!@W1A_|ac#XqM zAcO5p_cUF&eDaCE+no$MU0ut!U-NxmcDl1Czp~{r8(<|jm&s&2C8Ao*3j)YEs?7S? zvdom-5-OalIPoRt6;9anqN9f3)k?h%eQrOR>h&sQPobZ!IPQNC{yvP86E6v1K*2fO zF3U1*9+z85*sRsovt z#4Q-s>k%uu)7_{Sw#CK|U$j&hN08HwCM!w<4OVc6t#$JdM)c};+<4rs?g~{`DuGcg z9kXe7qWdqrMq__lC>aMZrq!?~_~I-$t|f`5HgH66NN^Lu9Q+2sb%H9vH8vwh%IvlVAhh}Je0aHd^#z;z~9^+QnR{K#4y93ikUP5C^i)>Dg z89I-Vsh)_L$)w{eC>lo|Qz?yG8tX~zM`0AI&})Q0MF1C!V-&q9jaV9_t3HSEQ5%2l zC0-L|v@m}}j$wkFO;6bBlUg)Xur$SH&;HT0y|JQIFIbwfG;3*&&AXHK%jf(^s8ER} z=P+yZc{VRoTCjb)MF+f1aEHd?F6)!#>A1%R^wTlg^K`7ioSDP5D$Jk40_!((xQO=q z+AcNKEIqLFkam4U@R)|~34&?gQv~yU)_VqFuhV~Oh|MSvKVV6#YiKjDPVkI^8wAfu zdO`4#V3RzrXoV4M5p1)(G>?dbCn0epCAbV zBa?vzB(uOEy#)y-!5Y{=0{{THlZGaG0g#gcCteW}ApigX0000`O9lr30001QlW->+ K1`;6v0000lkhFXN delta 886 zcmV-+1Bv|6PmxcsKplTolv!8XFc5`9Si?>r>}yxLrJ;M%H7up6h@n7=yQ;<($dK5P zqZ~MW>Qnl+JCYgYSa*5BmnD(NDNb}ErbIFbeVyyJNyVaZj}NtSrRUIc+s z8AdLrG)obyPHlu!ny^s6~D?pWg_Bx^x05 zy15b)B*DAYTRO-}3`uCVJI+Kr;43E9&@lwsq;U zn)Jg!^eIP`WVnV3N!xwKT09XcwR#SPhs^2$lLaXw(XXU4GGULlzQU1 z+k)Fy!9gW)*tLI#Q-l+Q#{hccj}X=natIH}6i?Y}P+K=Vd*;jRvE!bzz@rzwal9L6JD^0d8N>6RG1jDlj=ZnVz?(qax7}J((P${Mr6I+H@tJQwp z*AwG1k=F7OL}FTOGw%3xSq^H@8!@is1u(`mPiFXlrV)sYw?*_oMh#br3GEn$R7^vL z2HR5Qzz=-sI~D&A^F&5F2EZFrzo8yH=Wd7(nE0L(IaM*KgkE&?A*4)t)VzI6iJH=e z#>w>67fpYd6U!KN($JKlX+tw))*dteIi;se`jS;C6;oQDC9^!HIn%c%xWT6g&u}cB zlMb#giWj7d#~<3dDAvWa9>VpUm}!bx(g{O22i}*;%N6Afy)yI~OTIyPixc+_K=z1YrlEM25M(J}lTt zx^+iH3_84fe8yn57Wj~z0jX#m92TsNEov1?TeNHuYxl?!i%qdaMxiy!P-PeDE^tMA zqzBrs=(C}HZ2JN3>=2=h@I_ZxB4bc-1)h)evm%?MppK%ihQ49N@0kCA@Dt$|ej>%= z|BWEvka4c>z+;k*Juap^^spFY_CHWd2MG8bdxi}I003x{jUXQZ8k4~TB?@+Ba&u{K zZbm{DlYs>nvLU?%3HTj*h7AJ%0BDoOCVBx~lR+n55z-(40000008mQ?2LJ#7056lU MCmRORAOHXW0ID*YE&u=k diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index 3072f21d4687..2ccf6ee18d3a 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -150,7 +150,7 @@ Traceback of TorchScript (most recent call last): return self.A0.forward(x, y) + self.B0.forward(x) ~~~~~~~~~~~~~~~ <--- HERE - File "", line 5, in FunctionName_UNKNOWN + File "", line 5, in forward typed_inputs: List[Any] = [x, y, ] if self.__backend.is_available() : _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) @@ -163,7 +163,7 @@ Traceback of TorchScript (most recent call last): return self.AA0.forward(x, y) + 3 ~~~~~~~~~~~~~~~~ <--- HERE - File "", line 3, in FunctionName_UNKNOWN + File "", line 3, in forward def forward(self, x, y): return x + y diff --git a/torch/csrc/jit/ir/scope.cpp b/torch/csrc/jit/ir/scope.cpp index 474dc47cc9fd..b3fd559dcea3 100644 --- a/torch/csrc/jit/ir/scope.cpp +++ b/torch/csrc/jit/ir/scope.cpp @@ -88,7 +88,11 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() { } InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range) - : fn_(fn), source_range_(std::move(source_range)) {} + : fn_(fn), source_range_(std::move(source_range)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( Function* fn, @@ -96,7 +100,11 @@ InlinedCallStack::InlinedCallStack( 
c10::optional module_instance_info) : fn_(fn), source_range_(std::move(source_range)), - module_instance_info_(std::move(module_instance_info)) {} + module_instance_info_(std::move(module_instance_info)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, @@ -104,7 +112,11 @@ InlinedCallStack::InlinedCallStack( SourceRange source_range) : callee_(std::move(callee)), fn_(fn), - source_range_(std::move(source_range)) {} + source_range_(std::move(source_range)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, @@ -114,7 +126,11 @@ InlinedCallStack::InlinedCallStack( : callee_(std::move(callee)), fn_(fn), source_range_(std::move(source_range)), - module_instance_info_(std::move(module_instance_info)) {} + module_instance_info_(std::move(module_instance_info)) { + if (fn_) { + set_function_name(fn_->name()); + } +} c10::optional InlinedCallStack::callee() const { return callee_; @@ -132,6 +148,18 @@ SourceRange InlinedCallStack::source_range() const { return source_range_; } +Function* InlinedCallStack::function() const { + return fn_; +} + +void InlinedCallStack::set_function_name(std::string fn_name) { + fn_name_ = std::move(fn_name); +} + +std::string InlinedCallStack::function_name() const { + return fn_name_; +} + std::vector InlinedCallStack::vec() { std::vector r; c10::optional current = intrusive_from_this(); diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index a8166129cb6f..83d4e8fdd132 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -120,6 +120,15 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { private: c10::optional callee_; Function* fn_; + // Reason for fn_name_ even though we have fn_ + // Serialized callstack is used in circustmances where InlinedCallstack + // cannot be constructed during runtime, e.g. mobile runtime or + // delegated backends. + // Since in those cases we do not have Function* we store function name + // fn_name does not give you access to the same information that Function* + // does, however in mobile/delegated backend runtime we use InlindedCallStack + // for exception stack and for that purpose fn_name_ suffices. + std::string fn_name_; SourceRange source_range_; InlinedCallStackPtr intrusive_from_this(); c10::optional module_instance_info_; @@ -155,6 +164,12 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { // Returns the source range of the node SourceRange source_range() const; + Function* function() const; + + void set_function_name(std::string fn_name); + + std::string function_name() const; + // Return callstack as a vector of [Function, SourceRange] pairs. 
std::vector vec(); diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index 07b797b3d691..c38deda6eba2 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -14,7 +14,8 @@ namespace jit { namespace { std::pair, std::string> getStackTraceWithModuleHierarchy( - const DebugInfoTuple& source_callstack) { + const DebugInfoTuple& source_callstack, + const std::string& caller_name) { constexpr size_t kSourceRange = 1; constexpr size_t kModuleInstanceInfo = 2; std::vector entries; @@ -23,15 +24,15 @@ std::pair, std::string> getStackTraceWithModuleHierarchy std::get(source_callstack); InlinedCallStackPtr callstack_ptr = std::get(source_callstack); + std::string prev_function_name = caller_name; std::string module_info; if (!callstack_ptr) { // If not cs then top level node - entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range}); + entries.emplace_back(StackEntry{prev_function_name, range}); return {std::move(entries), std::move(module_info)}; } else { - for (const auto& element : callstack_ptr->vec()) { - const auto& opt_module_instance_info = - std::get(element); + while (callstack_ptr) { + const auto& opt_module_instance_info = callstack_ptr->module_instance(); if (opt_module_instance_info.has_value()) { const auto& module_instance_info = opt_module_instance_info.value(); if (module_instance_info.class_type()) { @@ -57,9 +58,20 @@ std::pair, std::string> getStackTraceWithModuleHierarchy // When we serialize function names, those can be added here. // TODO: Add function name separately entries.emplace_back( - StackEntry{"FunctionName_UNKNOWN", std::get(element)}); + StackEntry{prev_function_name, callstack_ptr->source_range()}); + if (callstack_ptr->function()) { + prev_function_name = callstack_ptr->function()->name(); + } else { + prev_function_name = callstack_ptr->function_name(); + } + + if (callstack_ptr->callee()) { + callstack_ptr = callstack_ptr->callee().value(); + } else { + callstack_ptr = c10::intrusive_ptr(); + } } - entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range}); + entries.emplace_back(StackEntry{prev_function_name, range}); return {std::move(entries), std::move(module_info)}; } } @@ -78,8 +90,10 @@ std::pair getStackTraceWithModuleHierarchy( std::vector stack_entries; std::string module_info = root_scope_string + "(" + top_module_type_name + ")"; + std::string caller_fn_name = "FunctionName_UNKNOWN"; for (const auto& debug_info : source_callstacks) { - auto debug_info_pair = getStackTraceWithModuleHierarchy(debug_info); + auto debug_info_pair = + getStackTraceWithModuleHierarchy(debug_info, caller_fn_name); auto entries = std::move(debug_info_pair.first); stack_entries.insert(stack_entries.end(), entries.begin(), entries.end()); module_info += debug_info_pair.second; diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 6f0bdc6389b0..0480a3af2cfe 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -19,10 +19,17 @@ c10::IValue InlinedCallStackSerializer::serialize( if (cs_it != serialized_inlined_callstack_.end()) { return cs_it->second; } - // Inlined callstack pointer is serialized as tuple of 3 elements - // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack)} + // Inlined callstack pointer is serialized as tuple of 4 elements + // 
{IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack), + // function name} Note function name is serialized separately because Function + // is only in memory structure. It gets constructed by JIT from serialized + // Code at runtime. As such even InlinedCallStack get constructed by JIT at + // runtime during graph inlining. However, we introduce + // serialization/deserialization of it in order to generate callstack debug + // information, _when_ equivalent InlinedCallStack cannot be constructed at + // runtime. For example, in lite interpreter or delegated backend. std::vector elements; - elements.reserve(3); + elements.reserve(4); elements.emplace_back( serialize_module_instance_info(cs_ptr->module_instance())); int64_t source_range_tag{kInvalidSourceRangeTag}; @@ -40,6 +47,16 @@ c10::IValue InlinedCallStackSerializer::serialize( } else { elements.emplace_back(c10::IValue()); } + if (cs_ptr->function()) { + elements.emplace_back(cs_ptr->function()->name()); + } else { + auto fn_name = cs_ptr->function_name(); + if (!fn_name.empty()) { + elements.emplace_back(fn_name); + } else { + elements.emplace_back("FunctionName_UNKNOWN"); + } + } c10::IValue serialized_cs = c10::ivalue::Tuple::create(elements); serialized_inlined_callstack_[cs_ptr] = serialized_cs; return serialized_cs; @@ -123,8 +140,9 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( } auto tup_elems = tup->elements(); - TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); - // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack)} + TORCH_INTERNAL_ASSERT(tup_elems.size() == 4); + // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack), + // function name} auto module_instance_info = deserialize_module_instance_info(tup_elems[0], cu); int64_t source_range_tag = tup_elems[1].toInt(); @@ -140,6 +158,7 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( source_range = source_range_it->second; } auto callee = deserialize(tup_elems[2], source_range_map, cu); + auto function_name = tup_elems[3].toStringRef(); InlinedCallStackPtr cs_ptr; if (callee) { cs_ptr = c10::make_intrusive( @@ -148,6 +167,7 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( cs_ptr = c10::make_intrusive( nullptr, source_range, module_instance_info); } + cs_ptr->set_function_name(function_name); cached_inlined_callstacks_[tup] = cs_ptr; // Invoking move constructor // It is not clear if copy-ellision can happen since From ec89bf65352e0cf1e24738cd383c94785db3d4c2 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Tue, 25 May 2021 13:57:16 -0700 Subject: [PATCH 04/18] .github: Ensure 7zip install for windows (#58924) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58924 Was observing behavior where 7zip was nowhere to be found after a build was completed. Let's just have 7zip be installed within the workflow as well just to be completely sure 7zip is there. 
Signed-off-by: Eli Uriegas Test Plan: Imported from OSS Reviewed By: samestep Differential Revision: D28681241 Pulled By: seemethere fbshipit-source-id: f649c1713edcdeb82c84fd67866700caa2726d71 --- .github/templates/windows_ci_workflow.yml.in | 8 ++++++++ .github/workflows/pytorch-win-vs2019-cpu-py3.yml | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/.github/templates/windows_ci_workflow.yml.in b/.github/templates/windows_ci_workflow.yml.in index 5a1c602b40f2..9544b83138e2 100644 --- a/.github/templates/windows_ci_workflow.yml.in +++ b/.github/templates/windows_ci_workflow.yml.in @@ -31,6 +31,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -73,6 +77,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | diff --git a/.github/workflows/pytorch-win-vs2019-cpu-py3.yml b/.github/workflows/pytorch-win-vs2019-cpu-py3.yml index aba6ecdd2cc6..d3166967ed8c 100644 --- a/.github/workflows/pytorch-win-vs2019-cpu-py3.yml +++ b/.github/workflows/pytorch-win-vs2019-cpu-py3.yml @@ -30,6 +30,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -72,6 +76,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | From 133133afa80819d37361b4c1319cf72286885db0 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 25 May 2021 14:45:55 -0700 Subject: [PATCH 05/18] [PyTorch] Extract non-template parts of torch::class_ (#54548) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54548 We don't need to inline most of this class; doing so bloats code size and build time. ghstack-source-id: 129765666 Test Plan: Existing CI buildsizebot some mobile apps Reviewed By: jamesr66a Differential Revision: D27277317 fbshipit-source-id: 7643aa35e4d794fee0a48a3bbe0890c2e428ae78 --- aten/src/ATen/core/custom_class.cpp | 46 ++++++++++++++++ torch/custom_class.h | 84 +++-------------------------- torch/custom_class_detail.h | 48 +++++++++++++++++ 3 files changed, 102 insertions(+), 76 deletions(-) diff --git a/aten/src/ATen/core/custom_class.cpp b/aten/src/ATen/core/custom_class.cpp index c396e810eabe..8f1a66452576 100644 --- a/aten/src/ATen/core/custom_class.cpp +++ b/aten/src/ATen/core/custom_class.cpp @@ -50,5 +50,51 @@ std::vector customClassSchemasForBCCheck() { }); } +namespace detail { +class_base::class_base( + const std::string& namespaceName, + const std::string& className, + std::string doc_string, + const std::type_info& intrusivePtrClassTypeid, + const std::type_info& taggedCapsuleClassTypeid) + : qualClassName("__torch__.torch.classes." + namespaceName + '.' 
+ className), + classTypePtr(at::ClassType::create( + c10::QualifiedName(qualClassName), + std::weak_ptr(), + /*is_module=*/false, + std::move(doc_string))) +{ + detail::checkValidIdent(namespaceName, "Namespace name"); + detail::checkValidIdent(className, "Class name"); + classTypePtr->addAttribute("capsule", at::CapsuleType::get()); + c10::getCustomClassTypeMap().insert( + {std::type_index(intrusivePtrClassTypeid), classTypePtr}); + c10::getCustomClassTypeMap().insert( + {std::type_index(taggedCapsuleClassTypeid), classTypePtr}); + registerCustomClass(classTypePtr); +} + +c10::FunctionSchema class_base::withNewArguments( + const c10::FunctionSchema& schema, + std::initializer_list default_args) { + const auto& old_args = schema.arguments(); + std::vector new_args; + new_args.reserve(old_args.size()); + + new_args.emplace_back(old_args[0]); + // Skip self. + size_t argIdx = 1; + for (const auto& default_arg : default_args) { + auto& old_arg = old_args[argIdx++]; + new_args.emplace_back( + default_arg.name_, + old_arg.type(), + old_arg.N(), + default_arg.value_); + } + return schema.cloneWithArguments(std::move(new_args)); +} + +} // namespace detail } // namespace torch diff --git a/torch/custom_class.h b/torch/custom_class.h index f5b5b07b7291..cbbbae1a3869 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -18,38 +18,6 @@ namespace torch { -/// This struct is used to represent default values for arguments -/// when registering methods for custom classes. -/// static auto register_foo = torch::class_("myclasses", "Foo") -/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); -struct arg { - // Static method for representing a default value of None. This is meant to - // be used like so: - // torch::arg("name") = torch::arg::none - // and is identical to: - // torch::arg("name") = IValue() - static c10::IValue none() { - return c10::IValue(); - } - - // Explicit constructor. - explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {} - // Assignment operator. This enables the pybind-like syntax of - // torch::arg("name") = value. - arg& operator=(const c10::IValue& rhs) { - value_ = rhs; - return *this; - } - - // The name of the argument. This is copied to the schema; argument - // names cannot be extracted from the C++ declaration. - std::string name_; - // IValue's default constructor makes it None, which is not distinguishable from - // an actual, user-provided default value that is None. This boolean - // helps distinguish between the two cases. - c10::optional value_; -}; - /// This function is used in conjunction with `class_::def()` to register /// a constructor for a given C++ class type. For example, /// `torch::init()` would register a two-argument constructor @@ -93,7 +61,7 @@ decltype(auto) init(Func&& f) { /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()` /// is registered with a C++ lambda expression. template -class class_ { +class class_ : public ::torch::detail::class_base { static_assert(std::is_base_of::value, "torch::class_ requires T to inherit from CustomClassHolder"); @@ -105,25 +73,8 @@ class class_ { /// see this class exposed as in Python and TorchScript. 
For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript - explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") { - detail::checkValidIdent(namespaceName, "Namespace name"); - detail::checkValidIdent(className, "Class name"); - qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; - - classTypePtr = at::ClassType::create( - c10::QualifiedName(qualClassName), - std::weak_ptr(), - /*is_module=*/false, - std::move(doc_string)); - classTypePtr->addAttribute("capsule", at::CapsuleType::get()); - - c10::getCustomClassTypeMap().insert( - {std::type_index(typeid(c10::intrusive_ptr)), classTypePtr}); - c10::getCustomClassTypeMap().insert( - {std::type_index(typeid(c10::tagged_capsule)), classTypePtr}); - - registerCustomClass(classTypePtr); - } + explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") + : class_base(namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr), typeid(c10::tagged_capsule)) {} /// def() can be used in conjunction with `torch::init()` to register /// a constructor for a given C++ class type. For example, passing @@ -419,31 +370,15 @@ class class_ { // extracted by inferFunctionSchemaSingleReturn, and so there must be a // torch::arg instance in default_args even for arguments that do not // have an actual default value provided. - TORCH_CHECK( - default_args.size() == 0 || - default_args.size() == schema.arguments().size() - 1, - "Default values must be specified for none or all arguments"); + TORCH_CHECK( + default_args.size() == 0 || + default_args.size() == schema.arguments().size() - 1, + "Default values must be specified for none or all arguments"); // If there are default args, copy the argument names and default values to the // function schema. if (default_args.size() > 0) { - const auto& old_args = schema.arguments(); - std::vector new_args; - new_args.reserve(old_args.size()); - std::vector default_args_v(default_args); - - new_args.emplace_back(old_args[0]); - for (size_t i = 0; i < default_args_v.size(); ++i) { - // Skip self. - auto& arg = old_args[i+1]; - new_args.emplace_back(c10::Argument( - std::move(default_args_v[i].name_), - arg.type(), - arg.N(), - default_args_v[i].value_.has_value() ? std::move(*default_args_v[i].value_) : c10::nullopt)); - } - - schema = schema.cloneWithArguments(new_args); + schema = withNewArguments(schema, default_args); } auto wrapped_func = @@ -467,9 +402,6 @@ class class_ { registerCustomClassMethod(std::move(method)); return method_val; } - - std::string qualClassName; - at::ClassTypePtr classTypePtr; }; /// make_custom_class() is a convenient way to create an instance of a registered diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index 4d5ed3f3556c..6984d9f09962 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -7,6 +7,38 @@ namespace torch { +/// This struct is used to represent default values for arguments +/// when registering methods for custom classes. +/// static auto register_foo = torch::class_("myclasses", "Foo") +/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); +struct arg { + // Static method for representing a default value of None. 
This is meant to + // be used like so: + // torch::arg("name") = torch::arg::none + // and is identical to: + // torch::arg("name") = IValue() + static c10::IValue none() { + return c10::IValue(); + } + + // Explicit constructor. + explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {} + // Assignment operator. This enables the pybind-like syntax of + // torch::arg("name") = value. + arg& operator=(const c10::IValue& rhs) { + value_ = rhs; + return *this; + } + + // The name of the argument. This is copied to the schema; argument + // names cannot be extracted from the C++ declaration. + std::string name_; + // IValue's default constructor makes it None, which is not distinguishable from + // an actual, user-provided default value that is None. This boolean + // helps distinguish between the two cases. + c10::optional value_; +}; + namespace detail { // Argument type utilities @@ -134,6 +166,22 @@ inline void checkValidIdent(const std::string& str, const char *type) { } } +class TORCH_API class_base { + protected: + explicit class_base( + const std::string& namespaceName, + const std::string& className, + std::string doc_string, + const std::type_info& intrusivePtrClassTypeid, + const std::type_info& taggedCapsuleClass); + + static c10::FunctionSchema withNewArguments( + const c10::FunctionSchema& schema, + std::initializer_list default_args); + std::string qualClassName; + at::ClassTypePtr classTypePtr; +}; + } // namespace detail TORCH_API void registerCustomClass(at::ClassTypePtr class_type); From b4b95fc87a314d2f9c7e58cee635897132e01c48 Mon Sep 17 00:00:00 2001 From: Corey Lammie Date: Tue, 25 May 2021 14:57:00 -0700 Subject: [PATCH 06/18] Expose `cudaMemGetInfo` (#58635) Summary: This PR resolves the second issue outlined in https://github.com/pytorch/pytorch/issues/58376, which has previously been discussed in https://github.com/pytorch/pytorch/issues/50722. `cudaMemGetInfo` is bound/exposed to the Python API. An example function call is provided below: ``` device_free, device_total = torch.cuda.mem_get_info(torch.device('cuda:0')) print(device_free, device_total) ``` In `CUDACachingAllocator.cpp`, in constant to my initial PR, the newly defined function `std::pair raw_cuda_mem_get_info(int device)` has been moved from the `CUDACaching` namespace to the `cuda` namespace. In addition, as suugested by ezyang, `det` has been removed from all function names. 
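As a small usage sketch building on the example above (assuming this patch is applied and a CUDA build with at least one visible device; the variable names are illustrative), the returned values are device-wide byte counts, so used memory can be derived directly, and a plain integer index is accepted as well because the binding routes through `_get_device_index`:

```python
import torch

if torch.cuda.is_available():
    # Free and total device memory in bytes, as reported by cudaMemGetInfo.
    # These are device-wide figures, i.e. they include other processes.
    device_free, device_total = torch.cuda.mem_get_info(0)  # int index also accepted
    device_used = device_total - device_free
    print(f"cuda:0 uses {device_used / 2**30:.2f} GiB of {device_total / 2**30:.2f} GiB")
```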
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58635 Reviewed By: zou3519 Differential Revision: D28649093 Pulled By: ezyang fbshipit-source-id: d8b7c53e52cf73f35495d8651863c5bb408d7a6a --- torch/csrc/cuda/shared/cudart.cpp | 8 ++++++++ torch/cuda/memory.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp index a8f80a35855d..30a43bed0534 100644 --- a/torch/csrc/cuda/shared/cudart.cpp +++ b/torch/csrc/cuda/shared/cudart.cpp @@ -6,6 +6,7 @@ #else #include #endif +#include namespace torch { namespace cuda { namespace shared { @@ -38,6 +39,13 @@ void initCudartBindings(PyObject* module) { #ifndef __HIP_PLATFORM_HCC__ cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize); #endif + cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair { + C10_CUDA_CHECK(cudaGetDevice(&device)); + size_t device_free; + size_t device_total; + cudaMemGetInfo(&device_free, &device_total); + return {device_free, device_total}; + }); } } // namespace shared diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 61574877ac16..85e5f57d78fd 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -531,3 +531,21 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: mem = p.usedGpuMemory / (1024 * 1024) lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) + +def mem_get_info(device: Union[Device, int] = None) -> int: + r"""Returns the global free and total GPU memory occupied for a given + device using cudaMemGetInfo. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + return torch.cuda.cudart().cudaMemGetInfo(device) From 45aa54d83c123792bd4085e9b6679dc845c2af11 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 25 May 2021 14:57:17 -0700 Subject: [PATCH 07/18] relax test deadlines Summary: Relax test deadlines for c2 tests. We run on loaded machines, and timings are unreliable. 
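For context, the change below only adjusts the standard Hypothesis `settings` knob; a minimal standalone sketch of the pattern (the test body and names here are made up for illustration) looks like this:

```python
import numpy as np
from hypothesis import given, settings, strategies as st

# deadline=None disables Hypothesis's per-example time limit, avoiding flaky
# DeadlineExceeded errors on loaded machines; a generous value such as
# deadline=10000 (milliseconds) is the softer alternative used for some tests.
@given(x=st.integers(1, 10))
@settings(deadline=None)
def test_example(x):
    data = np.zeros(x, dtype=np.float32)
    assert data.shape == (x,)

if __name__ == "__main__":
    test_example()  # calling a @given-wrapped test runs the Hypothesis search
```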
Test Plan: Fixes existing tests Reviewed By: mruberry Differential Revision: D28690006 fbshipit-source-id: 457707e81a1ec92548c1f23ea7a0022fa0a3bfda --- .../python/operator_test/batch_sparse_to_dense_op_test.py | 4 ++-- caffe2/python/operator_test/conv_test.py | 4 ++-- .../python/operator_test/elementwise_op_broadcast_test.py | 8 ++++---- caffe2/python/operator_test/fc_operator_test.py | 4 ++-- caffe2/python/operator_test/locally_connected_op_test.py | 4 ++-- caffe2/python/operator_test/reduce_ops_test.py | 8 ++++---- caffe2/python/operator_test/softmax_ops_test.py | 6 +++--- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py b/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py index adfc735c66fd..968da8da8405 100644 --- a/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py +++ b/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py @@ -19,7 +19,7 @@ class TestBatchSparseToDense(serial.SerializedTestCase): default_value=st.floats(min_value=2.0, max_value=3.0), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_batch_sparse_to_dense( self, batch_size, dense_last_dim, default_value, gc, dc ): @@ -75,7 +75,7 @@ def batch_sparse_to_dense_ref(L, I, V, S=None): dense_last_dim=st.integers(5, 10), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_batch_dense_to_sparse(self, batch_size, dense_last_dim, gc, dc): L = np.random.randint(1, dense_last_dim + 1, size=(batch_size)) # The following logic ensure that indices in each batch will not be duplicated diff --git a/caffe2/python/operator_test/conv_test.py b/caffe2/python/operator_test/conv_test.py index e600aa2c9ee9..23217b15b82d 100644 --- a/caffe2/python/operator_test/conv_test.py +++ b/caffe2/python/operator_test/conv_test.py @@ -164,7 +164,7 @@ def test_convolution_separate_stride_pad_gradients( use_bias=st.booleans(), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_convolution_separate_stride_pad_layout( self, op_type, @@ -761,7 +761,7 @@ def canonical(o): engine=st.sampled_from(["CUDNN", ""]), **hu.gcs_no_hip ) - @settings(deadline=1000) + @settings(deadline=None) def test_convolution_sync(self, net_type, num_workers, engine, gc, dc): m = ModelHelper(name="test_model") n = 1 diff --git a/caffe2/python/operator_test/elementwise_op_broadcast_test.py b/caffe2/python/operator_test/elementwise_op_broadcast_test.py index 605c1d741271..bd19ebc6ed97 100644 --- a/caffe2/python/operator_test/elementwise_op_broadcast_test.py +++ b/caffe2/python/operator_test/elementwise_op_broadcast_test.py @@ -75,22 +75,22 @@ def __test_binary_op(self, gc, dc, caffe2_op, op_function): self.assertGradientChecks(gc, op, [X, Y], 1, [0]) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Add(self, gc, dc): self.__test_binary_op(gc, dc, "Add", operator.add) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Mul(self, gc, dc): self.__test_binary_op(gc, dc, "Mul", operator.mul) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Sub(self, gc, dc): self.__test_binary_op(gc, dc, "Sub", operator.sub) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_powt(self, gc, dc): np.random.seed(101) diff --git a/caffe2/python/operator_test/fc_operator_test.py b/caffe2/python/operator_test/fc_operator_test.py index 1e8b5522053d..bd203b7c84a6 100644 --- 
a/caffe2/python/operator_test/fc_operator_test.py +++ b/caffe2/python/operator_test/fc_operator_test.py @@ -61,8 +61,8 @@ def fc_transposed_op(X, W, b): op.arg.extend([a]) # Check against numpy reference - # ReferenceChecks is flaky on rocm with threshold of 1e-4 for fp16. Relaxing to 1e-3. - threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP and dtype == np.float16) else 1e-4 + # ReferenceChecks is flaky, Relaxing to 1e-3. + threshold = 1e-3 self.assertReferenceChecks( device_option=gc, op=op, diff --git a/caffe2/python/operator_test/locally_connected_op_test.py b/caffe2/python/operator_test/locally_connected_op_test.py index 2adc253f4d88..445c3641573f 100644 --- a/caffe2/python/operator_test/locally_connected_op_test.py +++ b/caffe2/python/operator_test/locally_connected_op_test.py @@ -103,7 +103,7 @@ def lc_2d_nhwc(X, W, b=None): op_name=st.sampled_from(["LC", "LC1D"]), use_bias=st.booleans(), **hu.gcs) - @settings(deadline=5000) + @settings(deadline=None) # Increased timeout from 1 second to 5 for ROCM def test_lc_1d(self, N, C, size, M, kernel, op_name, use_bias, gc, dc): if size < kernel: @@ -163,7 +163,7 @@ def conv(n, m, yl): op_name=st.sampled_from(["LC", "LC3D"]), use_bias=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_lc_3d(self, N, C, T, H, W, M, kernel, op_name, use_bias, gc, dc): if T < kernel: kernel = T diff --git a/caffe2/python/operator_test/reduce_ops_test.py b/caffe2/python/operator_test/reduce_ops_test.py index 7b79b3b81aed..299b373e509d 100644 --- a/caffe2/python/operator_test/reduce_ops_test.py +++ b/caffe2/python/operator_test/reduce_ops_test.py @@ -96,7 +96,7 @@ def test_reduce_mean(self, X, keepdims, num_axes, gc, dc): @given(n=st.integers(1, 3), m=st.integers(1, 3), k=st.integers(1, 3), keepdims=st.booleans(), num_axes=st.integers(1, 3), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_reduce_l1(self, n, m, k, keepdims, num_axes, gc, dc): X = np.arange(n * m * k, dtype=np.float32) - 0.5 np.random.shuffle(X) @@ -253,7 +253,7 @@ def ref_sum(X): np.testing.assert_allclose(output, ref_sum(X)[0], atol=1e-3) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_reduce_front_sum_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) @@ -286,7 +286,7 @@ def ref_mean(X): "ReduceFrontMeanGradient", X, ref_mean, num_reduce_dim) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_reduce_front_mean_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) @@ -411,7 +411,7 @@ def ref_mean(X): "ReduceBackMeanGradient", X, ref_mean, num_reduce_dim) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_reduce_back_mean_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) diff --git a/caffe2/python/operator_test/softmax_ops_test.py b/caffe2/python/operator_test/softmax_ops_test.py index 533d575ee59f..8ec92ae1af9e 100644 --- a/caffe2/python/operator_test/softmax_ops_test.py +++ b/caffe2/python/operator_test/softmax_ops_test.py @@ -143,7 +143,7 @@ def label_softmax(X): @given(n=st.integers(2, 10), D=st.integers(4, 16), only_loss=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_softmax_with_loss(self, n, D, gc, only_loss, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability @@ -301,7 +301,7 @@ def label_softmax_crossent(X, 
label): ) @given(n=st.integers(2, 10), D=st.integers(4, 16), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_softmax_with_loss_label_prob(self, n, D, gc, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability @@ -358,7 +358,7 @@ def label_softmax_crossent(X, label): D=st.integers(4, 16), only_loss=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_softmax_with_loss_weighted(self, n, D, only_loss, gc, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability From 7179e7ea7bb18e7df80a1c99c02bf54f215da160 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 25 May 2021 15:08:56 -0700 Subject: [PATCH 08/18] [CMake] Prefer third_party/pybind11 by default (#58951) Summary: To make build behaviour aligned with other third_party/ libraries, introduce `USE_SYSTEM_PYBIND11 (https://github.com/pytorch/pytorch/commit/d55b25a633b7e2e6122becf6dbdf0528df6e8b13)` build option, which set to OFF by default, which means PyTorch will be build with bundled pybind11 even if other version is already installed locally. Fixes https://github.com/pytorch/pytorch/issues/58750 Pull Request resolved: https://github.com/pytorch/pytorch/pull/58951 Reviewed By: driazati Differential Revision: D28690411 Pulled By: malfet fbshipit-source-id: e56b5a8f2a23ee1834b2a6d3807f287149decf8c --- CMakeLists.txt | 2 ++ cmake/Dependencies.cmake | 12 ++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f308a75f072..4818b5012b57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -351,6 +351,7 @@ option(USE_SYSTEM_CPUINFO "Use system-provided cpuinfo." OFF) option(USE_SYSTEM_SLEEF "Use system-provided sleef." OFF) option(USE_SYSTEM_GLOO "Use system-provided gloo." OFF) option(USE_SYSTEM_FP16 "Use system-provided fp16." OFF) +option(USE_SYSTEM_PYBIND11 "Use system-provided PyBind11." OFF) option(USE_SYSTEM_PTHREADPOOL "Use system-provided pthreadpool." OFF) option(USE_SYSTEM_PSIMD "Use system-provided psimd." OFF) option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." 
OFF) @@ -371,6 +372,7 @@ if(USE_SYSTEM_LIBS) set(USE_SYSTEM_BENCHMARK ON) set(USE_SYSTEM_ONNX ON) set(USE_SYSTEM_XNNPACK ON) + set(USE_SYSTEM_PYBIND11 ON) endif() # Used when building Caffe2 through setup.py diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index c7fe9b7d4bde..6d9c3ac3ab90 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -999,24 +999,20 @@ if(BUILD_PYTHON) endif() # ---[ pybind11 -if(NOT pybind11_PREFER_third_party) +if(USE_SYSTEM_BIND11) find_package(pybind11 CONFIG) if(NOT pybind11_FOUND) find_package(pybind11) endif() -endif() - -if(pybind11_FOUND) - message(STATUS "System pybind11 found") + if(NOT pybind11_FOUND) + message(FATAL "Cannot find system pybind11") + endif() else() message(STATUS "Using third_party/pybind11.") set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/../third_party/pybind11/include) install(DIRECTORY ${pybind11_INCLUDE_DIRS} DESTINATION ${CMAKE_INSTALL_PREFIX} FILES_MATCHING PATTERN "*.h") - set(pybind11_PREFER_third_party ON CACHE BOOL - "Use the third_party/pybind11 submodule, instead of looking for system - installation of pybind11") endif() message(STATUS "pybind11 include dirs: " "${pybind11_INCLUDE_DIRS}") include_directories(SYSTEM ${pybind11_INCLUDE_DIRS}) From 36a77580f513434820b45f058b9378ed0cd20ebf Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 25 May 2021 15:26:02 -0700 Subject: [PATCH 09/18] [docs] Clarify batch_first behavior for nn.LSTM, nn.RNN, and nn.GRU (#58809) Summary: Fixes the high-pri doc component of https://github.com/pytorch/pytorch/issues/4145. To make the input / output shapes more readable for both `batch_first` states, this PR also introduces short dim names. Opinions welcome on the readability of the restructured docs! Screenshot for `nn.LSTM`: Screen Shot 2021-05-24 at 5 11 39 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/58809 Reviewed By: gchanan Differential Revision: D28685415 Pulled By: jbschlosser fbshipit-source-id: e8c92e3d7e052071a505b55dca976fd2ef5a8307 --- torch/nn/modules/rnn.py | 203 ++++++++++++++++++++-------------------- 1 file changed, 102 insertions(+), 101 deletions(-) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index ea338fbf020a..7a6fde26ba9d 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -373,50 +373,42 @@ class RNN(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as `(batch, seq, feature)`. Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each RNN layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` Inputs: input, h_0 - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. The input can also be a packed variable length - sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` - or :func:`torch.nn.utils.rnn.pack_sequence` - for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - Defaults to zero if not provided. 
If the RNN is bidirectional, - num_directions should be 2, else it should be 1. - - Outputs: output, h_n - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features (`h_t`) from the last layer of the RNN, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has - been given as the input, the output will also be a packed sequence. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for each element in the batch. Defaults to zeros if not provided. - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len`. + where: - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} - Shape: - - Input1: :math:`(L, N, H_{in})` tensor containing input features where - :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. - - Input2: :math:`(S, N, H_{out})` tensor - containing the initial hidden state for each element in the batch. - :math:`H_{out}=\text{hidden\_size}` - Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` - If the RNN is bidirectional, num_directions should be 2, else it should be 1. - - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` - - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state - for each element in the batch + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. Attributes: weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, @@ -433,6 +425,11 @@ class RNN(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_rnn_determinism.rst .. 
include:: ../cudnn_persistent_rnn.rst @@ -518,7 +515,9 @@ class LSTM(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 @@ -526,41 +525,40 @@ class LSTM(RNNBase): proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Inputs: input, (h_0, c_0) - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. - The input can also be a packed variable length sequence. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or :func:`torch.nn.utils.rnn.pack_sequence` for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - If the LSTM is bidirectional, num_directions should be 2, else it should be 1. - If ``proj_size > 0`` was specified, the shape has to be - `(num_layers * num_directions, batch, proj_size)`. - - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial cell state for each element in the batch. - - If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. - + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the batch. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the batch. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} Outputs: output, (h_n, c_n) - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features `(h_t)` from the last layer of the LSTM, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. If ``proj_size > 0`` - was specified, output shape will be `(seq_len, batch, num_directions * proj_size)`. - - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len`. 
If ``proj_size > 0`` - was specified, ``h_n`` shape will be `(num_layers * num_directions, batch, proj_size)`. - - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*. - - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the cell state for `t = seq_len`. + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the batch. + * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer @@ -581,6 +579,11 @@ class LSTM(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_rnn_determinism.rst .. include:: ../cudnn_persistent_rnn.rst @@ -724,49 +727,42 @@ class GRU(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` Inputs: input, h_0 - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. The input can also be a packed variable length - sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` - for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - Defaults to zero if not provided. If the RNN is bidirectional, - num_directions should be 2, else it should be 1. - - Outputs: output, h_n - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features h_t from the last layer of the GRU, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. 
The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for each element in the batch. Defaults to zeros if not provided. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len` + where: - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} - Shape: - - Input1: :math:`(L, N, H_{in})` tensor containing input features where - :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. - - Input2: :math:`(S, N, H_{out})` tensor - containing the initial hidden state for each element in the batch. - :math:`H_{out}=\text{hidden\_size}` - Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` - If the RNN is bidirectional, num_directions should be 2, else it should be 1. - - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` - - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state - for each element in the batch + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer @@ -783,6 +779,11 @@ class GRU(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_persistent_rnn.rst Examples:: From fb120493b10c950d7da58a46597d26e8c3013579 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 25 May 2021 15:32:36 -0700 Subject: [PATCH 10/18] Make Scalar.to<> for invalid types a compile-time error (#58726) Summary: Currently calling `scalar.to>()` for example compiles but throws an error at runtime. Instead, marking the non-specialized cases as `= delete` means the code fails to compile and you catch the error sooner. 
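For readers unfamiliar with the idiom, a minimal standalone sketch of the same pattern follows (a toy class, not the actual `c10::Scalar`; all names are invented for illustration): the primary member template is deleted and only explicit specializations for the supported types are provided, so an unsupported `to<T>()` fails at compile time instead of at runtime.

```cpp
#include <iostream>

// Toy stand-in for a tagged scalar; illustrative only.
class ToyScalar {
 public:
  explicit ToyScalar(double v) : v_(v) {}

  // Deleted primary template: to<T>() for an unsupported T no longer
  // compiles, instead of throwing an "unexpected type" error at runtime.
  template <typename T>
  T to() const = delete;

 private:
  double v_;
};

// Explicit specializations supply the supported conversions.
template <>
inline double ToyScalar::to<double>() const {
  return v_;
}

template <>
inline int ToyScalar::to<int>() const {
  return static_cast<int>(v_);
}

int main() {
  ToyScalar s(2.5);
  std::cout << s.to<double>() << " " << s.to<int>() << "\n";
  // s.to<bool>();  // error: use of deleted function
  return 0;
}
```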
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58726 Reviewed By: zou3519, seemethere Differential Revision: D28646057 Pulled By: ezyang fbshipit-source-id: 9e4e3d1b4586eeecbb73db61bba56560b2657351 --- c10/core/Scalar.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 802bf17e0411..4c0baa431d53 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -63,8 +63,9 @@ class C10_API Scalar { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR) // also support scalar.to(); + // Deleted for unsupported types, but specialized below for supported types template - T to() const; + T to() const = delete; #undef DEFINE_ACCESSOR bool isFloatingPoint() const { @@ -186,11 +187,6 @@ class C10_API Scalar { }; // define the scalar.to() specializations -template -inline T Scalar::to() const { - throw std::runtime_error("to() cast to unexpected type."); -} - #define DEFINE_TO(T, name) \ template <> \ inline T Scalar::to() const { \ From 60af6e928ab2f0e94de1e27899048d552e6acb58 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Tue, 25 May 2021 15:33:50 -0700 Subject: [PATCH 11/18] [PyTorch Edge][Version] Fix torchscript model after backport (#58892) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58892 The torchscript model after backport misses the `constants` archive. Add it back, and extend the unit test to run torchscript part. ghstack-source-id: 129853819 Test Plan: ``` buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions' ``` Reviewed By: raziel, iseeyuan Differential Revision: D28664507 fbshipit-source-id: 5f98723231cc64ed203c062ee6f00d8adbdccf77 --- test/cpp/jit/test_lite_interpreter.cpp | 38 ++++++++++++++++++---- torch/csrc/jit/mobile/backport_manager.cpp | 2 +- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index c5acb0f68d9d..3c5aa5e5cf58 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -624,6 +624,34 @@ TEST(LiteInterpreterTest, GetByteCodeVersion) { } namespace { + +void compareModelOutput( + const std::vector& actual_result_list, + const std::vector& expect_result_list) { + AT_ASSERT(actual_result_list.size() == expect_result_list.size()); + AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0])); + AT_ASSERT( + actual_result_list[1].toTensor().dim() == expect_result_list[1].dim()); + AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2])); +} + +void runAndCheckTorchScriptModel( + std::stringstream& input_model_stream, + const std::vector& input_data, + const std::vector& expect_result_list, + const int64_t expect_version) { + auto actual_version = _get_model_bytecode_version(input_model_stream); + AT_ASSERT(actual_version == expect_version); + + // Load and run the backport model, then compare the result with expect + // result + Module m_mobile = load(input_model_stream); + + auto actual_result = m_mobile.forward(input_data); + std::vector actual_result_list = actual_result.toTuple()->elements(); + compareModelOutput(actual_result_list, expect_result_list); +} + void runAndCheckBytecodeModel( std::stringstream& input_model_stream, const std::vector& input_data, @@ -634,16 +662,12 @@ void runAndCheckBytecodeModel( // Load and run the backport model, then compare the result with expect // result - mobile::Module 
m_mobile = _load_for_mobile(input_model_stream); + Module m_mobile = load(input_model_stream); auto actual_result = m_mobile.forward(input_data); std::vector actual_result_list = actual_result.toTuple()->elements(); - AT_ASSERT(actual_result_list.size() == expect_result_list.size()); - AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0])); - AT_ASSERT( - actual_result_list[1].toTensor().dim() == expect_result_list[1].dim()); - AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2])); + compareModelOutput(actual_result_list, expect_result_list); } void backportAllVersionCheck( @@ -676,6 +700,8 @@ void backportAllVersionCheck( // result runAndCheckBytecodeModel( iss, input_data, expect_result_list, current_to_version); + runAndCheckTorchScriptModel( + iss, input_data, expect_result_list, current_to_version); current_to_version--; } diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp index 25d12c9c5661..ab4f9a3fe034 100644 --- a/torch/csrc/jit/mobile/backport_manager.cpp +++ b/torch/csrc/jit/mobile/backport_manager.cpp @@ -181,7 +181,7 @@ bool backport_v5_to_v4( // write `constants` archive auto constants_tuple = c10::ivalue::Tuple::create(std::move(constants_values)); - writeArchiveV4(writer, kArchiveNameConstants, bytecode_tuple); + writeArchiveV4(writer, kArchiveNameConstants, constants_tuple); return true; } From db5e5781adcbfc42ecc7f3f5af8b86b5e8e09f4e Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 25 May 2021 15:53:44 -0700 Subject: [PATCH 12/18] replace all remaining occurrences of deadline=1000, to prevent test flakiness Summary: Per title Test Plan: Fixes existing tests Reviewed By: robieta Differential Revision: D28690296 fbshipit-source-id: d7b5b5065517373b75d501872814c89b24ec8cfc --- caffe2/python/operator_test/activation_ops_test.py | 2 +- caffe2/python/operator_test/adadelta_test.py | 2 +- caffe2/python/operator_test/adagrad_test.py | 8 ++++---- caffe2/python/operator_test/assert_test.py | 2 +- caffe2/python/operator_test/bbox_transform_test.py | 4 ++-- caffe2/python/operator_test/boolean_mask_test.py | 4 ++-- .../python/operator_test/box_with_nms_limit_op_test.py | 6 +++--- caffe2/python/operator_test/clip_op_test.py | 2 +- caffe2/python/operator_test/clip_tensor_op_test.py | 2 +- caffe2/python/operator_test/crf_test.py | 2 +- caffe2/python/operator_test/dropout_op_test.py | 2 +- caffe2/python/operator_test/elementwise_ops_test.py | 6 +++--- caffe2/python/operator_test/erf_op_test.py | 2 +- caffe2/python/operator_test/expand_op_test.py | 2 +- caffe2/python/operator_test/filler_ops_test.py | 4 ++-- caffe2/python/operator_test/flexible_top_k_test.py | 2 +- .../fused_nbit_rowwise_conversion_ops_test.py | 2 +- caffe2/python/operator_test/gather_ops_test.py | 2 +- caffe2/python/operator_test/gather_ranges_op_test.py | 4 ++-- caffe2/python/operator_test/instance_norm_test.py | 2 +- caffe2/python/operator_test/layer_norm_op_test.py | 2 +- caffe2/python/operator_test/length_split_op_test.py | 2 +- caffe2/python/operator_test/lpnorm_op_test.py | 2 +- .../operator_test/margin_ranking_criterion_op_test.py | 2 +- caffe2/python/operator_test/matmul_op_test.py | 2 +- caffe2/python/operator_test/one_hot_ops_test.py | 2 +- caffe2/python/operator_test/pooling_test.py | 2 +- caffe2/python/operator_test/python_op_test.py | 2 +- caffe2/python/operator_test/selu_op_test.py | 4 ++-- caffe2/python/operator_test/sequence_ops_test.py | 4 ++-- .../sinusoid_position_encoding_op_test.py | 2 +- 
caffe2/python/operator_test/softplus_op_test.py | 2 +- .../operator_test/sparse_to_dense_mask_op_test.py | 4 ++-- caffe2/python/operator_test/string_ops_test.py | 10 +++++----- caffe2/python/operator_test/top_k_test.py | 4 ++-- caffe2/python/operator_test/torch_integration_test.py | 2 +- caffe2/python/operator_test/utility_ops_test.py | 4 ++-- caffe2/python/operator_test/weighted_sum_test.py | 2 +- caffe2/python/operator_test/wngrad_test.py | 8 ++++---- 39 files changed, 62 insertions(+), 62 deletions(-) diff --git a/caffe2/python/operator_test/activation_ops_test.py b/caffe2/python/operator_test/activation_ops_test.py index 7e5c5f423606..47216d51500c 100644 --- a/caffe2/python/operator_test/activation_ops_test.py +++ b/caffe2/python/operator_test/activation_ops_test.py @@ -243,7 +243,7 @@ def leaky_relu_ref(X): @given(X=hu.tensor(), fast_gelu=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_gelu(self, X, fast_gelu, gc, dc): op = core.CreateOperator( "Gelu", diff --git a/caffe2/python/operator_test/adadelta_test.py b/caffe2/python/operator_test/adadelta_test.py index 930f74ecd99e..6c40c379697f 100644 --- a/caffe2/python/operator_test/adadelta_test.py +++ b/caffe2/python/operator_test/adadelta_test.py @@ -53,7 +53,7 @@ def ref_adadelta(param_in, decay=hu.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_adadelta(self, inputs, lr, epsilon, decay, gc, dc): param, moment, moment_delta, grad = inputs moment = np.abs(moment) diff --git a/caffe2/python/operator_test/adagrad_test.py b/caffe2/python/operator_test/adagrad_test.py index 309c54a25cb1..3172026df1bf 100644 --- a/caffe2/python/operator_test/adagrad_test.py +++ b/caffe2/python/operator_test/adagrad_test.py @@ -26,7 +26,7 @@ class TestAdagrad(serial.SerializedTestCase): weight_decay=st.sampled_from([0.0, 0.1]), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc): param, momentum, grad = inputs momentum = np.abs(momentum) @@ -98,7 +98,7 @@ def test_adagrad_output_effective_lr( ), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_adagrad_output_effective_lr_and_update(self, inputs, lr, epsilon, gc, dc): param, momentum, grad = inputs momentum = np.abs(momentum) @@ -158,7 +158,7 @@ def test_sparse_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc): ), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc): param, momentum = inputs grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32) @@ -190,7 +190,7 @@ def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc): # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. 
- @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=1000) + @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=10000) @given( inputs=hu.tensors(n=3), lr=st.floats( diff --git a/caffe2/python/operator_test/assert_test.py b/caffe2/python/operator_test/assert_test.py index 2bbca5ab7376..eef33bc22bc0 100644 --- a/caffe2/python/operator_test/assert_test.py +++ b/caffe2/python/operator_test/assert_test.py @@ -14,7 +14,7 @@ class TestAssert(hu.HypothesisTestCase): dtype=st.sampled_from(['bool_', 'int32', 'int64']), shape=st.lists(elements=st.integers(1, 10), min_size=1, max_size=4), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_assert(self, dtype, shape, gc, dc): test_tensor = np.random.rand(*shape).astype(np.dtype(dtype)) diff --git a/caffe2/python/operator_test/bbox_transform_test.py b/caffe2/python/operator_test/bbox_transform_test.py index d2584f18af40..adcc2f8723d2 100644 --- a/caffe2/python/operator_test/bbox_transform_test.py +++ b/caffe2/python/operator_test/bbox_transform_test.py @@ -214,7 +214,7 @@ class TestBBoxTransformOp(serial.SerializedTestCase): clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_bbox_transform( self, num_rois, @@ -282,7 +282,7 @@ def bbox_transform_ref(rois, deltas, im_info): clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_bbox_transform_batch( self, roi_counts, diff --git a/caffe2/python/operator_test/boolean_mask_test.py b/caffe2/python/operator_test/boolean_mask_test.py index 38fe43899990..0ccdbd928512 100644 --- a/caffe2/python/operator_test/boolean_mask_test.py +++ b/caffe2/python/operator_test/boolean_mask_test.py @@ -15,7 +15,7 @@ class TestBooleanMaskOp(serial.SerializedTestCase): max_len=100, elements=hu.floats(min_value=0.5, max_value=1.0)), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_boolean_mask_gradient(self, x, gc, dc): op = core.CreateOperator("BooleanMask", ["data", "mask"], @@ -30,7 +30,7 @@ def test_boolean_mask_gradient(self, x, gc, dc): max_len=5, elements=hu.floats(min_value=0.5, max_value=1.0)), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_boolean_mask(self, x, gc, dc): op = core.CreateOperator("BooleanMask", ["data", "mask"], diff --git a/caffe2/python/operator_test/box_with_nms_limit_op_test.py b/caffe2/python/operator_test/box_with_nms_limit_op_test.py index 3131316feefd..e459edb57de3 100644 --- a/caffe2/python/operator_test/box_with_nms_limit_op_test.py +++ b/caffe2/python/operator_test/box_with_nms_limit_op_test.py @@ -83,7 +83,7 @@ def ref(*args, **kwargs): self.assertReferenceChecks(gc, op, [scores, boxes], ref) @given(**HU_CONFIG) - @settings(deadline=1000) + @settings(deadline=10000) def test_score_thresh(self, gc): in_centers = [(0, 0), (20, 20), (50, 50)] in_scores = [0.7, 0.85, 0.6] @@ -102,7 +102,7 @@ def ref(*args, **kwargs): self.assertReferenceChecks(gc, op, [scores, boxes], ref) @given(det_per_im=st.integers(1, 3), **HU_CONFIG) - @settings(deadline=1000) + @settings(deadline=10000) def test_detections_per_im(self, det_per_im, gc): in_centers = [(0, 0), (20, 20), (50, 50)] in_scores = [0.7, 0.85, 0.6] @@ -131,7 +131,7 @@ def ref(*args, **kwargs): output_classes_include_bg_cls=st.booleans(), **HU_CONFIG ) - @settings(deadline=1000) + @settings(deadline=10000) def test_multiclass( self, num_classes, diff --git 
a/caffe2/python/operator_test/clip_op_test.py b/caffe2/python/operator_test/clip_op_test.py index 3304121aab08..0e800dafe01a 100644 --- a/caffe2/python/operator_test/clip_op_test.py +++ b/caffe2/python/operator_test/clip_op_test.py @@ -19,7 +19,7 @@ class TestClip(serial.SerializedTestCase): max_=st.floats(min_value=0, max_value=2), inplace=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_clip(self, X, min_, max_, inplace, gc, dc): # go away from the origin point to avoid kink problems if np.isscalar(X): diff --git a/caffe2/python/operator_test/clip_tensor_op_test.py b/caffe2/python/operator_test/clip_tensor_op_test.py index efc86815bc49..c90c38234c8e 100644 --- a/caffe2/python/operator_test/clip_tensor_op_test.py +++ b/caffe2/python/operator_test/clip_tensor_op_test.py @@ -19,7 +19,7 @@ class TestClipTensorByScalingOp(serial.SerializedTestCase): use_additional_threshold=st.booleans(), inplace=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_clip_tensor_by_scaling(self, n, d, threshold, additional_threshold, use_additional_threshold, inplace, gc, dc): diff --git a/caffe2/python/operator_test/crf_test.py b/caffe2/python/operator_test/crf_test.py index 4d7b90c431a6..a4447fa3f364 100644 --- a/caffe2/python/operator_test/crf_test.py +++ b/caffe2/python/operator_test/crf_test.py @@ -15,7 +15,7 @@ class TestCRFOp(hu.HypothesisTestCase): @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15)) - @settings(deadline=1000) + @settings(deadline=10000) def test_crf_with_loss_op(self, num_tags, num_words): model = ModelHelper(name='external') embeddings_dim = 200 diff --git a/caffe2/python/operator_test/dropout_op_test.py b/caffe2/python/operator_test/dropout_op_test.py index 84c2f7e35f56..d3a5c831d875 100644 --- a/caffe2/python/operator_test/dropout_op_test.py +++ b/caffe2/python/operator_test/dropout_op_test.py @@ -48,7 +48,7 @@ def reference_dropout_test(x): output_mask=st.booleans(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_dropout_ratio0(self, X, in_place, output_mask, engine, gc, dc): """Test with ratio=0 for a deterministic reference impl.""" # TODO(lukeyeager): enable this path when the op is fixed diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py index 922e4554e9a8..130ebade010b 100644 --- a/caffe2/python/operator_test/elementwise_ops_test.py +++ b/caffe2/python/operator_test/elementwise_ops_test.py @@ -59,7 +59,7 @@ def exp_ref(X): @given(n=st.integers(0, 6), m=st.integers(4, 6), seed=st.integers(0, 1000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_log(self, n, m, gc, dc, seed): np.random.seed(seed) X = np.random.rand(n, m).astype(np.float32) + 1.0 @@ -326,7 +326,7 @@ def swish(X): @given(n=st.integers(0, 6), m=st.integers(4, 6), seed=st.integers(0, 1000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_swish_gradient_inplace(self, n, m, gc, dc, seed): np.random.seed(seed) @@ -354,7 +354,7 @@ def swish_gradient(X, Y, dY): @given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_sigmoid(self, X, inplace, engine, gc, dc): op = core.CreateOperator( "Sigmoid", diff --git a/caffe2/python/operator_test/erf_op_test.py b/caffe2/python/operator_test/erf_op_test.py index 64714db4315c..a4ed0d5fb23e 
100644 --- a/caffe2/python/operator_test/erf_op_test.py +++ b/caffe2/python/operator_test/erf_op_test.py @@ -18,7 +18,7 @@ class TestErfOp(serial.SerializedTestCase): @given( X=hu.tensor(elements=hu.floats(min_value=-0.7, max_value=0.7)), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_erf(self, X, gc, dc): op = core.CreateOperator('Erf', ["X"], ["Y"]) self.assertReferenceChecks(gc, op, [X], lambda x: (np.vectorize(math.erf)(X),)) diff --git a/caffe2/python/operator_test/expand_op_test.py b/caffe2/python/operator_test/expand_op_test.py index aba2c1106da3..bd608f6fcc24 100644 --- a/caffe2/python/operator_test/expand_op_test.py +++ b/caffe2/python/operator_test/expand_op_test.py @@ -59,7 +59,7 @@ def test_expand_nonrand_shape1(self, X, gc, dc): np.ones([1, 4, 1, 2]), np.ones([4, 1, 2])]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_expand_nonrand_shape2(self, X, gc, dc): self._run_expand_op_test(X, [4, 1, 2, 2], gc, dc) self._run_expand_op_test(X, [4, -1, 2, 2], gc, dc) diff --git a/caffe2/python/operator_test/filler_ops_test.py b/caffe2/python/operator_test/filler_ops_test.py index e080dde3eb5f..442f5866cb09 100644 --- a/caffe2/python/operator_test/filler_ops_test.py +++ b/caffe2/python/operator_test/filler_ops_test.py @@ -22,7 +22,7 @@ def _fill_diagonal(shape, value): class TestFillerOperator(serial.SerializedTestCase): @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_shape_error(self, gc, dc): op = core.CreateOperator( 'GaussianFill', @@ -77,7 +77,7 @@ def test_int64_shape(self, gc, dc): b=st.integers(min_value=0, max_value=100), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_uniform_int_fill_op_blob_input(self, shape, a, b, gc, dc): net = core.Net('test_net') diff --git a/caffe2/python/operator_test/flexible_top_k_test.py b/caffe2/python/operator_test/flexible_top_k_test.py index 3e0e5722b0ce..0cccabb5f2e9 100644 --- a/caffe2/python/operator_test/flexible_top_k_test.py +++ b/caffe2/python/operator_test/flexible_top_k_test.py @@ -40,7 +40,7 @@ def flexible_top_k_ref(self, X, k): return (values_ref, indices_ref) @given(X=hu.tensor(min_dim=2), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_flexible_top_k(self, X, gc, dc): X = X.astype(dtype=np.float32) k_shape = (int(X.size / X.shape[-1]), ) diff --git a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py index b7cb5f68351f..d2e794da0651 100644 --- a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py +++ b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py @@ -205,7 +205,7 @@ def ErrorThresholdRow(X, bit_rate): class TestNBitFakeFused(hu.HypothesisTestCase): @given(bit_rate=st.sampled_from([2, 4])) - @settings(deadline=1000) + @settings(deadline=10000) def testNBit(self, bit_rate): # uncomment for debugging # np.random.seed(0) diff --git a/caffe2/python/operator_test/gather_ops_test.py b/caffe2/python/operator_test/gather_ops_test.py index fc23be13fdae..b0d64506e4c7 100644 --- a/caffe2/python/operator_test/gather_ops_test.py +++ b/caffe2/python/operator_test/gather_ops_test.py @@ -209,7 +209,7 @@ class TestGatherFused8BitRowwise(hu.HypothesisTestCase): cols_num=st.integers(1, 128), index_num=st.integers(0, 5000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_batch_gather_ops(self, rows_num, cols_num, index_num, gc, dc): data = 
np.random.random((rows_num, cols_num)).astype(np.float32) ind = np.random.randint(rows_num, size=(index_num, )).astype('int32') diff --git a/caffe2/python/operator_test/gather_ranges_op_test.py b/caffe2/python/operator_test/gather_ranges_op_test.py index c0d73af33601..b6ec8823f4dd 100644 --- a/caffe2/python/operator_test/gather_ranges_op_test.py +++ b/caffe2/python/operator_test/gather_ranges_op_test.py @@ -166,7 +166,7 @@ def gather_ranges_to_dense_with_key(data, ranges, key, lengths): class TestGatherRanges(serial.SerializedTestCase): @given(boarders_and_data=batched_boarders_and_data(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_gather_ranges(self, boarders_and_data, gc, dc): boarders, data = boarders_and_data @@ -187,7 +187,7 @@ def boarders_to_range(boarders): ) @given(tensor_splits=_tensor_splits(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_gather_ranges_split(self, tensor_splits, gc, dc): data, ranges, lengths, _ = tensor_splits diff --git a/caffe2/python/operator_test/instance_norm_test.py b/caffe2/python/operator_test/instance_norm_test.py index efce9d7001fe..d97385cbe215 100644 --- a/caffe2/python/operator_test/instance_norm_test.py +++ b/caffe2/python/operator_test/instance_norm_test.py @@ -60,7 +60,7 @@ def _feed_inputs(self, input_blobs, device_option): store_mean=st.booleans(), seed=st.integers(0, 1000), store_inv_stdev=st.booleans()) - @settings(deadline=1000) + @settings(deadline=10000) def test_instance_norm_gradients( self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev, epsilon, seed): diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py index 67d7f14bd336..32a2511e3e8e 100644 --- a/caffe2/python/operator_test/layer_norm_op_test.py +++ b/caffe2/python/operator_test/layer_norm_op_test.py @@ -322,7 +322,7 @@ def test_layer_norm_op_pytorch_cuda(self, X, eps, elementwise_affine): eps=st.floats(1e-5, 1e-3), elementwise_affine=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_layer_norm_op_jit(self, X, eps, elementwise_affine, gc, dc): @torch.jit.script def jit_layer_norm( diff --git a/caffe2/python/operator_test/length_split_op_test.py b/caffe2/python/operator_test/length_split_op_test.py index 28d7134ac5e8..3f20ff1f4585 100644 --- a/caffe2/python/operator_test/length_split_op_test.py +++ b/caffe2/python/operator_test/length_split_op_test.py @@ -28,7 +28,7 @@ def _length_split_op_ref(self, input_lengths, n_split_array): return [np.array(output).astype(np.int32)] @given(**hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_length_split_edge(self, gc, dc): input_lengths = np.array([3, 4, 5]).astype(np.int32) n_split_ = np.array([5]).astype(np.int32) diff --git a/caffe2/python/operator_test/lpnorm_op_test.py b/caffe2/python/operator_test/lpnorm_op_test.py index 3a58cbe6d960..e7ab634d0e7c 100644 --- a/caffe2/python/operator_test/lpnorm_op_test.py +++ b/caffe2/python/operator_test/lpnorm_op_test.py @@ -16,7 +16,7 @@ class LpnormTest(hu.HypothesisTestCase): max_dim=3, dtype=np.float32), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_Lp_Norm(self, inputs, gc, dc): X = inputs[0] # avoid kinks by moving away from 0 diff --git a/caffe2/python/operator_test/margin_ranking_criterion_op_test.py b/caffe2/python/operator_test/margin_ranking_criterion_op_test.py index e28dd1ce28f8..a91de60a8c19 100644 --- 
a/caffe2/python/operator_test/margin_ranking_criterion_op_test.py +++ b/caffe2/python/operator_test/margin_ranking_criterion_op_test.py @@ -17,7 +17,7 @@ class TestMarginRankingCriterion(serial.SerializedTestCase): seed=st.integers(min_value=0, max_value=65535), margin=st.floats(min_value=-0.5, max_value=0.5), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_margin_ranking_criterion(self, N, seed, margin, gc, dc): np.random.seed(seed) X1 = np.random.randn(N).astype(np.float32) diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py index 8b4001a574ac..067eeabbe2d9 100644 --- a/caffe2/python/operator_test/matmul_op_test.py +++ b/caffe2/python/operator_test/matmul_op_test.py @@ -60,7 +60,7 @@ def matmul_ref(X, Y, trans_a, trans_b): trans_b=st.booleans(), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_matmul_axis( self, M, K, N, axis_a, axis_b, trans_a, trans_b, gc, dc ): diff --git a/caffe2/python/operator_test/one_hot_ops_test.py b/caffe2/python/operator_test/one_hot_ops_test.py index 593d5b5aa58c..e23e04434ab3 100644 --- a/caffe2/python/operator_test/one_hot_ops_test.py +++ b/caffe2/python/operator_test/one_hot_ops_test.py @@ -63,7 +63,7 @@ def ref(x, lens, vals): elements=st.integers(min_value=-5, max_value=5)), seed=st.integers(min_value=0, max_value=1000), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_batch_bucketized_one_hot(self, x, seed, gc, dc): np.random.seed(seed) d = x.shape[1] diff --git a/caffe2/python/operator_test/pooling_test.py b/caffe2/python/operator_test/pooling_test.py index 7ef98249bd79..2954face6b85 100644 --- a/caffe2/python/operator_test/pooling_test.py +++ b/caffe2/python/operator_test/pooling_test.py @@ -90,7 +90,7 @@ def test_pooling_big_batch(self, gc, dc): op_type=st.sampled_from(["MaxPool", "AveragePool", "MaxPool1D", "AveragePool1D"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_pooling_1d(self, stride, pad, kernel, size, input_channels, batch_size, order, op_type, gc, dc): assume(pad < kernel) diff --git a/caffe2/python/operator_test/python_op_test.py b/caffe2/python/operator_test/python_op_test.py index b071070151d1..8f41815585dc 100644 --- a/caffe2/python/operator_test/python_op_test.py +++ b/caffe2/python/operator_test/python_op_test.py @@ -14,7 +14,7 @@ class PythonOpTest(hu.HypothesisTestCase): @given(x=hu.tensor(), n=st.integers(min_value=1, max_value=20), w=st.integers(min_value=1, max_value=20)) - @settings(deadline=1000) + @settings(deadline=10000) def test_simple_python_op(self, x, n, w): def g(input_, output): output[...] 
= input_ diff --git a/caffe2/python/operator_test/selu_op_test.py b/caffe2/python/operator_test/selu_op_test.py index 4dd2fa1848bf..73cb0736dcee 100644 --- a/caffe2/python/operator_test/selu_op_test.py +++ b/caffe2/python/operator_test/selu_op_test.py @@ -33,7 +33,7 @@ def test_selu_1(self, X, gc, dc, engine): @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_selu_2(self, X, gc, dc, engine): alpha = 1.6732 scale = 1.0507 @@ -50,7 +50,7 @@ def test_selu_2(self, X, gc, dc, engine): @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_selu_3(self, X, gc, dc, engine): alpha = 1.3 scale = 1.1 diff --git a/caffe2/python/operator_test/sequence_ops_test.py b/caffe2/python/operator_test/sequence_ops_test.py index 65c0669abfb0..524d3c8b4149 100644 --- a/caffe2/python/operator_test/sequence_ops_test.py +++ b/caffe2/python/operator_test/sequence_ops_test.py @@ -106,7 +106,7 @@ class TestSequenceOps(serial.SerializedTestCase): args=_gen_test_add_padding(with_pad_data=True), ret_lengths=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_add_padding( self, start_pad_width, end_pad_width, args, ret_lengths, gc, dc ): @@ -278,7 +278,7 @@ def op_ref(data, indices): min_size=0, max_size=10), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_find_duplicate_elements(self, elements, gc, dc): mapping = { 0: "a", diff --git a/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py b/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py index 6e8cae62dbff..03b50bfc952d 100644 --- a/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py +++ b/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py @@ -33,7 +33,7 @@ class TestSinusoidPositionEncodingOp(serial.SerializedTestCase): amplitude=st.floats(MIN_TEST_AMPLITUDE, MAX_TEST_AMPLITUDE), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_sinusoid_embedding( self, positions_vec, embedding_size, batch_size, alpha, amplitude, gc, dc ): diff --git a/caffe2/python/operator_test/softplus_op_test.py b/caffe2/python/operator_test/softplus_op_test.py index dd183b774f92..f8ca1817176e 100644 --- a/caffe2/python/operator_test/softplus_op_test.py +++ b/caffe2/python/operator_test/softplus_op_test.py @@ -14,7 +14,7 @@ class TestSoftplus(hu.HypothesisTestCase): @given(X=hu.tensor(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_softplus(self, X, gc, dc): op = core.CreateOperator("Softplus", ["X"], ["Y"]) self.assertDeviceChecks(dc, op, [X], [0]) diff --git a/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py b/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py index 41ec8808bb6a..267babf2145f 100644 --- a/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py +++ b/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py @@ -14,7 +14,7 @@ class TestFcOperator(hu.HypothesisTestCase): @given(n=st.integers(1, 10), k=st.integers(1, 5), use_length=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_to_dense_mask(self, n, k, use_length, gc, dc): lengths = np.random.randint(k, size=n).astype(np.int32) + 1 N = sum(lengths) @@ -47,7 +47,7 @@ def test_sparse_to_dense_mask(self, n, k, use_length, gc, dc): @given(n=st.integers(1, 10), k=st.integers(1, 5), 
use_length=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_to_dense_mask_with_int64(self, n, k, use_length, gc, dc): lengths = np.random.randint(k, size=n).astype(np.int32) + 1 N = sum(lengths) diff --git a/caffe2/python/operator_test/string_ops_test.py b/caffe2/python/operator_test/string_ops_test.py index a0c56a686666..aa706ad73d7c 100644 --- a/caffe2/python/operator_test/string_ops_test.py +++ b/caffe2/python/operator_test/string_ops_test.py @@ -20,7 +20,7 @@ def _string_lists(alphabet=None): class TestStringOps(serial.SerializedTestCase): @given(strings=_string_lists()) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_prefix(self, strings): length = 3 # although we are utf-8 encoding below to avoid python exceptions, @@ -48,7 +48,7 @@ def string_prefix_ref(strings): string_prefix_ref) @given(strings=_string_lists()) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_suffix(self, strings): length = 3 strings = np.array( @@ -72,7 +72,7 @@ def string_suffix_ref(strings): string_suffix_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_starts_with(self, strings): prefix = 'a' strings = np.array( @@ -96,7 +96,7 @@ def string_starts_with_ref(strings): string_starts_with_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_ends_with(self, strings): suffix = 'a' strings = np.array( @@ -120,7 +120,7 @@ def string_ends_with_ref(strings): string_ends_with_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_equals(self, strings): text = "" if strings: diff --git a/caffe2/python/operator_test/top_k_test.py b/caffe2/python/operator_test/top_k_test.py index fa628456c3a4..035b1fb3d099 100644 --- a/caffe2/python/operator_test/top_k_test.py +++ b/caffe2/python/operator_test/top_k_test.py @@ -140,7 +140,7 @@ def bind_ref(X_loc): @given(bs=st.integers(1, 3), n=st.integers(100, 10000), flatten_indices=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_top_k_4(self, bs, n, flatten_indices, gc, dc): k = np.random.randint(n // 3, 3 * n // 4) X = np.random.rand(bs, n).astype(dtype=np.float32) @@ -177,7 +177,7 @@ def bind_ref(X_loc): @given(bs=st.integers(1, 3), n=st.integers(1, 5000), flatten_indices=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_top_k_6(self, bs, n, flatten_indices, gc, dc): k = n X = np.random.rand(bs, n).astype(dtype=np.float32) diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index e568f8bdff74..f99a61688de6 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -991,7 +991,7 @@ def test_gather_ranges_to_dense_op(self): np.testing.assert_array_almost_equal(ref_outputs[i], outputs[i].numpy()) @given(lengths_0=st.integers(1, 10), lengths_1=st.integers(1, 10)) - @settings(deadline=1000) + @settings(deadline=10000) def test_merge_id_lists(self, lengths_0, lengths_1): def _merge_id_lists(lengths, values): ref_op = core.CreateOperator( diff --git a/caffe2/python/operator_test/utility_ops_test.py b/caffe2/python/operator_test/utility_ops_test.py index aeefbf596afe..187328f9e484 100644 --- a/caffe2/python/operator_test/utility_ops_test.py +++ 
b/caffe2/python/operator_test/utility_ops_test.py @@ -332,7 +332,7 @@ def sum_op_ref(*args): ) ), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_lengths_gather(self, inputs, gc, dc): items = inputs[0] lengths = inputs[1] @@ -359,7 +359,7 @@ def lengths_gather_op(items, lengths, indices): @given( inputs=hu.lengths_tensor(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_lengths_to_ranges(self, inputs, gc, dc): _, lengths = inputs diff --git a/caffe2/python/operator_test/weighted_sum_test.py b/caffe2/python/operator_test/weighted_sum_test.py index 2c7dffe92672..fbbe2a6bf6d8 100644 --- a/caffe2/python/operator_test/weighted_sum_test.py +++ b/caffe2/python/operator_test/weighted_sum_test.py @@ -61,7 +61,7 @@ def weighted_sum_op_ref(*args): @given(n=st.integers(1, 8), m=st.integers(1, 10), d=st.integers(1, 4), grad_on_w=st.booleans(), seed=st.integers(min_value=0, max_value=65535), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_weighted_sum_grad( self, n, m, d, grad_on_w, seed, gc, dc): input_names = [] diff --git a/caffe2/python/operator_test/wngrad_test.py b/caffe2/python/operator_test/wngrad_test.py index 48fe0f94731e..0a1f0405e92a 100644 --- a/caffe2/python/operator_test/wngrad_test.py +++ b/caffe2/python/operator_test/wngrad_test.py @@ -113,7 +113,7 @@ def test_wngrad_dense_base(self, inputs, seq_b, lr, epsilon, gc, dc): epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_wngrad_dense_output_effective_lr(self, inputs, seq_b, lr, epsilon, gc, dc): param, grad = inputs @@ -142,7 +142,7 @@ def test_wngrad_dense_output_effective_lr(self, inputs, seq_b, epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_wngrad_dense_output_effective_lr_and_update( self, inputs, seq_b, lr, epsilon, gc, dc): param, grad = inputs @@ -165,7 +165,7 @@ def test_wngrad_dense_output_effective_lr_and_update( # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. - @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=1000) + @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=10000) @given(inputs=hu.tensors(n=2), seq_b=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), @@ -186,7 +186,7 @@ def test_sparse_wngrad(self, inputs, seq_b, lr, epsilon, gc, dc): epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_wngrad_empty(self, inputs, seq_b, lr, epsilon, gc, dc): param = inputs[0] seq_b = np.array([seq_b, ], dtype=np.float32) From 32273e806a7462aea6d55212f81a93b7862d7db0 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 25 May 2021 16:31:55 -0700 Subject: [PATCH 13/18] Ensure NativeFunctions.h codegen output is deterministic (#58889) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58889 fixes https://github.com/pytorch/pytorch/issues/58796 Planning on re-testing locally tomorrow morning to confirm, but this change should fix the non-determinism in the codegen output that was causing `ccache` not to re-use its cached output. 
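To make the ordering issue concrete, the snippet below is a standalone illustration (using made-up declaration strings, not the actual codegen types) of why swapping `set()` for `OrderedDict.fromkeys()` removes duplicates without introducing run-to-run ordering differences:

```
from collections import OrderedDict

# Made-up declaration strings standing in for per-backend kernel declarations.
decls = ["void add_cpu(Tensor&);", "void mul_cpu(Tensor&);", "void add_cpu(Tensor&);"]

# set() removes duplicates, but its iteration order is unrelated to insertion
# order and, for strings, can change between runs because of hash randomization,
# so the emitted declarations (and therefore the ccache hash) can differ per run.
print(list(set(decls)))

# OrderedDict.fromkeys (equivalently dict.fromkeys on Python 3.7+) also removes
# duplicates but keeps first-seen order, so repeated codegen runs emit the
# declarations in the same order.
print(list(OrderedDict.fromkeys(decls)))
# ['void add_cpu(Tensor&);', 'void mul_cpu(Tensor&);']
```

The patch below applies exactly this substitution inside the `native_function_declarations` computation.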
I built from the commit referenced in https://github.com/pytorch/pytorch/issues/58796 a few times and ran `diff -Naur` on the codegen output in `build/aten/src/ATen`. After a few tries, `NativeFunctions.h` had a few diffs. The diffs were all related to the ordering of functional/inplace/out variants of a NativeFunctionGroup, which looked non-deterministic. That looks like it's coming from my calling `set()` to filter out duplicate NativeFunction declarations. The earlier version of the codegen also called `set()` to filter out duplicates, but it did so individually for each `NativeFunction` object, before merging the groups (I'm not too sure why this didn't introduce non-determinism before, though). With the refactor from https://github.com/pytorch/pytorch/pull/57361, we're calling `set()` on the declarations from every operator for a given DispatchKey, which is probably what introduced the non-determinism. Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D28675941 Pulled By: bdhirsh fbshipit-source-id: bb66de00aafeeb9720d85e8156ac9f7539aed0d6 --- tools/codegen/gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index af2cab42d1f7..9ce4ebcafcf7 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -1026,7 +1026,7 @@ def make_file_manager(install_dir: str) -> FileManager: 'native_function_declarations': list(concatMap( # Convert to a set first to remove duplicate kernel names. # Backends are allowed to repeat kernel names; only generate the declaration once! - lambda f: list(set(concatMap( + lambda f: list(OrderedDict.fromkeys(concatMap( lambda backend_idx: dest.compute_native_function_declaration(f, backend_idx), backend_indices.values()))), From 26c1f0f72e71c096648a16993484234399da307c Mon Sep 17 00:00:00 2001 From: driazati Date: Tue, 25 May 2021 17:00:04 -0700 Subject: [PATCH 14/18] [skip ci] Skip debug info on PRs (#58897) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58897 We don't need to be building debug info on PRs since it's just filling up S3/CircleCI storage with useless 800 MB zips; this flips it so it's only run on master + release branches. See #58898 for CI signal. Also see pytorch/builder counterpart (unlike the last debuginfo PR there is no hard dependency between these two so there won't be any churn on un-rebased PRs): https://github.com/pytorch/builder/pull/778 Test Plan: Imported from OSS Reviewed By: seemethere, samestep Differential Revision: D28689413 Pulled By: driazati fbshipit-source-id: 77a37e84afe492215008d5e023ceab0c24adb33c --- .circleci/scripts/binary_linux_build.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.circleci/scripts/binary_linux_build.sh b/.circleci/scripts/binary_linux_build.sh index e36d06906246..755a467fe248 100755 --- a/.circleci/scripts/binary_linux_build.sh +++ b/.circleci/scripts/binary_linux_build.sh @@ -22,5 +22,9 @@ else build_script='manywheel/build.sh' fi +if [[ "$CIRCLE_BRANCH" == "master" ]] || [[ "$CIRCLE_BRANCH" == release/* ]]; then + export BUILD_DEBUG_INFO=1 +fi + # Build the package SKIP_ALL_TESTS=1 "/builder/$build_script" From 083d3bb93b685e5f44a56fe9c293a6b76abb110e Mon Sep 17 00:00:00 2001 From: Serhat Yilmaz Date: Tue, 25 May 2021 20:03:32 -0700 Subject: [PATCH 15/18] [torch][repeat_interleave] Add to exception list in backward compat check (#58966) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58966 Same as title.
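For context, each entry in this allow list pairs an operator name with an expiry date, and the backward-compatibility check skips any schema whose name matches an entry that has not yet expired. The sketch below is a minimal, hypothetical rendering of that pattern (the actual helper in check_backward_compatibility.py may differ in its details):

```
import datetime
import re

# Hypothetical allow list in the same shape as the entry added by this patch:
# (schema-name pattern, date after which the exception stops applying).
ALLOW_LIST = [
    ("aten::repeat_interleave", datetime.date(2021, 5, 26)),
]

def is_allow_listed(schema_name, allow_list, today=None):
    """Return True if BC enforcement should be skipped for this schema."""
    today = today or datetime.date.today()
    return any(
        expiry >= today and re.match(pattern, schema_name)
        for pattern, expiry in allow_list
    )

# The schema change is tolerated until the expiry date, then enforced again.
print(is_allow_listed("aten::repeat_interleave(Tensor self) -> Tensor",
                      ALLOW_LIST, today=datetime.date(2021, 5, 25)))  # True
```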
Test Plan: CI since updated the check Reviewed By: ngimel Differential Revision: D28699577 fbshipit-source-id: 436fdc648a4c653081ff0e1b6b809c4af742055a --- test/backward_compatibility/check_backward_compatibility.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 8e03ad397913..9de94c511255 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -89,6 +89,7 @@ ("aten::_amp_update_scale", datetime.date(2021, 6, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)), + ("aten::repeat_interleave", datetime.date(2021, 5, 26)), ] def allow_listed(schema, allow_list): From 49c2da0ee06c2e3d9d300b76b655d31a6b7756ee Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 25 May 2021 21:12:52 -0700 Subject: [PATCH 16/18] [testing] improve broadcasts_input error message (#58295) Summary: Context: The Error message when `broadcasts_input` is marked incorrectly is uninformative [See Error Currently] https://github.com/pytorch/pytorch/pull/57941#discussion_r631749435 Error Currently ``` Traceback (most recent call last): File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 326, in test_variant_consistency_eager _test_consistency_helper(samples, variants) File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 310, in _test_consistency_helper variant_forward = variant(cloned, File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__ self._raiseFailure("{} not raised".format(exc_name)) File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure raise self.test_case.failureException(msg) AssertionError: RuntimeError not raised ``` Error After PR ``` Traceback (most recent call last): File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 329, in test_variant_consistency_eager _test_consistency_helper(samples, variants) File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 313, in _test_consistency_helper variant_forward = variant(cloned, File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__ self._raiseFailure("{} not raised".format(exc_name)) File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure raise self.test_case.failureException(msg) AssertionError: RuntimeError not raised : inplace variant either allowed resizing or you have marked the sample SampleInput(input=Tensor, args=(tensor([[[ 2.1750, -8.5027, -3.1403, -6.9942, 3.2609], [-2.5057, -5.9123, -5.4633, 6.1203, -8.2124], [-3.5802, -8.4869, -6.0700, 2.3431, -8.1955], [-7.3316, 1.3248, -6.8661, 7.1483, -8.0719], [ 4.5977, -4.0448, -6.2044, -2.1314, -8.4956]], [[ 3.2769, -8.4360, 1.2826, 7.1749, 4.7653], [-0.2816, -2.5997, -4.7659, -3.7814, 3.9704], [-2.1778, -3.8117, -6.0276, -0.8423, -5.9646], [ 8.6544, -3.0922, 0.2558, -4.9318, -4.7596], [ 4.5583, 4.3830, 5.8793, 0.9713, -2.1481]], [[-1.0447, 0.9334, 7.6405, -4.8933, -7.4010], [ 7.7168, -8.4266, -5.5980, -6.9368, 7.1309], [-8.7720, -5.0890, -0.4975, 1.9518, 1.7074], [-8.5783, 8.5510, -8.5459, -3.5451, 8.4319], [ 8.5052, -8.9149, -6.6298, -1.2750, -5.7367]], [[-6.5625, 8.2795, -4.9311, 1.9501, -7.1777], [-8.4035, 1.1136, -7.6418, -7.0726, -2.8281], [ 4.2668, 
-0.2883, -6.2246, 2.3396, 1.2911], [ 4.6550, -1.9525, 4.4873, -3.8061, -0.8653], [-3.4256, 4.4423, 8.2937, -5.3456, -4.2624]], [[ 7.6128, -6.3932, 4.7131, -5.4938, 6.4792], [-6.5385, 2.4385, 4.5570, 3.7803, -8.3281], [-2.9785, -4.4745, -1.1778, -8.9324, 1.3663], [ 3.7437, 3.5171, -6.3135, -8.4519, -2.7033], [-5.0568, -8.4630, -4.2870, -3.7284, -1.5238]]], device='cuda:0', dtype=torch.float32, requires_grad=True),), broadcasts_input=True) incorrectly with `broadcasts_self=True ``` **NOTE**: Printing the sample looks very verbose and it may be hard to figure out which sample is incorrectly configured if there are multiple samples with similar input shapes. Two Options to make this error less verbose * Don't print the sample and just print `inplace variant either allowed resizing or you have marked one of the sample incorrectly with broadcasts_self=True` * Have some mechanism to name samples which will be printed in the `repr` (which will need extra machinery) Pull Request resolved: https://github.com/pytorch/pytorch/pull/58295 Reviewed By: ngimel Differential Revision: D28627308 Pulled By: mruberry fbshipit-source-id: b3bdeacac3cf9c0d984f0b85410ecce474291d20 --- test/test_ops.py | 5 +- .../_internal/common_methods_invocations.py | 58 ++++++++++++++++--- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index f36fc8bd514c..ea160b65ae1b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -321,7 +321,10 @@ def _test_consistency_helper(samples, variants): cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input if variant in inplace_ops and sample.broadcasts_input: - with self.assertRaises(RuntimeError): + with self.assertRaises(RuntimeError, + msg=('inplace variant either incorrectly allowed ' + 'resizing or you have marked the sample {}' + ' incorrectly with `broadcasts_self=True'.format(sample.summary()))): variant_forward = variant(cloned, *sample.args, **sample.kwargs) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e6d73e0540fe..88437be53066 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -81,9 +81,9 @@ def __init__(self, cls_name=None, test_name=None, *, class SampleInput(object): """Represents sample inputs to a function.""" - __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input'] + __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input', 'name'] - def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False): + def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False, name=""): # input is the first input to the op and must be either a Tensor or TensorList (Sequence[Tensor]). # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). # op with TensorList inputs do not support method or inplace variants. @@ -92,6 +92,7 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N self.args = args self.kwargs = kwargs if kwargs is not None else {} self.output_process_fn_grad = output_process_fn_grad + self.name = name # Specifies if `self.input` is broadcasted or not, # given that the operator supports broadcasting. 
@@ -103,17 +104,56 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N # for such inputs (as they will error out otherwise). self.broadcasts_input = broadcasts_input - def __repr__(self): + def _repr_helper(self, formatter): + # Helper function to return the details of the SampleInput as `str` + # It consolidates all the fields of SampleInput and allows, + # formatting the fields like `input`, `args`, etc with `formatter` + # callable to customize the representation. + # Look at `summary` method for example. arguments = [ - 'input=Tensor' if isinstance(self.input, torch.Tensor) else f'input=TensorList[{len(self.input)}]', - f'args={self.args}' if len(self.args) > 0 else None, - f'kwargs={self.kwargs}' if len(self.kwargs) > 0 else None, - (f'output_process_fn_grad={self.output_process_fn_grad}' - if self.output_process_fn_grad is not None else None), - f'broadcasts_input={self.broadcasts_input}'] + f'input={formatter(self.input)}', + f'args={formatter(self.args)}', + f'kwargs={formatter(self.kwargs)}', + f'output_process_fn_grad={self.output_process_fn_grad}', + f'broadcasts_input={self.broadcasts_input}', + f'name={repr(self.name)}'] return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + def __repr__(self): + return self._repr_helper(lambda x: x) + + def summary(self): + # Returns the SampleInput details in a more + # friendly format. + # It formats `Tensor` and `TensorList` + # in a more condensed representation. + def is_iter(arg): + try: + iter(arg) + return True + except TypeError as te: + return False + + def formatter(arg): + # Format any instance of `Tensor` (standalone, in list, or in dict) + # by Tensor[TensorShape] + # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] + if isinstance(arg, torch.Tensor): + shape = str(tuple(arg.shape)).replace('(', '').replace(')', '') + return f"Tensor[{shape}]" + elif isinstance(arg, dict): + return {k: formatter(v) for k, v in arg.items()} + elif is_iterable_of_tensors(arg): + return "TensorList[" + ", ".join(map(formatter, arg)) + "]" + elif is_iter(arg): # Handle list, tuple or any iterable type + return "(" + ",".join(map(formatter, arg)) + ")" + + return repr(arg) + + return self._repr_helper(formatter) + + class AliasInfo(object): """Class holds alias information. 
For example, torch.abs -> torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ From 948df6c7a9935011e3142c60f96a821c23595be0 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 25 May 2021 22:00:55 -0700 Subject: [PATCH 17/18] [numpy] torch.i0: promote integer inputs to float (#52735) Summary: Reference : https://github.com/pytorch/pytorch/issues/42515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/52735 Reviewed By: zou3519 Differential Revision: D28630505 Pulled By: mruberry fbshipit-source-id: e81a35dfc1a322daf0c44718901470fac677bc94 --- aten/src/ATen/native/UnaryOps.cpp | 2 +- .../ATen/native/cuda/UnarySpecialOpsKernel.cu | 2 +- .../_internal/common_methods_invocations.py | 26 ++++++++++++++++--- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 04ba69a604f3..6a13af13b37f 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -49,6 +49,7 @@ CREATE_UNARY_FLOAT_META_FUNC(erfinv) CREATE_UNARY_FLOAT_META_FUNC(exp) CREATE_UNARY_FLOAT_META_FUNC(exp2) CREATE_UNARY_FLOAT_META_FUNC(expm1) +CREATE_UNARY_FLOAT_META_FUNC(i0) CREATE_UNARY_FLOAT_META_FUNC(lgamma) CREATE_UNARY_FLOAT_META_FUNC(log) CREATE_UNARY_FLOAT_META_FUNC(log10) @@ -78,7 +79,6 @@ TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) { } CREATE_UNARY_META_FUNC(bitwise_not) CREATE_UNARY_META_FUNC(frac) -CREATE_UNARY_META_FUNC(i0) CREATE_UNARY_META_FUNC(round) CREATE_UNARY_META_FUNC(sgn) diff --git a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu index b6218b9f5581..85108c980c15 100644 --- a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu @@ -31,7 +31,7 @@ void exp2_kernel_cuda(TensorIteratorBase& iter) { } void i0_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_i0(a); }); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 88437be53066..b3d18480956a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2043,6 +2043,24 @@ def wrapped_fn(x): return wrapped_fn +def np_unary_ufunc_integer_promotion_wrapper_with_astype(fn): + # Check np_unary_ufunc_integer_promotion_wrapper + def is_integral(dtype): + return dtype in [np.bool_, bool, np.uint8, np.int8, np.int16, np.int32, np.int64] + + @wraps(fn) + def wrapped_fn(x): + # As the default dtype can change, acquire it when function is called. + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + if is_integral(x.dtype): + return fn(x).astype(np_dtype) + return fn(x) + + return wrapped_fn + + # Metadata class for Fast Fourier Transforms in torch.fft. class SpectralFuncInfo(OpInfo): """Operator information for torch.fft transforms. 
""" @@ -4699,11 +4717,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_fliplr_flipud, supports_out=False), UnaryUfuncInfo('i0', - ref=np.i0, + ref=np_unary_ufunc_integer_promotion_wrapper_with_astype( + scipy.special.i0) if TEST_SCIPY else _NOTHING, decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 5e-1}),), - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + safe_casts_outputs=True, supports_autograd=False), UnaryUfuncInfo('special.i0e', aten_name='special_i0e', From be4ba29d49566a1d9069b8fb35d4b9b44bb1f1c5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 25 May 2021 22:15:09 -0700 Subject: [PATCH 18/18] Detect overflow in numel of sparse COO tensor (#57492) Summary: Fixes https://github.com/pytorch/pytorch/issues/57416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/57492 Reviewed By: albanD Differential Revision: D28273649 Pulled By: mruberry fbshipit-source-id: 08ba50509556df1981d7ede025d84a836d2e8e5e --- aten/src/ATen/SparseTensorImpl.h | 5 +++++ c10/core/TensorImpl.h | 33 +++++++++++++++++++++++++++++++- test/test_sparse.py | 14 ++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index a416e5e53051..e2fc89a9db84 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -29,6 +29,11 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { // because many algorithms proceed by merging two sorted lists (of indices). bool coalesced_ = false; + // compute_numel with integer multiplication overflow check, see gh-57542 + void refresh_numel() { + TensorImpl::safe_refresh_numel(); + } + public: // Public for now... explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 5e973da15fcd..e383ffb4c57a 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2023,6 +2023,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return n; } + /** + * Compute the number of elements based on the sizes of a + * tensor. Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. + */ + int64_t safe_compute_numel() const { + int64_t n = 1; + for (auto s : sizes()) { + TORCH_CHECK( + s == 0 || n <= std::numeric_limits::max() / s, + "numel: integer multiplication overflow"); + n *= s; + } + return n; + } + /** * Compute whether or not a tensor is contiguous based on the sizes and * strides of a tensor. @@ -2041,12 +2057,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { protected: /** - * Recompute the cached numel of a tensor. Call this if you modify sizes. + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. + * + * For tensors with sparse layouts, use safe_refresh_numel() instead + * because it will catch integer overflow that may occur for tensors + * with sparse layouts and large dimensions. */ void refresh_numel() { numel_ = compute_numel(); } + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. Use only for tensors with sparse layouts because only + * sparse tensor are likely to have sizes that may lead to integer + * overflow when computing numel. 
+ */ + void safe_refresh_numel() { + numel_ = safe_compute_numel(); + } + /** * Recompute the cached contiguity of a tensor. Call this if you modify sizes * or strides. diff --git a/test/test_sparse.py b/test/test_sparse.py index c201704cdf5f..5b9b873fe646 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -246,6 +246,20 @@ def test_sparse_sum(): ref = test_sparse_sum() self.assertTrue(ref.expired()) + @dtypes(torch.double) + def test_ctor_large_sizes(self, device, dtype): + # Test that integer overflow is detected when computing numel + # of a sparse tensor with large dimensions (gh-57416). Notice + # that numel is computed internally when constructing a + # tensor, hence the overflow may appear during the tensor + # construction step. + N = 100000 + indices = torch.tensor([[N, N - 1]] * 4, dtype=torch.int64, device=device) + values = torch.tensor([1, 2], dtype=dtype, device=device) + self.assertRaises(RuntimeError, + lambda: torch.sparse_coo_tensor( + indices, values, (N + 1,) * 4, device=device)) + @dtypes(torch.double, torch.cdouble) def test_ctor_size_checks(self, device, dtype): indices = self.index_tensor([