randint_like as well on "[ONNX] Support 'aten::randint' in torchscript onnx exporter"

Export as 'ONNX::RandomUniform', which produces a floating-point result,
then truncate it to an integer with 'ONNX::Cast'.

[ghstack-poisoned]
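
The equivalence this export relies on, sketched in plain numpy (illustrative, not part of the commit; bounds and shape are made up for the example): truncating uniform floats drawn from [low, high) lands exactly on torch.randint's inclusive-low, exclusive-high integer range.

import numpy as np

low, high = 3, 10
u = np.random.uniform(low, high, size=(2, 3, 4))  # mirrors ONNX RandomUniform
ints = u.astype(np.int64)                         # mirrors ONNX Cast: truncates
assert ints.min() >= low and ints.max() < high    # torch.randint's range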
BowenBao committed Jul 12, 2023
1 parent bccc4a0 commit ffd7663
Showing 2 changed files with 69 additions and 27 deletions.
10 changes: 10 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2653,6 +2653,16 @@ def forward(self, x):
        x = torch.randn(2, 3, 4)
        self.run_test(RandInt(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_randint_like(self):
        class RandInt(torch.nn.Module):
            def forward(self, x):
                # This randint call always returns 3
                return torch.randint_like(x, 3, 4) + x

        x = torch.randn(2, 3, 4)
        self.run_test(RandInt(), x)

    def test_randn(self):
        class RandN(torch.nn.Module):
            def forward(self, x):
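
For context, a rough hand-written sketch of what run_test exercises for this new case; this is not the helper's actual implementation, and the file name and opset below are illustrative.

import torch
import onnxruntime as ort

class RandIntLike(torch.nn.Module):
    def forward(self, x):
        # low=3, high=4, so every sampled value is exactly 3
        return torch.randint_like(x, 3, 4) + x

x = torch.randn(2, 3, 4)
torch.onnx.export(RandIntLike(), (x,), "randint_like.onnx", opset_version=9)

sess = ort.InferenceSession("randint_like.onnx")
(out,) = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
torch.testing.assert_close(torch.from_numpy(out), x + 3)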
86 changes: 59 additions & 27 deletions torch/onnx/symbolic_opset9.py
@@ -214,6 +214,8 @@
"prim_uninitialized",
"rand_like",
"rand",
"randint_like",
"randint",
"randn_like",
"randn",
"reciprocal",
@@ -5124,33 +5126,6 @@ def _pad_packed_sequence(
    return data, lengths


@_onnx_symbolic("aten::randn")
@_beartype.beartype
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        return g.op(
            "RandomNormalLike",
            shape_const,
            dtype_i=scalar_type.onnx_type(),
        )
    return g.op(
        "RandomNormal",
        shape_i=shape,
        dtype_i=scalar_type.onnx_type(),
    )


@_onnx_symbolic("aten::randint")
@_beartype.beartype
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
@@ -5195,6 +5170,63 @@ def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
    return randint


@_onnx_symbolic("aten::randint_like")
@_beartype.beartype
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.INT64
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", high)

    # Sample uniform floats in [low, high) with the same shape as `self`.
    uniform = g.op(
        "RandomUniformLike",
        self,
        low_f=low_i,
        high_f=high_i,
    )

    # Cast to integer; truncation maps the [low, high) floats onto
    # torch.randint's inclusive-low, exclusive-high integer range.
    int_dtype = _type_utils.JitScalarType.INT64
    randint = g.op("Cast", uniform, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        randint = g.op("Cast", randint, to_i=scalar_type.onnx_type())
    return randint


@_onnx_symbolic("aten::randn")
@_beartype.beartype
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        return g.op(
            "RandomNormalLike",
            shape_const,
            dtype_i=scalar_type.onnx_type(),
        )
    return g.op(
        "RandomNormal",
        shape_i=shape,
        dtype_i=scalar_type.onnx_type(),
    )


@_onnx_symbolic("aten::rand")
@_beartype.beartype
def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
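
To confirm the lowering described in the commit message, one can inspect the exported graph (again an illustrative sketch, reusing the model file from the export example above): a RandomUniformLike node should appear, followed by at least one Cast.

import onnx

model = onnx.load("randint_like.onnx")
op_types = [node.op_type for node in model.graph.node]
assert "RandomUniformLike" in op_types and "Cast" in op_types
print(op_types)  # e.g. ['RandomUniformLike', 'Cast', 'Cast', 'Add']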
