[inductor] Slightly faster memory allocation on CUDA (#118255)
Pull Request resolved: #118255
Approved by: https://github.com/peterbell10
ghstack dependencies: #118065, #118070, #118171
jansel authored and pytorchmergebot committed Jan 25, 2024
1 parent 3e76a0e commit 2de24c1
Showing 4 changed files with 60 additions and 44 deletions.
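In rough terms, this change extends the fast-allocation path that the Inductor wrapper codegen already used for CPU buffers to CUDA buffers as well. A minimal sketch of the generated allocation line before and after (buffer name, shape, strides, and dtype are made up for illustration; the `_empty_strided_cuda` binding assumes a CUDA device and a PyTorch build that includes this commit):

```python
import torch

# Made-up buffer name/shape/strides/dtype, purely illustrative.
# Before: contiguous CUDA buffers were allocated via torch.empty in the
# generated wrapper code.
buf0 = torch.empty((64, 64), device="cuda", dtype=torch.float32)

# After: all CUDA buffers go through the lower-overhead binding exposed from
# torch/csrc/dynamo/guards.cpp, called as (sizes, strides, dtype).
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
buf0 = empty_strided_cuda((64, 64), (64, 1), torch.float32)
```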
22 changes: 11 additions & 11 deletions test/distributed/test_inductor_collectives.py
@@ -554,7 +554,7 @@ def func(inp, *, tag, ranks, group_size):
# NOTE: Make sure we are not unneccessarily copying the outputs of
# wait_tensors before they are returned from the graph.
FileCheck() \
.check("buf0 = empty(") \
.check("buf0 = empty") \
.check("buf0.copy_(arg0_1)") \
.check("buf1 = buf0") \
.check("buf1_work = dist.all_reduce(buf1") \
@@ -625,8 +625,8 @@ def func(inp, *, tag, ranks, group_size):
# NOTE: Make sure we are not unneccessarily copying the outputs of
# wait_tensors before they are returned from the graph.
FileCheck() \
.check("buf0 = empty(") \
.check("buf5 = empty(") \
.check("buf0 = empty") \
.check("buf5 = empty") \
.check("triton_poi__0.run(arg0_1, buf0, buf5") \
.check_not("copy_(") \
.check("buf1 = buf0; del buf0 # reuse") \
@@ -942,11 +942,11 @@ def func(inp, *, tag, ranks, group_size):
# NOTE: Make sure we are not unneccessarily copying the outputs of
# wait_tensors before they are returned from the graph.
FileCheck() \
.check("buf0 = empty(") \
.check("buf5 = empty(") \
.check("buf0 = empty") \
.check("buf5 = empty") \
.check("triton_poi__0.run(arg0_1, buf0, buf5") \
.check("buf1 = empty(") \
.check("buf2 = empty(") \
.check("buf1 = empty") \
.check("buf2 = empty") \
.check_not("copy_(") \
.check("buf3_inputs = [buf0,arg0_1]") \
.check("buf3 = [buf1,buf2]") \
@@ -988,11 +988,11 @@ def func(inp, *, tag, ranks, group_size):
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unneccessary copy is made.
FileCheck() \
.check("buf0 = empty(") \
.check("buf5 = empty(") \
.check("buf0 = empty") \
.check("buf5 = empty") \
.check("triton_poi__0.run(arg0_1, buf0, buf5") \
.check("buf1 = empty(") \
.check("buf2 = empty(") \
.check("buf1 = empty") \
.check("buf2 = empty") \
.check_not("copy_(") \
.check("buf3 = [buf1,buf2]") \
.check("buf3_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
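The pattern changes above drop the trailing "(" because these buffers are now allocated with `empty_strided_cuda(...)` rather than `empty(...)`, and `buf0 = empty` is a prefix of both spellings. A small, self-contained illustration of the relaxed FileCheck pattern (the generated-code strings are made up):

```python
from torch.testing import FileCheck

# Made-up stand-ins for the old and new generated wrapper code.
old_code = "buf0 = empty((4, 4), device='cuda', dtype=torch.float32)"
new_code = "buf0 = empty_strided_cuda((4, 4), (4, 1), torch.float32)"

# "buf0 = empty" (without the opening parenthesis) matches both forms.
for code in (old_code, new_code):
    FileCheck().check("buf0 = empty").run(code)
```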
2 changes: 1 addition & 1 deletion test/inductor/test_memory_planning.py
@@ -50,7 +50,7 @@ def test_python_wrapper(self):
result, code = run_and_get_cpp_code(compiled, *args)

FileCheck().check(
"pool1 = empty_strided(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
"pool1 = empty_strided_cuda(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
).check_next(
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
).check(
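For context, the checked pattern reflects Inductor's memory planning: one flat pool is allocated up front (now with `empty_strided_cuda`), and individual buffers are carved out of it with `alloc_from_pool` at computed byte offsets. A rough, hand-written sketch of the same idea using plain tensor views; this is not Inductor's actual implementation and assumes a CUDA device:

```python
import torch

# Concrete sizes stand in for the symbolic s0/s1 in the test.
s0, s1 = 8, 4
pool_bytes = 4 * s0 * s0 + 4 * s0 * s1  # two float32 buffers, 4 bytes per element

# One flat byte pool instead of one allocation per buffer.
pool1 = torch.empty_strided((pool_bytes,), (1,), dtype=torch.uint8, device="cuda")

# Carve buffers out of the pool by reinterpreting byte ranges.
buf0 = pool1[: 4 * s0 * s0].view(torch.float32).view(s0, s0)
buf1 = pool1[4 * s0 * s0 :].view(torch.float32).view(s0, s1)
```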
32 changes: 11 additions & 21 deletions torch/_inductor/codegen/wrapper.py
@@ -430,7 +430,7 @@ def write_header(self):
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch import device, empty_strided
from {codecache.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
@@ -439,6 +439,7 @@ def write_header(self):
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
@@ -1163,32 +1164,21 @@ def make_buffer_allocation(self, buffer):
return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)

def make_allocation(self, name, device, dtype, shape, stride):
if device.type == "cpu":
if device.type in ("cpu", "cuda"):
# optimized path for faster allocations, saving ~2us versus the stuff below
return (
f"{name} = empty_strided_cpu("
f"{name} = empty_strided_{device.type}("
f"{self.codegen_shape_tuple(shape)}, "
f"{self.codegen_shape_tuple(stride)}, "
f"{dtype})"
)

try:
expected = tuple(ir.make_contiguous_strides_for(shape))
except Exception: # cannot determine truth value of Relational
expected = None
if stride == expected:
return (
f"{name} = empty("
f"{self.codegen_shape_tuple(shape)}, "
f"device='{device.type}', dtype={dtype})"
)
else:
return (
f"{name} = empty_strided("
f"{self.codegen_shape_tuple(shape)}, "
f"{self.codegen_shape_tuple(stride)}, "
f"device='{device.type}', dtype={dtype})"
)
# all other devices:
return (
f"{name} = empty_strided("
f"{self.codegen_shape_tuple(shape)}, "
f"{self.codegen_shape_tuple(stride)}, "
f"device='{device.type}', dtype={dtype})"
)

def make_tensor_alias(self, new_name, old_name, comment=""):
return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"
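Viewed on its own, the new dispatch in `make_allocation` is simple: CPU and CUDA buffers take the fast `empty_strided_{cpu,cuda}` path, every other device falls back to the generic `empty_strided` call, and the old "emit `empty()` when the strides are contiguous" special case goes away. A standalone sketch of the resulting codegen logic (a free function rather than the real method, which uses `codegen_shape_tuple` for shape formatting):

```python
# Standalone approximation of the make_allocation method above, after this change.
def make_allocation(name, device_type, dtype, shape, stride):
    if device_type in ("cpu", "cuda"):
        # optimized path: sizes, strides, dtype are positional; no device kwarg
        return f"{name} = empty_strided_{device_type}({shape}, {stride}, {dtype})"
    # all other devices: generic torch.empty_strided call
    return (
        f"{name} = empty_strided({shape}, {stride}, "
        f"device='{device_type}', dtype={dtype})"
    )

print(make_allocation("buf0", "cuda", "torch.float32", (8, 8), (8, 1)))
# -> buf0 = empty_strided_cuda((8, 8), (8, 1), torch.float32)
```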
48 changes: 37 additions & 11 deletions torch/csrc/dynamo/guards.cpp
@@ -7,6 +7,11 @@
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/extension.h>

#ifdef USE_CUDA
#include <ATen/cuda/EmptyTensor.h>
#endif

#include <sstream>

namespace {
@@ -598,27 +603,47 @@ inline static void unwrap_size_tuple(PyObject* obj, T& output) {
}
}

static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprising slow. This is a lower-overhead
// version that saves ~2us on every allocation.
HANDLE_TH_ERRORS;

template <typename T>
inline static void _parse_empty_strided_args(
PyObject* args,
T& sizes,
T& strides,
at::ScalarType& dtype) {
TORCH_CHECK(PyTuple_CheckExact(args));
TORCH_CHECK(PyTuple_GET_SIZE(args) == 3);

// note PyTuple_GET_ITEM returns a borrowed ref, so no need for refcounts
at::SmallVector<int64_t, 8> sizes;
unwrap_size_tuple(PyTuple_GET_ITEM(args, 0), sizes);

at::SmallVector<int64_t, 8> strides;
unwrap_size_tuple(PyTuple_GET_ITEM(args, 1), strides);

PyObject* py_dtype = PyTuple_GET_ITEM(args, 2);
TORCH_CHECK(THPDtype_Check(py_dtype));
at::ScalarType dtype = reinterpret_cast<THPDtype*>(py_dtype)->scalar_type;
dtype = reinterpret_cast<THPDtype*>(py_dtype)->scalar_type;
}

static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprising slow. This is a lower-overhead
// version that saves ~2us on every allocation.
HANDLE_TH_ERRORS;
at::SmallVector<int64_t, 8> sizes;
at::SmallVector<int64_t, 8> strides;
at::ScalarType dtype;
_parse_empty_strided_args(args, sizes, strides, dtype);
return THPVariable_Wrap(at::detail::empty_strided_cpu(sizes, strides, dtype));
END_HANDLE_TH_ERRORS;
}

static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprising slow. This is lower-overhead.
HANDLE_TH_ERRORS;
#ifdef USE_CUDA
at::SmallVector<int64_t, 8> sizes;
at::SmallVector<int64_t, 8> strides;
at::ScalarType dtype;
_parse_empty_strided_args(args, sizes, strides, dtype);
return THPVariable_Wrap(at::detail::empty_strided_cuda(
sizes, strides, dtype, c10::DeviceType::CUDA));
#else
TORCH_CHECK(false, "PyTorch compiled without USE_CUDA");
#endif
END_HANDLE_TH_ERRORS;
}

@@ -629,6 +654,7 @@ static PyMethodDef _methods[] = {
{"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
{"dict_version", dict_version, METH_VARARGS, nullptr},
{"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
{"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
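The "~2us per allocation" figure in the guards.cpp comments can be spot-checked from Python once the binding exists. A rough micro-benchmark sketch; numbers are machine-dependent, and it assumes a CUDA device plus a PyTorch build containing this commit:

```python
import timeit
import torch

shape, stride, dtype = (64, 64), (64, 1), torch.float32
fast = torch._C._dynamo.guards._empty_strided_cuda

# Standard allocation path vs. the low-overhead binding added in this commit.
t_slow = timeit.timeit(
    lambda: torch.empty_strided(shape, stride, dtype=dtype, device="cuda"),
    number=10_000,
)
t_fast = timeit.timeit(lambda: fast(shape, stride, dtype), number=10_000)

print(f"torch.empty_strided : {t_slow / 10_000 * 1e6:.2f} us/alloc")
print(f"_empty_strided_cuda : {t_fast / 10_000 * 1e6:.2f} us/alloc")
```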
