/
tf_predict.go
119 lines (98 loc) · 2.8 KB
/
tf_predict.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package tf
import (
"errors"
"sync"
tfg "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Model wraps a TensorFlow SavedModel together with the parameters used to
// (re)load it and the graph operation names Predict uses to feed inputs and
// fetch the result. Each Model carries its own name so several can be kept
// side by side.
type Model struct {
	// Identity and load parameters.
	name    string              // in-memory model name, lets callers keep multiple models
	path    string              // export directory passed to tfg.LoadSavedModel
	tags    []string            // tags passed to tfg.LoadSavedModel
	options *tfg.SessionOptions // session options passed to tfg.LoadSavedModel (New leaves it nil)

	// Loaded state.
	model *tfg.SavedModel // nil until Load succeeds; Destruct resets it to nil

	// Graph operation names (required).
	inputParamKey  []string // operations fed by Predict, one per input value, in order
	outputParamKey string   // operation whose output 0 Predict fetches

	count int          // number of successful Load calls
	lock  sync.RWMutex // held by Load and Destruct while they mutate model, tags, options and count
}
// New returns a Model configured from the given parameters but not yet
// loaded; call Load (or use Register/RegisterWithParamName) before Predict.
//
// name is the in-memory identifier, exportDir the SavedModel directory, tags
// the SavedModel tags, inputParamKey the graph operations to feed (one per
// Predict input, in order) and outputParamKey the operation to fetch.
// Session options are left nil.
func New(name, exportDir string, tags, inputParamKey []string, outputParamKey string) *Model {
	m := new(Model)
	m.name = name
	m.path = exportDir
	m.tags = tags
	m.inputParamKey = inputParamKey
	m.outputParamKey = outputParamKey
	return m
}
// Predict runs one inference on the loaded model.
//
// dataSet must contain exactly one value per configured input key, in the
// same order. Each value is converted with tfg.NewTensor and fed to output 0
// of the graph operation named by the matching key. The value of output 0 of
// the operation named by the output key is returned (only that single fetch).
//
// An error is returned when dataSet is empty, when its length differs from
// the number of input keys, or when the model is not loaded (or was
// destructed). It is also returned when a named operation is missing from the
// graph, a value cannot be converted to a tensor, or the session run fails.
func (m *Model) Predict(dataSet []interface{}) (ret interface{}, err error) {
	if len(dataSet) == 0 {
		return nil, errors.New("nil input")
	}
	if len(m.inputParamKey) != len(dataSet) {
		return nil, errors.New("input data size not equal param key size")
	}

	// Hold the read lock for the whole run so Load/Destruct cannot replace or
	// clear the model underneath an in-flight prediction. Concurrent Predict
	// calls can still proceed together.
	m.lock.RLock()
	defer m.lock.RUnlock()

	if m.model == nil {
		return nil, errors.New("model not loaded")
	}
	graph := m.model.Graph

	feeds := make(map[tfg.Output]*tfg.Tensor, len(dataSet))
	for i, data := range dataSet {
		key := m.inputParamKey[i]
		op := graph.Operation(key)
		if op == nil {
			// Graph.Operation returns nil for unknown names; feeding it would
			// panic inside Session.Run.
			return nil, errors.New("input operation not found: " + key)
		}
		tensor, convErr := tfg.NewTensor(data)
		if convErr != nil {
			return nil, convErr
		}
		// Every input key names its own operation, so feed that operation's
		// first output. Indexing the output by the input position would
		// address a non-existent output for every input after the first.
		feeds[op.Output(0)] = tensor
	}

	outOp := graph.Operation(m.outputParamKey)
	if outOp == nil {
		return nil, errors.New("output operation not found: " + m.outputParamKey)
	}
	results, err := m.model.Session.Run(feeds, []tfg.Output{outOp.Output(0)}, nil)
	if err != nil {
		return nil, err
	}
	if len(results) == 0 {
		return nil, errors.New("empty prediction result")
	}
	return results[0].Value(), nil // WARN: only the first (and only) fetch is returned
}
// Load loads the SavedModel from the model's export directory using its tags
// and session options, and installs it as the active model. On success the
// load counter is incremented. Any previously loaded model's session is then
// closed, so a reload does not keep the old TensorFlow session alive until
// garbage collection.
//
// If loading fails, the previously loaded model (if any) stays active and the
// load error is returned.
func (m *Model) Load() error {
	m.lock.Lock()
	defer m.lock.Unlock()
	// TODO 1. judge model file exist
	// TODO 2. check others
	loaded, err := tfg.LoadSavedModel(m.path, m.tags, m.options)
	if err != nil {
		return err
	}
	previous := m.model
	m.model = loaded
	m.count++
	if previous != nil && previous.Session != nil {
		// Best-effort: the new model is already live, so failing to release
		// the old session must not be reported as a failed load.
		_ = previous.Session.Close()
	}
	return nil
}
// Register creates a model and loads it immediately. It uses the single input
// operation "serving_default_input" and the output operation
// "StatefulPartitionedCall". These names presumably match a default
// TF2 serving signature; verify against the exported model. Use
// RegisterWithParamName for other operation names.
func Register(name, exportDir string, tags []string) (*Model, error) {
	const (
		defaultInputKey  = "serving_default_input"
		defaultOutputKey = "StatefulPartitionedCall"
	)
	inputKeys := []string{defaultInputKey}
	return RegisterWithParamName(name, exportDir, tags, inputKeys, defaultOutputKey)
}
// RegisterWithParamName creates a model with explicit input/output operation
// names (see New) and loads it. The *Model is returned even when loading
// fails, alongside the load error, so callers must check the error before
// using it.
func RegisterWithParamName(name, exportDir string, tags, inputParamKey []string, outputParamKey string) (*Model, error) {
	model := New(name, exportDir, tags, inputParamKey, outputParamKey)
	if loadErr := model.Load(); loadErr != nil {
		return model, loadErr
	}
	return model, nil
}
// Destruct releases the loaded model. It closes the model's TensorFlow
// session (if one is loaded) and drops the references to the model, tags and
// session options. Calling it on a nil *Model or more than once is a no-op
// that returns nil. The returned error is the one reported when closing the
// session, if any.
//
// Because tags and options are cleared, a later Load runs with nil tags and
// options.
func (m *Model) Destruct() error {
	// Check nil before locking: accessing m.lock through a nil receiver panics.
	if m == nil {
		return nil
	}
	m.lock.Lock()
	defer m.lock.Unlock()

	var closeErr error
	if m.model != nil && m.model.Session != nil {
		closeErr = m.model.Session.Close()
	}
	m.tags = nil
	m.options = nil
	m.model = nil
	return closeErr
}
// Name returns the in-memory name the model was created with.
func (m *Model) Name() string {
	return m.name
}
// Path returns the SavedModel export directory the model loads from.
func (m *Model) Path() string {
	return m.path
}
// Version returns the TensorFlow runtime version string as reported by the
// Go bindings' tfg.Version.
func Version() string {
	v := tfg.Version()
	return v
}