Update on "mocked module detection during save pass"
[ghstack-poisoned]
PaliC committed Dec 10, 2021
2 parents 6ac9fae + dec111e commit bc4c320
Showing 6 changed files with 964 additions and 679 deletions.
2 changes: 1 addition & 1 deletion cmake/Modules/FindMKLDNN.cmake
@@ -18,7 +18,7 @@ SET(MKLDNN_LIBRARIES)
 SET(MKLDNN_INCLUDE_DIR)

 SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
-SET(MKLDNN_ROOT "${IDEEP_ROOT}/mkl-dnn")
+SET(MKLDNN_ROOT "${IDEEP_ROOT}/mkl-dnn/third_party/oneDNN")

 FIND_PACKAGE(BLAS)
 FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
19 changes: 10 additions & 9 deletions test/test_autograd.py
@@ -8748,16 +8748,17 @@ def test_inference_mode_context_manager(self):
         self.assertFalse(torch.is_inference_mode_enabled())

     def test_inference_mode_decorator(self):
-        @torch.inference_mode()
-        def func(x):
-            self.assertTrue(torch.is_inference_mode_enabled())
-            return x * x
+        for mode in (True, False):
+            @torch.inference_mode(mode)
+            def func(x):
+                self.assertEqual(torch.is_inference_mode_enabled(), mode)
+                return x * x

-        for requires_grad in (True, False):
-            c = torch.ones(1, 2, 3, requires_grad=requires_grad)
-            d = func(c)
-            self.assertTrue(torch.is_inference(d))
-            self.assertFalse(d.requires_grad)
+            for requires_grad in (True, False):
+                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
+                d = func(c)
+                self.assertTrue(not mode or torch.is_inference(d))
+                self.assertEqual(d.requires_grad, requires_grad and not mode)

     def test_inference_mode_tensor_creation(self):
         with torch.inference_mode():
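The reworked test parameterizes the decorator over both modes. As a standalone sketch of the behavior it pins down (assuming a build that includes this patch): torch.inference_mode(False) used as a decorator must genuinely disable inference mode instead of silently falling back to the default mode=True, which is what the clone() change further below enables.

import torch

@torch.inference_mode(False)
def square(x):
    # With mode=False the decorator must keep inference mode off; this
    # only works once the decorator re-creates the context manager with
    # its original constructor argument (see clone() in grad_mode.py).
    assert not torch.is_inference_mode_enabled()
    return x * x

c = torch.ones(1, 2, 3, requires_grad=True)
d = square(c)
assert not torch.is_inference(d)  # not an inference tensor when mode=False
assert d.requires_grad            # autograd tracking is preserved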
61 changes: 31 additions & 30 deletions third_party/mkl-dnn.BUILD
@@ -5,53 +5,54 @@ _DNNL_RUNTIME_OMP = {
     "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
     "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
     "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
+    "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "/* undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE */",
     "#cmakedefine DNNL_WITH_SYCL": "/* #undef DNNL_WITH_SYCL */",
     "#cmakedefine DNNL_WITH_LEVEL_ZERO": "/* #undef DNNL_WITH_LEVEL_ZERO */",
     "#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */",
 }

 template_rule(
-    name = "include_dnnl_version",
-    src = "include/oneapi/dnnl/dnnl_version.h.in",
-    out = "include/oneapi/dnnl/dnnl_version.h",
+    name = "third_party/oneDNN/include_dnnl_version",
+    src = "third_party/oneDNN/include/oneapi/dnnl/dnnl_version.h.in",
+    out = "third_party/oneDNN/include/oneapi/dnnl/dnnl_version.h",
     substitutions = {
         "@DNNL_VERSION_MAJOR@": "2",
-        "@DNNL_VERSION_MINOR@": "2",
+        "@DNNL_VERSION_MINOR@": "3",
         "@DNNL_VERSION_PATCH@": "3",
-        "@DNNL_VERSION_HASH@": "7336ca9f055cf1bfa13efb658fe15dc9b41f0740",
+        "@DNNL_VERSION_HASH@": "f40443c413429c29570acd6cf5e3d1343cf647b4",
     },
 )

 template_rule(
-    name = "include_dnnl_config",
-    src = "include/oneapi/dnnl/dnnl_config.h.in",
-    out = "include/oneapi/dnnl/dnnl_config.h",
+    name = "third_party/oneDNN/include_dnnl_config",
+    src = "third_party/oneDNN/include/oneapi/dnnl/dnnl_config.h.in",
+    out = "third_party/oneDNN/include/oneapi/dnnl/dnnl_config.h",
     substitutions = _DNNL_RUNTIME_OMP,
 )

 cc_library(
     name = "mkl-dnn",
     srcs = glob([
-        "src/common/*.cpp",
-        "src/cpu/**/*.cpp",
+        "third_party/oneDNN/src/common/*.cpp",
+        "third_party/oneDNN/src/cpu/**/*.cpp",
     ], exclude=[
-        "src/cpu/aarch64/**/*.cpp",
+        "third_party/oneDNN/src/cpu/aarch64/**/*.cpp",
     ]),
     hdrs = glob([
-        "include/oneapi/dnnl/*.h",
-        "include/oneapi/dnnl/*.hpp",
-        "include/*.h",
-        "include/*.hpp",
-        "src/cpu/**/*.hpp",
-        "src/cpu/**/*.h",
-        "src/common/*.hpp",
-        "src/common/ittnotify/jitprofiling.h",
+        "third_party/oneDNN/include/oneapi/dnnl/*.h",
+        "third_party/oneDNN/include/oneapi/dnnl/*.hpp",
+        "third_party/oneDNN/include/*.h",
+        "third_party/oneDNN/include/*.hpp",
+        "third_party/oneDNN/src/cpu/**/*.hpp",
+        "third_party/oneDNN/src/cpu/**/*.h",
+        "third_party/oneDNN/src/common/*.hpp",
+        "third_party/oneDNN/src/common/ittnotify/jitprofiling.h",
     ], exclude=[
-        "src/cpu/aarch64/**/*.hpp",
-        "src/cpu/aarch64/**/*.h",
+        "third_party/oneDNN/src/cpu/aarch64/**/*.hpp",
+        "third_party/oneDNN/src/cpu/aarch64/**/*.h",
     ]) + [
-        "include/oneapi/dnnl/dnnl_config.h",
-        "include/oneapi/dnnl/dnnl_version.h",
+        "third_party/oneDNN/include/oneapi/dnnl/dnnl_config.h",
+        "third_party/oneDNN/include/oneapi/dnnl/dnnl_version.h",
     ],
     copts = [
         "-DUSE_AVX",
@@ -69,13 +70,13 @@ cc_library(
         "//conditions:default": ["-DDNNL_CPU_RUNTIME=2"],
     }),
     includes = [
-        "include/",
-        "include/oneapi/",
-        "include/oneapi/dnnl/",
-        "src/",
-        "src/common/",
-        "src/cpu/",
-        "src/cpu/x64/xbyak/",
+        "third_party/oneDNN/include/",
+        "third_party/oneDNN/include/oneapi/",
+        "third_party/oneDNN/include/oneapi/dnnl/",
+        "third_party/oneDNN/src/",
+        "third_party/oneDNN/src/common/",
+        "third_party/oneDNN/src/cpu/",
+        "third_party/oneDNN/src/cpu/x64/xbyak/",
     ],
     visibility = ["//visibility:public"],
     linkopts = [
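The path changes above track ideep's switch to vendoring oneDNN under mkl-dnn/third_party/oneDNN, and the version template moves from oneDNN 2.2.3 to 2.3.3. As a quick sanity check (not part of this diff, assuming a locally built PyTorch), the compiled-in oneDNN version can be inspected from Python:

import torch

# torch.__config__.show() prints the compile-time configuration,
# including the oneDNN (MKL-DNN) version the binary was built against.
print(torch.__config__.show())

# Whether the oneDNN backend is usable at runtime on this machine.
print(torch.backends.mkldnn.is_available())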
25 changes: 17 additions & 8 deletions torch/autograd/grad_mode.py
@@ -4,7 +4,6 @@
 import inspect
 from typing import Any, Callable, TypeVar, cast

-
 __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
            'inference_mode']

@@ -24,7 +23,7 @@ def __call__(self, func: F) -> F:

         @functools.wraps(func)
         def decorate_context(*args, **kwargs):
-            with self.__class__():
+            with self.clone():
                 return func(*args, **kwargs)
         return cast(F, decorate_context)

@@ -38,10 +37,9 @@ def generator_context(*args, **kwargs):
            # make sure the grad mode is properly set every time the execution
            # flow returns into the wrapped generator and restored when it
            # returns through our `yield` to our caller (see PR #49017).
-           cls = type(self)
            try:
                # Issuing `None` to a generator fires it up
-               with cls():
+               with self.clone():
                    response = gen.send(None)

                while True:
@@ -51,18 +49,18 @@ def generator_context(*args, **kwargs):

                    except GeneratorExit:
                        # Inform the still active generator about its imminent closure
-                       with cls():
+                       with self.clone():
                            gen.close()
                        raise

                    except BaseException:
                        # Propagate the exception thrown at us by the caller
-                       with cls():
+                       with self.clone():
                            response = gen.throw(*sys.exc_info())

                    else:
                        # Pass the last request to the generator and get its response
-                       with cls():
+                       with self.clone():
                            response = gen.send(request)

            # We let the exceptions raised above by the generator's `.throw` or
@@ -81,6 +79,10 @@ def __enter__(self) -> None:
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         raise NotImplementedError

+    def clone(self):
+        # override this method if your children class takes __init__ parameters
+        return self.__class__()
+

 class no_grad(_DecoratorContextManager):
     r"""Context-manager that disabled gradient calculation.
@@ -172,7 +174,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         torch._C._set_grad_enabled(self.prev)


-class set_grad_enabled(object):
+class set_grad_enabled(_DecoratorContextManager):
     r"""Context-manager that sets gradient calculation to on or off.

     ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
@@ -213,13 +215,17 @@ class set_grad_enabled(object):
     def __init__(self, mode: bool) -> None:
         self.prev = torch.is_grad_enabled()
         torch._C._set_grad_enabled(mode)
+        self.mode = mode

     def __enter__(self) -> None:
         pass

     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         torch._C._set_grad_enabled(self.prev)

+    def clone(self):
+        return self.__class__(self.mode)
+

 class inference_mode(_DecoratorContextManager):
     r"""Context-manager that enables or disables inference mode
@@ -274,3 +280,6 @@ def __enter__(self):

     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         del self._inference_mode_raii_guard
+
+    def clone(self):
+        return self.__class__(self.mode)
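Taken together, the grad_mode.py changes make decorator use of parameterized context managers sound: _DecoratorContextManager now re-creates the manager for every decorated call (and every generator re-entry) via clone(), so constructor arguments such as mode are no longer dropped, and set_grad_enabled becomes usable as a decorator by inheriting from it. A minimal sketch of the newly supported usage (assuming this patch is applied):

import torch

@torch.set_grad_enabled(False)
def run(x):
    # clone() rebuilds set_grad_enabled(False) around every call,
    # so autograd recording is disabled for the body of run().
    return x * 2

x = torch.ones(3, requires_grad=True)
assert not run(x).requires_grad

One caveat carries over from the old class: set_grad_enabled still flips the global grad mode in __init__, so the decorator line itself toggles grad mode once at definition time; this patch leaves that behavior unchanged.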
