Skip to content

Commit

Permalink
Use f-strings in torch.utils.cpp_extension (#47025)
Browse files Browse the repository at this point in the history
Summary:
Plus two minor fixes to `torch/csrc/Module.cpp`:
 - Use an iterator of type `Py_ssize_t` for array indexing in `THPModule_initNames`
 - Fix clang-tidy warning of unneeded defaultGenerator copy by capturing it as `const auto&`

Pull Request resolved: #47025

Reviewed By: samestep

Differential Revision: D24605907

Pulled By: malfet

fbshipit-source-id: c276567d320758fa8b6f4bd64ff46d2ea5d40eff
  • Loading branch information
malfet authored and facebook-github-bot committed Oct 29, 2020
1 parent 9d23fd5 commit 42a5114
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 53 deletions.
6 changes: 3 additions & 3 deletions torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
if (!types) return nullptr;

int num_classes = PySequence_Fast_GET_SIZE(types.get());
auto num_classes = PySequence_Fast_GET_SIZE(types.get());
names.reserve(names.size() + num_classes);
for (size_t i = 0; i < num_classes; i++) {
for (Py_ssize_t i = 0; i < num_classes; i++) {
PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject");
PyTypeObject* type = (PyTypeObject*)obj;
Expand Down Expand Up @@ -864,7 +864,7 @@ Call this whenever a new thread is created in order to propagate values from
ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
#endif

auto defaultGenerator = at::detail::getDefaultCPUGenerator();
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
// This reference is meant to be given away, so no need to incref here.
ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false));
Expand Down
94 changes: 44 additions & 50 deletions torch/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _find_cuda_home() -> Optional[str]:
if not os.path.exists(cuda_home):
cuda_home = None
if cuda_home and not torch.cuda.is_available():
print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'")
return cuda_home

def _find_rocm_home() -> Optional[str]:
Expand All @@ -72,7 +72,7 @@ def _find_rocm_home() -> Optional[str]:
if not os.path.exists(rocm_home):
rocm_home = None
if rocm_home and torch.version.hip is None:
print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home))
print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
return rocm_home


Expand Down Expand Up @@ -275,13 +275,13 @@ def check_compiler_abi_compatibility(compiler) -> bool:
version = (0, 0, 0) if match is None else match.groups()
except Exception:
_, error, _ = sys.exc_info()
warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error))
warnings.warn(f'Error checking compiler version for {compiler}: {error}')
return False

if tuple(map(int, version)) >= minimum_required_version:
return True

compiler = '{} {}'.format(compiler, ".".join(version))
compiler = f'{compiler} {".".join(version)}'
warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))

return False
Expand Down Expand Up @@ -715,7 +715,7 @@ def _define_torch_extension_name(self, extension):
# as the library name
names = extension.name.split('.')
name = names[-1]
define = '-DTORCH_EXTENSION_NAME={}'.format(name)
define = f'-DTORCH_EXTENSION_NAME={name}'
self._add_compile_flag(extension, define)

def _add_gnu_cpp_abi_flag(self, extension):
Expand Down Expand Up @@ -1102,9 +1102,7 @@ def load_inline(name,
# Make the function docstring the same as the function name.
functions = dict((f, f) for f in functions)
elif not isinstance(functions, dict):
raise ValueError(
"Expected 'functions' to be a list or dict, but was {}".format(
type(functions)))
raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")
for function_name, docstring in functions.items():
if with_pytorch_error_handling:
module_def.append(
Expand Down Expand Up @@ -1170,9 +1168,9 @@ def _jit_compile(name,
)
if version > 0:
if version != old_version and verbose:
print('The input conditions for extension module {} have changed. '.format(name) +
'Bumping to version {0} and re-building as {1}_v{0}...'.format(version, name))
name = '{}_v{}'.format(name, version)
print(f'The input conditions for extension module {name} have changed. ' +
f'Bumping to version {version} and re-building as {name}_v{version}...')
name = f'{name}_v{version}'

if version != old_version:
baton = FileBaton(os.path.join(build_directory, 'lock'))
Expand Down Expand Up @@ -1205,7 +1203,7 @@ def _jit_compile(name,
baton.wait()
elif verbose:
print('No modifications detected for re-loaded extension '
'module {}, skipping build step...'.format(name))
f'module {name}, skipping build step...')

if verbose:
print(f'Loading extension module {name}...')
Expand Down Expand Up @@ -1292,11 +1290,11 @@ def _write_ninja_file_and_build_library(
with_cuda=with_cuda)

if verbose:
print('Building extension module {}...'.format(name))
print(f'Building extension module {name}...')
_run_ninja_build(
build_directory,
verbose,
error_prefix="Error building extension '{}'".format(name))
error_prefix=f"Error building extension '{name}'")


def is_ninja_available():
Expand Down Expand Up @@ -1342,10 +1340,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
extra_ldflags.append('torch.lib')
extra_ldflags.append('torch_python.lib')
extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path))
extra_ldflags.append('/LIBPATH:{}'.format(lib_path))
extra_ldflags.append(f'/LIBPATH:{python_lib_path}')
extra_ldflags.append(f'/LIBPATH:{lib_path}')
else:
extra_ldflags.append('-L{}'.format(lib_path))
extra_ldflags.append(f'-L{lib_path}')
extra_ldflags.append('-lc10')
if with_cuda:
extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
Expand All @@ -1359,19 +1357,18 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
if verbose:
print('Detected CUDA files, patching ldflags')
if IS_WINDOWS:
extra_ldflags.append('/LIBPATH:{}'.format(
_join_cuda_home('lib/x64')))
extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}')
extra_ldflags.append('cudart.lib')
if CUDNN_HOME is not None:
extra_ldflags.append(os.path.join(CUDNN_HOME, 'lib/x64'))
elif not IS_HIP_EXTENSION:
extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64')))
extra_ldflags.append(f'-L{_join_cuda_home("lib64")}')
extra_ldflags.append('-lcudart')
if CUDNN_HOME is not None:
extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64')))
extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
elif IS_HIP_EXTENSION:
assert ROCM_VERSION is not None
extra_ldflags.append('-L{}'.format(_join_rocm_home('lib')))
extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc')
return extra_ldflags

Expand Down Expand Up @@ -1421,7 +1418,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
# If not given, determine what's needed for the GPU that can be found
if not _arch_list:
capability = torch.cuda.get_device_capability()
arch_list = ['{}.{}'.format(capability[0], capability[1])]
arch_list = [f'{capability[0]}.{capability[1]}']
else:
# Deal with lists that are ' ' separated (only deal with ';' after)
_arch_list = _arch_list.replace(' ', ';')
Expand All @@ -1434,12 +1431,12 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
flags = []
for arch in arch_list:
if arch not in valid_arch_strings:
raise ValueError("Unknown CUDA arch ({}) or GPU not supported".format(arch))
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
else:
num = arch[0] + arch[2]
flags.append('-gencode=arch=compute_{},code=sm_{}'.format(num, num))
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if arch.endswith('+PTX'):
flags.append('-gencode=arch=compute_{},code=compute_{}'.format(num, num))
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')

return list(set(flags))

Expand All @@ -1466,8 +1463,7 @@ def _get_build_directory(name: str, verbose: bool) -> str:
root_extensions_directory = get_default_build_root()

if verbose:
print('Using {} as PyTorch extensions root...'.format(
root_extensions_directory))
print(f'Using {root_extensions_directory} as PyTorch extensions root...')

build_directory = os.path.join(root_extensions_directory, name)
if not os.path.exists(build_directory):
Expand All @@ -1483,7 +1479,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
max_jobs = os.environ.get('MAX_JOBS')
if max_jobs is not None and max_jobs.isdigit():
if verbose:
print('Using envvar MAX_JOBS ({}) as the number of workers...'.format(max_jobs))
print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...')
return int(max_jobs)
if verbose:
print('Allowing ninja to set a default number of workers... '
Expand Down Expand Up @@ -1550,7 +1546,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
# `error` is a CalledProcessError (which has an `output` attribute), but
# mypy thinks it's Optional[BaseException] and doesn't narrow
if hasattr(error, 'output') and error.output: # type: ignore
message += ": {}".format(error.output.decode()) # type: ignore
message += f": {error.output.decode()}" # type: ignore
raise RuntimeError(message) from e


Expand Down Expand Up @@ -1592,10 +1588,10 @@ def _write_ninja_file_to_build_library(path,
user_includes += system_includes
system_includes.clear()

common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
common_cflags = [f'-DTORCH_EXTENSION_NAME={name}']
common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
common_cflags += ['-I{}'.format(include) for include in user_includes]
common_cflags += ['-isystem {}'.format(include) for include in system_includes]
common_cflags += [f'-I{include}' for include in user_includes]
common_cflags += [f'-isystem {include}' for include in system_includes]

common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]

Expand Down Expand Up @@ -1639,9 +1635,9 @@ def object_file_path(source_file: str) -> str:
if _is_cuda_file(source_file) and with_cuda:
# Use a different object filename in case a C++ and CUDA file have
# the same filename but different extension (.cpp vs. .cu).
target = '{}.cuda.o'.format(file_name)
target = f'{file_name}.cuda.o'
else:
target = '{}.o'.format(file_name)
target = f'{file_name}.o'
return target

objects = [object_file_path(src) for src in sources]
Expand All @@ -1657,7 +1653,7 @@ def object_file_path(source_file: str) -> str:
ldflags = _nt_quote_args(ldflags)

ext = 'pyd' if IS_WINDOWS else 'so'
library_target = '{}.{}'.format(name, ext)
library_target = f'{name}.{ext}'

_write_ninja_file(
path=path,
Expand Down Expand Up @@ -1719,20 +1715,20 @@ def sanitize_flags(flags):

# Version 1.3 is required for the `deps` directive.
config = ['ninja_required_version = 1.3']
config.append('cxx = {}'.format(compiler))
config.append(f'cxx = {compiler}')
if with_cuda:
if IS_HIP_EXTENSION:
nvcc = _join_rocm_home('bin', 'hipcc')
else:
nvcc = _join_cuda_home('bin', 'nvcc')
config.append('nvcc = {}'.format(nvcc))
config.append(f'nvcc = {nvcc}')

flags = ['cflags = {}'.format(' '.join(cflags))]
flags.append('post_cflags = {}'.format(' '.join(post_cflags)))
flags = [f'cflags = {" ".join(cflags)}']
flags.append(f'post_cflags = {" ".join(post_cflags)}')
if with_cuda:
flags.append('cuda_cflags = {}'.format(' '.join(cuda_cflags)))
flags.append('cuda_post_cflags = {}'.format(' '.join(cuda_post_cflags)))
flags.append('ldflags = {}'.format(' '.join(ldflags)))
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
flags.append(f'ldflags = {" ".join(ldflags)}')

# Turn into absolute paths so we can emit them into the ninja build
# file wherever it is.
Expand Down Expand Up @@ -1765,7 +1761,7 @@ def sanitize_flags(flags):
object_file = object_file.replace(':', '$:')
source_file = source_file.replace(" ", "$ ")
object_file = object_file.replace(" ", "$ ")
build.append('build {}: {} {}'.format(object_file, rule, source_file))
build.append(f'build {object_file}: {rule} {source_file}')

if library_target is not None:
link_rule = ['rule link']
Expand All @@ -1776,15 +1772,13 @@ def sanitize_flags(flags):
cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
else:
raise RuntimeError("MSVC is required to load C++ extensions")
link_rule.append(
' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format(
cl_path))
link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
else:
link_rule.append(' command = $cxx $in $ldflags -o $out')

link = ['build {}: link {}'.format(library_target, ' '.join(objects))]
link = [f'build {library_target}: link {" ".join(objects)}']

default = ['default {}'.format(library_target)]
default = [f'default {library_target}']
else:
link_rule, link, default = [], [], []

Expand All @@ -1796,7 +1790,7 @@ def sanitize_flags(flags):
with open(path, 'w') as build_file:
for block in blocks:
lines = '\n'.join(block)
build_file.write('{}\n\n'.format(lines))
build_file.write(f'{lines}\n\n')


def _join_cuda_home(*paths) -> str:
Expand Down

0 comments on commit 42a5114

Please sign in to comment.