/
inference.go
42 lines (38 loc) · 1.22 KB
/
inference.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package codelab
import (
"errors"
"log"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// GetInputShape finds the input image dimensions.
//
// It reads the static shape of output 0 of the graph's
// "module/hub_input/images" operation. The shape is assumed to follow the
// TF Hub image-module layout [batch, height, width, channels] — TODO confirm
// for any model other than the MobileNetV2 module this file targets — so the
// width is dimension 2 and the height is dimension 1. Sizes are passed
// through exactly as tf.Shape.Size reports them (that API documents -1 for an
// unknown dimension).
//
// The program is terminated via log.Fatal if the operation is missing.
func GetInputShape(graph *tf.Graph) (width, height int) {
	input := graph.Operation("module/hub_input/images")
	if input == nil {
		log.Fatal("Cannot find tensor \"module/hub_input/images\"")
	}
	shape := input.Output(0).Shape()
	// NHWC layout: index 1 is height, index 2 is width. Return them in the
	// order of the named results (width first).
	return int(shape.Size(2)), int(shape.Size(1))
}
// RunInference executes the model and returns the logits.
//
// image is wrapped into a batch of one and fed to output 0 of the graph's
// "module/hub_input/images" operation. Output 0 of
// "module/MobilenetV2/Logits/output" is fetched, and its first row is
// returned.
//
// An error is returned if either operation is missing, the input tensor
// cannot be built, the session run fails, or the fetched value is not a
// non-empty [][]float32.
func RunInference(graph *tf.Graph, session *tf.Session, image [][][3]float32) ([]float32, error) {
	inputOp := graph.Operation("module/hub_input/images")
	if inputOp == nil {
		return nil, errors.New("Cannot find tensor \"module/hub_input/images\"")
	}
	outputOp := graph.Operation("module/MobilenetV2/Logits/output")
	if outputOp == nil {
		return nil, errors.New("Cannot find tensor \"module/MobilenetV2/Logits/output\"")
	}
	input, output := inputOp.Output(0), outputOp.Output(0)

	// The extra outer slice turns the single image into a batch of size one.
	images, err := tf.NewTensor([][][][3]float32{image})
	if err != nil {
		return nil, err
	}

	result, err := session.Run(
		map[tf.Output]*tf.Tensor{input: images},
		[]tf.Output{output},
		nil,
	)
	if err != nil {
		return nil, err
	}

	// Check the dynamic type and batch length instead of asserting and
	// indexing blindly: an unexpected model output would otherwise panic
	// rather than surface through the error return.
	logits, ok := result[0].Value().([][]float32)
	if !ok {
		return nil, errors.New("unexpected logits output type, want [][]float32")
	}
	if len(logits) == 0 {
		return nil, errors.New("model returned an empty batch of logits")
	}
	return logits[0], nil
}