Skip to content

Commit

Permalink
[AOTI] support freezing on CPU
Browse files Browse the repository at this point in the history
ghstack-source-id: cbc0a48c92b62c560599982d419623b9294155b9
Pull Request resolved: #124350
  • Loading branch information
chunyuan-w committed Apr 18, 2024
1 parent 6dbad09 commit 6285963
Show file tree
Hide file tree
Showing 15 changed files with 325 additions and 40 deletions.
27 changes: 27 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>

#if AT_MKLDNN_ENABLED()

Expand Down Expand Up @@ -61,6 +62,18 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
}
}

int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return reinterpret_cast<int64_t>(data_ptr);
}

void* data_ptr_from_mkldnn_aot(at::Tensor* mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor->unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return data_ptr;
}

Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional<ScalarType> dtype, c10::optional<Device> device) {
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
Expand All @@ -81,6 +94,11 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
return mklimpl->unsafe_opaque_handle()->get_target();
}

int64_t data_size_from_mkldnn(const Tensor& mkldnn_tensor) {
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
return t.get_desc().get_size();
}

ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
TORCH_CHECK(
tensor.device().is_cpu(),
Expand Down Expand Up @@ -167,6 +185,15 @@ int set_verbose(int level) {
return ideep::utils::set_verbose(level);
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
TORCH_FN(data_ptr_from_mkldnn));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_data_size"),
TORCH_FN(data_size_from_mkldnn));
}

}}

#endif // AT_MKLDNN_ENABLED()
5 changes: 5 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
return get_mkldnn_dtype(t.scalar_type());
}

TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);

// TODO: unify with data_ptr_from_mkldnn
TORCH_API void* data_ptr_from_mkldnn_aot(at::Tensor* mkldnn_tensor);

// Construct aten MKL-DNN tensor given an ideep tensor
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional<ScalarType> dtype, c10::optional<Device> device);

Expand Down
25 changes: 25 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
Expand Down Expand Up @@ -477,6 +479,23 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
return {packed_w1, packed_w2};
}

static Tensor mkldnn_serialize(const Tensor& self) {
const ideep::tensor packed_w = itensor_from_tensor(self);
auto packed_w_desc = packed_w.get_desc();

// TODO: test ideep versioning
#if IDEEP_PREREQ(3, 4, 1, 2)
std::vector<uint8_t> serialized_wei_desc = packed_w_desc.get_blob();
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
auto res = at::empty_like(serialized_md);
// TODO: a copy is needed here so that from_blob won't be released
res.copy_(serialized_md);
return res;
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
Expand All @@ -492,6 +511,12 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_mkldnn_serialize"),
TORCH_FN(mkldnn_serialize));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ TORCH_LIBRARY(mkldnn, m) {
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
m.def("mkldnn::_mkldnn_serialize(Tensor mkldnn_tensor) -> Tensor");
// TODO: data_size should be size_t, but seems does not work in schema here. Fix the int return dtype
m.def("mkldnn::_data_size(Tensor mkldnn_tensor) -> int");
}

TORCH_LIBRARY(mkldnn_prepacked, m) {
Expand Down
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/opaque_tensor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

Expand Down
87 changes: 66 additions & 21 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check_model(
disable_constraint_solver,
)

self.assertTrue(same(actual, expected))
self.assertEqual(actual, expected)


def check_model_with_multiple_inputs(
Expand Down Expand Up @@ -306,24 +306,75 @@ def forward(self, x, y):
)
self.check_model(Model(self.device), example_inputs)

# TODO: unify freezing test into one function
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(9, 10, device=device)
self.padding = torch.randn(1, 10, device=device)
for dtype in [torch.float32, torch.bfloat16]:

def forward(self, x, y):
padded_weight = torch.cat((self.weight, self.padding), dim=0)
return x + torch.nn.functional.linear(y, padded_weight)
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)
# self.padding = torch.randn(1, 512, device=device).to(dtype)

example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
def forward(self, y):
# padded_weight = torch.cat((self.weight, self.padding), dim=0)
return torch.nn.functional.linear(y, self.weight)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
example_inputs = (
# torch.randn(10, 10, device=self.device).to(dtype),
torch.randn(10, 10, device=self.device).to(dtype),
)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

def test_conv_freezing(self):
import itertools

for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
iC = 2
oC = 3

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)

example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

def test_deconv_freezing(self):
import itertools

for dtype, groups in itertools.product([torch.bfloat16, torch.float], [2, 1]):
iC = 2
oC = 3

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 3, 3, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)

example_inputs = (torch.randn(2, iC, 10, 10, device=self.device).to(dtype),)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

def test_simple_split(self):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -2688,9 +2739,6 @@ def fail_non_abi_compatible_cuda(is_skip=False):
"test_dynamic_smem_above_default_limit": fail_with_and_without_stack_allocation(),
# https://github.com/pytorch/pytorch/issues/122980
"test_fft_c2c": fail_stack_allocation(is_skip=True),
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
# minimal arrayref interface only works with CPU; test crashes.
Expand Down Expand Up @@ -2928,9 +2976,6 @@ class AOTInductorTestNonABICompatibleCpu(TestCase):
"test_dynamic_smem_above_default_limit": TestFailure(
("non_abi_compatible_cpu",)
),
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
"test_runtime_checks_dtype_failed": TestFailure(
Expand Down
12 changes: 11 additions & 1 deletion torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,9 @@ def _compile_consts_darwin(consts: bytes) -> str:

output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
tensor.untyped_storage().nbytes()
torch.ops.mkldnn._data_size(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
for (name, tensor) in graph.constants.items()
if name not in graph.folded_constants
)
Expand Down Expand Up @@ -1842,6 +1844,13 @@ def _to_bytes(t: torch.Tensor) -> bytes:
if t.numel() == 0:
return b""

if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._data_size(t)),
)
return bytes(raw_array.contents)

t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),
Expand All @@ -1851,6 +1860,7 @@ def _to_bytes(t: torch.Tensor) -> bytes:
return bytes(raw_array.contents)

serialized_weights = b"".join(
# TODO: use size from opaque tensor for prepacked weight?
_to_bytes(graph.get_original_value_of_constant(name))
for name in graph.constants.keys()
if name not in graph.folded_constants
Expand Down
10 changes: 9 additions & 1 deletion torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@
"cuda": "at::kCUDA",
}

LAYOUT_TO_ATEN = {
torch.strided: "at::kStrided",
torch._mkldnn: "at::kMkldnn",
}

INDEX_TYPE = "long"

NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"}
Expand Down Expand Up @@ -2138,7 +2143,10 @@ def codegen_loops(self, code, worksharing):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
return "AOTI_TORCH_CHECK"
# return "AOTI_TORCH_CHECK"
# Using AOTI_TORCH_CHECK on some model is causing performance drop.
# TODO: add an option here to don't always make it noinline?
return "TORCH_CHECK"
else:
return "TORCH_CHECK"

Expand Down
31 changes: 28 additions & 3 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,9 +719,14 @@ def codegen_model_constructor(self):
self.prefix.writeline(
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
)
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)
if tensor.is_mkldnn:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._data_size(tensor)};"
)
else:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)
from_folded = "true" if name in V.graph.folded_constants else "false"
self.prefix.writeline(
f"constants_info_[{idx}].from_folded = {from_folded};"
Expand All @@ -734,6 +739,21 @@ def codegen_model_constructor(self):
self.prefix.writeline(
f"constants_info_[{idx}].stride = {{{stride_str}}};"
)
self.prefix.writeline(
f"constants_info_[{idx}].layout = static_cast<int8_t>({self.codegen_layout(tensor.layout)});"
)

if tensor.is_mkldnn:
serialized_tensor = torch.ops.mkldnn._mkldnn_serialize(tensor)
assert (
serialized_tensor.dim() == 1
), "Expect serialized_tensor to be 1-D"

serialized_list = serialized_tensor.tolist()
serialized_list_str = self.codegen_shape_tuple(serialized_list)
self.prefix.writeline(
f"constants_info_[{idx}].serialized_md = {serialized_list_str};"
)
if name in V.graph.dynamo_flat_name_to_original_fqn:
original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
name, name
Expand Down Expand Up @@ -1434,6 +1454,11 @@ def codegen_dtype(self, dtype):

return DTYPE_TO_ATEN[dtype]

def codegen_layout(self, layout):
from .cpp import LAYOUT_TO_ATEN

return LAYOUT_TO_ATEN[layout]

@functools.lru_cache(None)
def codegen_int_array_var(
self,
Expand Down
Loading

0 comments on commit 6285963

Please sign in to comment.