Skip to content

Commit

Permalink
Update torch/onnx/symbolic_opset
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Jul 13, 2021
1 parent 11dba85 commit 57a292c
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,11 +2918,25 @@ def remainder(g, input, other):
return g.op("Sub", input, quo)


def gelu(g, self):
_sqrt2 = 1.4142135623730951
erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
@parse_args("v", "b")
def gelu(g, self, approximate):
if approximate:
kBeta = math.sqrt(2 / math.pi)
kKappa = 0.044715

beta = torch.tensor(kBeta, dtype=torch.double)
kappa = torch.tensor(kKappa, dtype=torch.double)
three = torch.tensor(3, dtype=torch.double)
one = torch.tensor(1, dtype=torch.double)
half = torch.tensor(0.5, dtype=torch.double)

inner = mul(g, beta, add(g, self, mul(g, kappa, g.op("Pow", self, three))))
return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
else:
_sqrt2 = 1.4142135623730951
erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))

@parse_args("v", "i", "v", "v", "f", "i")
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
Expand Down

0 comments on commit 57a292c

Please sign in to comment.