Skip to content

Commit

Permalink
[ONNX] Fix assign input shape for tuple inputs & primitive type inputs (
Browse files Browse the repository at this point in the history
#54112) (#56164)

Summary: Pull Request resolved: #56164

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D27866139

Pulled By: SplitInfinity

fbshipit-source-id: c59f5a07df685e1ccdc4860d603ec422ec80d188
  • Loading branch information
BowenBao authored and facebook-github-bot committed Apr 21, 2021
1 parent 75995e4 commit 9986b10
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 31 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1554,14 +1554,14 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) {
if (typ->isSubtypeOf(IntType::get())) {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
return TensorType::createContiguous(at::kFloat, at::kCPU, {});
return TensorType::createContiguous(at::kDouble, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
return TensorType::createContiguous(at::kBool, at::kCPU, {});
}
TORCH_CHECK(false, "Unknown number type: ", typ->str());
}
inline TypePtr TensorType::fromBoolType() {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
return TensorType::createContiguous(at::kBool, at::kCPU, {});
}

inline c10::optional<c10::ScalarType> tryScalarTypeFromJitType(const c10::TypePtr & type) {
Expand Down
82 changes: 79 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,32 @@ def forward(self, a, b, c, d):
d = torch.randn(2, 3)
self.run_test(MyModel(), (a, b, c, d))

def test_tuple_input(self):
class TupleModel(torch.nn.Module):
def forward(self, a: Tuple[torch.Tensor, torch.Tensor]):
return a

x = (torch.randn(3, 4), torch.randn(4, 3))
self.run_test(TupleModel(), input=(x,))

def test_tuple_primitive_input(self):
class TupleModel(torch.nn.Module):
def forward(self, a: Tuple[int, torch.Tensor], b):
return a[0], a[1] + b

x = (3, torch.randn(4, 3))
y = torch.randn(4, 3)
self.run_test(TupleModel(), input=(x, y))

def test_nested_tuple_input(self):
class NestedTupleModel(torch.nn.Module):
def forward(self, a, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
return a + b[0] + b[1][0] + b[1][1]

x = torch.randn(4, 5)
y = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1)))
self.run_test(NestedTupleModel(), input=(x, y))

@disableScriptTest()
def test_optional_inputs_with_no_optionals(self):
class NoOptionalModel(torch.nn.Module):
Expand Down Expand Up @@ -873,6 +899,46 @@ def forward(self, x, y=None, z=None):
z = torch.randn(2, 3)
self.run_test(Model(), (x, None, z))

def test_primitive_input_integer(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: int, y):
return x + y

x = 3
y = torch.randint(10, (2, 3, 4))
self.run_test(Model(), (x, y))

def test_primitive_input_floating(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: float, y):
return x + y

x = 3.0
y = torch.randn(2, 3, 4)
self.run_test(Model(), (x, y))

def test_primitive_input_bool(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, flag: bool, x, y):
if flag:
return x
else:
return y

flag = True
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(torch.jit.script(Model()), (flag, x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_cste_script(self):
class MyModel(torch.jit.ScriptModule):
Expand Down Expand Up @@ -1424,7 +1490,7 @@ def forward(self, x, y: int, z: bool, t: float):
self.run_test(ArithmeticModule(), (x, y, z, t))

class ArithmeticModule(torch.nn.Module):
def forward(self, x: float, y: float):
def forward(self, x: int, y: int):
return x == y

x = 3
Expand Down Expand Up @@ -3166,7 +3232,6 @@ def forward(self, x):
self.run_test(Model(), x)

@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # scripting prim_dtype
def test_lstm_no_hidden(self):
class LSTMModel(torch.nn.Module):
def __init__(self):
Expand All @@ -3180,7 +3245,6 @@ def forward(self, x):
self.run_test(LSTMModel(), (input,))

@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # scripting prim_dtype
def test_lstm_proj_no_hidden(self):
class LSTMModel(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -3623,6 +3687,18 @@ def forward(self, input):
return input > 1
self._test_compare_ops(GreaterModel(), 1)

def test_gt_primitive(self):
class GreaterModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.y : int = 2

def forward(self, x: int):
return self.y > x

x = 3
self.run_test(GreaterModel(), (x, ))

@skipIfUnsupportedMinOpsetVersion(9)
def test_ge_scalar(self):
class GreaterOrEqualModel(torch.nn.Module):
Expand Down
21 changes: 13 additions & 8 deletions torch/csrc/jit/passes/erase_number_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
namespace torch {
namespace jit {

void SetNumTypeToTensorType(Value* v) {
if (v->type()->isSubtypeOf(NumberType::get())) {
v->setType(TensorType::fromNumberType(v->type()));
} else if (v->type()->isSubtypeOf(BoolType::get())) {
v->setType(TensorType::fromBoolType());
}
}

void EraseNumberTypesOnBlock(Block* block) {
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
++it) {
for (auto inp : it->inputs()) {
if (inp->type()->isSubtypeOf(NumberType::get())) {
inp->setType(TensorType::get());
}
SetNumTypeToTensorType(inp);
}
for (auto sub : it->blocks()) {
EraseNumberTypesOnBlock(sub);
Expand Down Expand Up @@ -49,18 +55,17 @@ void EraseNumberTypesOnBlock(Block* block) {
} break;
default: {
for (auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(TensorType::fromNumberType(o->type()));
} else if (o->type()->isSubtypeOf(BoolType::get())) {
o->setType(TensorType::fromBoolType());
}
SetNumTypeToTensorType(o);
}
} break;
}
}
}

void EraseNumberTypes(const std::shared_ptr<Graph>& graph) {
for (auto inp : graph->inputs()) {
SetNumTypeToTensorType(inp);
}
EraseNumberTypesOnBlock(graph->block());
}
} // namespace jit
Expand Down
48 changes: 31 additions & 17 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,26 +421,45 @@ static TypePtr getTensorType(const at::Tensor& t, bool complete) {
return r;
}

static TypePtr inferShapeAndTypeForInput(
TypePtr input_type,
Stack::const_iterator& s_iter,
const Stack::const_iterator& s_iter_end,
bool complete);

static TupleTypePtr getTupleTensorType(
const Stack::const_iterator& s_iter,
Stack::const_iterator& s_iter,
const Stack::const_iterator& s_iter_end,
const TypePtr& tupleType,
bool complete) {
AT_ASSERT(tupleType->kind() == TupleType::Kind);
AT_ASSERT(s_iter != s_iter_end);

TORCH_INTERNAL_ASSERT(tupleType->kind() == TupleType::Kind);
std::vector<TypePtr> types;
for (const auto& subType : tupleType->containedTypes()) {
if (subType->kind() == TupleType::Kind) {
types.push_back(
getTupleTensorType(s_iter + 1, s_iter_end, subType, complete));
} else {
types.push_back(getTensorType(s_iter->toTensor(), complete));
}
TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
types.emplace_back(
inferShapeAndTypeForInput(subType, s_iter, s_iter_end, complete));
}
return TupleType::create(types);
}

static TypePtr inferShapeAndTypeForInput(
TypePtr input_type,
Stack::const_iterator& s_iter,
const Stack::const_iterator& s_iter_end,
bool complete) {
if (input_type->kind() == TupleType::Kind) {
return getTupleTensorType(s_iter, s_iter_end, input_type, complete);
} else if (input_type->kind() == TensorType::Kind) {
auto type = getTensorType(s_iter->toTensor(), complete);
s_iter++;
return type;
} else {
// Primitive type, keep as is.
s_iter++;
return input_type;
}
}

static void setInputTensorTypes(Graph& g, const Stack& stack, bool complete) {
at::ArrayRef<Value*> input_values = g.inputs();
auto s_iter = stack.begin();
Expand All @@ -457,13 +476,8 @@ static void setInputTensorTypes(Graph& g, const Stack& stack, bool complete) {
}
}
}
if (v->type()->kind() == TupleType::Kind) {
AT_ASSERT(v->node()->kind() == prim::Param);
v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), complete));
} else {
v->setType(getTensorType(s_iter->toTensor(), complete));
s_iter++;
}
v->setType(
inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete));
}
}

Expand Down

0 comments on commit 9986b10

Please sign in to comment.