-
Notifications
You must be signed in to change notification settings - Fork 110
/
postprocessor.go
37 lines (33 loc) · 1.01 KB
/
postprocessor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package classification
import "strings"
// Postprocessor is a function that takes a collection of Classifications and returns a
// filtered and/or modified collection of Classifications.
type Postprocessor func(Classifications) Classifications
// NewScoreFilter builds a Postprocessor that keeps only the classifications whose Score is
// greater than or equal to conf. The relative order of the kept classifications is preserved,
// and the result is always a freshly allocated (non-nil) Classifications value.
func NewScoreFilter(conf float64) Postprocessor {
	filter := func(in Classifications) Classifications {
		// Capacity is sized for the case where every classification passes.
		kept := make(Classifications, 0, len(in))
		for i := range in {
			if score := in[i].Score(); score >= conf {
				kept = append(kept, in[i])
			}
		}
		return kept
	}
	return filter
}
// NewLabelFilter returns a Postprocessor that keeps only the classifications whose label is one
// of the keys of labels. Matching is case-insensitive: the keys of labels and each
// classification's label are both lowercased before they are compared. The values stored in
// labels are ignored.
//
// When labels is empty (or nil) no filtering is performed and the returned Postprocessor hands
// its input back unchanged. The label set is captured when NewLabelFilter is called; later
// changes to labels do not affect the returned Postprocessor.
func NewLabelFilter(labels map[string]interface{}) Postprocessor {
	if len(labels) < 1 {
		return func(in Classifications) Classifications { return in }
	}

	// Lowercase the keys once up front. The lookup below lowercases the classification label,
	// so a key such as "Dog" would otherwise never match anything.
	allowed := make(map[string]struct{}, len(labels))
	for label := range labels {
		allowed[strings.ToLower(label)] = struct{}{}
	}

	return func(in Classifications) Classifications {
		out := make(Classifications, 0, len(in))
		for _, c := range in {
			if _, ok := allowed[strings.ToLower(c.Label())]; ok {
				out = append(out, c)
			}
		}
		return out
	}
}