From 81778e28119271001cc75dc8ca74cffc6031ced8 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 8 Jan 2021 10:05:24 -0800 Subject: [PATCH] [onnx] Do not deref nullptr in scalar type analysis (#50237) Summary: Apply a little bit of defensive programming: `type->cast()` returns an optional pointer so dereferencing it can lead to a hard crash. Fixes SIGSEGV reported in https://github.com/pytorch/pytorch/issues/49959 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50237 Reviewed By: walterddr Differential Revision: D25839675 Pulled By: malfet fbshipit-source-id: 403d6df5e2392dd6adc308b1de48057f2f9d77ab --- .../jit/passes/onnx/scalar_type_analysis.cpp | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 1d38e1e1f4cd..1a63edf0fc15 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -71,6 +71,7 @@ static bool IsComparisonOp(const NodeKind& nkind) { static TensorTypePtr CreateProfiledTensorTypeWithScalarType( const TensorTypePtr& typePtr, const c10::ScalarType& scalar_type) { + AT_ASSERT(typePtr != nullptr); return typePtr->withScalarType({scalar_type}); } @@ -132,6 +133,15 @@ static c10::optional PromoteScalarTypesWithCategory( static c10::optional InferExpectedScalarType(const Node* n) { std::vector typesFromTensors; std::vector typesFromScalars; + + auto get_scalar_type = + [](const Value* input) -> c10::optional { + if (auto tensor_type = input->type()->cast()) { + return tensor_type->scalarType(); + } + return c10::nullopt; + }; + std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { auto nkind = input->node()->kind(); @@ -180,16 +190,13 @@ static c10::optional InferExpectedScalarType(const Node* n) { } else { typesFromTensors.emplace_back(scalar_type); } - } else if ( - auto scalar_type = - input->type()->cast()->scalarType()) { + } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } }); c10::optional st = c10::nullopt; - const c10::optional output_st = - n->output()->type()->cast()->scalarType(); + const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { // For comparison ops, always promote scalar type to highest among inputs, @@ -236,7 +243,8 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); - auto input_scalar_type = input_tensor_type->scalarType(); + auto input_scalar_type = + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; if ((input->node()->kind() == onnx::Constant) || (input_scalar_type && (*input_scalar_type != scalar_type))) {