Skip to content

Commit

Permalink
[onnx] Do not deref nullptr in scalar type analysis (#50237)
Browse files Browse the repository at this point in the history
Summary:
Apply a little bit of defensive programming: `type->cast<TensorType>()` returns an optional pointer so dereferencing it can lead to a hard crash.

Fixes SIGSEGV reported in #49959

Pull Request resolved: #50237

Reviewed By: walterddr

Differential Revision: D25839675

Pulled By: malfet

fbshipit-source-id: 403d6df5e2392dd6adc308b1de48057f2f9d77ab
  • Loading branch information
malfet authored and facebook-github-bot committed Jan 8, 2021
1 parent b5ab0a7 commit 81778e2
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
Expand Up @@ -71,6 +71,7 @@ static bool IsComparisonOp(const NodeKind& nkind) {
static TensorTypePtr CreateProfiledTensorTypeWithScalarType(
const TensorTypePtr& typePtr,
const c10::ScalarType& scalar_type) {
AT_ASSERT(typePtr != nullptr);
return typePtr->withScalarType({scalar_type});
}

Expand Down Expand Up @@ -132,6 +133,15 @@ static c10::optional<c10::ScalarType> PromoteScalarTypesWithCategory(
static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
std::vector<c10::ScalarType> typesFromTensors;
std::vector<c10::ScalarType> typesFromScalars;

auto get_scalar_type =
[](const Value* input) -> c10::optional<at::ScalarType> {
if (auto tensor_type = input->type()->cast<TensorType>()) {
return tensor_type->scalarType();
}
return c10::nullopt;
};

std::for_each(
n->inputs().begin(), n->inputs().end(), [&](const Value* input) {
auto nkind = input->node()->kind();
Expand Down Expand Up @@ -180,16 +190,13 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
} else {
typesFromTensors.emplace_back(scalar_type);
}
} else if (
auto scalar_type =
input->type()->cast<TensorType>()->scalarType()) {
} else if (auto scalar_type = get_scalar_type(input)) {
typesFromTensors.emplace_back(*scalar_type);
}
});

c10::optional<c10::ScalarType> st = c10::nullopt;
const c10::optional<c10::ScalarType> output_st =
n->output()->type()->cast<TensorType>()->scalarType();
const auto output_st = get_scalar_type(n->output());

if (IsComparisonOp(n->kind())) {
// For comparison ops, always promote scalar type to highest among inputs,
Expand Down Expand Up @@ -236,7 +243,8 @@ static void UpdateScalarTypeForInputs(

for (auto input : n->inputs()) {
auto input_tensor_type = input->type()->cast<TensorType>();
auto input_scalar_type = input_tensor_type->scalarType();
auto input_scalar_type =
input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt;

if ((input->node()->kind() == onnx::Constant) ||
(input_scalar_type && (*input_scalar_type != scalar_type))) {
Expand Down

0 comments on commit 81778e2

Please sign in to comment.