Update on "[ONNX] Diagnostic 'log' and 'log_and_raise_if_error'"
[ghstack-poisoned]
titaiwangms committed May 15, 2023
2 parents 54957f9 + e3488f3 commit 0139807
Showing 34 changed files with 710 additions and 252 deletions.
5 changes: 0 additions & 5 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -13,7 +13,6 @@
#include <mutex>
#include <condition_variable>
#include <type_traits>
#include <c10/core/SafePyObject.h>

#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>
@@ -391,10 +390,6 @@ class TORCH_API OperatorHandle {
return operatorDef_->op.getTags();
}

void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
operatorDef_->op.setReportErrorCallback_(std::move(callback));
}

bool hasTag(const at::Tag& tag) const {
for(const auto& tag_: getTags()) {
if (tag == tag_) {
9 changes: 0 additions & 9 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -530,11 +530,6 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const {
// If there is an invariant problem, report it now.
checkInvariants();

if (report_error_callback_ != nullptr) {
report_error_callback_->pyinterpreter()->reportErrorCallback(report_error_callback_->ptr(&report_error_callback_->pyinterpreter()), dispatchKey);
// reportErrorCallback should have raised an error
TORCH_INTERNAL_ASSERT(false);
}
if (dispatchKey == DispatchKey::Undefined) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"There were no tensor arguments to this function (e.g., you passed an "
@@ -579,10 +574,6 @@ std::string OperatorEntry::dumpComputedTable() const {
return oss.str();
}

void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
report_error_callback_ = std::move(callback);
}

// Inspect the "canonical" information in OperatorEntry. This only prints out
// *non-derived* information including kernels registered to alias dispatch keys;
// i.e., what the source of truth says about the operator. This dumping function
5 changes: 0 additions & 5 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -7,7 +7,6 @@
#include <c10/util/Optional.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/PyHandleCache.h>
#include <c10/core/SafePyObject.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
@@ -212,7 +211,6 @@ class TORCH_API OperatorEntry final {
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);

template <typename F>
PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
@@ -288,9 +286,6 @@ class TORCH_API OperatorEntry final {
c10::optional<CppSignatureWithDebug> cpp_signature_;
c10::optional<CppSignatureWithDebug> sym_cpp_signature_;

// A Python custom error handler for OperatorEntry::reportError
std::unique_ptr<c10::SafePyObject> report_error_callback_;

// Whether this operator needs to be observed with RecordFunction
const bool is_observed_;

2 changes: 1 addition & 1 deletion aten/src/ATen/native/MaxPooling.cpp
@@ -75,7 +75,7 @@ static void check_max_pool1d(
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);

const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}

} // namespace
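The MaxPooling.cpp change above tightens the output-size check from OW >= 0 to OW > 0, so a configuration whose computed output width is exactly zero is now rejected up front. A minimal sketch of the effect, assuming the default CPU max_pool1d path; shapes are chosen so the output-size formula evaluates to 0:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 2)  # input width 2
# With kernel_size=3 (stride defaults to kernel_size), the computed output width is
# floor((2 - (3 - 1) - 1) / 3) + 1 = 0, which the tightened check now rejects.
try:
    F.max_pool1d(x, kernel_size=3)
except RuntimeError as e:
    print(e)  # expected to read roughly "max_pool1d() Invalid computed output size: 0"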
4 changes: 0 additions & 4 deletions c10/core/impl/PyInterpreter.cpp
@@ -27,10 +27,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
PANIC(dispatch);
}

void reportErrorCallback(PyObject* callback, DispatchKey key) const override {
PANIC(reportErrorCallback);
}

void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey,
3 changes: 0 additions & 3 deletions c10/core/impl/PyInterpreter.h
@@ -141,9 +141,6 @@ struct C10_API PyInterpreterVTable {
virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const = 0;

virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
const = 0;

// This is only invoked in the multipy/torchdeploy situation from
// pythonOpRegistrationTrampoline; this lets us get to the Python
// interpreter to actually find the appropriate Python op registration
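The Dispatcher/OperatorEntry/PyInterpreter hunks above remove the Python report_error_callback hook, so a dispatch miss is reported through OperatorEntry::reportError's built-in messages rather than any Python callback. A rough sketch of that remaining behavior; the op and backend below are illustrative and the exact wording may differ:

import torch

# Sparse CPU tensors have no registered kernel for this op, so the call should
# fall through to OperatorEntry::reportError.
x = torch.eye(3).to_sparse()
try:
    torch.nn.functional.gelu(x)
except (NotImplementedError, RuntimeError) as e:
    print(e)  # e.g. "Could not run 'aten::gelu' with arguments from the 'SparseCPU' backend ..."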
76 changes: 76 additions & 0 deletions test/cpp_extensions/open_registration_extension.cpp
@@ -1,9 +1,11 @@
#include <unordered_map>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>
@@ -44,6 +46,68 @@ namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);

} // namespace at::native
struct CustomBackendMetadata : public c10::BackendMeta {
// For testing: this field will mutate when clone() is called by shallow_copy_from.
int backend_version_format_{-1};
int format_number_{-1};
mutable bool cloned_{false};
// define the constructor
CustomBackendMetadata(int backend_version_format, int format_number): backend_version_format_(backend_version_format), format_number_(format_number) {}
c10::intrusive_ptr<c10::BackendMeta> clone(const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
cloned_ = true;
return c10::BackendMeta::clone(ptr);
}
};

// we need to register two functions for serialization
void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
return;
}
CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
if (tmeta->backend_version_format_ == 1) {
m["backend_version_format"] = true;
}
if (tmeta->format_number_ == 29) {
m["format_number"] = true;
}
}

void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
int backend_version_format{-1};
int format_number{-1};
if (m.find("backend_version_format") != m.end()) {
backend_version_format = 1;
}
if (m.find("format_number") != m.end()) {
format_number = 29;
}
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata(backend_version_format, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}

void custom_serialization_registry(){
torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);
}

// check if BackendMeta serialization works correctly
bool check_backend_meta(const at::Tensor& t) {
if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
return true;
}
}
return false;
}

// a fake set function is exposed to the Python side
void custom_set_backend_meta(const at::Tensor& t) {
int backend_version_format{1};
int format_number{29};
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata(backend_version_format, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
@@ -127,6 +191,14 @@ at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
return result;
}

// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result, c10::Storage storage, int64_t storage_offset, c10::IntArrayRef size, c10::IntArrayRef stride) {
result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : c10::nullopt;
at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(), size, stride_opt, /*resize_storage=*/!result.is_meta());
return result;
}

// basic dummy functions related to pin_memory.
std::vector<void*> custom_pinned_data_ptr;

@@ -186,6 +258,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("_copy_from", &custom__copy_from);
m.impl("empty_strided", &custom_empty_strided);
m.impl("set_.source_Storage", &custom_set_source_Storage);
m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
m.impl("_pin_memory", &custom__pin_memory);
m.impl("is_pinned", &custom_is_pinned);
m.impl("resize_", &custom_resize_);
@@ -251,4 +324,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_abs_called", &custom_abs_called, "check if our custom abs function was called");
m.def("register_generator", &register_generator, "register generator for custom device");
m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization works correctly");
m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
}
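A rough sketch of how the bindings added to this extension might be exercised from Python once the file is built with torch.utils.cpp_extension.load; the module name and source path are illustrative, while the bound function names come from the PYBIND11_MODULE block above:

import torch
import torch.utils.cpp_extension

ext = torch.utils.cpp_extension.load(
    name="custom_device_extension",
    sources=["test/cpp_extensions/open_registration_extension.cpp"],
    verbose=True,
)

ext.custom_serialization_registry()   # register the PrivateUse1 (de)serialization hooks
t = torch.empty(4, 4)
ext.custom_set_backend_meta(t)        # attaches CustomBackendMetadata(1, 29) to the TensorImpl
assert ext.check_backend_meta(t)      # the metadata is now visible through the accessor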
30 changes: 16 additions & 14 deletions test/distributed/_spmd/test_tracing.py
@@ -399,7 +399,7 @@ def transform(

return gm

@compile(module_override={nn.Linear: AssertOverride(self)})
@compile(module_override=[AssertOverride(self)])
def train_step(mod, opt, inp):
mod(inp).sum().backward()
opt.step()
@@ -434,14 +434,18 @@ def test_adam_foreach(self):
def test_adam_fused(self):
self._test_adam(foreach=False, fused=True)

def _test_train_step_override(self, module_override_key):
def _test_train_step_override(self):
transform_targets = []

class DDMOverride(Override):
def replacement(
self, fqn: str, orig_submodule: torch.nn.Module
) -> torch.nn.Module:
return DummyDDM()
return (
DummyDDM()
if isinstance(orig_submodule, DataDependentModule)
else orig_submodule
)

def transform(
self,
@@ -467,7 +471,7 @@ def transform(

return gm

@compile(module_override={module_override_key: DDMOverride()})
@compile(module_override=[DDMOverride()])
def train_step(mod, opt, inp):
mod(inp).sum().backward()
opt.step()
@@ -486,13 +490,8 @@ def train_step(mod, opt, inp):

@skip_if_lt_x_gpu(2)
@with_comms
def test_module_type_override(self):
self._test_train_step_override(module_override_key=DataDependentModule)

@skip_if_lt_x_gpu(2)
@with_comms
def test_module_fqn_override(self):
self._test_train_step_override(module_override_key="ddm")
def test_module_override(self):
self._test_train_step_override()

@skip_if_lt_x_gpu(2)
@with_comms
@@ -503,8 +502,11 @@ class DDMOverride(Override):
def replacement(
self, fqn: str, orig_submodule: torch.nn.Module
) -> torch.nn.Module:
if fqn in ["ddm1", "ddm2"]:
return DummyDDM()
return (
DummyDDM()
if isinstance(orig_submodule, DataDependentModule)
else orig_submodule
)

def transform(
self,
@@ -543,7 +545,7 @@ def forward(self, x):

return self.relu(self.ddm2(self.l2(self.ddm1(self.l1(x)))))

@compile(module_override={DataDependentModule: DDMOverride()})
@compile(module_override=[DDMOverride()])
def train_step(mod, opt, inp):
mod(inp).sum().backward()
opt.step()
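The test updates above track an API change in the _spmd frontend: module_override is now a list of Override objects, and each Override.replacement receives both the FQN and the original submodule and decides for itself whether to substitute it. A minimal sketch under that reading of the diff; the import path and the transform signature are assumptions, not shown in this hunk:

from torch.distributed._spmd.api import Override, compile  # assumed import path

class NoopOverride(Override):
    def replacement(self, fqn, orig_submodule):
        # Swap in a stand-in for the submodules you care about; returning the
        # original leaves everything else untouched.
        return orig_submodule

    def transform(self, gm, flat_state):
        # Post-process the traced graph; returned unchanged here.
        return gm

@compile(module_override=[NoopOverride()])
def train_step(mod, opt, inp):
    mod(inp).sum().backward()
    opt.step()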
23 changes: 10 additions & 13 deletions test/export/test_passes.py
@@ -93,7 +93,7 @@ def forward(self, x):
self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)

with self.assertRaisesRegex(RuntimeError, "Input #0"):
with self.assertRaisesRegex(RuntimeError, "Input arg0"):
pass_result.graph_module(torch.zeros(2, 7, 3))

self.assertEqual(pass_result.graph_module(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@@ -127,10 +127,10 @@ def forward(self, x, y):
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)

with self.assertRaisesRegex(RuntimeError, "Input #0"):
with self.assertRaisesRegex(RuntimeError, "Input arg0"):
pass_result.graph_module(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

with self.assertRaisesRegex(RuntimeError, "Input #1"):
with self.assertRaisesRegex(RuntimeError, "Input arg1"):
pass_result.graph_module(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -162,11 +162,11 @@ def forward(self, x, y):
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)

with self.assertRaisesRegex(RuntimeError, "Input #0"):
with self.assertRaisesRegex(RuntimeError, "Input arg0"):
pass_result.graph_module(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(RuntimeError, "Input #1's dimension #0 size is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "Input arg1's dimension #0 size is specialized at 5"):
pass_result.graph_module(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
@@ -203,11 +203,11 @@ def forward(self, x, y):
self.assertEqual(num_assert, 7)
self.assertEqual(num_scalar_tensor, 7)

with self.assertRaisesRegex(RuntimeError, "Input #0"):
with self.assertRaisesRegex(RuntimeError, "Input arg0"):
pass_result.graph_module(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(RuntimeError, "Input #1's dimension #0 size is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "Input arg1's dimension #0 size is specialized at 5"):
pass_result.graph_module(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
@@ -253,7 +253,6 @@ def test_views_op_having_view_copy(self) -> None:
if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))


def test_runtime_assert_inline_constraints_for_item(self) -> None:
class M(torch.nn.Module):
def __init__(self):
@@ -284,7 +283,6 @@ def forward(self, x):
new_inp = torch.tensor([5])
self.assertEqual(mod(new_inp), new_gm(new_inp))


def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
class M(torch.nn.Module):
def __init__(self):
@@ -320,8 +318,6 @@ def forward(self, x):
new_inp = torch.tensor([1, 1, 1, 1])
self.assertEqual(mod(new_inp), new_gm(new_inp))

# FIXME: support control flow operators for the pass
@unittest.expectedFailure
def test_runtime_assert_inline_constraints_for_cond(self) -> None:
class M(torch.nn.Module):
def __init__(self):
@@ -346,8 +342,9 @@ def false_fn(x, y):
mod = M()
gm = _export(mod, (torch.tensor(True), x, y))

_ = AddRuntimeAssertionsForConstraintsPass()(gm)

pass_result = AddRuntimeAssertionsForConstraintsPass()(gm)
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
pass_result.graph_module(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))


if __name__ == '__main__':
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
@@ -2742,7 +2742,6 @@ def forward(self, x):
xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lstsq', ''), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lstsq', 'grad_oriented'), # aten.linalg_lstsq.default - couldn't find symbolic meta funct...
xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function...
xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta funct...
xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco...
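The removed xfail indicates that aten.linalg_lu.default now has a symbolic meta registration, so linalg.lu should trace under dynamic shapes. A small sketch of what that enables; the compile mode and shapes here are illustrative:

import torch

def f(x):
    P, L, U = torch.linalg.lu(x)
    return L.sum() + U.sum()

x = torch.randn(4, 4)
compiled = torch.compile(f, dynamic=True)  # exercises the symbolic-shape meta path
print(compiled(x))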
