-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset_custom.py
108 lines (86 loc) · 3.44 KB
/
dataset_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import sys
from tqdm import tqdm
import json
import glob
from plyfile import PlyData, PlyElement
# In[ ]:
class CustomDataset(data.Dataset):
    """Point-cloud dataset backed by per-category folders of ``.txt`` files.

    Layout read by the constructor:

    * ``<root>/synsetoffset2category.txt`` -- one ``<category> <folder>``
      pair per line (whitespace separated).
    * ``<root>/<folder>/*.txt`` -- one sample per file, loaded with
      ``numpy.loadtxt``.  The first three columns are used as point
      coordinates and the last column as the integer per-point label.
    * a "seg classes" file -- one ``<category> <count>`` pair per line.

    Each item is ``(point_set, cls)`` when ``classification`` is true and
    ``(point_set, seg)`` otherwise, where ``point_set`` is a float32 tensor
    of shape ``(npoints, 3)``, ``seg`` an int64 tensor of shape
    ``(npoints,)`` and ``cls`` an int64 tensor of shape ``(1,)``.

    Args:
        root: Dataset root directory.
        npoints: Number of points drawn (with replacement) per sample.
        classification: Return the category index instead of point labels.
        class_choice: Optional container of category names to keep; all
            categories are kept when ``None``.
        split: Accepted for interface compatibility; not used by this class.
        data_augmentation: Apply a random rotation and jitter to each sample.
        seg_classes_file: Path of the seg-classes file.  Defaults to
            ``../misc/num_seg_classes_head.txt`` relative to this module.

    Raises:
        IndexError: If no category remains after applying ``class_choice``.
        KeyError: If the first category is missing from the seg-classes file.
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True,
                 seg_classes_file=None):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.classification = classification
        self.seg_classes = {}

        # category name -> folder name
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                if len(ls) < 2:
                    continue  # tolerate blank / malformed lines
                self.cat[ls[0]] = ls[1]

        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        self.id2cat = {v: k for k, v in self.cat.items()}

        # category name -> files found in that category's own folder.
        # Sorted because glob returns files in arbitrary order, which would
        # make the index -> sample mapping differ between runs/machines.
        self.meta = {}
        for name, folder in self.cat.items():
            pattern = os.path.join(self.root, folder, '*.txt')
            self.meta[name] = sorted(glob.glob(pattern))

        # Flat list of (category name, file path) addressed by __getitem__.
        self.datapath = []
        for name in self.cat:
            for fn in self.meta[name]:
                self.datapath.append((name, fn))

        # category name -> class index, assigned in sorted name order
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))

        if seg_classes_file is None:
            seg_classes_file = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                '../misc/num_seg_classes_head.txt')
        with open(seg_classes_file, 'r') as f:
            for line in f:
                ls = line.strip().split()
                if len(ls) < 2:
                    continue  # tolerate blank / malformed lines
                self.seg_classes[ls[0]] = int(ls[1])

        # The count of the first retained category is used for the dataset.
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]

    def __getitem__(self, index):
        """Load, resample, normalise and optionally augment one sample.

        Raises:
            ValueError: If the sample file contains no rows.
        """
        category, path = self.datapath[index]
        cls = self.classes[category]

        # ndmin=2 keeps a single-row file two-dimensional so the column
        # slicing below works.
        raw = np.loadtxt(path, ndmin=2)
        if raw.size == 0:
            raise ValueError('no points found in %s' % path)
        point_set = raw[:, 0:3].astype(np.float32)
        seg = raw[:, -1].astype(np.int64)

        # Resample to a fixed number of points (with replacement).
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]
        seg = seg[choice]

        # Center on the centroid, then scale so the farthest point has
        # norm 1.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        if dist > 0:
            # dist == 0 means every sampled point coincides; dividing would
            # fill the sample with NaNs, so leave it at the origin instead.
            point_set = point_set / dist

        if self.data_augmentation:
            # Random rotation applied to columns 0 and 2 (column 1 is left
            # unchanged), followed by Gaussian jitter with sigma 0.02.
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                        [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
            point_set += np.random.normal(0, 0.02, size=point_set.shape)

        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        return point_set, seg

    def __len__(self):
        """Return the number of indexed sample files."""
        return len(self.datapath)