-
Notifications
You must be signed in to change notification settings - Fork 102
/
dataset_wrapper.py
127 lines (108 loc) · 5.2 KB
/
dataset_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import numpy as np
import torch
import numba as nb
from torch.utils import data
from dataloader.transform_3d import PadMultiViewImage, NormalizeMultiviewImage, \
PhotoMetricDistortionMultiViewImage, RandomScaleImageMultiViewImage
# Multi-view image normalization config: per-channel mean subtraction with
# unit std; to_rgb=False keeps the channel order as loaded.
# NOTE(review): these mean values look like the common Caffe-style BGR
# means — confirm against the pretrained backbone's expectations.
img_norm_cfg = dict(
    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
class DatasetWrapper_NuScenes(data.Dataset):
    """Wrap an image+point dataset with augmentation and label voxelization.

    Each sample from the wrapped dataset is ``(imgs, img_metas, xyz, labels)``.
    This wrapper applies multi-view image transforms (photometric distortion
    for training only, normalization, optional rescaling, padding to a
    multiple of 32) and converts per-point labels into a dense voxel label
    grid by majority vote (delegated to ``nb_process_label``).

    Args:
        in_dataset: underlying dataset yielding (imgs, img_metas, xyz, labels).
        grid_size: voxel grid resolution per axis, e.g. (200, 200, 16).
        fill_label: label written into voxels that contain no points.
        fixed_volume_space: must be True; only fixed bounds are implemented.
        max_volume_space: upper (x, y, z) bound of the volume in meters.
        min_volume_space: lower (x, y, z) bound of the volume in meters.
        phase: 'train' enables photometric distortion augmentation.
        scale_rate: image rescale factor; 1 disables rescaling.
    """

    def __init__(self, in_dataset, grid_size, fill_label=0,
                 fixed_volume_space=False, max_volume_space=(51.2, 51.2, 3),
                 min_volume_space=(-51.2, -51.2, -5), phase='train', scale_rate=1):
        'Initialization'
        self.imagepoint_dataset = in_dataset
        self.grid_size = np.asarray(grid_size)
        self.fill_label = fill_label
        self.fixed_volume_space = fixed_volume_space
        self.max_volume_space = max_volume_space
        self.min_volume_space = min_volume_space

        # Build the augmentation pipeline in order. Photometric distortion
        # only during training; rescaling only when scale_rate != 1; always
        # normalize first (after distortion) and pad last.
        transforms = []
        if phase == 'train':
            transforms.append(PhotoMetricDistortionMultiViewImage())
        transforms.append(NormalizeMultiviewImage(**img_norm_cfg))
        if scale_rate != 1:
            transforms.append(RandomScaleImageMultiViewImage([scale_rate]))
        transforms.append(PadMultiViewImage(size_divisor=32))
        self.transforms = transforms

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.imagepoint_dataset)

    def __getitem__(self, index):
        data = self.imagepoint_dataset[index]
        imgs, img_metas, xyz, labels = data

        # Image augmentations operate on a dict so transforms can update the
        # lidar2img projection matrices alongside the images (e.g. rescale).
        imgs_dict = {'img': imgs, 'lidar2img': img_metas['lidar2img']}
        for t in self.transforms:
            imgs_dict = t(imgs_dict)
        imgs = imgs_dict['img']
        # HWC -> CHW for each camera view.
        imgs = [img.transpose(2, 0, 1) for img in imgs]

        img_metas['img_shape'] = imgs_dict['img_shape']
        img_metas['lidar2img'] = imgs_dict['lidar2img']

        # Only fixed volume bounds are implemented.
        assert self.fixed_volume_space
        max_bound = np.asarray(self.max_volume_space)  # e.g. 51.2 51.2 3
        min_bound = np.asarray(self.min_volume_space)  # e.g. -51.2 -51.2 -5

        # Voxel size per axis.
        crop_range = max_bound - min_bound
        cur_grid_size = self.grid_size  # e.g. 200, 200, 16
        # TODO: intervals should not minus one.
        intervals = crop_range / (cur_grid_size - 1)
        if (intervals == 0).any():
            print("Zero interval!")

        # TODO: grid_ind_float should actually be returned.
        # grid_ind_float = (np.clip(xyz, min_bound, max_bound - 1e-3) - min_bound) / intervals
        grid_ind_float = (np.clip(xyz, min_bound, max_bound) - min_bound) / intervals
        # np.int was removed in NumPy 1.24; the builtin int is the exact
        # replacement (np.int was a deprecated alias for it).
        grid_ind = np.floor(grid_ind_float).astype(int)

        # Majority-vote a label per voxel; voxels with no points keep
        # fill_label. Rows must be sorted so equal voxel indices are
        # contiguous before handing off to nb_process_label.
        processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.fill_label
        label_voxel_pair = np.concatenate([grid_ind, labels], axis=1)
        label_voxel_pair = label_voxel_pair[np.lexsort(
            (grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :]
        processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair)

        data_tuple = (imgs, img_metas, processed_label)
        data_tuple += (grid_ind, labels)
        return data_tuple
@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False)
def nb_process_label(processed_label, sorted_label_voxel_pair):
    """Assign each visited voxel its most frequent point label (majority vote).

    Args:
        processed_label: (X, Y, Z) uint8 voxel grid, pre-filled with the
            default label; written in place and also returned.
        sorted_label_voxel_pair: (N, 4) int64 array of rows
            [ix, iy, iz, label], sorted so rows with the same voxel index
            are contiguous (caller sorts via np.lexsort).

    Returns:
        processed_label with each voxel that has points set to the label
        occurring most often among them (np.argmax resolves ties toward the
        smallest label value).
    """
    # Counter indexed by label value; labels are assumed to fit in [0, 256).
    label_size = 256
    counter = np.zeros((label_size,), dtype=np.uint16)
    # Seed the vote with the first row's label, then scan the rest.
    counter[sorted_label_voxel_pair[0, 3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0, :3]
    for i in range(1, sorted_label_voxel_pair.shape[0]):
        cur_ind = sorted_label_voxel_pair[i, :3]
        if not np.all(np.equal(cur_ind, cur_sear_ind)):
            # Reached a new voxel: commit the vote for the previous one and
            # reset the counter.
            processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
            counter = np.zeros((label_size,), dtype=np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_label_voxel_pair[i, 3]] += 1
    # Commit the final voxel's vote (the loop only commits on transitions).
    processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
    return processed_label
def custom_collate_fn(data):
    """Collate DatasetWrapper_NuScenes samples into batched tensors.

    Args:
        data: list of (imgs, img_metas, processed_label, grid_ind, labels)
            tuples as produced by DatasetWrapper_NuScenes.__getitem__.

    Returns:
        Tuple of (img_tensor, meta_list, label_tensor, grid_ind_tensor,
        point_label_tensor); all array fields are stacked along a new batch
        dimension and converted via torch.from_numpy.
    """
    img2stack = np.stack([d[0] for d in data]).astype(np.float32)
    meta2stack = [d[1] for d in data]
    # np.int / np.float were removed in NumPy 1.24; the builtins int/float
    # are the exact replacements (they were deprecated aliases for them).
    label2stack = np.stack([d[2] for d in data]).astype(int)
    # because we use a batch size of 1, so we can stack these tensor together.
    grid_ind_stack = np.stack([d[3] for d in data]).astype(float)
    point_label = np.stack([d[4] for d in data]).astype(int)
    return torch.from_numpy(img2stack), \
        meta2stack, \
        torch.from_numpy(label2stack), \
        torch.from_numpy(grid_ind_stack), \
        torch.from_numpy(point_label)