Skip to content
Permalink
Browse files

fix iou computation

  • Loading branch information...
rusty1s committed Aug 4, 2019
1 parent 67476ff commit 527e0a0e54756328c27caf68be02e8dedc0f70f2
Showing with 18 additions and 26 deletions.
  1. +9 −13 examples/dgcnn_segmentation.py
  2. +9 −13 examples/pointnet2_segmentation.py
@@ -20,18 +20,14 @@
T.RandomRotate(15, axis=2)
])
pre_transform = T.NormalizeScale()
train_dataset = ShapeNet(
path,
category,
train=True,
transform=transform,
pre_transform=pre_transform)
test_dataset = ShapeNet(
path, category, train=False, pre_transform=pre_transform)
train_loader = DataLoader(
train_dataset, batch_size=10, shuffle=True, num_workers=6)
test_loader = DataLoader(
test_dataset, batch_size=10, shuffle=False, num_workers=6)
train_dataset = ShapeNet(path, category, train=True, transform=transform,
pre_transform=pre_transform)
test_dataset = ShapeNet(path, category, train=False,
pre_transform=pre_transform)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True,
num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False,
num_workers=6)


class Net(torch.nn.Module):
@@ -106,7 +102,7 @@ def test(loader):
intersection = torch.cat(intersections, dim=0)
union = torch.cat(unions, dim=0)

ious = [[]] * len(loader.dataset.categories)
ious = [[] for _ in range(loader.dataset.categories)]
for j in range(len(loader.dataset)):
i = intersection[j, loader.dataset.y_mask[category[j]]]
u = union[j, loader.dataset.y_mask[category[j]]]
@@ -19,18 +19,14 @@
T.RandomRotate(15, axis=2)
])
pre_transform = T.NormalizeScale()
train_dataset = ShapeNet(
path,
category,
train=True,
transform=transform,
pre_transform=pre_transform)
test_dataset = ShapeNet(
path, category, train=False, pre_transform=pre_transform)
train_loader = DataLoader(
train_dataset, batch_size=12, shuffle=True, num_workers=6)
test_loader = DataLoader(
test_dataset, batch_size=12, shuffle=False, num_workers=6)
train_dataset = ShapeNet(path, category, train=True, transform=transform,
pre_transform=pre_transform)
test_dataset = ShapeNet(path, category, train=False,
pre_transform=pre_transform)
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True,
num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False,
num_workers=6)


class FPModule(torch.nn.Module):
@@ -128,7 +124,7 @@ def test(loader):
intersection = torch.cat(intersections, dim=0)
union = torch.cat(unions, dim=0)

ious = [[]] * len(loader.dataset.categories)
ious = [[] for _ in range(loader.dataset.categories)]
for j in range(len(loader.dataset)):
i = intersection[j, loader.dataset.y_mask[category[j]]]
u = union[j, loader.dataset.y_mask[category[j]]]

0 comments on commit 527e0a0

Please sign in to comment.
You can’t perform that action at this time.