Skip to content

[BE] Introduce torch.AcceleratorError #152023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed

Conversation

malfet
Copy link
Contributor

@malfet malfet commented Apr 23, 2025

Stack from ghstack (oldest at bottom):

Which inherits from RuntimeError and contains error_code, which in case of CUDA should contain error returned by cudaGetLastError

torch::detail::_new_accelerator_error_object(c10::AcceleratorError&) follows the pattern of CPython's PyErr_SetString, namely

  • Convert cstr into Python string with PyUnicode_FromString
  • Create new exception object using PyObject_CallOneArg just like it's done in _PyErr_CreateException
  • Set error_code property using PyObject_SetAttrString
  • decref all temporary references

Test that it works and captures CPP backtrace (in addition to CI) by running

import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'

import torch

x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
    x[y] = 2
    print(x)
except torch.AcceleratorError as e:
    print("Exception was raised", e.args[0])
    print("Captured error code is ", e.error_code)

which produces following output

Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0

Captured error code is  710

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Apr 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152023

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 96c16f9 with merge base 1ca082d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

malfet added a commit that referenced this pull request Apr 23, 2025
TODO: May be we just need generic DeviceError class that preserves error code?

ghstack-source-id: bd06c59
Pull Request resolved: #152023
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@malfet malfet changed the title [BE] Throw different errors for CUDA exceptions [BE] Introduce torch.DeviceError May 28, 2025
@mradmila
Copy link
Contributor

It would be super-nice to also have a test for this, with exception marshalled to Python having all the properties we want.

@malfet malfet added the topic: improvements topic category label May 28, 2025
[ghstack-poisoned]
@malfet malfet requested a review from albanD as a code owner May 28, 2025 06:49
@malfet malfet added the release notes: python_frontend python frontend release notes category label May 28, 2025
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 28, 2025
ghstack-source-id: a45116c
Pull Request resolved: #152023
@malfet
Copy link
Contributor Author

malfet commented May 28, 2025

It would be super-nice to also have a test for this, with exception marshalled to Python having all the properties we want.

There are already test_index_out_of_bounds_exception_cuda, see below, will modify it to check the error_code

def test_index_out_of_bounds_exception_cuda(self):

@malfet malfet changed the title [BE] Introduce torch.DeviceError [BE] Introduce torch.AcceleratorError May 29, 2025
@malfet malfet added the ciflow/trunk Trigger trunk jobs on your pull request label May 29, 2025
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 29, 2025
ghstack-source-id: 72ff0fa
Pull Request resolved: #152023
[ghstack-poisoned]
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 31, 2025
ghstack-source-id: 889a78e
Pull Request resolved: #152023
@malfet
Copy link
Contributor Author

malfet commented Jun 1, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
Which inherits from `RuntimeError` and contains `error_code`, which in case of CUDA should contain error returned by `cudaGetLastError`

`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's  [`PyErr_SetString`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L282), namely
- Convert cstr into Python string with `PyUnicode_FromString`
- Create new exception object using `PyObject_CallOneArg` just like it's done in [`_PyErr_CreateException`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L32)
- Set `error_code` property using `PyObject_SetAttrString`
- decref all temporary references

Test that it works and captures CPP backtrace (in addition to CI) by running
```python
import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'

import torch

x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
    x[y] = 2
    print(x)
except torch.AcceleratorError as e:
    print("Exception was raised", e.args[0])
    print("Captured error code is ", e.error_code)
```

which produces following output
```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0

Captured error code is  710
```
Pull Request resolved: pytorch#152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: pytorch#154436
@github-actions github-actions bot deleted the gh/malfet/299/head branch July 2, 2025 02:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: python_frontend python frontend release notes category topic: improvements topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants