/
classifier.go
86 lines (73 loc) · 2.31 KB
/
classifier.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Package classifier is an extremely naive implementation of a TensorFlow image
// classification wrapper. It abstracts away much of the boilerplate required
// to load and process images using the model and label outputs of TensorFlow
// Hub's retrain.py script.
// See: https://github.com/tensorflow/hub/blob/master/examples/image_retraining/retrain.py
package classifier /* import "s32x.com/gamedetect/classifier" */
import (
"bufio"
"io/ioutil"
"os"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Classifier is a struct used for classifying images.
//
// It bundles a TensorFlow graph imported from a serialized model file, a
// session created on that graph, and the list of labels read from a labels
// file. Construct one with NewClassifier or NewClassifierWithConfig and
// release its session with Close when done.
type Classifier struct {
	config  Config      // settings passed at construction (DefaultConfig when built via NewClassifier)
	graph   *tf.Graph   // graph imported from the model file
	session *tf.Session // session executing graph; closed by Close
	labels  []string    // one entry per line of the labels file, in file order; presumably index i names model output i — confirm against the inference code
}
// NewClassifier builds a Classifier from the model at graphPath and the labels
// file at labelPath, using DefaultConfig as its configuration. It is shorthand
// for NewClassifierWithConfig(graphPath, labelPath, DefaultConfig) and returns
// exactly what that call returns.
func NewClassifier(graphPath, labelPath string) (*Classifier, error) {
	var cfg Config = DefaultConfig
	return NewClassifierWithConfig(graphPath, labelPath, cfg)
}
// NewClassifierWithConfig creates a new image Classifier for processing image
// predictions.
//
// graphPath is the path to a serialized TensorFlow graph (the model output of
// retrain.py) and labelPath is a text file containing one label per line.
// config is stored on the returned Classifier unchanged.
//
// On success the caller owns the returned Classifier and must call Close to
// release its TensorFlow session. On failure a nil Classifier and the
// underlying error are returned, and any session created along the way has
// already been closed.
func NewClassifierWithConfig(graphPath, labelPath string,
	config Config) (*Classifier, error) {
	// Read the serialized model file from disk.
	model, err := ioutil.ReadFile(graphPath)
	if err != nil {
		return nil, err
	}
	// Populate a new graph from the model. An empty prefix imports the
	// operations under their original names.
	graph := tf.NewGraph()
	if err := graph.Import(model, ""); err != nil {
		return nil, err
	}
	// Create a new execution session on the graph.
	session, err := tf.NewSession(graph, nil)
	if err != nil {
		return nil, err
	}
	// Read all labels from labelPath.
	labels, err := readLabels(labelPath)
	if err != nil {
		// The session is not returned to the caller on this path, so close it
		// here to avoid leaking it. Its Close error is intentionally dropped:
		// the label error is the one the caller needs to see.
		_ = session.Close()
		return nil, err
	}
	// Return a fully populated Classifier.
	return &Classifier{
		config:  config,
		graph:   graph,
		session: session,
		labels:  labels,
	}, nil
}
// Close closes the Classifier by closing all it's closers ;)
func (c *Classifier) Close() error { return c.session.Close() }
// readLabels opens the file at labelsPath and returns its lines as labels, in
// file order. Each line terminator ("\n", optionally preceded by "\r") is
// stripped, and blank lines are kept as empty labels. An empty file yields a
// nil slice. If the file cannot be opened or scanning fails, a nil slice and
// the error are returned.
func readLabels(labelsPath string) ([]string, error) {
	file, openErr := os.Open(labelsPath)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	scanner.Split(bufio.ScanLines) // the default, stated explicitly
	var out []string
	for scanner.Scan() {
		line := scanner.Text()
		out = append(out, line)
	}
	if scanErr := scanner.Err(); scanErr != nil {
		return nil, scanErr
	}
	return out, nil
}