In [1]:
import numpy as onp 
import data_prep

# Points per cloud: forwarded to FlowDataset as `npoint` and used as the point
# axis of every batch array allocated in get_batch_data.
NPOINT = 1024
# Size of the mask axis of the batch_momasks array allocated in get_batch_data.
NMASK = 10
# Dataset object built from a path relative to the working directory
# (name suggests the training split — confirm against data_prep.FlowDataset).
TRAIN_DATASET = data_prep.FlowDataset('data/flow_train.mat', npoint=NPOINT)

def get_batch_data(dataset, idxs, start_idx, end_idx):
    """Stack dataset samples ``idxs[start_idx] .. idxs[end_idx - 1]`` into dense arrays.

    Args:
        dataset: Indexable object; ``dataset[k]`` must unpack into
            ``(pc1, pc2, flow12, vismask, momasks)``.
        idxs: Sequence of dataset indices (e.g. an identity or shuffled order).
        start_idx: Position in ``idxs`` of the first sample of the batch.
        end_idx: Position in ``idxs`` one past the last sample of the batch.

    Returns:
        Tuple of float64 arrays, with ``B = end_idx - start_idx``:
            batch_pcpair:  (B, NPOINT, 6) - ``pc1`` and ``pc2`` concatenated on axis 1
            batch_flow:    (B, NPOINT, 3) - ``flow12``
            batch_vismask: (B, NPOINT)    - ``vismask``
            batch_momasks: (B, NMASK, NPOINT) - ``momasks`` transposed

    Raises:
        IndexError: if ``end_idx`` runs past the end of ``idxs``.
    """
    bsize = end_idx - start_idx
    # `onp.float` was only an alias of the builtin float (i.e. float64); it was
    # deprecated in NumPy 1.20 and removed in 1.24, so name the dtype explicitly.
    batch_pcpair = onp.zeros((bsize, NPOINT, 6), dtype=onp.float64)
    batch_flow = onp.zeros((bsize, NPOINT, 3), dtype=onp.float64)
    batch_vismask = onp.zeros((bsize, NPOINT), dtype=onp.float64)
    batch_momasks = onp.zeros((bsize, NMASK, NPOINT), dtype=onp.float64)
    for i in range(bsize):
        pc1, pc2, flow12, vismask, momasks = dataset[idxs[start_idx + i]]
        # Both clouds share one row per point: columns 0-2 from pc1, 3-5 from pc2.
        batch_pcpair[i, ...] = onp.concatenate((pc1, pc2), 1)
        batch_flow[i, ...] = flow12
        batch_vismask[i, :] = vismask
        # Transposed so the mask axis comes before the point axis.
        batch_momasks[i, ...] = onp.transpose(momasks)
    return batch_pcpair, batch_flow, batch_vismask, batch_momasks

# Materialise the whole dataset as one batch: identity index order, full range.
train_idxs = onp.arange(len(TRAIN_DATASET))
start_idx, end_idx = 0, len(TRAIN_DATASET)
(
    batch_pcpair,
    batch_flow,
    batch_vismask,
    batch_momasks,
) = get_batch_data(TRAIN_DATASET, train_idxs, start_idx, end_idx)

In [2]:
# Sanity check: one shape per line, point-pair array first, then the flow array.
print(batch_pcpair.shape, batch_flow.shape, sep="\n")

(15084, 1024, 6)
(15084, 1024, 3)


In [3]:
# Halve the last axis of the (batch, point, 6) array: columns 0-2 -> xyz1, 3-5 -> xyz2.
xyz1, xyz2 = onp.split(batch_pcpair, 2, axis=-1)

In [4]:
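# Disabled experiment, kept for reference: for each flowed point (xyz1 + flow),
# find the index of the nearest point in xyz2 by squared Euclidean distance.
# NOTE(review): `diff` broadcasts to (batch, NPOINT, NPOINT, 3); for the full
# (15084, 1024, ...) batch that is ~4.7e10 elements, so run this on a slice of
# the batch rather than on everything at once — presumably why it is disabled.
# diff = onp.expand_dims(xyz2, 1) - onp.expand_dims(xyz1 + batch_flow, 2)
# print(diff.shape)
# matching = onp.argmin(
#     onp.sum(
#         onp.square(
#             diff
#         ),
#         axis=-1
#     ),
#     axis=2
# )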

In [5]:
# One random RGB triple in [0, 1) per point index (drawn from NumPy's global RNG).
colors = onp.random.random_sample((NPOINT, 3))

In [None]:
import open3d as o3d

idx = 200         # sample of the batch to display
viz_offset = 1.0  # added to every coordinate of the second cloud for display
nssp = NPOINT     # number of points to keep in the random subsample

# One permutation indexes both clouds, so row i of each subsample stays paired.
permidx = onp.random.permutation(xyz1[idx].shape[0])[:nssp]
ssp_xyz1 = xyz1[idx][permidx]
ssp_xyz2 = xyz2[idx][permidx]

pcd1 = o3d.geometry.PointCloud()
pcd1.points = o3d.utility.Vector3dVector(ssp_xyz1)

pcd2 = o3d.geometry.PointCloud()
pcd2.points = o3d.utility.Vector3dVector(ssp_xyz2 + viz_offset)

# Link point i of the first cloud to point i of the second.
corr = [(i, i) for i in range(len(ssp_xyz1))]

# The factory is static, so call it on the class instead of a throwaway instance.
line_set = o3d.geometry.LineSet.create_from_point_cloud_correspondences(pcd1, pcd2, corr)
# One colour per line: slice by the number of correspondences actually built.
line_set.colors = o3d.utility.Vector3dVector(colors[:len(corr)])

# line_set used to be built and coloured but omitted from the draw list, so the
# correspondences never appeared; include it alongside the two clouds.
o3d.visualization.draw_geometries([pcd1, pcd2, line_set])