Implement operator Softmax for backend Gorgonia/Gorgonnx #46
For a start, a simple version should be implemented.
WIP in the branch softmax-issue-46.
The code can be copied from the Softmax operator of Gorgonia instead of using it out-of-the-box. Commit 7b45c9b partially implements the softmax. The trivial test passes:
The other tests don't. The error is:
This is probably linked to what ONNX expects:
A reshape actually does the trick.
Using stabilization does not seem to help much:

```diff
diff --git a/backend/x/gorgonnx/softmax.go b/backend/x/gorgonnx/softmax.go
index 11adc69..604d9af 100644
--- a/backend/x/gorgonnx/softmax.go
+++ b/backend/x/gorgonnx/softmax.go
@@ -22,7 +22,7 @@ func (s *softmax) apply(g *Graph, n *Node) error {
 		return err
 	}
 	a := children[0].gorgoniaNode
-	var reshaped *gorgonia.Node
+	var max, reshaped *gorgonia.Node
 	if len(a.Shape()) > 2 {
 		if s.axis > len(a.Shape()) {
 			return errors.New("softmax cannot be applied on an axis > len(shape()) of the input")
@@ -43,8 +43,19 @@ func (s *softmax) apply(g *Graph, n *Node) error {
 	} else {
 		reshaped = a
 	}
+	if max, err = gorgonia.Max(reshaped); err != nil {
+		return err
+	}
+	a2, b2, err := gorgonia.Broadcast(reshaped, max, gorgonia.NewBroadcastPattern(nil, []byte{0, 1}))
+	if err != nil {
+		return err
+	}
+	output, err := gorgonia.Sub(a2, b2)
+	if err != nil {
+		return err
+	}
 	var exp, sum *gorgonia.Node
-	if exp, err = gorgonia.Exp(reshaped); err == nil {
+	if exp, err = gorgonia.Exp(output); err == nil {
 		axis := 1
 		if exp.IsScalar() {
 			axis = 0
```
Implemented by PR #56.
Why is this operator needed?
This operator is needed at least to run the Inception v1 model.
Implementation
- Softmax operator in ONNX
Link to existing material on the backend
- Softmax on Gorgonia
- StableSoftmax on Gorgonia
Expected problems?
- Whether to use the stable or the non-stable version
Tests
go test -run=ONNX/TestSoftmax