torch.nn.GRUCell: Segfault by heap buffer overflow #106769

Closed
Sehun0819 opened this issue Aug 8, 2023 · 8 comments
Labels
actionable
high priority
module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
module: edge cases (Adversarial inputs unlikely to occur in practice)
module: nn (Related to torch.nn)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Sehun0819 commented Aug 8, 2023

🐛 Describe the bug

The torch.nn.GRUCell module crashes with a heap-buffer-overflow when the hidden state has a specific malformed shape.

Test code:

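import torch
m = torch.nn.GRUCell(1, 1)          # input_size=1, hidden_size=1
input = torch.randn(1, 1)           # (batch=1, input_size=1)
hx = torch.randn(1, 1, 5, 127, 1)   # malformed hidden state; the expected shape is (1, 1)
hx = m(input, hx)                   # crashes with a heap-buffer-overflow instead of raising an error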
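For reference, the GRUCell documentation describes `input` as `(N, input_size)` or `(input_size)` and `hx` as `(N, hidden_size)` or `(hidden_size)`. The sketch below checks a pair of shapes against that contract using plain tuples, so it runs without PyTorch. `gru_cell_shapes_ok` is a hypothetical helper, not a PyTorch API, and the rule that batched and unbatched inputs must not be mixed is an assumption of this sketch. It shows that the reproducer's `hx` falls outside the documented contract.

```python
def gru_cell_shapes_ok(input_shape, hx_shape, input_size, hidden_size):
    """Return True if the shapes fit the documented GRUCell contract.

    Hypothetical helper: input is (N, input_size) or (input_size,),
    hx is (N, hidden_size) or (hidden_size,), and both are either
    batched or unbatched (an assumption of this sketch).
    """
    # Only 1-D (unbatched) or 2-D (batched) tensors are allowed.
    if len(input_shape) not in (1, 2) or len(hx_shape) not in (1, 2):
        return False
    # Do not mix batched and unbatched arguments.
    if len(input_shape) != len(hx_shape):
        return False
    # Feature dimensions must match the module configuration.
    if input_shape[-1] != input_size or hx_shape[-1] != hidden_size:
        return False
    # In the batched case the batch sizes must agree.
    return len(input_shape) == 1 or input_shape[0] == hx_shape[0]


# Shapes from the reproducer: GRUCell(1, 1), input (1, 1), hx (1, 1, 5, 127, 1).
print(gru_cell_shapes_ok((1, 1), (1, 1), 1, 1))             # True
print(gru_cell_shapes_ok((1, 1), (1, 1, 5, 127, 1), 1, 1))  # False
```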

Error log:

==750219==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x6020009e42f8 at pc 0x7f102133115e bp 0x7ffd2d81b9f0 sp 0x7ffd2d81b9e8
READ of size 8 at 0x6020009e42f8 thread T0
    #0 0x7f102133115d in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12
    #1 0x7f10213aca22 in at::TensorBase::key_set() const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/TensorBase.h:331:12
    #2 0x7f102390f2e6 in c10::detail::MultiDispatchKeySet::operator()(at::Tensor const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:57:19
    #3 0x7f1024cb623d in c10::detail::MultiDispatchKeySet& at::IterArgs<c10::detail::MultiDispatchKeySet>::apply<at::Tensor const&, at::Tensor const&, c10::Scalar const&>(at::Tensor const&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/Variadic.h:32:5
    #4 0x7f1024cb6151 in c10::DispatchKeySet c10::detail::multi_dispatch_key_set<at::Tensor, at::Tensor, c10::Scalar>(at::Tensor const&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:107:34
    #5 0x7f102618ca79 in c10::DispatchKeySet c10::DispatchKeyExtractor::getDispatchKeySetUnboxed<at::Tensor&, at::Tensor const&, c10::Scalar const&>(at::Tensor& const&, at::Tensor const& const&, c10::Scalar const& const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:178:15
    #6 0x7f1025b38d2c in at::Tensor& c10::Dispatcher::call<at::Tensor&, at::Tensor&, at::Tensor const&, c10::Scalar const&>(c10::TypedOperatorHandle<at::Tensor& (at::Tensor&, at::Tensor const&, c10::Scalar const&)> const&, at::Tensor&, at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:637:15
    #7 0x7f1025b38d2c in c10::TypedOperatorHandle<at::Tensor& (at::Tensor&, at::Tensor const&, c10::Scalar const&)>::call(at::Tensor&, at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #8 0x7f1025b38d2c in at::_ops::add__Tensor::call(at::Tensor&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:996:15
    #9 0x7f1021dadcf4 in at::Tensor::add_(at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:1662:12
    #10 0x7f1022d5d55a in at::native::(anonymous namespace)::GRUCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, at::Tensor const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:797:27
    #11 0x7f1022d5c55f in at::native::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1667:10
    #12 0x7f10291935ae in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4607:10
    #13 0x7f10295594cb in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >::operator()(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #14 0x7f10295594cb in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #15 0x7f1024d52ce4 in at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #16 0x7f1025dba1bc in at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #17 0x7f1025dba1bc in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #18 0x7f1025dba1bc in c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #19 0x7f1025dba1bc in at::_ops::gru_cell::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:8021:15
    #20 0x7f104ea1358e in at::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/ops/gru_cell.h:27:12
    #21 0x7f104e9aabc6 in torch::autograd::THPVariable_gru_cell(_object*, _object*, _object*)::$_333::operator()(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/python_torch_functions_0.cpp:7397:12
    #22 0x7f104e902d6f in torch::autograd::THPVariable_gru_cell(_object*, _object*, _object*) /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/python_torch_functions_0.cpp:7399:15
    #23 0x5072d6 in cfunction_call /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543:19
    #24 0x4f06ab in _PyObject_MakeTpCall /usr/local/src/conda/python-3.9.17/Objects/call.c:191:18
    #25 0x4ecbfa in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116:16
    #26 0x4ecbfa in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103:1
    #27 0x4ecbfa in PyObject_Vectorcall /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127:12
    #28 0x4ecbfa in call_function /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077:13
    #29 0x4ecbfa in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489:23
    #30 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #31 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #32 0x4f7d83 in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #33 0x505070 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118:11
    #34 0x505070 in method_vectorcall /usr/local/src/conda/python-3.9.17/Objects/classobject.c:83:18
    #35 0x4eb0f3 in do_call_core /usr/local/src/conda/python-3.9.17/Python/ceval.c:5125:12
    #36 0x4eb0f3 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3582:22
    #37 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #38 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #39 0x4f7d83 in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #40 0x505070 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118:11
    #41 0x505070 in method_vectorcall /usr/local/src/conda/python-3.9.17/Objects/classobject.c:83:18
    #42 0x4eb0f3 in do_call_core /usr/local/src/conda/python-3.9.17/Python/ceval.c:5125:12
    #43 0x4eb0f3 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3582:22
    #44 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #45 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #46 0x4eff1d in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #47 0x4eff1d in _PyObject_FastCallDictTstate /usr/local/src/conda/python-3.9.17/Objects/call.c:118:15
    #48 0x502cc5 in _PyObject_Call_Prepend /usr/local/src/conda/python-3.9.17/Objects/call.c:489:24
    #49 0x5cb1e2 in slot_tp_call /usr/local/src/conda/python-3.9.17/Objects/typeobject.c:6731:15
    #50 0x4f06ab in _PyObject_MakeTpCall /usr/local/src/conda/python-3.9.17/Objects/call.c:191:18
    #51 0x4ec613 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116:16
    #52 0x4ec613 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103:1
    #53 0x4ec613 in PyObject_Vectorcall /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127:12
    #54 0x4ec613 in call_function /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077:13
    #55 0x4ec613 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520:19
    #56 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #57 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #58 0x4e6716 in _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361:12
    #59 0x4e66c8 in PyEval_EvalCodeEx /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377:12
    #60 0x59398a in PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:828:12
    #61 0x5c1216 in run_eval_code_obj /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221:9
    #62 0x5bd21f in run_mod /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242:19
    #63 0x4d07aa in PyRun_InteractiveOneObjectEx /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:274:9
    #64 0x4d0947 in PyRun_InteractiveLoopFlags /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:127:15
    #65 0x455882 in PyRun_AnyFileExFlags.cold /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:86:19
    #66 0x4527f3 in pymain_run_stdin /usr/local/src/conda/python-3.9.17/Modules/main.c:518:15
    #67 0x4527f3 in pymain_run_python /usr/local/src/conda/python-3.9.17/Modules/main.c:607:21
    #68 0x4527f3 in Py_RunMain.cold /usr/local/src/conda/python-3.9.17/Modules/main.c:683:5
    #69 0x587a38 in Py_BytesMain /usr/local/src/conda/python-3.9.17/Modules/main.c:1129:12
    #70 0x7f105eaa6d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16
    #71 0x7f105eaa6e3f in __libc_start_main csu/../csu/libc-start.c:392:3
    #72 0x5878ed in _start (/home/sehoon/anaconda3/envs/torch-latest-asan/bin/python3.9+0x5878ed)

0x6020009e42f8 is located 0 bytes to the right of 8-byte region [0x6020009e42f0,0x6020009e42f8)
allocated by thread T0 here:
    #0 0x7f105eea8af7 in operator new(unsigned long) /home/sehoon/llvm-project/compiler-rt/lib/asan/asan_new_delete.cpp:99:3
    #1 0x7f10213af239 in std::__new_allocator<at::Tensor>::allocate(unsigned long, void const*) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/new_allocator.h:137:27
    #2 0x7f10213af1c0 in std::allocator_traits<std::allocator<at::Tensor> >::allocate(std::allocator<at::Tensor>&, unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/alloc_traits.h:464:20
    #3 0x7f10213aef8f in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_allocate(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:378:20
    #4 0x7f10213da680 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_create_storage(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:395:33
    #5 0x7f10213da594 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_Vector_base(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:332:9
    #6 0x7f10213c5fa8 in std::vector<at::Tensor, std::allocator<at::Tensor> >::vector(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:552:9
    #7 0x7f102347dd67 in at::native::split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2512:23
    #8 0x7f102347e63f in at::native::unsafe_split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2527:17
    #9 0x7f10284eed4a in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3602:10
    #10 0x7f1028877e25 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >::operator()(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #11 0x7f1028877e25 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #12 0x7f1024c92401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #13 0x7f1024c937ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #14 0x7f1024c937ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::redispatch<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:670:26
    #15 0x7f10267d16a5 in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:506:41
    #16 0x7f10267d16a5 in at::_ops::unsafe_split_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4793:15
    #17 0x7f103071bc58 in at::redispatch::unsafe_split_symint(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RedispatchFunctions.h:6707:16
    #18 0x7f1030561b5d in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)::$_154::operator()() const /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16806:12
    #19 0x7f1030560613 in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16804:15
    #20 0x7f10305625e7 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >::operator()(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #21 0x7f10305625e7 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:480:14
    #22 0x7f1024c92401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #23 0x7f10267d082b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #24 0x7f10267d082b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #25 0x7f10267d082b in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #26 0x7f10267d082b in at::_ops::unsafe_split_Tensor::call(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4786:15
    #27 0x7f10234d278e in at::Tensor::unsafe_split(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:3442:12
    #28 0x7f10234517ed in at::native::unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:1011:17
    #29 0x7f1029104452 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:1821:10
    #30 0x7f1029323f34 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >::operator()(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #31 0x7f1029323f34 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #32 0x7f1024c92798 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #33 0x7f1025b6b13e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #34 0x7f1025b6b13e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)> const&, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #35 0x7f1025b6b13e in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #36 0x7f1025b6b13e in at::_ops::unsafe_chunk::call(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:1721:15
    #37 0x7f1022df0e62 in at::Tensor::unsafe_chunk(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:1992:12
    #38 0x7f1022d5d3ff in at::native::(anonymous namespace)::GRUCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, at::Tensor const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:793:52
    #39 0x7f1022d5c55f in at::native::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1667:10
    #40 0x7f10291935ae in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4607:10

SUMMARY: AddressSanitizer: heap-buffer-overflow /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const
Shadow bytes around the buggy address:
  0x0c0480134800: fa fa fd fa fa fa fd fa fa fa fd fa fa fa fd fa
  0x0c0480134810: fa fa 00 04 fa fa fd fa fa fa 00 00 fa fa 00 00
  0x0c0480134820: fa fa 00 00 fa fa fd fa fa fa fd fa fa fa fd fa
  0x0c0480134830: fa fa 00 00 fa fa fd fa fa fa 00 00 fa fa 00 00
  0x0c0480134840: fa fa 00 00 fa fa fd fa fa fa 04 fa fa fa fd fa
=>0x0c0480134850: fa fa 00 00 fa fa fd fa fa fa fd fa fa fa 00[fa]
  0x0c0480134860: fa fa fd fa fa fa fd fa fa fa fd fa fa fa fa fa
  0x0c0480134870: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c0480134880: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c0480134890: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c04801348a0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07 
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
  Shadow gap:              cc
==750219==ABORTING
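The allocation trace attributes the overflowed buffer to `at::native::split`, which built a `std::vector<at::Tensor>` in an 8-byte region. Assuming `sizeof(at::Tensor)` is 8 on this 64-bit build (a `Tensor` wraps a single `intrusive_ptr`), that region holds exactly one element, and the faulting address is where element index 1 would begin. This matches the shadow bytes, where `00` (8 addressable bytes) is immediately followed by the `[fa]` redzone. The arithmetic as a small Python check, using the addresses from the report:

```python
# Addresses copied from the ASan report above.
region_start = 0x6020009E42F0  # start of the 8-byte region allocated by split()
region_end = 0x6020009E42F8    # one past the end of that region
fault_addr = 0x6020009E42F8    # address of the 8-byte READ

# Assumption: sizeof(at::Tensor) == 8 on a 64-bit build.
TENSOR_SIZE = 8

capacity = (region_end - region_start) // TENSOR_SIZE  # elements in the vector
index = (fault_addr - region_start) // TENSOR_SIZE     # element being read

print(f"vector holds {capacity} tensor(s); faulting read is element [{index}]")
# vector holds 1 tensor(s); faulting read is element [1]
```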

Error location:

template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda() || input.is_xpu()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(3, 1)
        : params.linear_ih(input).unsafe_chunk(3, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    return (hidden - new_gate).mul_(input_gate).add_(new_gate);
  }
};
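To make the gate arithmetic above concrete, here is a scalar Python transcription of the non-CUDA branch for `hidden_size == 1`, assuming the three input-side and hidden-side gate values (biases included) are already computed. `gru_step` is a hypothetical name. Its explicit length check is the kind of guard the C++ code relies on implicitly when it indexes `chunked_hgates[1]` and `chunked_hgates[2]`.

```python
import math


def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gru_step(igates, hgates, hidden):
    """Scalar sketch of the CPU GRUCell update for hidden_size == 1.

    igates, hgates: sequences of 3 floats (reset, input, new) taken from
    linear_ih(input) and linear_hh(hidden), with biases already added.
    """
    # The C++ code assumes exactly three chunks; check it explicitly here.
    if len(igates) != 3 or len(hgates) != 3:
        raise ValueError(
            f"expected 3 gate chunks, got {len(igates)} and {len(hgates)}")
    reset_gate = _sigmoid(hgates[0] + igates[0])
    input_gate = _sigmoid(hgates[1] + igates[1])
    new_gate = math.tanh(igates[2] + hgates[2] * reset_gate)
    # Same form as the C++ return: (hidden - new_gate) * input_gate + new_gate
    return (hidden - new_gate) * input_gate + new_gate


print(gru_step([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], 1.0))  # 0.5
```

With all gate pre-activations at zero, both sigmoid gates are 0.5 and the new gate is 0, so the result is half of the previous hidden value. Passing a one-element `hgates`, as `unsafe_chunk` produces in this bug, raises `ValueError` instead of reading out of bounds.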

In this execution, a hidden tensor with the unexpected shape (1,1,5,127,1) passes the existing input checks.
params.linear_hh(hidden).unsafe_chunk(3, 1) is expected to return 3 tensors. However, linear_hh maps the last dimension to 3 * hidden_size, so dimension 1 of its result has size 1, and chunking along it returns only one tensor. chunked_hgates[1] therefore reads past the end of the vector, which ASan detects inside add_.
Note that the same bug happens when using the C++ frontend.
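The chunk count can be reproduced with shape arithmetic alone. A minimal sketch, assuming `linear_hh` maps the last dimension to `3 * hidden_size` (as a linear layer does) and that `unsafe_chunk` follows the documented `torch.chunk` rule of using pieces of size `ceil(n / chunks)`, which can yield fewer than `chunks` pieces. The helper names are hypothetical.

```python
import math


def linear_out_shape(shape, out_features):
    """Shape after a linear layer applied over the last dimension."""
    return tuple(shape[:-1]) + (out_features,)


def num_chunks(dim_size, chunks):
    """Number of pieces torch.chunk-style splitting yields along one dim."""
    if dim_size < 1:
        raise ValueError("this sketch only handles non-empty dimensions")
    chunk_size = math.ceil(dim_size / chunks)
    return math.ceil(dim_size / chunk_size)


hidden_size = 1
hx_shape = (1, 1, 5, 127, 1)                       # hidden state from the reproducer
hgates_shape = linear_out_shape(hx_shape, 3 * hidden_size)
n = num_chunks(hgates_shape[1], 3)                 # unsafe_chunk(3, 1) splits dim 1
print(hgates_shape, n)                             # (1, 1, 5, 127, 3) 1

# A well-formed hx of shape (N, hidden_size) puts the 3 * hidden_size gates on dim 1.
print(num_chunks(linear_out_shape((1, 1), 3 * hidden_size)[1], 3))  # 3
```

With only one piece returned, any access to index 1 or 2 of the result is out of bounds, which is exactly what the GRUCell code does next.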

Test code (C++):

#include <stdint.h>
#include <stddef.h>
#include <c10/util/irange.h>
#include <cassert>
#include <torch/torch.h>

namespace F = torch::nn::functional;
using namespace torch::nn;

int main() {
  try {
    torch::TensorOptions toptions = torch::TensorOptions();

    auto tensor_0 = torch::randn({1,1}, toptions);
    auto tensor_1 = torch::randn({1,1,5,127,1}, toptions);
    
    auto moptions = torch::nn::GRUCellOptions(1,1).bias(true);
    auto m = torch::nn::GRUCell(moptions);

    auto result = m->forward(tensor_0, tensor_1);
  } catch (std::exception& e) {
    return -2;
  }

  return 0;
}

Error log:

==750643==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x6020004b92d8 at pc 0x7fcbac67815e bp 0x7ffcd9f05e70 sp 0x7ffcd9f05e68
READ of size 8 at 0x6020004b92d8 thread T0
    #0 0x7fcbac67815d in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12
    #1 0x7fcbac6f3a22 in at::TensorBase::key_set() const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/TensorBase.h:331:12
    #2 0x7fcbaec562e6 in c10::detail::MultiDispatchKeySet::operator()(at::Tensor const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:57:19
    #3 0x7fcbafffd23d in c10::detail::MultiDispatchKeySet& at::IterArgs<c10::detail::MultiDispatchKeySet>::apply<at::Tensor const&, at::Tensor const&, c10::Scalar const&>(at::Tensor const&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/Variadic.h:32:5
    #4 0x7fcbafffd151 in c10::DispatchKeySet c10::detail::multi_dispatch_key_set<at::Tensor, at::Tensor, c10::Scalar>(at::Tensor const&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:107:34
    #5 0x7fcbb14d3a79 in c10::DispatchKeySet c10::DispatchKeyExtractor::getDispatchKeySetUnboxed<at::Tensor&, at::Tensor const&, c10::Scalar const&>(at::Tensor& const&, at::Tensor const& const&, c10::Scalar const& const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:178:15
    #6 0x7fcbb0e7fd2c in at::Tensor& c10::Dispatcher::call<at::Tensor&, at::Tensor&, at::Tensor const&, c10::Scalar const&>(c10::TypedOperatorHandle<at::Tensor& (at::Tensor&, at::Tensor const&, c10::Scalar const&)> const&, at::Tensor&, at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:637:15
    #7 0x7fcbb0e7fd2c in c10::TypedOperatorHandle<at::Tensor& (at::Tensor&, at::Tensor const&, c10::Scalar const&)>::call(at::Tensor&, at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #8 0x7fcbb0e7fd2c in at::_ops::add__Tensor::call(at::Tensor&, at::Tensor const&, c10::Scalar const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:996:15
    #9 0x7fcbad0f4cf4 in at::Tensor::add_(at::Tensor const&, c10::Scalar const&) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:1662:12
    #10 0x7fcbae0a455a in at::native::(anonymous namespace)::GRUCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, at::Tensor const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:797:27
    #11 0x7fcbae0a355f in at::native::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1667:10
    #12 0x7fcbb44da5ae in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4607:10
    #13 0x7fcbb48a04cb in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >::operator()(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #14 0x7fcbb48a04cb in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #15 0x7fcbb0099ce4 in at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #16 0x7fcbb11011bc in at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #17 0x7fcbb11011bc in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #18 0x7fcbb11011bc in c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #19 0x7fcbb11011bc in at::_ops::gru_cell::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:8021:15
    #20 0x7fcbc1841bae in at::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/ops/gru_cell.h:27:12
    #21 0x7fcbc1810e79 in torch::nn::GRUCellImpl::forward(at::Tensor const&, at::Tensor) /home/sehoon/pytorch-latest-asan/torch/csrc/api/src/nn/modules/rnn.cpp:1006:10
    #22 0x404d6a in main /home/sehoon/pytorch-latest-asan/test/cpp/reproduce/GRUCell.cpp:20:22
    #23 0x7fcb9fc4dd8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16
    #24 0x7fcb9fc4de3f in __libc_start_main csu/../csu/libc-start.c:392:3
    #25 0x404484 in _start (/home/sehoon/pytorch-latest-asan/build/bin/reproduce_GRUCell+0x404484)

0x6020004b92d8 is located 0 bytes to the right of 8-byte region [0x6020004b92d0,0x6020004b92d8)
allocated by thread T0 here:
    #0 0x7fcbd8b62af7 in operator new(unsigned long) /home/sehoon/llvm-project/compiler-rt/lib/asan/asan_new_delete.cpp:99:3
    #1 0x7fcbac6f6239 in std::__new_allocator<at::Tensor>::allocate(unsigned long, void const*) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/new_allocator.h:137:27
    #2 0x7fcbac6f61c0 in std::allocator_traits<std::allocator<at::Tensor> >::allocate(std::allocator<at::Tensor>&, unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/alloc_traits.h:464:20
    #3 0x7fcbac6f5f8f in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_allocate(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:378:20
    #4 0x7fcbac721680 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_create_storage(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:395:33
    #5 0x7fcbac721594 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_Vector_base(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:332:9
    #6 0x7fcbac70cfa8 in std::vector<at::Tensor, std::allocator<at::Tensor> >::vector(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:552:9
    #7 0x7fcbae7c4d67 in at::native::split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2512:23
    #8 0x7fcbae7c563f in at::native::unsafe_split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2527:17
    #9 0x7fcbb3835d4a in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3602:10
    #10 0x7fcbb3bbee25 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >::operator()(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #11 0x7fcbb3bbee25 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #12 0x7fcbaffd9401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #13 0x7fcbaffda7ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #14 0x7fcbaffda7ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::redispatch<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:670:26
    #15 0x7fcbb1b186a5 in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:506:41
    #16 0x7fcbb1b186a5 in at::_ops::unsafe_split_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4793:15
    #17 0x7fcbbba62c58 in at::redispatch::unsafe_split_symint(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RedispatchFunctions.h:6707:16
    #18 0x7fcbbb8a8b5d in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)::$_154::operator()() const /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16806:12
    #19 0x7fcbbb8a7613 in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16804:15
    #20 0x7fcbbb8a95e7 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >::operator()(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #21 0x7fcbbb8a95e7 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:480:14
    #22 0x7fcbaffd9401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #23 0x7fcbb1b1782b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #24 0x7fcbb1b1782b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #25 0x7fcbb1b1782b in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #26 0x7fcbb1b1782b in at::_ops::unsafe_split_Tensor::call(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4786:15
    #27 0x7fcbae81978e in at::Tensor::unsafe_split(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:3442:12
    #28 0x7fcbae7987ed in at::native::unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:1011:17
    #29 0x7fcbb444b452 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:1821:10
    #30 0x7fcbb466af34 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >::operator()(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #31 0x7fcbb466af34 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #32 0x7fcbaffd9798 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #33 0x7fcbb0eb213e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #34 0x7fcbb0eb213e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)> const&, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #35 0x7fcbb0eb213e in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #36 0x7fcbb0eb213e in at::_ops::unsafe_chunk::call(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:1721:15
    #37 0x7fcbae137e62 in at::Tensor::unsafe_chunk(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:1992:12
    #38 0x7fcbae0a43ff in at::native::(anonymous namespace)::GRUCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, at::Tensor const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:793:52
    #39 0x7fcbae0a355f in at::native::gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1667:10
    #40 0x7fcbb44da5ae in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__gru_cell(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4607:10

SUMMARY: AddressSanitizer: heap-buffer-overflow /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const
Shadow bytes around the buggy address:
  0x0c048008f200: fa fa 00 00 fa fa 00 00 fa fa 00 00 fa fa fd fa
  0x0c048008f210: fa fa 00 00 fa fa 00 00 fa fa 00 00 fa fa fd fa
  0x0c048008f220: fa fa fd fa fa fa fd fa fa fa 00 00 fa fa fd fa
  0x0c048008f230: fa fa 00 00 fa fa 00 00 fa fa 00 00 fa fa fd fa
  0x0c048008f240: fa fa fd fd fa fa fd fd fa fa fd fa fa fa 00 00
=>0x0c048008f250: fa fa fd fa fa fa fd fa fa fa 00[fa]fa fa fd fa
  0x0c048008f260: fa fa fd fa fa fa fd fa fa fa fa fa fa fa fa fa
  0x0c048008f270: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c048008f280: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c048008f290: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c048008f2a0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07 
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
  Shadow gap:              cc
==750643==ABORTING

Versions

PyTorch version: 2.1.0a0+git416bf4e
Is debug build: True
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 12.0.1 (git@github.com:starlab-unist/llvm-project.git 99b485c50897f9ca281636746cc468bf9b7a0bad)
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.9.17 (main, Jul 5 2023, 20:41:20) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 535.86.10
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.2
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
Stepping: 7
CPU max MHz: 4000.0000
CPU min MHz: 1200.0000
BogoMIPS: 6000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 1.5 MiB (48 instances)
L1i cache: 1.5 MiB (48 instances)
L2 cache: 48 MiB (48 instances)
L3 cache: 71.5 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.1.0a0+gitad22f0f
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 h06a4308_46342
[conda] numpy 1.25.2 pypi_0 pypi
[conda] torch 2.1.0a0+gitad22f0f dev_0

cc @ezyang @gchanan @zou3519 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@Sehun0819
Contributor Author

torch.nn.LSTMCell appears to have the same bug.

Test code:

import torch
m = torch.nn.LSTMCell(1,1)
input = torch.randn(1,1)
hx = torch.randn(1,1,5,127,1)
cx = torch.randn(1,1)
hx, cx = m(input,(hx,cx))
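Both reproductions pass a 5-D hx that is never rejected before it reaches the native cell. The sketch below shows the kind of up-front check that would raise a Python error here instead of overflowing the heap. The helper check_cell_hidden is hypothetical and works on shapes given as plain tuples, so it runs without PyTorch. It is not PyTorch's actual validation code. It only encodes the documented GRUCell/LSTMCell shapes: input (N, H_in) or (H_in), hidden (N, H_out) or (H_out). For LSTMCell, it is applied to hx and cx separately.

```python
def check_cell_hidden(name, input_shape, hx_shape, hidden_size):
    """Hypothetical argument check for an RNN cell, working on shape tuples.

    Follows the documented cell shapes: input is (N, H_in) or (H_in),
    and the hidden state must be (N, H_out) or (H_out) to match.
    Raises ValueError instead of letting a bad shape reach the native kernel.
    """
    if len(input_shape) not in (1, 2):
        raise ValueError(f"{name}: expected input to be 1D or 2D, got {len(input_shape)}D")
    if len(hx_shape) != len(input_shape):
        raise ValueError(
            f"{name}: expected hidden state to be {len(input_shape)}D, got {len(hx_shape)}D"
        )
    if hx_shape[-1] != hidden_size:
        raise ValueError(f"{name}: expected hidden size {hidden_size}, got {hx_shape[-1]}")
    if len(input_shape) == 2 and hx_shape[0] != input_shape[0]:
        raise ValueError(
            f"{name}: input batch size {input_shape[0]} != hidden batch size {hx_shape[0]}"
        )


# Shapes from the LSTMCell reproduction above: input (1, 1), hx 5-D, cx (1, 1).
for label, shape in (("hx", (1, 1, 5, 127, 1)), ("cx", (1, 1))):
    try:
        check_cell_hidden("LSTMCell", (1, 1), shape, hidden_size=1)
        print(label, "ok")
    except ValueError as err:
        print(label, "rejected:", err)
```

The checks run on shapes alone, so they cost nothing compared with the cell computation. They would reject the hidden state before linear_hh and unsafe_chunk ever see it.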

Error log:

==750846==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x6020009e4658 at pc 0x7f7b1413115e bp 0x7fff4474a870 sp 0x7fff4474a868
READ of size 8 at 0x6020009e4658 thread T0
    #0 0x7f7b1413115d in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12
    #1 0x7f7b141aca22 in at::TensorBase::key_set() const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/TensorBase.h:331:12
    #2 0x7f7b1670f2e6 in c10::detail::MultiDispatchKeySet::operator()(at::Tensor const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:57:19
    #3 0x7f7b1670f4d5 in c10::detail::MultiDispatchKeySet& at::IterArgs<c10::detail::MultiDispatchKeySet>::apply<at::Tensor const&>(at::Tensor const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/Variadic.h:32:5
    #4 0x7f7b177fc851 in c10::DispatchKeySet c10::detail::multi_dispatch_key_set<at::Tensor>(at::Tensor const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:107:34
    #5 0x7f7b17850ccd in c10::DispatchKeySet c10::DispatchKeyExtractor::getDispatchKeySetUnboxed<at::Tensor&>(at::Tensor& const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:178:15
    #6 0x7f7b17ff48e4 in at::Tensor& c10::Dispatcher::call<at::Tensor&, at::Tensor&>(c10::TypedOperatorHandle<at::Tensor& (at::Tensor&)> const&, at::Tensor&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:637:15
    #7 0x7f7b17ff48e4 in c10::TypedOperatorHandle<at::Tensor& (at::Tensor&)>::call(at::Tensor&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #8 0x7f7b17ff48e4 in at::_ops::sigmoid_::call(at::Tensor&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_1.cpp:4943:15
    #9 0x7f7b1564b5e4 in at::Tensor::sigmoid_() const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:3322:12
    #10 0x7f7b15b56ac0 in at::native::(anonymous namespace)::LSTMCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, std::tuple<at::Tensor, at::Tensor> const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:761:40
    #11 0x7f7b15b5443d in at::native::lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1556:10
    #12 0x7f7b1bf93484 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4600:10
    #13 0x7f7b1c357e53 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >::operator()(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #14 0x7f7b1c357e53 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&> >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #15 0x7f7b1a2b50fa in std::tuple<at::Tensor, at::Tensor> c10::callUnboxedKernelFunction<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<at::Tensor>&&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #16 0x7f7b19eb7dab in std::tuple<at::Tensor, at::Tensor> c10::KernelFunction::call<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #17 0x7f7b19eb7dab in std::tuple<at::Tensor, at::Tensor> c10::Dispatcher::call<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&>(c10::TypedOperatorHandle<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)> const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #18 0x7f7b19eb7dab in c10::TypedOperatorHandle<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&)>::call(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #19 0x7f7b19eb7dab in at::_ops::lstm_cell::call(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_4.cpp:6170:15
    #20 0x7f7b41bb4474 in at::lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/ops/lstm_cell.h:27:12
    #21 0x7f7b41b623f0 in torch::autograd::THPVariable_lstm_cell(_object*, _object*, _object*)::$_365::operator()(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) const /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/python_torch_functions_2.cpp:7823:12
    #22 0x7f7b41ac1720 in torch::autograd::THPVariable_lstm_cell(_object*, _object*, _object*) /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/python_torch_functions_2.cpp:7825:15
    #23 0x5072d6 in cfunction_call /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543:19
    #24 0x4f06ab in _PyObject_MakeTpCall /usr/local/src/conda/python-3.9.17/Objects/call.c:191:18
    #25 0x4ecbfa in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116:16
    #26 0x4ecbfa in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103:1
    #27 0x4ecbfa in PyObject_Vectorcall /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127:12
    #28 0x4ecbfa in call_function /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077:13
    #29 0x4ecbfa in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3489:23
    #30 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #31 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #32 0x4f7d83 in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #33 0x505070 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118:11
    #34 0x505070 in method_vectorcall /usr/local/src/conda/python-3.9.17/Objects/classobject.c:83:18
    #35 0x4eb0f3 in do_call_core /usr/local/src/conda/python-3.9.17/Python/ceval.c:5125:12
    #36 0x4eb0f3 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3582:22
    #37 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #38 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #39 0x4f7d83 in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #40 0x505070 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:118:11
    #41 0x505070 in method_vectorcall /usr/local/src/conda/python-3.9.17/Objects/classobject.c:83:18
    #42 0x4eb0f3 in do_call_core /usr/local/src/conda/python-3.9.17/Python/ceval.c:5125:12
    #43 0x4eb0f3 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3582:22
    #44 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #45 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #46 0x4eff1d in _PyFunction_Vectorcall /usr/local/src/conda/python-3.9.17/Objects/call.c:396:12
    #47 0x4eff1d in _PyObject_FastCallDictTstate /usr/local/src/conda/python-3.9.17/Objects/call.c:118:15
    #48 0x502cc5 in _PyObject_Call_Prepend /usr/local/src/conda/python-3.9.17/Objects/call.c:489:24
    #49 0x5cb1e2 in slot_tp_call /usr/local/src/conda/python-3.9.17/Objects/typeobject.c:6731:15
    #50 0x4f06ab in _PyObject_MakeTpCall /usr/local/src/conda/python-3.9.17/Objects/call.c:191:18
    #51 0x4ec613 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:116:16
    #52 0x4ec613 in _PyObject_VectorcallTstate /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:103:1
    #53 0x4ec613 in PyObject_Vectorcall /usr/local/src/conda/python-3.9.17/Include/cpython/abstract.h:127:12
    #54 0x4ec613 in call_function /usr/local/src/conda/python-3.9.17/Python/ceval.c:5077:13
    #55 0x4ec613 in _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.9.17/Python/ceval.c:3520:19
    #56 0x4e6a89 in _PyEval_EvalFrame /usr/local/src/conda/python-3.9.17/Include/internal/pycore_ceval.h:40:12
    #57 0x4e6a89 in _PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:4329:14
    #58 0x4e6716 in _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.9.17/Python/ceval.c:4361:12
    #59 0x4e66c8 in PyEval_EvalCodeEx /usr/local/src/conda/python-3.9.17/Python/ceval.c:4377:12
    #60 0x59398a in PyEval_EvalCode /usr/local/src/conda/python-3.9.17/Python/ceval.c:828:12
    #61 0x5c1216 in run_eval_code_obj /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1221:9
    #62 0x5bd21f in run_mod /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:1242:19
    #63 0x4d07aa in PyRun_InteractiveOneObjectEx /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:274:9
    #64 0x4d0947 in PyRun_InteractiveLoopFlags /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:127:15
    #65 0x455882 in PyRun_AnyFileExFlags.cold /usr/local/src/conda/python-3.9.17/Python/pythonrun.c:86:19
    #66 0x4527f3 in pymain_run_stdin /usr/local/src/conda/python-3.9.17/Modules/main.c:518:15
    #67 0x4527f3 in pymain_run_python /usr/local/src/conda/python-3.9.17/Modules/main.c:607:21
    #68 0x4527f3 in Py_RunMain.cold /usr/local/src/conda/python-3.9.17/Modules/main.c:683:5
    #69 0x587a38 in Py_BytesMain /usr/local/src/conda/python-3.9.17/Modules/main.c:1129:12
    #70 0x7f7b51833d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16
    #71 0x7f7b51833e3f in __libc_start_main csu/../csu/libc-start.c:392:3
    #72 0x5878ed in _start (/home/sehoon/anaconda3/envs/torch-latest-asan/bin/python3.9+0x5878ed)

0x6020009e4658 is located 0 bytes to the right of 8-byte region [0x6020009e4650,0x6020009e4658)
allocated by thread T0 here:
    #0 0x7f7b51c35af7 in operator new(unsigned long) /home/sehoon/llvm-project/compiler-rt/lib/asan/asan_new_delete.cpp:99:3
    #1 0x7f7b141af239 in std::__new_allocator<at::Tensor>::allocate(unsigned long, void const*) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/new_allocator.h:137:27
    #2 0x7f7b141af1c0 in std::allocator_traits<std::allocator<at::Tensor> >::allocate(std::allocator<at::Tensor>&, unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/alloc_traits.h:464:20
    #3 0x7f7b141aef8f in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_allocate(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:378:20
    #4 0x7f7b141da680 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_M_create_storage(unsigned long) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:395:33
    #5 0x7f7b141da594 in std::_Vector_base<at::Tensor, std::allocator<at::Tensor> >::_Vector_base(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:332:9
    #6 0x7f7b141c5fa8 in std::vector<at::Tensor, std::allocator<at::Tensor> >::vector(unsigned long, std::allocator<at::Tensor> const&) /usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/stl_vector.h:552:9
    #7 0x7f7b1627dd67 in at::native::split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2512:23
    #8 0x7f7b1627e63f in at::native::unsafe_split(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:2527:17
    #9 0x7f7b1b2eed4a in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3602:10
    #10 0x7f7b1b677e25 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >::operator()(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #11 0x7f7b1b677e25 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_Tensor_unsafe_split(at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #12 0x7f7b17a92401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #13 0x7f7b17a937ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #14 0x7f7b17a937ac in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::redispatch<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:670:26
    #15 0x7f7b195d16a5 in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:506:41
    #16 0x7f7b195d16a5 in at::_ops::unsafe_split_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4793:15
    #17 0x7f7b2351bc58 in at::redispatch::unsafe_split_symint(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RedispatchFunctions.h:6707:16
    #18 0x7f7b23361b5d in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)::$_154::operator()() const /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16806:12
    #19 0x7f7b23360613 in torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/torch/csrc/autograd/generated/VariableType_3.cpp:16804:15
    #20 0x7f7b233625e7 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >::operator()(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #21 0x7f7b233625e7 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long), &(torch::autograd::VariableType::(anonymous namespace)::unsafe_split_Tensor(c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:480:14
    #22 0x7f7b17a92401 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::SymInt&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #23 0x7f7b195d082b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:90:16
    #24 0x7f7b195d082b in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, c10::SymInt, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)> const&, at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #25 0x7f7b195d082b in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, c10::SymInt, long)>::call(at::Tensor const&, c10::SymInt, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #26 0x7f7b195d082b in at::_ops::unsafe_split_Tensor::call(at::Tensor const&, c10::SymInt, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_3.cpp:4786:15
    #27 0x7f7b162d278e in at::Tensor::unsafe_split(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:3442:12
    #28 0x7f7b162517ed in at::native::unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/TensorShape.cpp:1011:17
    #29 0x7f7b1bf04452 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:1821:10
    #30 0x7f7b1c123f34 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >::operator()(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13:16
    #31 0x7f7b1c123f34 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__unsafe_chunk(at::Tensor const&, long, long))>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<at::Tensor const&, long, long> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463:14
    #32 0x7f7b17a92798 in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::callUnboxedKernelFunction<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long&&, long&&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:50:12
    #33 0x7f7b1896b13e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::KernelFunction::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/boxing/KernelFunction_impl.h:102:16
    #34 0x7f7b1896b13e in std::vector<at::Tensor, std::allocator<at::Tensor> > c10::Dispatcher::call<std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, long, long>(c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)> const&, at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:653:26
    #35 0x7f7b1896b13e in c10::TypedOperatorHandle<std::vector<at::Tensor, std::allocator<at::Tensor> > (at::Tensor const&, long, long)>::call(at::Tensor const&, long, long) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/core/dispatch/Dispatcher.h:501:41
    #36 0x7f7b1896b13e in at::_ops::unsafe_chunk::call(at::Tensor const&, long, long) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/Operators_2.cpp:1721:15
    #37 0x7f7b15bf0e62 in at::Tensor::unsafe_chunk(long, long) const /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/core/TensorBody.h:1992:12
    #38 0x7f7b15b56a4d in at::native::(anonymous namespace)::LSTMCell<at::native::(anonymous namespace)::CellParams>::operator()(at::Tensor const&, std::tuple<at::Tensor, at::Tensor> const&, at::native::(anonymous namespace)::CellParams const&, bool) const /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:759:32
    #39 0x7f7b15b5443d in at::native::lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/aten/src/ATen/native/RNN.cpp:1556:10
    #40 0x7f7b1bf93484 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__lstm_cell(at::Tensor const&, c10::ArrayRef<at::Tensor>, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&) /home/sehoon/pytorch-latest-asan/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:4600:10

SUMMARY: AddressSanitizer: heap-buffer-overflow /home/sehoon/pytorch-latest-asan/c10/util/intrusive_ptr.h:409:12 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::operator->() const
Shadow bytes around the buggy address:
  0x0c0480134870: fa fa fd fa fa fa 01 fa fa fa 00 04 fa fa 00 00
  0x0c0480134880: fa fa fd fa fa fa 00 00 fa fa 00 00 fa fa 00 00
  0x0c0480134890: fa fa fd fa fa fa fd fa fa fa fd fa fa fa 04 fa
  0x0c04801348a0: fa fa fd fa fa fa 00 00 fa fa fd fa fa fa fd fa
  0x0c04801348b0: fa fa fd fa fa fa 00 00 fa fa 00 00 fa fa 00 00
=>0x0c04801348c0: fa fa fd fa fa fa fd fa fa fa 00[fa]fa fa fd fa
  0x0c04801348d0: fa fa fd fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c04801348e0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c04801348f0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c0480134900: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0c0480134910: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07 
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
  Shadow gap:              cc
==750846==ABORTING
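The legend says one shadow byte covers 8 application bytes, so the faulting address can be checked against the dump. The sketch below assumes the default x86-64 Linux ASan mapping `Shadow = (Addr >> 3) + 0x7fff8000`; the report itself does not state the mapping. Under that assumption, the start of the 8-byte region lands on the `00` byte in column 10 of the `=>` row. The bad access lands on the bracketed `[fa]` redzone byte in column 11.

```python
# Assumed default AddressSanitizer shadow mapping on x86-64 Linux:
# one shadow byte describes 8 application bytes.
SHADOW_SCALE = 3
SHADOW_OFFSET = 0x7FFF8000


def shadow_addr(addr: int) -> int:
    """Map an application address to its shadow byte address."""
    return (addr >> SHADOW_SCALE) + SHADOW_OFFSET


region_start = 0x6020009E4650  # start of the 8-byte region in the report
bad_access = 0x6020009E4658    # address reported "0 bytes to the right" of it

for name, addr in [("region", region_start), ("bad access", bad_access)]:
    s = shadow_addr(addr)
    row = s & ~0xF  # each dump row lists 16 shadow bytes
    print(f"{name}: shadow {s:#014x}, row {row:#014x}, column {s - row}")
```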

@drisspg drisspg added module: edge cases Adversarial inputs unlikely to occur in practice triage review module: crash Problem manifests as a hard crash, as opposed to a RuntimeError labels Aug 9, 2023
drisspg (Contributor) commented Aug 9, 2023

I was able to reproduce this on my machine.

drisspg (Contributor) commented Aug 9, 2023

Would you like to submit a PR fixing this issue?

@mikaylagawarecki mikaylagawarecki added the module: nn Related to torch.nn label Aug 9, 2023
zou3519 (Contributor) commented Aug 14, 2023

This sounds like we're not validating the input shapes
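The "allocated by" trace shows how an unvalidated shape can become a memory error. The allocation happens in `unsafe_chunk` inside `LSTMCell::operator()` (RNN.cpp:759), and the invalid read happens two lines later in `sigmoid_()` (RNN.cpp:761). The 8-byte region is consistent with a `std::vector<at::Tensor>` that holds a single tensor handle. That suggests the gates were split into fewer than the four pieces the cell then indexes.

This is an inference from the trace, not a confirmed root cause. The Python snippet below only shows the documented `chunk` behavior that makes it possible: splitting a dimension of size 1 into 4 chunks yields one tensor, not four. The `malformed` shape is hypothetical and was chosen only to give dim 1 a size of 1.

```python
import torch


def count_gate_chunks(gates: torch.Tensor, num_gates: int = 4) -> int:
    # Mirrors the unsafe_chunk(4, 1) step seen in the allocation trace.
    return len(gates.chunk(num_gates, 1))


well_formed = torch.randn(1, 4)           # dim 1 == 4 * hidden_size (hidden_size = 1)
malformed = torch.randn(1, 1, 5, 127, 4)  # hypothetical: dim 1 has size 1

print(count_gate_chunks(well_formed))  # 4
print(count_gate_chunks(malformed))    # 1
```

In Python, indexing the second chunk of `malformed` would raise `IndexError`. If the C++ kernel indexes the resulting `std::vector` without a bounds check, the same access reads past the allocation, which matches what ASan reports.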

@albanD albanD added high priority triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module actionable and removed triage review labels Aug 14, 2023
albanD (Collaborator) commented Aug 14, 2023

Could you please send a PR fixing this?

summerdo pushed a commit to summerdo/pytorch that referenced this issue Aug 17, 2023
Fixes pytorch#106769

As mentioned in [GRUCell](https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html#grucell), `hidden` should have the same dimension as `input`, and the dimension should be either `1D` or `2D`.

Other constraints are already checked on the `C++` side: for example, that `input` and `hidden` have the same batch size, that dim 1 of `input` matches `input_size`, and that dim 1 of `hidden` matches `hidden_size`.
Pull Request resolved: pytorch#107223
Approved by: https://github.com/albanD
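The constraints listed in the commit message can be expressed as a Python-level pre-check. The sketch below is hypothetical and is not the code merged in pytorch#107223. The function name `check_rnn_cell_shapes` and its error messages are illustrative only. It rejects the shapes from the original report before they reach the native kernel.

```python
import torch


def check_rnn_cell_shapes(input: torch.Tensor, hx: torch.Tensor,
                          input_size: int, hidden_size: int) -> None:
    """Hypothetical pre-check mirroring the constraints in the commit message."""
    # input must be unbatched (1D) or batched (2D)
    if input.dim() not in (1, 2):
        raise ValueError(f"expected input to be 1D or 2D, got {input.dim()}D")
    # hx must have the same number of dimensions as input
    if hx.dim() != input.dim():
        raise ValueError(f"expected hx to be {input.dim()}D like input, got {hx.dim()}D")
    # batched case: batch sizes must agree
    if input.dim() == 2 and input.size(0) != hx.size(0):
        raise ValueError(f"batch size mismatch: input {input.size(0)} vs hx {hx.size(0)}")
    # feature sizes must match the cell's configuration
    if input.size(-1) != input_size:
        raise ValueError(f"expected input feature size {input_size}, got {input.size(-1)}")
    if hx.size(-1) != hidden_size:
        raise ValueError(f"expected hidden size {hidden_size}, got {hx.size(-1)}")


# The shapes from the original report are rejected up front.
try:
    check_rnn_cell_shapes(torch.randn(1, 1), torch.randn(1, 1, 5, 127, 1), 1, 1)
except ValueError as e:
    print(e)  # expected hx to be 2D like input, got 5D
```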
Sehun0819 (Contributor, Author) commented
@albanD Could you consider reopening this? Thanks to @FFFrog, the problem no longer occurs through the Python API, but the C++ frontend still segfaults (see the C++ test code I wrote above).

FFFrog (Collaborator) commented Aug 30, 2023

> @albanD Could you consider reopening it? Thanks to @FFFrog we don't suffer this problem when we use Python API, but C++ frontend still leads to segfault(note the C++ test code I wrote above).

Sorry, I only looked at the Python frontend and overlooked the C++ one.

FFFrog (Collaborator) commented Aug 30, 2023

If no one is already working on a fix for the C++ frontend, I'll submit a PR tomorrow. Is that OK?
