-
Notifications
You must be signed in to change notification settings - Fork 1
/
iou.go
62 lines (53 loc) · 2.17 KB
/
iou.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package metrics
import (
"github.com/chewxy/math32"
"github.com/rai-project/config"
"github.com/rai-project/dlframework"
)
// BoundingBoxIntersectionOverUnion returns the intersection over union (IoU,
// also known as the Jaccard index) of two axis-aligned bounding boxes: the
// area of their overlap divided by the area of their union.
//
// Each box is described by its Xmin/Ymin/Xmax/Ymax getters. For well-formed
// boxes (Xmin <= Xmax, Ymin <= Ymax) the result lies in [0, 1]. Boxes that do
// not overlap yield 0. If the union has no positive area (for example when
// both boxes are degenerate), 0 is returned instead of NaN.
//
// References:
//   - https://stackoverflow.com/questions/28723670/intersection-over-union-between-two-detections
//   - https://resources.wolframcloud.com/NeuralNetRepository/resources/SSD-VGG-300-Trained-on-PASCAL-VOC-Data
func BoundingBoxIntersectionOverUnion(boxA, boxB *dlframework.BoundingBox) float64 {
	// Corners of the (possibly empty) intersection rectangle.
	xA := math32.Max(boxA.GetXmin(), boxB.GetXmin())
	yA := math32.Max(boxA.GetYmin(), boxB.GetYmin())
	xB := math32.Min(boxA.GetXmax(), boxB.GetXmax())
	yB := math32.Min(boxA.GetYmax(), boxB.GetYmax())

	// Clamp the side lengths at zero. When the boxes are disjoint along an
	// axis, xB-xA (or yB-yA) is negative. Without the clamp, disjoint boxes
	// produce a negative intersection, or, when disjoint along both axes, a
	// bogus positive one from multiplying two negative lengths.
	interWidth := math32.Max(0, xB-xA)
	interHeight := math32.Max(0, yB-yA)
	interArea := float64(interWidth) * float64(interHeight)

	// Areas of the prediction and ground-truth rectangles.
	boxAArea := float64(boxA.GetXmax()-boxA.GetXmin()) * float64(boxA.GetYmax()-boxA.GetYmin())
	boxBArea := float64(boxB.GetXmax()-boxB.GetXmin()) * float64(boxB.GetYmax()-boxB.GetYmin())

	// The union counts the overlap only once.
	unionArea := boxAArea + boxBArea - interArea
	if unionArea <= 0 {
		// Degenerate (zero-area or inverted) boxes: avoid 0/0 = NaN.
		return 0
	}
	return interArea / unionArea
}
// IntersectionOverUnion computes the intersection over union of two features
// whose payloads are both bounding boxes, by unwrapping them and delegating to
// BoundingBoxIntersectionOverUnion.
//
// It panics if the first, or then the second, feature's payload is not a
// *dlframework.Feature_BoundingBox.
func IntersectionOverUnion(featA, featB *dlframework.Feature) float64 {
	wrappedA, isBox := featA.Feature.(*dlframework.Feature_BoundingBox)
	if !isBox {
		panic("unable to convert first feature to boundingbox")
	}

	wrappedB, isBox := featB.Feature.(*dlframework.Feature_BoundingBox)
	if !isBox {
		panic("unable to convert second feature to boundingbox")
	}

	// Unwrap only after both payloads have been checked, matching the
	// original's order of failure.
	first, second := wrappedA.BoundingBox, wrappedB.BoundingBox
	return BoundingBoxIntersectionOverUnion(first, second)
}
func init() {
config.AfterInit(func() {
RegisterFeatureCompareFunction("IntersectionOverUnion",
func(actual *dlframework.Features, expected interface{}) float64 {
if actual == nil || len(*actual) != 1 {
panic("expecting one feature for argument")
}
expectedFeature, ok := expected.(*dlframework.Feature)
if !ok {
panic("expecting a feature for second argument")
}
return IntersectionOverUnion((*actual)[0], expectedFeature)
})
})
}