In [1]:
import torch
import numpy as np
import open3d as o3d
from open3d.web_visualizer import draw
import gpytorch
from im2mesh import config
from im2mesh.checkpoints import CheckpointIO
from im2mesh.common import make_3d_grid

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Open3D INFO] Resetting default logger to print to terminal.


In [28]:
device = torch.device('cpu')
cfg = config.load_config('configs/pointcloud/onet_pretrained.yaml', 'configs/default.yaml')

dataset = config.get_dataset('test', cfg, return_idx=True)
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(cfg['training']['out_dir'], model=model)
re = checkpoint_io.load(cfg['test']['model_file'])

https://s3.eu-central-1.amazonaws.com/avg-projects/occupancy_networks/models/onet_pcl2mesh-5c0be168.pt
=> Loading checkpoint from url...


In [29]:
def get_data_sample(idx=500):

    data = dataset[idx]
    print(dataset.metadata[dataset.get_model_dict(idx)['category']])
    
    inputs = torch.from_numpy(data.get('inputs'))
    inputs = inputs.unsqueeze(0).to(device)

    points = torch.from_numpy(data.get('points'))
    points = points.unsqueeze(0).to(device)
    return inputs, points

def predict(inputs, points):
    kwargs = {}
    model.eval()
    with torch.no_grad():
        c = model.encode_inputs(inputs)
        occ_hat = model.decode(points, z=None, c=c, **kwargs).probs

    return occ_hat.squeeze()

In [30]:
nx = 32
box_size = 1.1
points_grid = box_size * make_3d_grid((-0.5,)*3, (0.5,)*3, (nx,)*3)
points_grid = points_grid.unsqueeze(0).to(device)

In [31]:
inputs, points = get_data_sample(idx=2000)
print(inputs.shape, points.shape)
occ_hat = predict(inputs, points_grid)
print(occ_hat.shape)

{'id': '02691156', 'name': 'airplane,aeroplane,plane', 'idx': 4}
torch.Size([1, 300, 3]) torch.Size([1, 2048, 3])
torch.Size([32768])


In [66]:
occ_hat_pred = torch.where(occ_hat > 0.5)
print('num of points eval: {}'.format(occ_hat.shape))
print('num of pred points: {}'.format(occ_hat_pred[0].shape))

num of points eval: torch.Size([32768])
num of pred points: torch.Size([177])


(tensor([ 4559,  4560,  5583,  5584,  6607,  6608,  6639,  6640,  7631,  7632,
          7663,  7664,  8623,  8624,  8655,  8656,  8687,  8688,  9647,  9648,
          9679,  9680,  9711,  9712, 10671, 10672, 10703, 10704, 10735, 10736,
         11695, 11696, 11727, 11728, 11746, 11759, 11760, 12719, 12720, 12751,
         12752, 12783, 12784, 13743, 13744, 13767, 13775, 13776, 13784, 13807,
         13808, 14766, 14767, 14768, 14791, 14792, 14793, 14799, 14800, 14806,
         14807, 14808, 14831, 14832, 15789, 15790, 15791, 15792, 15793, 15794,
         15816, 15817, 15818, 15823, 15824, 15829, 15830, 15831, 15855, 15856,
         16813, 16814, 16815, 16816, 16817, 16818, 16841, 16842, 16847, 16848,
         16853, 16854, 16879, 16880, 17835, 17838, 17839, 17840, 17841, 17842,
         17844, 17871, 17872, 17903, 17904, 18859, 18862, 18863, 18864, 18865,
         18868, 18895, 18896, 18927, 18928, 19851, 19860, 19883, 19886, 19887,
         19888, 19889, 19892, 19919, 19920, 19951, 1

In [33]:
points_grid_np = points_grid.squeeze().cpu()
grid_colors = np.zeros(points_grid_np.shape)
grid_colors[:, 2] = 1
grid_pc = o3d.geometry.PointCloud()
grid_pc.points = o3d.utility.Vector3dVector(points_grid_np)
grid_pc.colors = o3d.utility.Vector3dVector(grid_colors)

points_grid_pred = points_grid[0, occ_hat_pred[0]].cpu().numpy()
output_colors = np.zeros(points_grid_pred.shape)
output_colors[:, 0] = 1
output_pc = o3d.geometry.PointCloud()
output_pc.points = o3d.utility.Vector3dVector(points_grid_pred)
output_pc.colors = o3d.utility.Vector3dVector(output_colors)


inputs_np = inputs.squeeze().cpu().numpy()
input_colors = np.zeros(inputs_np.shape)
input_colors[:, 1] = 1
input_pc = o3d.geometry.PointCloud()
input_pc.points = o3d.utility.Vector3dVector(inputs_np)
input_pc.colors = o3d.utility.Vector3dVector(input_colors)

draw([output_pc, input_pc])

[Open3D INFO] Window window_0 created.
[Open3D INFO] EGL headless mode enabled.


WebVisualizer(window_uid='window_0')

[Open3D INFO] ICE servers: {"stun:stun.l.google.com:19302", "turn:user:password@34.69.27.100:3478", "turn:user:password@34.69.27.100:3478?transport=tcp"}
FEngine (64 bits) created at 0x7fb660007460 (threading is enabled)
[Open3D INFO] Set WEBRTC_STUN_SERVER environment variable add a customized WebRTC STUN server.
[Open3D INFO] WebRTC Jupyter handshake mode enabled.
[Open3D INFO] [Called HTTP API (custom handshake)] /api/call
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCand

In [52]:
class OccNetMean(gpytorch.means.Mean):
    def __init__(self, occ_net, shape_code, batch_shape=torch.Size()):
        super(OccNetMean, self).__init__()
        self.batch_shape = batch_shape
        self.occ_net = occ_net.to('cpu')
        self.occ_net.eval()
        self.shape_code = shape_code
        
    def forward(self, x):
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        print(x.shape)
        with torch.no_grad():
            occ_hat = self.occ_net.decode(x, z=None, c=self.shape_code).probs

        return occ_hat.squeeze()

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, shape_code, occ_net,
                 kernel='rbf', mean_fun='occ', n=None, l=None):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        if mean_fun == 'zero':
            self.mean_module = gpytorch.means.ZeroMean()
        elif mean_fun == 'occ':
            self.mean_module = OccNetMean(occ_net=occ_net, shape_code=shape_code)
            
        if kernel == 'rbf':
            self.covar_module = gpytorch.kernels.RBFKernel()
    
        self.set_hparams(n, l)
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    
    def set_hparams(self, n, l):
        if n is not None:
            self.likelihood.noise = n
        if l is not None:
            self.covar_module.lengthscale = l
        print('likelihood noise:   {}'.format(self.likelihood.noise.item()))
        print('kernel lengthscale: {}'.format(self.covar_module.lengthscale.item()))

def evaluate_gp(model, likelihood, test_x):
    model.eval()
    likelihood.eval()

    with torch.no_grad():
        pred_y = likelihood(model(test_x))
    
    return pred_y

def train_gp(model, likelihood, num_iter=1000):
    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.5) 
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
#         print('Iter %d/%d - Loss: %.3f' % (i + 1, num_epochs, loss.item()))
        optimizer.step()

In [53]:
train_x = inputs.squeeze().cpu()
train_y = torch.ones(train_x.shape[0])
likelihood = gpytorch.likelihoods.GaussianLikelihood()

In [83]:
with torch.no_grad():
    shape_code = model.encode_inputs(inputs)
model_rbf = ExactGPModel(train_x, train_y, likelihood, shape_code=shape_code, occ_net=model, n=0.01, l=0.1)

likelihood noise:   0.010000001639127731
kernel lengthscale: 0.10000000149011612


In [84]:
pred_y = evaluate_gp(model_rbf, likelihood, points_grid)

torch.Size([1, 300, 3])
torch.Size([1, 33068, 3])


In [93]:
pred_y_thresh = torch.where(pred_y.mean.squeeze() > 0.99)

In [94]:
points_grid_pred = points_grid[0, pred_y_thresh[0]].cpu().numpy()
print(points_grid_pred.shape)
output_colors = np.zeros(points_grid_pred.shape)
output_colors[:, 0] = 1
output_pc = o3d.geometry.PointCloud()
output_pc.points = o3d.utility.Vector3dVector(points_grid_pred)
output_pc.colors = o3d.utility.Vector3dVector(output_colors)

inputs_np = inputs.squeeze().cpu().numpy()
input_colors = np.zeros(inputs_np.shape)
input_colors[:, 1] = 1
input_pc = o3d.geometry.PointCloud()
input_pc.points = o3d.utility.Vector3dVector(inputs_np)
input_pc.colors = o3d.utility.Vector3dVector(input_colors)

draw([output_pc, input_pc])

(3984, 3)
[Open3D INFO] Window window_9 created.


WebVisualizer(window_uid='window_9')

[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceCandidate
[Open3D INFO] DataChannelObserver::OnStateChange label: ServerDataChannel, state: open, peerid: 0.8957463745930327
[Open3D INFO] DataChannelObserver::OnStateChange label: ClientDataChannel, state: open, peerid: 0.8957463745930327
[Open3D INFO] Sending init frames to window_9.


[2602:994][79505] (stun_port.cc:96): Binding request timed out from 10.75.15.x:52778 (enp4s0)
[2619:738][79505] (stun_port.cc:96): Binding request timed out from 10.75.15.x:45468 (enp4s0)
[2624:512][79505] (stun_port.cc:96): Binding request timed out from 10.75.15.x:35753 (enp4s0)
