Skip to content

Commit

Permalink
fixed caching and sampling problem
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyul committed Dec 31, 2019
1 parent 3d6c7cb commit fbc174f
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions flying_things_dataset.py
Expand Up @@ -44,27 +44,27 @@ def __getitem__(self, index):
if len(self.cache) < self.cache_size:
self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)

if self.train:
n1 = pos1.shape[0]
sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
n2 = pos2.shape[0]
sample_idx2 = np.random.choice(n2, self.npoints, replace=False)

pos1 = pos1[sample_idx1, :]
pos2 = pos2[sample_idx2, :]
color1 = color1[sample_idx1, :]
color2 = color2[sample_idx2, :]
flow = flow[sample_idx1, :]
mask1 = mask1[sample_idx1]
else:
pos1 = pos1[:self.npoints, :]
pos2 = pos2[:self.npoints, :]
color1 = color1[:self.npoints, :]
color2 = color2[:self.npoints, :]
flow = flow[:self.npoints, :]
mask1 = mask1[:self.npoints]

return pos1, pos2, color1, color2, flow, mask1
if self.train:
n1 = pos1.shape[0]
sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
n2 = pos2.shape[0]
sample_idx2 = np.random.choice(n2, self.npoints, replace=False)

pos1_ = np.copy(pos1[sample_idx1, :])
pos2_ = np.copy(pos2[sample_idx2, :])
color1_ = np.copy(color1[sample_idx1, :])
color2_ = np.copy(color2[sample_idx2, :])
flow_ = np.copy(flow[sample_idx1, :])
mask1_ = np.copy(mask1[sample_idx1])
else:
pos1_ = np.copy(pos1[:self.npoints, :])
pos2_ = np.copy(pos2[:self.npoints, :])
color1_ = np.copy(color1[:self.npoints, :])
color2_ = np.copy(color2[:self.npoints, :])
flow_ = np.copy(flow[:self.npoints, :])
mask1_ = np.copy(mask1[:self.npoints])

return pos1_, pos2_, color1_, color2_, flow_, mask1_

def __len__(self):
return len(self.datapath)
Expand Down

0 comments on commit fbc174f

Please sign in to comment.