Skip to content
This repository has been archived by the owner on May 31, 2024. It is now read-only.

Commit

Permalink
wip: Softmax is roughly a copy/paste from Gorgonia
Browse files Browse the repository at this point in the history
This should allow to implement axis!=1
By now, only the simple example is passing the tests.
The problem is related to the comment in the ONNX definition:

Input does not need to explicitly be a 2D vector; rather, it will be coerced into one. For an arbitrary n-dimensional tensor input \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is the axis provided, then input will be coerced into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default case where axis=1, this means the input tensor will be coerced into a 2D tensor of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size. In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D. Each of these dimensions must be matched correctly, or else the operator will throw errors.
  • Loading branch information
owulveryck committed May 8, 2019
1 parent 8116393 commit 7b45c9b
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions backend/x/gorgonnx/softmax.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"

"github.com/owulveryck/onnx-go"
"gorgonia.org/gorgonia"
)

type softmax struct {
Expand All @@ -14,21 +15,41 @@ func init() {
register("Softmax", &softmax{})
}

func (a *softmax) apply(g *Graph, n *Node) error {
return &onnx.ErrNotImplemented{
Operator: "Softmax",
func (s *softmax) apply(g *Graph, n *Node) error {
children := getOrderedChildren(g.g, n)
err := checkCondition(children, 1)
if err != nil {
return err
}
a := children[0].gorgoniaNode
var exp, sum *gorgonia.Node
if exp, err = gorgonia.Exp(a); err == nil {
if sum, err = gorgonia.Sum(exp, s.axis); err == nil {
if sum.IsScalar() {
n.gorgoniaNode, err = gorgonia.HadamardDiv(exp, sum)
return err
}
a, b, err := gorgonia.Broadcast(exp, sum, gorgonia.NewBroadcastPattern(nil, []byte{1}))
if err != nil {
return err
}
n.gorgoniaNode, err = gorgonia.Div(a, b)
return err
}
return err
}
return err
}

func (a *softmax) init(o onnx.Operation) error {
func (s *softmax) init(o onnx.Operation) error {
axis, ok := o.Attributes["axis"]
if !ok {
a.axis = 1
s.axis = 1
return nil
}
err := errors.New("axis in not an int")
if axis, ok := axis.(int64); ok {
a.axis = int(axis)
s.axis = int(axis)
err = nil
}
return err
Expand Down

0 comments on commit 7b45c9b

Please sign in to comment.