In [1]:
import os
import torch
import numpy as np
import open3d as o3d
import gpytorch
from im2mesh import config
from im2mesh.checkpoints import CheckpointIO
from im2mesh.common import make_3d_grid

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Everything in this notebook runs on the CPU.
device = torch.device('cpu')
# Pretrained point-cloud ONet config, loaded on top of the default config.
cfg = config.load_config('configs/pointcloud/onet_pretrained.yaml', 'configs/default.yaml')

# Test split of the dataset described by cfg; the same dataset object is handed to get_model.
dataset = config.get_dataset('test', cfg, return_idx=True)
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(cfg['training']['out_dir'], model=model)
# cfg['test']['model_file'] may be a URL (the captured output shows a download).
# NOTE(review): `re` shadows the stdlib regex module's name and is not used by the
# visible cells; consider renaming it (e.g. `load_dict`) if nothing depends on it.
re = checkpoint_io.load(cfg['test']['model_file'])

https://s3.eu-central-1.amazonaws.com/avg-projects/occupancy_networks/models/onet_pcl2mesh-5c0be168.pt
=> Loading checkpoint from url...


In [3]:
def get_data_sample(idx=500):
    """Fetch dataset item ``idx`` as batched tensors on ``device``.

    Returns:
        ``(inputs, points)``: the item's ``'inputs'`` and ``'points'`` arrays,
        converted with ``torch.from_numpy``, each given a leading batch
        dimension of size 1 and moved to the global ``device``.
    """
    sample = dataset[idx]

    def _as_batched_tensor(key):
        # numpy array -> tensor -> add batch dim -> target device
        return torch.from_numpy(sample.get(key)).unsqueeze(0).to(device)

    return _as_batched_tensor('inputs'), _as_batched_tensor('points')

def get_metadata(idx=500):
    """Return the dataset's model record for item ``idx`` (``dataset.models[idx]``).

    The visualisation cell reads its ``'category'`` and ``'model'`` entries to
    locate the ShapeNet mesh on disk.
    """
    return dataset.models[idx]

def predict(inputs, points):
    """Run the global ONet ``model`` and return occupancy logits at ``points``.

    ``inputs`` are encoded into a shape code, and the decoder is queried at
    ``points`` with that code. Gradients are disabled throughout.

    Side effect: puts the global ``model`` into eval mode.

    Returns:
        The decoder's ``.logits`` with all size-1 dimensions squeezed out.
    """
    model.eval()
    with torch.no_grad():
        shape_code = model.encode_inputs(inputs)
        occ_dist = model.decode(points, z=None, c=shape_code)
    logits = occ_dist.logits
    return logits.squeeze()

In [69]:
# Dense query grid: nx points per axis over [-0.5, 0.5]^3, scaled by box_size
# (so it spans [-box_size/2, box_size/2]^3), for nx**3 query points in total.
nx = 32
box_size = 1.1
points_grid = box_size * make_3d_grid((-0.5,)*3, (0.5,)*3, (nx,)*3)
# Add a leading batch dimension of size 1 for the ONet decoder.
points_grid = points_grid.unsqueeze(0).to(device)

In [70]:
# Dataset item under study (1658 is an alternative item that was tried).
# idx = 1658
idx = 2000
# NOTE(review): `points` (the item's own query points) is not used by the visible cells.
inputs, points = get_data_sample(idx=idx)
meta = get_metadata(idx=idx)
# ONet logits evaluated on the dense grid, not on the item's own `points`.
occ_hat = predict(inputs, points_grid)

In [71]:
# Indices of grid points the ONet considers occupied.
# NOTE(review): `predict` returns logits (`.logits`), so 0.8 is a *logit* threshold
# (sigmoid(0.8) ~ 0.69 probability), whereas the GP cell applies a sigmoid before
# thresholding at 0.99. Confirm which convention is intended.
occ_hat_pred = torch.where(occ_hat > 0.8)
print('num of points eval: {}'.format(occ_hat.shape))
print('num of pred points: {}'.format(occ_hat_pred[0].shape))

num of points eval: torch.Size([32768])
num of pred points: torch.Size([166])


In [72]:
class OccNetMean(gpytorch.means.Mean):
    """GP prior mean given by a frozen Occupancy Network's logits.

    The mean at a set of query points is the ONet decoder's occupancy logits
    for those points, conditioned on a fixed ``shape_code``. The network is
    never trained here: it runs under ``torch.no_grad`` and is pinned to eval
    mode (see ``train``).

    Args:
        occ_net: ONet model exposing ``decode(points, z=None, c=...)``.
        shape_code: conditioning code passed to ``decode`` as ``c``.
        batch_shape: stored as ``self.batch_shape``.
    """

    def __init__(self, occ_net, shape_code, batch_shape=torch.Size()):
        super(OccNetMean, self).__init__()
        self.batch_shape = batch_shape
        self.occ_net = occ_net
        self.occ_net.eval()
        self.shape_code = shape_code

    def train(self, mode=True):
        """Set this module's mode, but always keep the wrapped ONet in eval mode.

        ``occ_net`` is registered as a child module, and ``nn.Module.train`` /
        ``eval`` recurse into children. Without this override, ``gp.train()``
        would switch the ONet into training mode, changing the behaviour of any
        normalisation/dropout layers it contains while it serves as a fixed mean.
        """
        super(OccNetMean, self).train(mode)
        self.occ_net.eval()
        return self

    def forward(self, x):
        """Return ONet logits at ``x``.

        Args:
            x: query points; a 2-D input (presumably ``(N, 3)``) gets a
                leading batch dimension added.

        Returns:
            The logits with size-1 dimensions squeezed out, but never 0-d: a
            single query point yields shape ``(1,)`` so the GP still receives
            a vector mean.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        with torch.no_grad():
            occ_hat = self.occ_net.decode(x, z=None, c=self.shape_code).logits

        logits = occ_hat.squeeze()
        # squeeze() collapses a single-point result to a scalar; keep it 1-D.
        if logits.dim() == 0:
            logits = logits.reshape(1)
        return logits

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a selectable prior mean and kernel.

    Args:
        train_x: training inputs, passed to ``ExactGP``.
        train_y: training targets, passed to ``ExactGP``.
        likelihood: GP likelihood, passed to ``ExactGP``.
        shape_code: ONet conditioning code; only used when ``mean_fun='occ'``.
        occ_net: ONet model; only used when ``mean_fun='occ'``.
        kernel: ``'rbf'`` (``RBFKernel``) or ``'matern'`` (``MaternKernel``).
        mean_fun: ``'occ'`` (ONet logits via ``OccNetMean``) or ``'zero'``
            (``ZeroMean``).
        n: optional initial likelihood noise.
        l: optional initial kernel lengthscale.

    Raises:
        ValueError: if ``mean_fun`` or ``kernel`` is not a supported name.
            Previously an unknown name silently left the module unset and failed
            later with an AttributeError.
    """

    def __init__(self, train_x, train_y, likelihood, shape_code, occ_net,
                 kernel='matern', mean_fun='occ', n=None, l=None):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        if mean_fun == 'zero':
            self.mean_module = gpytorch.means.ZeroMean()
        elif mean_fun == 'occ':
            self.mean_module = OccNetMean(occ_net=occ_net, shape_code=shape_code)
        else:
            raise ValueError(
                "mean_fun must be 'zero' or 'occ', got {!r}".format(mean_fun))

        if kernel == 'rbf':
            self.covar_module = gpytorch.kernels.RBFKernel()
        elif kernel == 'matern':
            self.covar_module = gpytorch.kernels.MaternKernel()
        else:
            raise ValueError(
                "kernel must be 'rbf' or 'matern', got {!r}".format(kernel))

        self.set_hparams(n, l)

    def forward(self, x):
        """Return the GP prior at ``x`` as a MultivariateNormal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def set_hparams(self, n, l):
        """Optionally overwrite the noise and lengthscale, then print both.

        Args:
            n: new likelihood noise, or ``None`` to keep the current value.
            l: new kernel lengthscale, or ``None`` to keep the current value.
        """
        if n is not None:
            self.likelihood.noise = n
        if l is not None:
            self.covar_module.lengthscale = l
        print('likelihood noise:   {}'.format(self.likelihood.noise.item()))
        print('kernel lengthscale: {}'.format(self.covar_module.lengthscale.item()))

def evaluate_gp(model, likelihood, test_x):
    """Compute the GP predictive distribution at ``test_x``.

    Switches ``model`` and then ``likelihood`` into eval mode, and returns
    ``likelihood(model(test_x))`` computed with gradients disabled.
    """
    for module in (model, likelihood):
        module.eval()

    with torch.no_grad():
        latent = model(test_x)
        return likelihood(latent)

def train_gp(model, likelihood, num_iter=100, lr=0.5):
    """Fit GP hyperparameters by maximising the exact marginal log likelihood.

    Trains on the model's own training data (``model.train_inputs`` /
    ``model.train_targets``), not on notebook globals. Previously the function
    read the globals ``train_x``/``train_y``/``gp`` and silently mixed data from
    whatever GP was defined last.

    Args:
        model: an ``ExactGP`` model exposing ``covar_module.lengthscale``.
        likelihood: the model's likelihood (its ``noise`` is logged).
        num_iter: number of Adam steps.
        lr: Adam learning rate (default 0.5, the previously hard-coded value).
    """
    # ExactGP stores its training inputs as a tuple of tensors.
    train_x = model.train_inputs[0]
    train_y = model.train_targets

    model.train()
    likelihood.train()

    # When the mean is an OccNetMean, model.parameters() also yields the ONet's
    # weights. They get no gradient (the mean runs under torch.no_grad), so
    # Adam leaves them untouched.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))
        print('Lengthscale: {} Likelihood Noise: {}'.format(
            model.covar_module.lengthscale.item(), likelihood.noise.item()))
        optimizer.step()

In [73]:
# Sanity check: render the raw query grid in red.
# .numpy() gives Open3D a real ndarray (the variable name promised one; the
# visualisation cell converts its tensors the same way).
points_grid_np = points_grid.squeeze().cpu().numpy()
grid_pc = o3d.geometry.PointCloud()
grid_pc.points = o3d.utility.Vector3dVector(points_grid_np)
grid_pc.paint_uniform_color([1, 0, 0])


o3d.visualization.draw_geometries([grid_pc])

In [74]:
# GP training data: the observed input point cloud, every point labelled 10.
# NOTE(review): the GP mean is the ONet logit, so 10 is presumably meant as a
# confident "occupied" logit for observed points. Confirm this label is intended.
train_x = inputs.squeeze().cpu()
train_y = torch.ones(train_x.shape[0]) * 10
# GP test inputs: every point of the query grid (batch dim squeezed away).
test_x = points_grid.squeeze()
print('train shape: {}'.format(train_x.shape))
print('test shape : {}'.format(test_x.shape))
# Shape code from the ONet encoder; passed to the GP as the OccNetMean condition.
with torch.no_grad():
    shape_code = model.encode_inputs(inputs)

train shape: torch.Size([300, 3])
test shape : torch.Size([32768, 3])


In [75]:
# GP with the ONet logits as prior mean (mean_fun defaults to 'occ') and an
# RBF kernel with hand-set hyperparameters: lengthscale 1, noise 0.01.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
gp = ExactGPModel(
    train_x, train_y, likelihood, shape_code=shape_code, 
    occ_net=model, kernel='rbf', l=1, n=0.01)
# Hyperparameter fitting is disabled; the hand-set values above are used as-is.
# train_gp(gp, likelihood, num_iter=1000)

likelihood noise:   0.010000001639127731
kernel lengthscale: 1.0


In [76]:
# NOTE: despite its name, `pred_logits` is the predictive distribution returned by
# evaluate_gp (it exposes .mean and .variance); a sigmoid is applied to its mean below.
pred_logits = evaluate_gp(gp, likelihood, test_x)
# Predictive mean -> probability; keep grid points above 0.99.
pred_probs = torch.sigmoid(pred_logits.mean)
pred_probs_hat = torch.where(pred_probs.squeeze() > 0.99)
# Alternative selection by low predictive variance (disabled):
# pred_probs_hat = torch.where(pred_logits.variance.squeeze() < 0.011)
print(pred_logits.variance.max())
print(pred_logits.variance.min())
print(pred_probs_hat[0].shape)

tensor(0.1892)
tensor(0.0101)
torch.Size([344])


In [77]:
# Side-by-side comparison, stacked along z so the groups do not overlap:
#   z + 0: GP prediction (blue) with the input points (green)
#   z + 1: ground-truth ShapeNet mesh
#   z + 2: ONet prediction (red) with a copy of the input points (green)
# Set the SHAPENET_PATH environment variable instead of editing this cell;
# the default is the original author's local path.
shapenet_path = os.environ.get(
    'SHAPENET_PATH', '/home/tarik/projects/datasets/shapenet/ShapeNetCore.v1')


def make_point_cloud(points_np, color, offset=None):
    """Build an Open3D point cloud from an (N, 3) array, painted one colour.

    Args:
        points_np: point coordinates, one row per point.
        color: RGB triple passed to ``paint_uniform_color``.
        offset: optional translation applied with ``translate``.

    Returns:
        The new ``o3d.geometry.PointCloud``.
    """
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(points_np)
    pc.paint_uniform_color(color)
    if offset is not None:
        pc.translate(offset)
    return pc


mesh_path = os.path.join(shapenet_path, meta['category'], meta['model'], 'model.obj')
# Fail loudly here rather than relying on Open3D's reader to report a missing file.
if not os.path.isfile(mesh_path):
    raise FileNotFoundError(
        'ShapeNet mesh not found: {} (set SHAPENET_PATH)'.format(mesh_path))
mesh = o3d.io.read_triangle_mesh(mesh_path)
mesh = mesh.translate(np.array([0, 0, 1.0]))

inputs_np = inputs.squeeze().cpu().numpy()
input_pc = make_point_cloud(inputs_np, [0, 1, 0])
# Shifted copy of the inputs, overlaid on the ONet prediction.
input_pc_t = make_point_cloud(inputs_np, [0, 1, 0], offset=[0, 0, 2.0])

# ONet prediction: grid points selected by the ONet threshold cell.
points_grid_pred_onet = points_grid[0, occ_hat_pred[0]].cpu().numpy()
output_pc = make_point_cloud(points_grid_pred_onet, [1, 0, 0], offset=[0, 0, 2.0])

# GP prediction: grid points selected by the GP evaluation cell.
points_grid_pred = points_grid[0, pred_probs_hat[0]].cpu().numpy()
output_gp_pc = make_point_cloud(points_grid_pred, [0, 0, 1])

o3d.visualization.draw_geometries([output_gp_pc, input_pc, mesh, output_pc, input_pc_t])