generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
52 lines (43 loc) · 1.38 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
# Mean intersection-over-union (IoU) metric.
# Taken from fastai v1.
def one_hot(input, targs, classes=None, argmax=True):
    """One-hot encode a batch of predictions and targets for segmentation metrics.

    Args:
        input: 4-D tensor whose shape is unpacked as ``(n, c, h, w)``. When
            ``argmax`` is True it is reduced with ``argmax(dim=1)`` to per-pixel
            class indices. When False it must already hold class indices, so
            the reshape below only succeeds when ``c == 1``.
        targs: target class indices containing ``n * h * w`` elements
            (e.g. shaped ``(n, 1, h, w)`` or ``(n, h, w)``).
        classes: number of classes to encode; defaults to ``c``.
        argmax: whether to take the argmax over ``input``'s channel dimension.

    Returns:
        ``(input_, targs_, classes, n, h, w)``. ``input_`` and ``targs_`` are
        float tensors of shape ``(classes, n, h * w)`` on ``input``'s device,
        holding ``1.0`` where a pixel's index equals that class id and ``0.0``
        elsewhere.

    Raises:
        RuntimeError: if ``input`` (after the optional argmax) or ``targs``
            does not contain exactly ``n * h * w`` elements.
    """
    n, c, h, w = input.shape
    if classes is None:
        classes = c
    device = input.device
    # Class ids shaped (classes, 1, 1): comparing against (1, n, h*w) broadcasts
    # to (classes, n, h*w) without stacking h*w / n / classes copies of tensors.
    class_ids = torch.arange(classes, device=device, dtype=torch.float).view(
        classes, 1, 1
    )
    if argmax:
        input = input.argmax(dim=1)
    # reshape (unlike view) accepts non-contiguous tensors and raises
    # RuntimeError if the element count is not n * h * w.
    input_flat = input.float().reshape(1, n, h * w)
    targs_flat = targs.to(device).float().reshape(1, n, h * w)
    input_ = (input_flat == class_ids).float()
    targs_ = (targs_flat == class_ids).float()
    return input_, targs_, classes, n, h, w
def IOU(input, targs, classes=None, argmax=True, eps=1e-15):
    """Mean intersection-over-union over every (class, image) pair.

    Args:
        input: prediction tensor unpacked as ``(n, c, h, w)``; see ``one_hot``.
        targs: target class indices; see ``one_hot``.
        classes: number of classes; defaults to ``c``.
        argmax: whether to take the argmax over ``input``'s channel dimension.
        eps: added to each union to avoid division by zero.

    Returns:
        0-d tensor (no autograd graph) equal to the mean over all
        ``classes * n`` pairs of ``|pred & target| / (|pred | target| + eps)``.
        A class absent from both prediction and target in an image has an
        intersection and union of 0, so it contributes 0 and lowers the mean.
    """
    input_, targs_, classes, n, _, _ = one_hot(input, targs, classes, argmax)
    # Per-(class, image) pixel counts: the masks are (classes, n, pixels).
    intersection = (input_ * targs_).sum(dim=2)
    union = input_.sum(dim=2) + targs_.sum(dim=2) - intersection
    ious = intersection / (union + eps)
    # mean() over the (classes, n) matrix equals sum / (n * classes). detach()
    # replaces torch.tensor(tensor), which made the same non-grad copy but
    # emitted a UserWarning on every call.
    return ious.mean().detach()