In [None]:
import torch
from torch import nn

from tqdm.notebook import tqdm

import numpy as np

from kaolin.rep import TriangleMesh
from kaolin.metrics.mesh import chamfer_distance,edge_length,point_to_surface

import os

import plotly.graph_objects as go
import matplotlib.pyplot as plt



   No module named 'nuscenes'


In [None]:
def get_model(file, device='cuda'):
    """Load a triangle mesh from an OBJ file and move its tensors to ``device``.

    Args:
        file: Path of the ``.obj`` file. It is converted with ``str`` before
            being handed to ``TriangleMesh.from_obj``, so path-like objects
            are accepted as well as plain strings.
        device: Device the vertex and face tensors are moved to. Defaults to
            ``'cuda'``, which is what this function previously hard-coded.

    Returns:
        A new ``TriangleMesh`` created by ``TriangleMesh.from_tensors`` from
        the moved vertex and face tensors.

    Note:
        Vertices are used exactly as stored in the file: no re-centring or
        rescaling is applied here.
    """
    loaded = TriangleMesh.from_obj(str(file))

    # Move both tensors to the same device before rebuilding the mesh so the
    # returned object never mixes devices.
    vertices = loaded.vertices.to(device)
    faces = loaded.faces.to(device)

    return TriangleMesh.from_tensors(vertices, faces)
    
    
def plot_mesh(mesh, axis_range=(-0.5, 0.5)):
    """Display ``mesh`` as an interactive Plotly ``Mesh3d`` figure.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` tensors. Columns 0-2
            of each are read, so vertices are presumably ``(V, 3)``
            coordinates and faces ``(F, 3)`` vertex indices.
        axis_range: ``(low, high)`` limits applied to all three axes. The
            default ``(-0.5, 0.5)`` reproduces the view that used to be
            hard-coded. Pass ``None`` to let Plotly fit the axes to the data,
            which is needed for meshes that do not lie inside that box.

    Returns:
        None. The figure is shown via ``fig.show()``.
    """
    vertices = mesh.vertices.detach().cpu().numpy()
    faces = mesh.faces.detach().cpu().numpy()

    def _axis_spec():
        # Build a fresh dict per axis so the three axes never share state.
        spec = dict(nticks=4)
        if axis_range is not None:
            spec['range'] = list(axis_range)
        return spec

    surface = go.Mesh3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        colorbar_title='z',
        colorscale=[[0, 'gold'],
                    [0.5, 'mediumturquoise'],
                    [1, 'magenta']],
        # Colour follows vertex *order* (0 for the first vertex, 1 for the
        # last), not any geometric quantity; Plotly interpolates it over faces.
        intensity=np.linspace(0, 1, len(vertices), endpoint=True),
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        name='mesh',
        showscale=False,
    )
    layout = go.Layout(
        scene=dict(xaxis=_axis_spec(), yaxis=_axis_spec(), zaxis=_axis_spec())
    )

    fig = go.Figure(data=[surface], layout=layout)
    fig.show()

In [None]:
# Mesh whose vertices will be optimised: gradients must be tracked on them.
m1 = get_model('car2.obj')
m1.vertices.requires_grad_(True)

# Second mesh, loaded without gradient tracking.
m2 = get_model('car1.obj')

In [None]:
# Visualise the two loaded meshes, one figure each: m1 (from car2.obj) first,
# then m2 (from car1.obj).
plot_mesh(m1)
plot_mesh(m2)


In [None]:
# Optimisation settings, previously inline magic numbers.
LEARNING_RATE = 0.01
NUM_STEPS = 15000
# Fifth positional argument of chamfer_distance; presumably the number of
# points sampled per mesh for each evaluation — confirm against kaolin's docs.
NUM_SAMPLE_POINTS = 10000
# Weight of the edge-length term relative to the Chamfer term.
EDGE_LENGTH_WEIGHT = 10
# A mesh figure is drawn every PLOT_EVERY steps (including step 0).
PLOT_EVERY = 1000

# Only m1's vertices are handed to the optimiser; m2 is never updated.
optimizer = torch.optim.Adam([m1.vertices], lr=LEARNING_RATE)

loop = tqdm(range(NUM_STEPS))
losses = []

for step in loop:
    optimizer.zero_grad()

    # Chamfer term between the two meshes plus a weighted edge-length term on m1.
    loss = (
        chamfer_distance(m1, m2, 1, 1, NUM_SAMPLE_POINTS)
        + EDGE_LENGTH_WEIGHT * edge_length(m1)
    )
    losses.append(loss.item())
    loss.backward()
    optimizer.step()

    if step % PLOT_EVERY == 0:
        plot_mesh(m1)
    loop.set_description(f'loss:{losses[-1]}')

# Labelled loss curve; plt.show() avoids the stray "[<Line2D ...>]" cell output.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    xlabel='optimisation step',
    ylabel='loss',
    title='Chamfer + edge-length loss',
)
plt.show()



In [8]:
# Quick check of the optimised mesh: request 10 samples from it. The displayed
# result is a pair of tensors: ten 3-component rows and ten integer indices
# (presumably surface points and the faces they were drawn from — confirm
# against kaolin's TriangleMesh.sample documentation).
m1.sample(10)

(tensor([[-0.3763, -0.0782,  0.0452],
         [ 0.3166, -0.1144, -0.1155],
         [ 0.0957,  0.0448,  0.1361],
         [-0.3459,  0.0068,  0.1563],
         [-0.1247,  0.1003, -0.1088],
         [ 0.2506, -0.0566, -0.1553],
         [ 0.3904,  0.0028, -0.1034],
         [ 0.0123,  0.1292, -0.0710],
         [-0.2488, -0.0055, -0.1782],
         [ 0.0906,  0.0441, -0.1616]], device='cuda:0', grad_fn=<AddBackward0>),
 tensor([36015, 70028, 25104, 15105, 35828, 71543, 30686, 23848,  4950, 24645],
        device='cuda:0'))