Skip to content

Commit

Permalink
[ONNX] Add silu operator support for onnx (#51193)
Browse files Browse the repository at this point in the history
Support for yolov5 compound-scaled object detection models export.

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Feb 2, 2021
1 parent 6d26106 commit 73c93b0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
12 changes: 12 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4372,6 +4372,18 @@ def forward(self, x):
dynamic_axes={'x': [1, 2]},
test_with_inputs=[y])

def test_silu(self):
class SiLUModel(torch.nn.Module):
def __init__(self):
super(SiLUModel, self).__init__()
self.silu = torch.nn.SiLU()

def forward(self, x):
return self.silu(x)

x = torch.randn(2, 3, 4)
self.run_test(SiLUModel(), (x))

def test_remainder(self):
class RemainderModel(torch.nn.Module):
def forward(self, input, other):
Expand Down
4 changes: 4 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,10 @@ def prelu(g, self, weight):
return g.op("PRelu", self, weight)


def silu(g, input):
return g.op('Mul', input, g.op('Sigmoid', input))


def relu(g, input):
return g.op("Relu", input)

Expand Down

0 comments on commit 73c93b0

Please sign in to comment.