diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 1d38e1e1f4cd..1a63edf0fc15 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -71,6 +71,7 @@ static bool IsComparisonOp(const NodeKind& nkind) { static TensorTypePtr CreateProfiledTensorTypeWithScalarType( const TensorTypePtr& typePtr, const c10::ScalarType& scalar_type) { + AT_ASSERT(typePtr != nullptr); return typePtr->withScalarType({scalar_type}); } @@ -132,6 +133,15 @@ static c10::optional PromoteScalarTypesWithCategory( static c10::optional InferExpectedScalarType(const Node* n) { std::vector typesFromTensors; std::vector typesFromScalars; + + auto get_scalar_type = + [](const Value* input) -> c10::optional { + if (auto tensor_type = input->type()->cast()) { + return tensor_type->scalarType(); + } + return c10::nullopt; + }; + std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { auto nkind = input->node()->kind(); @@ -180,16 +190,13 @@ static c10::optional InferExpectedScalarType(const Node* n) { } else { typesFromTensors.emplace_back(scalar_type); } - } else if ( - auto scalar_type = - input->type()->cast()->scalarType()) { + } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } }); c10::optional st = c10::nullopt; - const c10::optional output_st = - n->output()->type()->cast()->scalarType(); + const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { // For comparison ops, always promote scalar type to highest among inputs, @@ -236,7 +243,8 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); - auto input_scalar_type = input_tensor_type->scalarType(); + auto input_scalar_type = + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; if ((input->node()->kind() == onnx::Constant) || (input_scalar_type && (*input_scalar_type != scalar_type))) {