/
embeddings.go
187 lines (145 loc) · 3.78 KB
/
embeddings.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
package face
import (
"encoding/json"
"fmt"
"strings"
"github.com/montanaflynn/stats"
"github.com/photoprism/photoprism/pkg/clusters"
)
// Embeddings represents a face embedding cluster, stored as a slice of
// Embedding values. The zero value (nil) is a valid, empty set.
type Embeddings []Embedding
// NewEmbeddings converts raw inference results into Embeddings.
//
// The result always has the same length as inference, and the entry at index i
// is created from inference[i]. Entries for which CanMatch reports false are
// not stored, so their slot keeps the zero value of Embedding.
func NewEmbeddings(inference [][]float32) Embeddings {
	embeddings := make(Embeddings, len(inference))

	for idx := range inference {
		candidate := NewEmbedding(inference[idx])

		// Leave the slot untouched if the embedding cannot be matched.
		if !candidate.CanMatch() {
			continue
		}

		embeddings[idx] = candidate
	}

	return embeddings
}
// Empty reports whether there are no usable embeddings, i.e. the slice has
// no entries or its first entry contains no values.
func (embeddings Embeddings) Empty() bool {
	return len(embeddings) == 0 || len(embeddings[0]) == 0
}
// Count returns the number of embeddings, or 0 if Empty reports true.
func (embeddings Embeddings) Count() int {
	if !embeddings.Empty() {
		return len(embeddings)
	}

	return 0
}
// Kind returns the type of face e.g. regular, kids, or ignored. It is the
// greatest Kind reported by any of the embeddings, or the zero value of Kind
// if there are none.
func (embeddings Embeddings) Kind() (result Kind) {
	for i := range embeddings {
		k := embeddings[i].Kind()

		// Keep the highest value seen so far.
		if result < k {
			result = k
		}
	}

	return result
}
// One reports whether Count returns exactly 1.
func (embeddings Embeddings) One() bool {
	n := embeddings.Count()

	return n == 1
}
// First returns the first face embedding, or NullEmbedding if Empty
// reports true.
func (embeddings Embeddings) First() Embedding {
	if !embeddings.Empty() {
		return embeddings[0]
	}

	return NullEmbedding
}
// Float64 returns the embeddings as a slice of float64 slices.
//
// Only the outer slice is newly allocated; the inner slices are the same
// values as the embeddings, so they share their backing arrays.
func (embeddings Embeddings) Float64() [][]float64 {
	vectors := make([][]float64, 0, len(embeddings))

	for i := range embeddings {
		vectors = append(vectors, []float64(embeddings[i]))
	}

	return vectors
}
// Contains reports whether the distance between other and at least one of
// the embeddings is strictly less than radius.
func (embeddings Embeddings) Contains(other Embedding, radius float64) bool {
	for i := range embeddings {
		if embeddings[i].Dist(other) < radius {
			return true
		}
	}

	return false
}
// Dist returns the smallest distance between other and any of the
// embeddings, or -1 if there are no embeddings to compare with.
func (embeddings Embeddings) Dist(other Embedding) (dist float64) {
	// A negative value marks "no distance found yet".
	dist = -1

	for i := range embeddings {
		d := embeddings[i].Dist(other)

		if dist < 0 || d < dist {
			dist = d
		}
	}

	return dist
}
// JSON returns the embeddings encoded as JSON bytes.
//
// An empty, non-nil byte slice is returned if Empty reports true or if
// the embeddings cannot be marshaled.
func (embeddings Embeddings) JSON() []byte {
	if embeddings.Empty() {
		return []byte("")
	}

	data, err := json.Marshal(embeddings)

	if err != nil {
		return []byte("")
	}

	return data
}
// EmbeddingsMidpoint returns the embeddings vector midpoint (the
// component-wise mean), a radius around it, and the number of embeddings.
//
// Results:
//   - Empty reports true: an empty Embedding, radius 0, and count 0.
//   - Exactly one embedding: that embedding, radius 0, and count 1.
//   - Otherwise: the mean vector, the largest Euclidean distance between the
//     mean and any embedding plus 0.01 (or 0 if every distance is 0), and
//     the number of embeddings.
//
// The dimension is taken from the first embedding. An embedding with fewer
// values causes an index out of range panic.
func EmbeddingsMidpoint(embeddings Embeddings) (result Embedding, radius float64, count int) {
	// Return if there are no embeddings.
	if embeddings.Empty() {
		return Embedding{}, 0, 0
	}

	// Count embeddings.
	count = len(embeddings)

	// The midpoint of a single embedding is the embedding itself.
	if count == 1 {
		return embeddings[0], 0.0, 1
	}

	dim := len(embeddings[0])

	// No embedding values?
	if dim == 0 {
		return Embedding{}, 0.0, count
	}

	result = make(Embedding, dim)

	// Scratch buffer holding one component of every embedding. It is fully
	// overwritten for each dimension, so it can be allocated once.
	values := make(stats.Float64Data, count)

	// The mean of a set of vectors is calculated component-wise.
	for i := 0; i < dim; i++ {
		for j := range embeddings {
			values[j] = embeddings[j][i]
		}

		if m, err := stats.Mean(values); err != nil {
			log.Warnf("embeddings: %s", err)
		} else {
			result[i] = m
		}
	}

	// Find the largest distance first and add the margin once at the end.
	// Comparing against a radius that already includes the margin would skip
	// distances up to 0.01 above the previous maximum and make the result
	// depend on the order of the embeddings.
	var maxDist float64

	for _, emb := range embeddings {
		if d := clusters.EuclideanDist(result, emb); d > maxDist {
			maxDist = d
		}
	}

	// Radius is the max embedding distance + 0.01 from result.
	if maxDist > 0 {
		radius = maxDist + 0.01
	}

	return result, radius, count
}
// UnmarshalEmbeddings parses face embedding JSON.
//
// The string must start with "[[", i.e. encode an array of arrays.
// An error is returned if s is empty, does not start with "[[", or cannot
// be decoded; in the first two cases the returned embeddings are nil.
func UnmarshalEmbeddings(s string) (result Embeddings, err error) {
	if s == "" {
		return result, fmt.Errorf("cannot unmarshal embeddings, empty string provided")
	}

	// Cheap sanity check before running the JSON decoder.
	if !strings.HasPrefix(s, "[[") {
		return result, fmt.Errorf("cannot unmarshal embeddings, invalid json provided")
	}

	err = json.Unmarshal([]byte(s), &result)

	return result, err
}