-
Notifications
You must be signed in to change notification settings - Fork 81
/
s3dis.py
110 lines (98 loc) · 4.29 KB
/
s3dis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os.path as osp
from glob import glob
import numpy as np
import torch
from ..ops import voxelization_idx
from .custom import CustomDataset
class S3DISDataset(CustomDataset):
    """S3DIS indoor-scene dataset built on top of ``CustomDataset``.

    Overrides file discovery, scan loading, the default crop step, the
    test-time transform and test-time collation. At test time each scan is
    split into ``NUM_TEST_PIECES`` interleaved point subsets that are voxelized
    as separate batch entries while all points share point-level batch index 0.
    """

    CLASSES = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'chair', 'table',
               'bookcase', 'sofa', 'board', 'clutter')

    # Number of interleaved subsets a scan is split into at test time. The same
    # value is the range of piece indices written by ``transform_test`` and the
    # batch count used by ``collate_fn``; defining it once keeps them in sync.
    NUM_TEST_PIECES = 4

    def get_filenames(self):
        """Collect the scan files to use, honoring ``self.repeat``.

        A single string ``self.prefix`` is normalized in place to a one-element
        list. For every prefix, files matching ``<data_root>/<prefix>*<suffix>``
        are collected; the combined list is repeated ``self.repeat`` times and
        returned sorted.

        Raises:
            AssertionError: if a prefix matches no file (the check is an
                ``assert`` and is therefore skipped under ``python -O``).
        """
        if isinstance(self.prefix, str):
            self.prefix = [self.prefix]
        filenames_all = []
        for p in self.prefix:
            filenames = glob(osp.join(self.data_root, p + '*' + self.suffix))
            assert len(filenames) > 0, f'Empty {p}'
            filenames_all.extend(filenames)
        filenames_all = sorted(filenames_all * self.repeat)
        return filenames_all

    def load(self, filename):
        """Load one scan and, in training mode, randomly keep 25% of its points.

        Args:
            filename: path passed to ``torch.load``; the stored object must
                unpack into 6 items whose first four are
                ``(xyz, rgb, semantic_label, instance_label)``. The last two
                items are discarded.

        Returns:
            Tuple ``(xyz, rgb, semantic_label, instance_label)``.
        """
        # TODO make file load results consistent
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        xyz, rgb, semantic_label, instance_label, _, _ = torch.load(filename)
        # subsample data: draw 25% of the point indices without replacement
        if self.training:
            N = xyz.shape[0]
            inds = np.random.choice(N, int(N * 0.25), replace=False)
            xyz = xyz[inds]
            rgb = rgb[inds]
            semantic_label = semantic_label[inds]
            # helper from the parent class; presumably also re-indexes the
            # surviving instance ids -- confirm in CustomDataset
            instance_label = self.getCroppedInstLabel(instance_label, inds)
        return xyz, rgb, semantic_label, instance_label

    def crop(self, xyz, step=64):
        """Delegate to ``CustomDataset.crop`` with an S3DIS default ``step`` of 64."""
        return super().crop(xyz, step=step)

    def transform_test(self, xyz, rgb, semantic_label, instance_label):
        """Split a scan into interleaved pieces for test-time inference.

        Point ``i`` is assigned to piece ``i % NUM_TEST_PIECES``. The
        augmentation is applied once to the whole scan. Each piece is then
        scaled by ``voxel_cfg.scale``, shifted so that its per-axis minimum is
        zero, and prefixed with its piece index as column 0.

        Returns:
            Tuple ``(xyz, xyz_middle, rgb, semantic_label, instance_label)``
            with the pieces concatenated in piece order, so the output point
            order differs from the input order. ``xyz`` holds the piece index
            in column 0 followed by the scaled, shifted coordinates.
            ``xyz_middle`` holds the augmented, unscaled coordinates.
        """
        num_pieces = self.NUM_TEST_PIECES
        # divide into interleaved pieces: piece k holds points k, k + n, k + 2n, ...
        inds = np.arange(xyz.shape[0])
        pieces = [inds[k::num_pieces] for k in range(num_pieces)]
        # NOTE(review): the three False flags presumably disable the random
        # augmentations of the parent's dataAugment -- confirm.
        xyz_aug = self.dataAugment(xyz, False, False, False)

        xyz_list = []
        xyz_middle_list = []
        rgb_list = []
        semantic_label_list = []
        instance_label_list = []
        for batch, piece in enumerate(pieces):
            xyz_middle = xyz_aug[piece]
            xyz_piece = xyz_middle * self.voxel_cfg.scale
            # shift each piece independently so its coordinates start at zero
            xyz_piece -= xyz_piece.min(0)
            batch_col = np.full((xyz_piece.shape[0], 1), batch)
            xyz_list.append(np.concatenate([batch_col, xyz_piece], 1))
            xyz_middle_list.append(xyz_middle)
            rgb_list.append(rgb[piece])
            semantic_label_list.append(semantic_label[piece])
            instance_label_list.append(instance_label[piece])
        xyz = np.concatenate(xyz_list, 0)
        xyz_middle = np.concatenate(xyz_middle_list, 0)
        rgb = np.concatenate(rgb_list, 0)
        semantic_label = np.concatenate(semantic_label_list, 0)
        instance_label = np.concatenate(instance_label_list, 0)
        # every point is kept; the call only runs the parent's label post-processing
        valid_idxs = np.ones(xyz.shape[0], dtype=bool)
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)  # TODO remove this
        return xyz, xyz_middle, rgb, semantic_label, instance_label

    def collate_fn(self, batch):
        """Collate samples into a model input dict.

        In training mode this defers to ``CustomDataset.collate_fn``.
        Otherwise only ``batch[0]`` is used (a single scan). Its coordinates
        are voxelized as ``NUM_TEST_PIECES`` batch entries, while every point
        gets point-level batch index 0. This assumes column 0 of ``coord`` is
        the piece index written by ``transform_test``.
        """
        if self.training:
            return super().collate_fn(batch)

        num_pieces = self.NUM_TEST_PIECES
        # assume 1 scan only
        (scan_id, coord, coord_float, feat, semantic_label, instance_label, _inst_num,
         inst_pointnum, inst_cls, pt_offset_label) = batch[0]
        scan_ids = [scan_id]
        coords = coord.long()
        # single point-level batch index for the whole scan, even though it is
        # voxelized as num_pieces separate batch entries below
        batch_idxs = torch.zeros_like(coord[:, 0].int())
        coords_float = coord_float.float()
        feats = feat.float()
        semantic_labels = semantic_label.long()
        instance_labels = instance_label.long()
        # NOTE(review): the extra list wrapping adds a leading dimension of size 1;
        # confirm the consumer expects that layout.
        instance_pointnum = torch.tensor([inst_pointnum], dtype=torch.int)
        instance_cls = torch.tensor([inst_cls], dtype=torch.long)
        pt_offset_labels = pt_offset_label.float()
        # extent of the coordinate columns (column 0 is skipped), lower-bounded
        # by voxel_cfg.spatial_shape[0] on every axis
        spatial_shape = np.clip((coords.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
                                None)
        voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, num_pieces)
        return {
            'scan_ids': scan_ids,
            'batch_idxs': batch_idxs,
            'voxel_coords': voxel_coords,
            'p2v_map': p2v_map,
            'v2p_map': v2p_map,
            'coords_float': coords_float,
            'feats': feats,
            'semantic_labels': semantic_labels,
            'instance_labels': instance_labels,
            'instance_pointnum': instance_pointnum,
            'instance_cls': instance_cls,
            'pt_offset_labels': pt_offset_labels,
            'spatial_shape': spatial_shape,
            'batch_size': num_pieces
        }