Skip to content

Commit

Permalink
fixed iou computation for multiple categories
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 1, 2019
1 parent 68e9add commit 83f6c30
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
25 changes: 19 additions & 6 deletions examples/dgcnn_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test(loader):
model.eval()

correct_nodes = total_nodes = 0
intersections, unions = [], []
intersections, unions, categories = [], [], []
for data in loader:
data = data.to(device)
with torch.no_grad():
Expand All @@ -99,12 +99,25 @@ def test(loader):
total_nodes += data.num_nodes
i, u = i_and_u(pred, data.y, test_dataset.num_classes, data.batch)
intersections.append(i.to(torch.device('cpu')))
unions.append(i.to(torch.device('cpu')))
i, u = torch.cat(intersections, dim=0), torch.cat(unions, dim=0)
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1
unions.append(u.to(torch.device('cpu')))
categories.append(data.category.to(torch.device('cpu')))

return correct_nodes / total_nodes, iou.mean().item()
category = torch.cat(categories, dim=0)
intersection = torch.cat(intersections, dim=0)
union = torch.cat(unions, dim=0)

ious = [[]] * len(loader.dataset.categories)
for j in range(len(loader.dataset)):
i = intersection[j, loader.dataset.y_mask[category[j]]]
u = union[j, loader.dataset.y_mask[category[j]]]
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1
ious[category[j]].append(iou.mean().item())

for cat in range(len(loader.dataset.categories)):
ious[cat] = torch.tensor(ious[cat]).mean().item()

return correct_nodes / total_nodes, torch.tensor(ious).mean().item()


for epoch in range(1, 31):
Expand Down
25 changes: 19 additions & 6 deletions examples/pointnet2_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test(loader):
model.eval()

correct_nodes = total_nodes = 0
intersections, unions = [], []
intersections, unions, categories = [], [], []
for data in loader:
data = data.to(device)
with torch.no_grad():
Expand All @@ -121,12 +121,25 @@ def test(loader):
total_nodes += data.num_nodes
i, u = i_and_u(pred, data.y, test_dataset.num_classes, data.batch)
intersections.append(i.to(torch.device('cpu')))
unions.append(i.to(torch.device('cpu')))
i, u = torch.cat(intersections, dim=0), torch.cat(unions, dim=0)
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1
unions.append(u.to(torch.device('cpu')))
categories.append(data.category.to(torch.device('cpu')))

return correct_nodes / total_nodes, iou.mean().item()
category = torch.cat(categories, dim=0)
intersection = torch.cat(intersections, dim=0)
union = torch.cat(unions, dim=0)

ious = [[]] * len(loader.dataset.categories)
for j in range(len(loader.dataset)):
i = intersection[j, loader.dataset.y_mask[category[j]]]
u = union[j, loader.dataset.y_mask[category[j]]]
iou = i.to(torch.float) / u.to(torch.float)
iou[torch.isnan(iou)] = 1
ious[category[j]].append(iou.mean().item())

for cat in range(len(loader.dataset.categories)):
ious[cat] = torch.tensor(ious[cat]).mean().item()

return correct_nodes / total_nodes, torch.tensor(ious).mean().item()


for epoch in range(1, 31):
Expand Down
28 changes: 20 additions & 8 deletions torch_geometric/datasets/shapenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self,
pre_filter)
path = self.processed_paths[0] if train else self.processed_paths[1]
self.data, self.slices = torch.load(path)
self.y_mask = torch.load(self.processed_paths[2])

@property
def raw_file_names(self):
Expand All @@ -91,7 +92,9 @@ def raw_file_names(self):
@property
def processed_file_names(self):
cats = '_'.join([cat[:3].lower() for cat in self.categories])
return ['{}_{}.pt'.format(cats, s) for s in ['training', 'test']]
return [
'{}_{}.pt'.format(cats, s) for s in ['training', 'test', 'y_mask']
]

def download(self):
for name in self.raw_file_names:
Expand All @@ -103,8 +106,9 @@ def download(self):
def process_raw_path(self, data_path, label_path):
y_offset = 0
data_list = []
for category in self.categories:
idx = self.category_ids[category]
cat_ys = []
for cat_idx, cat in enumerate(self.categories):
idx = self.category_ids[cat]
point_paths = sorted(glob.glob(osp.join(data_path, idx, '*.pts')))
y_paths = sorted(glob.glob(osp.join(label_path, idx, '*.seg')))

Expand All @@ -113,26 +117,34 @@ def process_raw_path(self, data_path, label_path):
lens = [y.size(0) for y in ys]

y = torch.cat(ys).unique(return_inverse=True)[1] + y_offset
cat_ys.append(y.unique())
y_offset = y.max().item() + 1
ys = y.split(lens)

for (pos, y) in zip(points, ys):
data = Data(y=y, pos=pos)
data = Data(y=y, pos=pos, category=cat_idx)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
data_list.append(data)
return data_list

y_mask = torch.zeros((len(self.categories), y_offset),
dtype=torch.uint8)
for i in range(len(cat_ys)):
y_mask[i, cat_ys[i]] = 1

return data_list, y_mask

def process(self):
train_data_list = self.process_raw_path(*self.raw_paths[0:2])
val_data_list = self.process_raw_path(*self.raw_paths[2:4])
test_data_list = self.process_raw_path(*self.raw_paths[4:6])
train_data_list, y_mask = self.process_raw_path(*self.raw_paths[0:2])
val_data_list, _ = self.process_raw_path(*self.raw_paths[2:4])
test_data_list, _ = self.process_raw_path(*self.raw_paths[4:6])

data = self.collate(train_data_list + val_data_list)
torch.save(data, self.processed_paths[0])
torch.save(self.collate(test_data_list), self.processed_paths[1])
torch.save(y_mask, self.processed_paths[2])

def __repr__(self):
return '{}({}, categories={})'.format(self.__class__.__name__,
Expand Down

0 comments on commit 83f6c30

Please sign in to comment.