[AOTI] support freezing for MKLDNN (#124350)
## Description
Fixes #114450. This PR builds upon the work done by @imzhuhl in #114451.

This PR requires #122472 to land first.

We leverage the serialization and deserialization APIs from oneDNN v3.4.1 to save the opaque MKLDNN tensor during compilation and restore it when loading the compiled .so.
The ideep version is checked so that we won't break any pipeline even if third_party/ideep is not updated at the same time.
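
The sketch below is a rough, hypothetical illustration of the save side of this flow using the new custom ops registered in this PR (`mkldnn::_get_mkldnn_serialized_md`, `mkldnn::data_ptr`, `mkldnn::_nbytes`). The freezing/AOTI plumbing around them is omitted, and the standalone usage shown here is only an assumption for illustration (it requires an MKLDNN-enabled build with oneDNN >= 3.4.1):

```python
import ctypes

import torch

# Hypothetical standalone sketch (not the actual AOTI code path): gather what is
# needed to embed an opaque MKLDNN weight into the compiled .so, mirroring the
# codecache.py changes in this PR.
w = torch.randn(16, 16).to_mkldnn()  # opaque, oneDNN-packed tensor

# Serialized descriptor blob (uint8 tensor) from oneDNN's serialization API.
serialized_md = torch.ops.mkldnn._get_mkldnn_serialized_md(w)

# Raw payload: size in bytes plus the data handle of the packed buffer.
nbytes = torch.ops.mkldnn._nbytes(w)
raw = ctypes.cast(
    torch.ops.mkldnn.data_ptr(w),
    ctypes.POINTER(ctypes.c_ubyte * nbytes),
)
weight_bytes = bytes(raw.contents)

# When the .so is loaded, the C++ helper mkldnn_tensor_from_data_ptr
# deserializes the descriptor blob via ideep and wraps the restored buffer
# back into an opaque MKLDNN tensor.
```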

### Test plan:
```sh
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu
```
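
For context, a minimal sketch of what these tests exercise (assuming a CPU build with MKLDNN enabled; the toy model mirrors the `test_linear_freezing` case below, and loading/running the resulting .so through the AOTI runner is omitted):

```python
import torch
import torch._export
from torch._inductor import config


class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(10, 10)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)


example_inputs = (torch.randn(10, 10),)

# With freezing enabled, the linear weight should be constant-folded into an
# MKLDNN-packed (opaque) constant, which the serialization support in this PR
# lets AOTI embed into, and later restore from, the generated .so.
with torch.no_grad(), config.patch({"freezing": True}):
    so_path = torch._export.aot_compile(LinearModel(), example_inputs)
print(so_path)
```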

### TODOs in follow-up PRs
1. We found that using `AOTI_TORCH_CHECK` causes a performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor, which uses `TORCH_CHECK`. How to address this may need further discussion (`AOTI_TORCH_CHECK` was introduced in #119220).
2. Freezing in non-ABI-compatible mode works with the support in this PR. For ABI-compatible mode, we first need to address this issue: `AssertionError: None, i.e. optional output is not supported`.
https://github.com/pytorch/pytorch/blob/6c4f43f82675b5fcfe8cf3e5983d0c0f326408aa/torch/_inductor/codegen/cpp_wrapper_cpu.py#L2023-L2024

Pull Request resolved: #124350
Approved by: https://github.com/jgong5, https://github.com/desertfire
chunyuan-w authored and pytorchmergebot committed May 25, 2024
1 parent e7a4270 commit 4a997de
Showing 18 changed files with 452 additions and 22 deletions.
42 changes: 42 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -1,6 +1,7 @@
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>

#if AT_MKLDNN_ENABLED()

@@ -61,6 +62,33 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
}
}

int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return reinterpret_cast<int64_t>(data_ptr);
}

at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
std::vector<uint8_t> vector_serialized_md{
opaque_metadata, opaque_metadata + opaque_metadata_size};
ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
// groups is needed for grouped conv
deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif

auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}

Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
@@ -81,6 +109,11 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
return mklimpl->unsafe_opaque_handle()->get_target();
}

int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
return t.get_desc().get_size();
}

ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
TORCH_CHECK(
tensor.device().is_cpu(),
@@ -167,6 +200,15 @@ int set_verbose(int level) {
return ideep::utils::set_verbose(level);
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
TORCH_FN(data_ptr_from_mkldnn));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
TORCH_FN(nbytes_from_mkldnn));
}

}}

#endif // AT_MKLDNN_ENABLED()
12 changes: 12 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -28,12 +28,24 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
return get_mkldnn_dtype(t.scalar_type());
}

TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);

TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);

// Construct aten MKL-DNN tensor given an ideep tensor
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);

// Retrieve `ideep::tensor` from MKL-DNN tensor
TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);

TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);

// Construct an `ideep::tensor` "view" from dense tensor, note the
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);
27 changes: 27 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -12,7 +12,9 @@
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
@@ -508,6 +510,25 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
return {packed_w1, packed_w2};
}

static Tensor get_mkldnn_serialized_md(const Tensor& self) {
const ideep::tensor packed_w = itensor_from_tensor(self);
auto packed_w_desc = packed_w.get_desc();
std::vector<uint8_t> serialized_wei_desc;

#if IDEEP_PREREQ(3, 4, 1, 2)
serialized_wei_desc = packed_w_desc.get_blob();
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
auto res = at::empty_like(serialized_md);
// serialized_md shares the buffer with serialized_wei_desc,
// which will be released outside of this function thus invalidating the buffer of serialized_md.
// A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released.
res.copy_(serialized_md);
return res;
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@@ -523,6 +544,12 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
TORCH_FN(get_mkldnn_serialized_md ));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
3 changes: 3 additions & 0 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
@@ -74,6 +74,9 @@ TORCH_LIBRARY(mkldnn, m) {
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}

TORCH_LIBRARY(mkldnn_prepacked, m) {
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -471,6 +471,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

103 changes: 96 additions & 7 deletions test/inductor/test_aot_inductor.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
@@ -89,6 +90,8 @@ def check_model(
options=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
):
with torch.no_grad(), config.patch(
{
@@ -114,7 +117,7 @@ def check_model(
disable_constraint_solver,
)

self.assertTrue(same(actual, expected))
self.assertEqual(actual, expected, atol=atol, rtol=rtol)


def check_model_with_multiple_inputs(
@@ -312,6 +315,10 @@ def forward(self, x, y):
)
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
@@ -331,6 +338,80 @@ def forward(self, x, y):
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
iC = 2
oC = 3

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)

example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_deconv_freezing(self):
dtypes = [torch.float]
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype, groups in itertools.product(dtypes, [2, 1]):
iC = 4
oC = 2

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)

example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
for dtype in [torch.float32, torch.bfloat16]:

class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)

def forward(self, y):
return torch.nn.functional.linear(y, self.weight)

example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

with config.patch({"freezing": True}):
self.check_model(LinearModel(self.device), example_inputs)

@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
@@ -1390,7 +1471,9 @@ def forward(self, x, y):
torch.randn(87, 87, device=self.device),
torch.randn(87, 87, device=self.device),
)
self.check_model(Model(), example_inputs)
self.check_model(
Model(), example_inputs, atol=1e-4, rtol=1e-4
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py

if self.device == "cuda":
so_path = torch._export.aot_compile(Model(), example_inputs)
@@ -2872,6 +2955,12 @@ def fail_non_abi_compatible_cuda(is_skip=False):
# test_failures, xfail by default, set is_skip=True to skip
CPU_TEST_FAILURES = {
"test_add_complex": fail_stack_allocation(is_skip=True),
# TODO: test_conv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_deconv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
is_skip=True
@@ -2885,9 +2974,12 @@ def fail_non_abi_compatible_cuda(is_skip=False):
"test_dynamic_scalar": fail_stack_allocation(is_skip=True),
# https://github.com/pytorch/pytorch/issues/122980
"test_fft_c2c": fail_stack_allocation(is_skip=True),
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
# TODO: test_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_linear_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
# minimal arrayref interface only works with CPU; test crashes.
@@ -3127,9 +3219,6 @@ class AOTInductorTestNonABICompatibleCpu(TestCase):
"test_duplicate_constant_folding": TestFailure(
("non_abi_compatible_cpu",), is_skip=True
),
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
"test_runtime_checks_dtype_failed": TestFailure(
17 changes: 15 additions & 2 deletions torch/_inductor/codecache.py
@@ -1521,6 +1521,10 @@ def use_custom_generated_macros() -> str:

def use_fb_internal_macros() -> str:
if config.is_fbcode():
# TODO: this is to avoid FC breakage for fbcode. When using newly
# generated model.so on an older version of PyTorch, need to use
# the v1 version for aoti_torch_create_tensor_from_blob
create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
openmp_lib = build_paths.openmp_lib()
preprocessor_flags = " ".join(
(
@@ -1529,7 +1533,7 @@ def use_fb_internal_macros() -> str:
"-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
)
)
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
else:
return ""

@@ -2075,7 +2079,9 @@ def _compile_consts_darwin(consts: bytes) -> str:

output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
tensor.untyped_storage().nbytes()
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
for (name, tensor) in graph.constants.items()
if name not in graph.folded_constants
)
@@ -2108,6 +2114,13 @@ def _to_bytes(t: torch.Tensor) -> bytes:
if t.numel() == 0:
return b""

if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
)
return bytes(raw_array.contents)

t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),
2 changes: 2 additions & 0 deletions torch/_inductor/codegen/cpp.py
@@ -1971,6 +1971,8 @@ def codegen_loops(self, code, worksharing):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
# TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
# compared with JIT Inductor which uses TORCH_CHECK
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"
5 changes: 5 additions & 0 deletions torch/_inductor/codegen/cpp_utils.py
@@ -64,6 +64,11 @@
"cuda": "at::kCUDA",
}

LAYOUT_TO_ATEN = {
torch.strided: "at::kStrided",
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
}

INDEX_TYPE = "long"

GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])