In [None]:
%matplotlib inline
import os
import tqdm
import tqdm.notebook
import imageio
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import seaborn as sns

sns.set_style("dark")
plt.style.use("dark_background")

import torch

from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency

from IPython import display

import utils

In [None]:
# Meshes used in this notebook; mesh_list[0] is the optimisation target.
meshname0 = "SM_Env_Plant_01"
meshname1 = "SM_Env_TreeBirch_03"

mesh_list = [utils.load_mesh(f"meshes/{name}.obj") for name in (meshname0, meshname1)]

# Number of surface points drawn per sample_points_from_meshes call.
points_to_sample = 500
# Subdivision level of the icosphere that gets deformed into the target.
sphere_level = 2

# Point sampling is random: seed torch so a Restart & Run All draws the same
# samples (GPU kernels may still introduce small nondeterminism).
SEED = 0
torch.manual_seed(SEED)
sphere_level = 2

In [None]:
def _render_pointcloud_frame(mesh, azim, num_points):
    """Render one off-screen RGBA frame of points sampled from ``mesh``.

    Samples ``num_points`` surface points with ``sample_points_from_meshes``,
    scatters the first mesh's points in a 5x5-inch 3D axes using the
    (x, z, -y) axis mapping, and rasterises the figure with the Agg canvas.

    Returns:
        np.ndarray of shape (height, width, 4), dtype uint8.
    """
    points = sample_points_from_meshes(mesh, num_points)
    # Index the batch instead of squeeze(): squeeze() would also drop the
    # point axis when num_points == 1. Only the first mesh is drawn.
    x, y, z = points[0].detach().cpu().numpy().T
    fig = plt.figure(figsize=(5, 5))
    canvas = FigureCanvas(fig)
    # Axes3D(fig) no longer attaches itself to the figure in Matplotlib >= 3.6
    # (deprecated since 3.4), which yields blank frames; add_subplot does.
    ax = fig.add_subplot(projection="3d")
    ax.scatter3D(x, z, -y)
    ax.view_init(elev=190, azim=azim)
    ax.set_axis_off()
    canvas.draw()
    buffer, (width, height) = canvas.print_to_buffer()
    # Close this specific figure so pyplot does not keep it alive.
    plt.close(fig)
    return np.frombuffer(buffer, np.uint8).reshape((height, width, 4))


def animate_pointcloud(mesh, anim_file, restore_anim=True, num_points=None, n_frames=24, fps=8):
    """Write a turntable GIF of points sampled from ``mesh`` and return its path.

    Args:
        mesh: PyTorch3D mesh batch to sample from (only the first mesh is
            drawn). May be None only when an existing GIF is being reused.
        anim_file: Output path of the GIF; missing parent directories are created.
        restore_anim: If True and ``anim_file`` already exists, skip rendering
            and return the existing file.
        num_points: Points sampled per frame; None means the notebook-level
            ``points_to_sample``.
        n_frames: Frames per full 360-degree rotation.
        fps: Frame rate passed to ``imageio.mimsave``.

    Returns:
        ``anim_file``.

    Raises:
        ValueError: if ``mesh`` is None and there is no existing GIF to reuse.
    """
    if restore_anim and os.path.isfile(anim_file):
        return anim_file
    if mesh is None:
        raise ValueError(f"no mesh given and no cached animation at {anim_file!r}")
    if num_points is None:
        num_points = points_to_sample

    out_dir = os.path.dirname(anim_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    frames = [
        _render_pointcloud_frame(mesh, 360 * frame_idx / n_frames, num_points)
        for frame_idx in range(n_frames)
    ]
    imageio.mimsave(anim_file, frames, 'GIF', fps=fps)
    return anim_file

In [None]:
# Turntable preview of the first mesh; an existing GIF at this path is reused.
anim_file = animate_pointcloud(mesh=mesh_list[0], anim_file=f"outputs/{meshname0}.gif")
display.Image(filename=anim_file)

In [None]:
# Turntable preview of the second mesh; an existing GIF at this path is reused.
anim_file = animate_pointcloud(mesh=mesh_list[1], anim_file=f"outputs/{meshname1}.gif")
display.Image(filename=anim_file)

In [None]:
# Deform an icosphere towards target_mesh by optimising per-vertex offsets
# against a weighted sum of chamfer distance and mesh regularisers.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Put the target on the optimisation device; chamfer_distance needs both
# point sets on the same device.
target_mesh = mesh_list[0].to(device)

src_mesh = ico_sphere(sphere_level, device)

# Per-vertex offsets of the icosphere: the only trainable parameters.
deform_verts = torch.zeros(src_mesh.verts_packed().shape, device=device, requires_grad=True)

learning_rate = 0.01
num_iter = 500
w_chamfer = 1.0
w_edge = 0.05
w_normal = 0.0005
w_laplacian = 0.005

optimizer = torch.optim.Adam([deform_verts], lr=learning_rate, betas=(0.5, 0.999))

plot_period = 100
loop = tqdm.notebook.tqdm(range(num_iter))

chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []

# Iterations at which a frame of the optimisation GIF is captured.
plot_steps = np.linspace(0, num_iter - 1, num=48, dtype=int)

anim_frames = []

for i in loop:
    optimizer.zero_grad()

    new_src_mesh = src_mesh.offset_verts(deform_verts)

    sample_trg = sample_points_from_meshes(target_mesh, points_to_sample)
    sample_src = sample_points_from_meshes(new_src_mesh, points_to_sample)

    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    loss_edge = mesh_edge_loss(new_src_mesh)
    loss_normal = mesh_normal_consistency(new_src_mesh)
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

    loss.backward()
    optimizer.step()

    loop.set_description('total_loss = %.6f' % loss.item())

    # Store plain floats: keeping the tensors would retain every iteration's
    # autograd graph (and device memory), and matplotlib cannot plot CUDA
    # tensors that require grad.
    chamfer_losses.append(loss_chamfer.item())
    edge_losses.append(loss_edge.item())
    normal_losses.append(loss_normal.item())
    laplacian_losses.append(loss_laplacian.item())

    if i in plot_steps:
        plot_i = np.where(plot_steps == i)[0][0]
        # Preview sampling does not need gradients.
        with torch.no_grad():
            points = sample_points_from_meshes(new_src_mesh, points_to_sample)
        x, y, z = points[0].cpu().numpy().T
        fig = plt.figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        # Axes3D(fig) no longer attaches itself to the figure in Matplotlib >= 3.6.
        ax = fig.add_subplot(projection="3d")
        ax.scatter3D(x, z, -y)
        # 24 frames per revolution, so the 48 captured frames spin the view twice.
        ax.view_init(elev=190, azim=360 * (plot_i / 24))
        ax.set_axis_off()
        canvas.draw()
        s, (width, height) = canvas.print_to_buffer()
        plt.close(fig)
        anim_frames.append(np.frombuffer(s, np.uint8).reshape((height, width, 4)))

# The loop's new_src_mesh was built before the last optimizer.step(); rebuild
# it from the final offsets so later cells see the fully optimised mesh.
new_src_mesh = src_mesh.offset_verts(deform_verts.detach())

os.makedirs("outputs", exist_ok=True)
imageio.mimsave("outputs/optimization.gif", anim_frames, 'GIF', fps=8)

In [None]:
# Always re-render (restore_anim=False): the deformed mesh changes on every optimisation run.
anim_file = animate_pointcloud(
    mesh=new_src_mesh,
    anim_file="outputs/deformed_sphere.gif",
    restore_anim=False,
)
display.Image(filename=anim_file)

In [None]:
# The optimisation cell writes this GIF itself, so display the file directly
# instead of calling animate_pointcloud with a None mesh just to get the path.
anim_file = "outputs/optimization.gif"
display.Image(filename=anim_file)

In [None]:
# Unweighted loss terms recorded at each optimisation iteration.
loss_curves = {
    "chamfer": chamfer_losses,
    "edge": edge_losses,
    "normal": normal_losses,
    "laplacian": laplacian_losses,
}

fig, ax = plt.subplots(figsize=(10, 5))
for label, values in loss_curves.items():
    # float() accepts Python floats as well as 0-dim tensors (including CUDA
    # tensors that require grad, which matplotlib cannot convert itself).
    ax.plot([float(v) for v in values], label=label)
ax.set_title("Loss")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
os.makedirs("outputs", exist_ok=True)
fig.savefig("outputs/loss.png")
plt.show()