-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.go
executable file
·101 lines (81 loc) · 2.46 KB
/
run.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package tf
import (
"fmt"
models "github.com/TIBCOSoftware/flogo-contrib/activity/inference/model"
"github.com/golang/protobuf/proto"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Run executes a TensorFlow SavedModel with the input data carried by model.
//
// model.Instance must hold a *tf.SavedModel. Each entry of
// model.Metadata.Inputs.Params names a graph operation that is fed a tensor
// built by createInputExampleTensor from model.Inputs. Each entry of
// model.Metadata.Outputs names an operation whose first output (index 0) is
// fetched. The returned map is keyed by the keys of model.Metadata.Outputs.
//
// An error is returned if model is nil, if model.Instance is not a
// *tf.SavedModel, if any named input or output operation is missing from the
// graph, or if building an input tensor or running the session fails.
func (i *TensorflowModel) Run(model *models.Model) (out map[string]interface{}, err error) {
	if model == nil {
		return nil, fmt.Errorf("nil model")
	}

	// Grab the native tf SavedModel. The comma-ok form turns a wrong
	// instance type into an error instead of a runtime panic.
	savedModel, ok := model.Instance.(*tf.SavedModel)
	if !ok {
		return nil, fmt.Errorf("model instance is %T, not *tf.SavedModel", model.Instance)
	}

	// Validate that the input operations exist and collect them by input key.
	inputOps := make(map[string]*tf.Operation)
	for k, v := range model.Metadata.Inputs.Params {
		if !validateOperation(v.Name, savedModel) {
			return nil, fmt.Errorf("Invalid operation %s", v.Name)
		}
		inputOps[k] = savedModel.Graph.Operation(v.Name)
	}

	// Collect the output fetches. Map iteration order is random, so remember
	// which output key each fetch belongs to; Session.Run returns results in
	// the same order as the fetches.
	var outputOps []tf.Output
	var outputKeys []string
	for k, o := range model.Metadata.Outputs {
		// Graph.Operation yields nil for an unknown name, and calling Output
		// on that would panic, so validate exactly as for inputs.
		if !validateOperation(o.Name, savedModel) {
			return nil, fmt.Errorf("Invalid operation %s", o.Name)
		}
		outputOps = append(outputOps, savedModel.Graph.Operation(o.Name).Output(0))
		outputKeys = append(outputKeys, k)
	}

	// Create the input tensors, feeding each into its operation's first output.
	inputs := make(map[tf.Output]*tf.Tensor, len(inputOps))
	for inputName, op := range inputOps {
		exampleTensor, err := createInputExampleTensor(inputName, model)
		if err != nil {
			return nil, err
		}
		inputs[op.Output(0)] = exampleTensor
	}

	results, err := savedModel.Session.Run(inputs, outputOps, nil)
	if err != nil {
		return nil, err
	}

	// results[j] is the value fetched for outputOps[j], i.e. for outputKeys[j].
	out = make(map[string]interface{}, len(outputKeys))
	for j, k := range outputKeys {
		out[k] = getTensorValue(results[j])
	}
	return out, nil
}
// getTensorValue returns the Go value held by tensor when that value is a
// [][]string or a [][]float32. For any other value type it returns nil.
func getTensorValue(tensor *tf.Tensor) interface{} {
	switch val := tensor.Value().(type) {
	case [][]string:
		return val
	case [][]float32:
		return val
	default:
		// Unsupported value types are deliberately not passed through.
		return nil
	}
}
// createInputExampleTensor builds the feed tensor for the input named
// inputName. It passes model.Inputs[inputName] to Example, serializes the
// resulting protobuf message with proto.Marshal, and wraps the bytes as a
// one-element string tensor.
//
// Errors from Example and proto.Marshal are wrapped with %w, so callers can
// still inspect the underlying cause with errors.Is / errors.As. An error
// from tf.NewTensor is returned unchanged.
func createInputExampleTensor(inputName string, model *models.Model) (*tf.Tensor, error) {
	pb, err := Example(model.Inputs[inputName])
	if err != nil {
		return nil, fmt.Errorf("Failed to create Example: %w", err)
	}
	byteList, err := proto.Marshal(pb)
	if err != nil {
		return nil, fmt.Errorf("marshaling error: %w", err)
	}
	// NOTE(review): presumably the graph input expects a batch of serialized
	// Example protos; this feeds a batch of exactly one — confirm.
	newTensor, err := tf.NewTensor([]string{string(byteList)})
	if err != nil {
		return nil, err
	}
	return newTensor, nil
}
// validateOperation reports whether the graph of savedModel contains an
// operation named op. Graph.Operation yields nil when no such operation exists.
func validateOperation(op string, savedModel *tf.SavedModel) bool {
	return savedModel.Graph.Operation(op) != nil
}