Rework compat bindings. #47863

Closed · wants to merge 9 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -93,6 +93,8 @@ torch/lib64
torch/include/
torch/share/
torch/test/
+torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
+torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
# Root level file used in CI to specify certain env configs.
# E.g., see .circleci/config.yaml
14 changes: 12 additions & 2 deletions setup.py
@@ -327,8 +327,16 @@ def check_file(f):

# Use copies instead of symbolic files.
# Windows has very poor support for them.
-sym_files = ['tools/shared/_utils_internal.py']
-orig_files = ['torch/_utils_internal.py']
+sym_files = [
+    'tools/shared/_utils_internal.py',
+    'torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h',
+    'torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h',
+]
+orig_files = [
+    'torch/_utils_internal.py',
+    'third_party/valgrind-headers/callgrind.h',
+    'third_party/valgrind-headers/valgrind.h',
+]
for sym_file, orig_file in zip(sym_files, orig_files):
    same = False
    if os.path.exists(sym_file):
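The rest of the loop body is collapsed in this view. As a hedged sketch (reconstructed from context, not text of this diff), the pre-existing logic copies each original file over its twin whenever the copy is missing or stale:

import filecmp
import os
import shutil

for sym_file, orig_file in zip(sym_files, orig_files):
    same = False
    if os.path.exists(sym_file):
        if filecmp.cmp(sym_file, orig_file):
            same = True          # copy is current; nothing to do
        else:
            os.remove(sym_file)  # stale copy; drop it and re-copy below
    if not same:
        shutil.copyfile(orig_file, sym_file)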
@@ -907,6 +915,8 @@ def print_box(msg):
'share/cmake/Gloo/*.cmake',
'share/cmake/Tensorpipe/*.cmake',
'share/cmake/Torch/*.cmake',
+'utils/benchmark/utils/valgrind_wrapper/*.cpp',
+'utils/benchmark/utils/valgrind_wrapper/*.h',
],
'caffe2': [
'python/serialized_test/data/operator_test/*.zip',
24 changes: 24 additions & 0 deletions torch/utils/benchmark/utils/_stubs.py
@@ -0,0 +1,24 @@
import sys
from typing import TYPE_CHECKING


if TYPE_CHECKING or sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class CallgrindModuleType(Protocol):
    """Replicates the valgrind endpoints in `torch._C`.

    These bindings are used to collect Callgrind profiles on earlier versions
    of PyTorch and will eventually be removed.
    """
    __file__: str
    __name__: str

    def _valgrind_supported_platform(self) -> bool:
        ...

    def _valgrind_toggle(self) -> None:
        ...
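Because a Protocol checks structurally rather than by inheritance, either the real `torch._C` bindings or a JIT-compiled shim module can satisfy this annotation. A minimal sketch of the intended consumption pattern (the `pick_bindings` helper is hypothetical, not part of this PR):

import torch
from torch.utils.benchmark.utils._stubs import CallgrindModuleType


def pick_bindings() -> CallgrindModuleType:
    # New builds expose the endpoints directly on `torch._C`, which
    # structurally matches the Protocol.
    if hasattr(torch._C, "_valgrind_supported_platform"):
        return torch._C
    # Old builds fall back to the JIT-compiled compat module (see cpp_jit below).
    from torch.utils.benchmark.utils import cpp_jit
    return cpp_jit.get_compat_bindings()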
68 changes: 68 additions & 0 deletions torch/utils/benchmark/utils/cpp_jit.py
@@ -0,0 +1,68 @@
"""JIT C++ strings into executables."""
import os
import threading
from typing import List, Optional

import torch
from torch.utils.benchmark.utils._stubs import CallgrindModuleType
from torch.utils import cpp_extension


LOCK = threading.Lock()
SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]

# BACK_TESTING_NOTE:
# There are two workflows where this code could be used. One is the obvious
# case where someone simply builds or installs PyTorch and uses Timer.
# The other is that the entire `torch/utils/benchmark` folder from a CURRENT
# PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
# source code. This is what we refer to here as "back testing". The rationale
# is that we might want to use current tooling to study some aspect of an
# earlier version of PyTorch. (e.g. a regression.)
#
# The problem is that Timer relies on several aspects of core PyTorch, namely
# some binding functions for Valgrind symbols in `torch._C` and the
# `torch.__config__._cxx_flags()` method. If we were to naively copy code
# around, it wouldn't work, because the symbols of interest aren't present in
# earlier versions of PyTorch. To work around this, we add back testing
# shims. These shims will never activate during normal use, but will
# allow Timer to function outside of the "correct" version of PyTorch by
# emulating functionality that was added later.
#
# These shims are temporary, and as Timer becomes more integrated with
# PyTorch the cost and complexity of such shims will increase. Once back
# testing is no longer required (which is to say we have done enough historic
# analysis and the shims no longer justify their maintenance and code
# complexity costs) back testing paths will be removed.
Contributor:
Thanks, this note is super helpful. One extra benefit of the note is it also clues the reader in on what kinds of code changes to Timer are permissible, and what are not (changes that add more dependencies on C symbols => you need to add more backtesting support.)


if hasattr(torch.__config__, "_cxx_flags"):
    CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
    if "-g" not in CXX_FLAGS:
        CXX_FLAGS.append("-g")
else:
    # FIXME: Remove when back testing is no longer required.
    CXX_FLAGS = ["-O2", "-fPIC", "-g"]

EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
CONDA_PREFIX = os.getenv("CONDA_PREFIX")
if CONDA_PREFIX is not None:
    # Load will automatically search /usr/include, but not conda include.
    EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))


Contributor:
It would probably be nice to have a Note explaining at a high level what's going on here

Contributor:
I understand that this is "temporary" stuff but because it interacts nontrivially with some code that isn't in PyTorch itself, it will be harder for other people to figure out how this relates to the bigger picture. A note here will help a lot.

Author:
No, it's a valid point. Check out BACK_TESTING_NOTE: and let me know if it seems reasonable.

COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None
def get_compat_bindings() -> CallgrindModuleType:
    with LOCK:
        global COMPAT_CALLGRIND_BINDINGS
        if COMPAT_CALLGRIND_BINDINGS is None:
            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
                name="callgrind_bindings",
                sources=[os.path.join(
                    SOURCE_ROOT,
                    "valgrind_wrapper",
                    "compat_bindings.cpp"
                )],
                extra_cflags=CXX_FLAGS,
                extra_include_paths=EXTRA_INCLUDE_PATHS,
            )
    return COMPAT_CALLGRIND_BINDINGS
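The lock plus the module-level cache means the extension is compiled at most once per process; subsequent callers get the cached module. Hypothetical usage on an old PyTorch build:

from torch.utils.benchmark.utils.cpp_jit import get_compat_bindings

bindings = get_compat_bindings()     # first call JIT-compiles, later calls hit the cache
if bindings._valgrind_supported_platform():
    bindings._valgrind_toggle()      # flips Callgrind collection on or off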
25 changes: 25 additions & 0 deletions torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -0,0 +1,25 @@
/* Used to collect profiles of old versions of PyTorch. */
#include <callgrind.h>
#include <pybind11/pybind11.h>
#include <c10/util/Exception.h>  // provides TORCH_CHECK for the NVALGRIND branch


bool _valgrind_supported_platform() {
#if defined(NVALGRIND)
  return false;
#else
  return true;
#endif
}

void _valgrind_toggle() {
#if defined(NVALGRIND)
  TORCH_CHECK(false, "Valgrind is not supported.");
#else
  CALLGRIND_TOGGLE_COLLECT;
#endif
}

PYBIND11_MODULE(callgrind_bindings, m) {
  m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
  m.def("_valgrind_toggle", &_valgrind_toggle);
}
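For context (a sketch, not part of this diff): the toggles only have an effect when the interpreter itself runs under Callgrind with collection initially disabled, so a profiling session looks roughly like this, where `run_workload` stands in for the code under test:

# Launch as: valgrind --tool=callgrind --collect-atstart=no python profile_me.py
bindings = get_compat_bindings()
bindings._valgrind_toggle()   # CALLGRIND_TOGGLE_COLLECT: start counting instructions
run_workload()                # hypothetical function under test
bindings._valgrind_toggle()   # second toggle stops collection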
41 changes: 0 additions & 41 deletions torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py

This file was deleted.

15 changes: 6 additions & 9 deletions torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -11,13 +11,13 @@
import sys
import tempfile
import textwrap
-from types import ModuleType
from typing import (
    cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple,
    Optional, Tuple, Union, TYPE_CHECKING)

import torch
-from torch.utils.benchmark.utils import common
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType


__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
@@ -444,17 +444,14 @@ def construct(self) -> str:

class _ValgrindWrapper(object):
    def __init__(self) -> None:
-        self._bindings_module: Optional[ModuleType] = None
+        self._bindings_module: Optional[CallgrindModuleType] = None
        if hasattr(torch._C, "_valgrind_supported_platform"):
            self._supported_platform: bool = torch._C._valgrind_supported_platform()

        else:
            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
-            # This import will JIT the Callgrind control bindings, so don't
-            # invoke unless we know we'll need it.
-            from torch.utils.benchmark.utils.valgrind_wrapper.compat_bindings import bindings
-            self._bindings_module = bindings
-            self._supported_platform = bindings._valgrind_supported_platform()
+            self._bindings_module = cpp_jit.get_compat_bindings()
+            self._supported_platform = self._bindings_module._valgrind_supported_platform()

        self._commands_available: Dict[str, bool] = {}
        if self._supported_platform:
@@ -643,7 +640,7 @@ def _construct_script(
        number: int,
        error_log: str,
        stat_log: str,
-        bindings: Optional[ModuleType],
+        bindings: Optional[CallgrindModuleType],
    ) -> str:
        # The naive template looks something like:
        #   "for _ in range({number}): {stmt}"
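The rest of the hunk is truncated, but the surrounding idea is that the generated script brackets that naive loop with the toggle calls so Callgrind only counts the statement under test. A rough sketch of the generated script's shape (illustrative; the PR's actual template is more elaborate):

import torch

# ... setup code injected here ...
torch._C._valgrind_toggle()   # or `bindings._valgrind_toggle()` on old builds
for _ in range(number):       # `number` and `stmt` are template placeholders
    stmt
torch._C._valgrind_toggle()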