In [None]:
from dataset import MultiObject3dDataset
from utils import get_clustering_colors

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Choose the dataset to visualise: comment/uncomment to switch between the two.
dataset_name = 'clevr3d'
#dataset_name = 'multishapenet'

# Index of the example (within the 'val' split loaded below) shown by the later cells.
idx = 3

# NOTE(review): width/height are not referenced by any visible cell;
# presumably the image resolution in pixels — confirm before relying on them.
width, height = 320, 240

# Load the 'val' split from ../data/<dataset_name>.
dataset = MultiObject3dDataset(
    f'../data/{dataset_name}',
    'val',
    dataset=dataset_name,
    max_n=6,
    points_per_item=20000,
    do_frustum_culling=False,
)

In [None]:
# Fetch example `idx`. __getitem__ is called explicitly (rather than `dataset[idx]`)
# because subscript syntax cannot pass the extra `noisy` keyword argument.
# NOTE(review): `noisy=False` presumably disables noise augmentation — confirm in dataset.py.
example = dataset.__getitem__(idx, noisy=False)

In [None]:
# Plot input view in 2D
# Shows, in order: the input image, its depth map, and a colour-coded mask image.

# Move axis 0 of the input to the end so that imshow receives the colour axis last.
image = np.transpose(example.get('inputs'), (1, 2, 0))
input_depths = example['input_depths']

# The image and the depth map are displayed the same way: caption, then figure.
for title, picture in (('Input image', image), ('Input depths', input_depths)):
    print(title)
    plt.imshow(picture)
    plt.show()

print('Masks')
masks = example['masks']
mask_colors = get_clustering_colors(10)
# Index of the largest mask value along axis 0 (not used by the visible cells).
mask_idx = masks.argmax(0)

print(masks.shape, mask_colors.shape)

# Blend the masks into one colour image: each of the k masks (k, h, w) weights its
# own colour (k, c) and the contributions are summed over k, giving (h, w, c).
# Only the first k of the 10 colours are needed.
n_masks = masks.shape[0]
mask_img = np.einsum('khw,kc->hwc', masks, mask_colors[:n_masks])
plt.imshow(mask_img)
plt.show()

In [None]:
# Plot input view in 3D
# Every input pixel is drawn at its 3D position and coloured with the pixel's colour.
import plotly.express as px
import plotly.graph_objects as go

input_points = example['input_points']
input_points_flat = np.reshape(input_points, (-1, 3))

# `image` is the channels-last input image produced by the 2D plotting cell.
image_flat = np.reshape(image, (-1, 3))

# Build plotly 'rgb(r, g, b)' strings from plain Python ints.
# Formatting a tuple of np.uint8 scalars directly would embed 'np.uint8(...)' in the
# string under NumPy >= 2 (scalar reprs changed), which plotly rejects as a colour.
# Converting the whole array once and calling .tolist() yields built-in ints, so the
# text is 'rgb(r, g, b)' on every NumPy version, and avoids a per-pixel array allocation.
# Assumes the image holds float colours in [0, 1] — TODO confirm.
image_uint8 = (image_flat * 255).astype(np.uint8)
values_plotly = [f'rgb{tuple(rgb)}' for rgb in image_uint8.tolist()]

fig = px.scatter_3d()

fig.add_trace(go.Scatter3d(x=input_points_flat[..., 0], y=input_points_flat[..., 1], z=input_points_flat[..., 2],
                           mode='markers', name='surface', marker=dict(size=1, color=values_plotly)))

# Use one common [min, max] range on all three axes together with a 1:1:1 aspect
# ratio, so that the scene is not stretched along any axis.
min_c = input_points.min()
max_c = input_points.max()
full_scene = dict(   xaxis = dict(range=[min_c, max_c],),
                     yaxis = dict(range=[min_c, max_c],),
                     zaxis = dict(range=[min_c, max_c],),
                     aspectratio=dict(x=1, y=1, z=1)
                 )


# Last expression of the cell: the returned figure is what the notebook displays.
fig.update_layout(scene=full_scene, coloraxis_showscale=True, width=1000, height=750)


In [None]:
# Plot training points in 3D
# Draws the coloured surface points, a subset of the empty points and the camera position.
# Relies on `px` and `go` imported by the preceding 3D cell.

surface_points = example.get('surface_points')
c_pos = example.get('camera_pos')
empty_points = example.get('empty_points')

values = example['values']
# Build plotly 'rgb(r, g, b)' strings from plain Python ints.
# Formatting a tuple of np.uint8 scalars directly would embed 'np.uint8(...)' in the
# string under NumPy >= 2 (scalar reprs changed), which plotly rejects as a colour.
# Converting the whole array once and calling .tolist() yields built-in ints, so the
# text is 'rgb(r, g, b)' on every NumPy version, and avoids a per-point array allocation.
# Assumes `values` holds one float colour in [0, 1] per surface point — TODO confirm.
values_uint8 = (np.asarray(values) * 255).astype(np.uint8)
values_plotly = [f'rgb{tuple(rgb)}' for rgb in values_uint8.tolist()]

fig = px.scatter_3d()
fig.add_trace(go.Scatter3d(x=surface_points[..., 0], y=surface_points[..., 1], z=surface_points[..., 2],
                           mode='markers', name='surface', marker=dict(size=1, color=values_plotly)))
fig.update_traces(marker=dict(size=1), selector=dict(mode='markers'))

# We only plot 1000 of the empty points, to avoid making the plot too busy. Adjust as preferred.
fig.add_trace(go.Scatter3d(x=empty_points[:1000, 0], y=empty_points[:1000, 1], z=empty_points[:1000, 2],
                           marker=dict(size=1), mode='markers',
                           name='empty'))


# The camera position is sliced (rather than indexed) so that each coordinate is
# passed as a length-1 sequence, i.e. a single point.
fig.add_trace(go.Scatter3d(x=c_pos[:1], y=c_pos[1:2], z=c_pos[2:], marker=dict(size=3),
                           marker_symbol='x', name='camera'))

# Use one common [min, max] range on all three axes together with a 1:1:1 aspect ratio.
# NOTE(review): the range is taken from the surface points only, so empty points or
# the camera marker lying outside it may not be visible — confirm this is intended.
max_c = surface_points.max()
min_c = surface_points.min()

full_scene = dict(   xaxis = dict(range=[min_c, max_c],),
                     yaxis = dict(range=[min_c, max_c],),
                     zaxis = dict(range=[min_c, max_c],),
                     aspectratio=dict(x=1, y=1, z=1)
                 )

# Last expression of the cell: the returned figure is what the notebook displays.
fig.update_layout(scene=full_scene, coloraxis_showscale=False, width=1000, height=750)
