Skip to content

Commit

Permalink
update iou computation
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 1, 2019
1 parent aa81dc8 commit 68e9add
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 19 deletions.
15 changes: 10 additions & 5 deletions examples/dgcnn_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import DynamicEdgeConv
from torch_geometric.utils import mean_iou
from torch_geometric.utils import intersection_and_union as i_and_u

from pointnet2_classification import MLP

Expand Down Expand Up @@ -89,20 +89,25 @@ def test(loader):
model.eval()

correct_nodes = total_nodes = 0
ious = []
intersections, unions = [], []
for data in loader:
data = data.to(device)
with torch.no_grad():
out = model(data)
pred = out.max(dim=1)[1]
correct_nodes += pred.eq(data.y).sum().item()
ious += [mean_iou(pred, data.y, test_dataset.num_classes, data.batch)]
total_nodes += data.num_nodes
return correct_nodes / total_nodes, torch.cat(ious, dim=0).mean().item()
i, u = i_and_u(pred, data.y, test_dataset.num_classes, data.batch)
intersections.append(i.to(torch.device('cpu')))
unions.append(i.to(torch.device('cpu')))
i, u = torch.cat(intersections, dim=0), torch.cat(unions, dim=0)
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1

return correct_nodes / total_nodes, iou.mean().item()


for epoch in range(1, 31):
train()
acc, iou = test(test_loader)
print('Epoch: {:02d}, Acc: {:.4f}, IoU: {:.4f}'.format(epoch, acc, iou))
scheduler.step()
14 changes: 10 additions & 4 deletions examples/pointnet2_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import knn_interpolate
from torch_geometric.utils import mean_iou
from torch_geometric.utils import intersection_and_union as i_and_u

from pointnet2_classification import SAModule, GlobalSAModule, MLP

Expand Down Expand Up @@ -111,16 +111,22 @@ def test(loader):
model.eval()

correct_nodes = total_nodes = 0
ious = []
intersections, unions = [], []
for data in loader:
data = data.to(device)
with torch.no_grad():
out = model(data)
pred = out.max(dim=1)[1]
correct_nodes += pred.eq(data.y).sum().item()
ious += [mean_iou(pred, data.y, test_dataset.num_classes, data.batch)]
total_nodes += data.num_nodes
return correct_nodes / total_nodes, torch.cat(ious, dim=0).mean().item()
i, u = i_and_u(pred, data.y, test_dataset.num_classes, data.batch)
intersections.append(i.to(torch.device('cpu')))
unions.append(i.to(torch.device('cpu')))
i, u = torch.cat(intersections, dim=0), torch.cat(unions, dim=0)
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1

return correct_nodes / total_nodes, iou.mean().item()


for epoch in range(1, 31):
Expand Down
1 change: 1 addition & 0 deletions test/utils/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ def test_mean_iou():

batch = torch.tensor([0, 0, 0, 0, 1, 1])
out = mean_iou(pred, target, num_classes=2, batch=batch)
assert out.size() == (2, )
assert out[0] == (1 / 3 + 1 / 3) / 2
assert out[1] == 0.25
4 changes: 3 additions & 1 deletion torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
barabasi_albert_graph)
from .metric import (accuracy, true_positive, true_negative, false_positive,
false_negative, precision, recall, f1_score, mean_iou)
false_negative, precision, recall, f1_score,
intersection_and_union, mean_iou)

__all__ = [
'degree',
Expand Down Expand Up @@ -54,5 +55,6 @@
'precision',
'recall',
'f1_score',
'intersection_and_union',
'mean_iou',
]
34 changes: 25 additions & 9 deletions torch_geometric/utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def f1_score(pred, target, num_classes):
return score


def mean_iou(pred, target, num_classes, batch=None):
r"""Computes the mean Intersection over Union score of predictions.
def intersection_and_union(pred, target, num_classes, batch=None):
r"""Computes intersection and union of predictions.
Args:
pred (LongTensor): The predictions.
Expand All @@ -156,18 +156,34 @@ def mean_iou(pred, target, num_classes, batch=None):
batch (LongTensor): The assignment vector which maps each pred-target
pair to an example.
:rtype: :class:`Tensor`
:rtype: (:class:`LongTensor`, :class:`LongTensor`)
"""
pred, target = F.one_hot(pred, num_classes), F.one_hot(target, num_classes)

if batch is not None:
i = scatter_add(pred & target, batch, dim=0).to(torch.float)
u = scatter_add(pred | target, batch, dim=0).to(torch.float)
if batch is None:
i = (pred & target).sum(dim=0)
u = (pred | target).sum(dim=0)
else:
i = (pred & target).sum(dim=0).to(torch.float)
u = (pred | target).sum(dim=0).to(torch.float)
i = scatter_add(pred & target, batch, dim=0)
u = scatter_add(pred | target, batch, dim=0)

return i, u


iou = i / u
def mean_iou(pred, target, num_classes, batch=None):
r"""Computes the mean intersection over union score of predictions.
Args:
pred (LongTensor): The predictions.
target (LongTensor): The targets.
num_classes (int): The number of classes.
batch (LongTensor): The assignment vector which maps each pred-target
pair to an example.
:rtype: :class:`Tensor`
"""
i, u = intersection_and_union(pred, target, num_classes, batch)
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1
iou = iou.mean(dim=-1)
return iou

0 comments on commit 68e9add

Please sign in to comment.