Revert "[AOTI] support freezing for MKLDNN (#124350)"
This reverts commit 654afb6.

Reverted #124350 on behalf of https://github.com/clee2000 because it appears to have broken inductor/test_aot_inductor.py::AOTInductorTestNonABICompatibleCpu::test_freezing_non_abi_compatible_cpu (see https://hud.pytorch.org/pytorch/pytorch/commit/654afb6f3ae3ddbd926a753f9af95a6f6e22131c and https://github.com/pytorch/pytorch/actions/runs/9224838183/job/25382780192; comment on #124350)
pytorchmergebot committed May 24, 2024
1 parent 2ac739c commit 5ae9daa
Showing 17 changed files with 17 additions and 447 deletions.
42 changes: 0 additions & 42 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -1,7 +1,6 @@
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>

#if AT_MKLDNN_ENABLED()

@@ -62,33 +61,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
}
}

int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return reinterpret_cast<int64_t>(data_ptr);
}

at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
std::vector<uint8_t> vector_serialized_md{
opaque_metadata, opaque_metadata + opaque_metadata_size};
ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
// groups is needed for grouped conv
deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif

auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}

Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
@@ -109,11 +81,6 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
return mklimpl->unsafe_opaque_handle()->get_target();
}

int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
return t.get_desc().get_size();
}

ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
TORCH_CHECK(
tensor.device().is_cpu(),
@@ -200,15 +167,6 @@ int set_verbose(int level) {
return ideep::utils::set_verbose(level);
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
TORCH_FN(data_ptr_from_mkldnn));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
TORCH_FN(nbytes_from_mkldnn));
}

}}

#endif // AT_MKLDNN_ENABLED()
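
For context, the two helpers deleted above were registered as the custom ops `mkldnn::data_ptr` and `mkldnn::_nbytes` and consumed from Python by Inductor's code cache (see the torch/_inductor/codecache.py hunk later in this diff). A minimal sketch of that consumption pattern, assuming those ops are still registered; the wrapper function name is illustrative, not from the source:

```python
# Minimal sketch, assuming the (now-removed) mkldnn::data_ptr and
# mkldnn::_nbytes ops are registered; mirrors the usage removed from
# torch/_inductor/codecache.py in this commit.
import ctypes
import torch

def mkldnn_tensor_to_bytes(t: torch.Tensor) -> bytes:
    assert t.is_mkldnn
    nbytes = torch.ops.mkldnn._nbytes(t)      # size of the packed (opaque) buffer
    raw = ctypes.cast(
        torch.ops.mkldnn.data_ptr(t),         # raw data address as an int64
        ctypes.POINTER(ctypes.c_ubyte * nbytes),
    )
    return bytes(raw.contents)                # copy the packed weight bytes out
```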
12 changes: 0 additions & 12 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -28,24 +28,12 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
return get_mkldnn_dtype(t.scalar_type());
}

TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);

TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);

// Construct aten MKL-DNN tensor given an ideep tensor
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);

// Retrieve `ideep::tensor` from MKL-DNN tensor
TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);

TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);

// Construct an `ideep::tensor` "view" from dense tensor, note the
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);
27 changes: 0 additions & 27 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -12,9 +12,7 @@
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
@@ -510,25 +508,6 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
return {packed_w1, packed_w2};
}

static Tensor get_mkldnn_serialized_md(const Tensor& self) {
const ideep::tensor packed_w = itensor_from_tensor(self);
auto packed_w_desc = packed_w.get_desc();
std::vector<uint8_t> serialized_wei_desc;

#if IDEEP_PREREQ(3, 4, 1, 2)
serialized_wei_desc = packed_w_desc.get_blob();
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
auto res = at::empty_like(serialized_md);
// serialized_md shares the buffer with serialized_wei_desc,
// which will be released outside of this function thus invalidating the buffer of serialized_md.
// A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released.
res.copy_(serialized_md);
return res;
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@@ -544,12 +523,6 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
TORCH_FN(get_mkldnn_serialized_md ));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
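
The comment in the removed `get_mkldnn_serialized_md` explains why the descriptor bytes are copied before returning: `at::from_blob` only wraps caller-owned memory, so the result would dangle once `serialized_wei_desc` goes out of scope. A small Python analogue of that pitfall, for illustration only (the function name below is not from the source):

```python
# Illustration of the aliasing pitfall noted in the removed comment above:
# torch.frombuffer, like at::from_blob, shares memory with the source buffer,
# so the result must be cloned if the buffer may be freed or reused later.
import torch

def own_descriptor_bytes(desc: bytearray) -> torch.Tensor:
    view = torch.frombuffer(desc, dtype=torch.uint8)  # shares memory with `desc`
    return view.clone()  # take ownership, mirroring `res.copy_(serialized_md)`
```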
3 changes: 0 additions & 3 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
@@ -74,9 +74,6 @@ TORCH_LIBRARY(mkldnn, m) {
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}

TORCH_LIBRARY(mkldnn_prepacked, m) {
1 change: 0 additions & 1 deletion build_variables.bzl
@@ -471,7 +471,6 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

103 changes: 7 additions & 96 deletions test/inductor/test_aot_inductor.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
@@ -90,8 +89,6 @@ def check_model(
options=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
):
with torch.no_grad(), config.patch(
{
@@ -117,7 +114,7 @@ def check_model(
disable_constraint_solver,
)

self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertTrue(same(actual, expected))


def check_model_with_multiple_inputs(
@@ -315,10 +312,6 @@ def forward(self, x, y):
)
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
@@ -338,80 +331,6 @@ def forward(self, x, y):
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
iC = 2
oC = 3

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)

example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)

with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_deconv_freezing(self):
dtypes = [torch.float]
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype, groups in itertools.product(dtypes, [2, 1]):
iC = 4
oC = 2

class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
dtype
)

def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)

example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
for dtype in [torch.float32, torch.bfloat16]:

class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)

def forward(self, y):
return torch.nn.functional.linear(y, self.weight)

example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

with config.patch({"freezing": True}):
self.check_model(LinearModel(self.device), example_inputs)

@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
@@ -1471,9 +1390,7 @@ def forward(self, x, y):
torch.randn(87, 87, device=self.device),
torch.randn(87, 87, device=self.device),
)
self.check_model(
Model(), example_inputs, atol=1e-4, rtol=1e-4
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
self.check_model(Model(), example_inputs)

if self.device == "cuda":
so_path = torch._export.aot_compile(Model(), example_inputs)
@@ -2955,12 +2872,6 @@ def fail_non_abi_compatible_cuda(is_skip=False):
# test_failures, xfail by default, set is_skip=True to skip
CPU_TEST_FAILURES = {
"test_add_complex": fail_stack_allocation(is_skip=True),
# TODO: test_conv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_deconv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
is_skip=True
@@ -2974,12 +2885,9 @@ def fail_non_abi_compatible_cuda(is_skip=False):
"test_dynamic_scalar": fail_stack_allocation(is_skip=True),
# https://github.com/pytorch/pytorch/issues/122980
"test_fft_c2c": fail_stack_allocation(is_skip=True),
# TODO: test_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_linear_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
# minimal arrayref interface only works with CPU; test crashes.
@@ -3221,6 +3129,9 @@ class AOTInductorTestNonABICompatibleCpu(TestCase):
"test_duplicate_constant_folding": TestFailure(
("non_abi_compatible_cpu",), is_skip=True
),
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
"test_runtime_checks_dtype_failed": TestFailure(
17 changes: 2 additions & 15 deletions torch/_inductor/codecache.py
@@ -1522,10 +1522,6 @@ def use_custom_generated_macros() -> str:

def use_fb_internal_macros() -> str:
if config.is_fbcode():
# TODO: this is to avoid FC breakage for fbcode. When using newly
# generated model.so on an older verion of PyTorch, need to use
# the v1 version for aoti_torch_create_tensor_from_blob
create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
openmp_lib = build_paths.openmp_lib()
preprocessor_flags = " ".join(
(
@@ -1534,7 +1530,7 @@ def use_fb_internal_macros() -> str:
"-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
)
)
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
else:
return ""

@@ -2080,9 +2076,7 @@ def _compile_consts_darwin(consts: bytes) -> str:

output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
tensor.untyped_storage().nbytes()
for (name, tensor) in graph.constants.items()
if name not in graph.folded_constants
)
@@ -2115,13 +2109,6 @@ def _to_bytes(t: torch.Tensor) -> bytes:
if t.numel() == 0:
return b""

if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
)
return bytes(raw_array.contents)

t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),
2 changes: 0 additions & 2 deletions torch/_inductor/codegen/cpp.py
@@ -1971,8 +1971,6 @@ def codegen_loops(self, code, worksharing):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
# TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
# compared with JIT Inductor which uses TORCH_CHECK
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"
5 changes: 0 additions & 5 deletions torch/_inductor/codegen/cpp_utils.py
@@ -64,11 +64,6 @@
"cuda": "at::kCUDA",
}

LAYOUT_TO_ATEN = {
torch.strided: "at::kStrided",
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
}

INDEX_TYPE = "long"

GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
Expand Down
(Diffs for the remaining 8 changed files are not shown.)
