In [None]:
%matplotlib inline
import os
import time
import shutil
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

sns.set_style("dark")
plt.style.use("dark_background")

import torch
import torch.nn as nn

from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency

from datetime import datetime

import utils
import batch_iterator
import models
import create_data

# Every run gets its own output directory, keyed by the Unix timestamp (in
# whole seconds) at which this cell was executed.
training_timestamp = str(int(time.time()))
model_dir = f'trained_models/model_{training_timestamp}/'

# exist_ok=True replaces the separate "exists?" check, so re-running the cell
# within the same second (or racing another process) cannot raise.
os.makedirs(model_dir, exist_ok=True)

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def print_and_log(text):
    """Print ``text`` to stdout and append it to ``log.txt`` inside ``model_dir``.

    Args:
        text: Object to print; it is written to the log exactly as ``print``
            renders it, followed by a newline.
    """
    print(text)
    # The context manager closes (and therefore flushes) the log file on every
    # call instead of leaving an open handle behind for the garbage collector.
    with open(f'{model_dir}/log.txt', 'a') as log_file:
        print(text, file=log_file)

In [None]:
def plot_pointcloud(pointcloud, title="", from_mesh=False, points_to_sample=5000):
    """Draw a 3D scatter plot of a point cloud, save it as a PNG and show it.

    Args:
        pointcloud: When ``from_mesh`` is False, a NumPy array that is split
            along axis 1 into x, y and z (so presumably of shape (N, 3)).
            When ``from_mesh`` is True, an object accepted by
            ``sample_points_from_meshes`` (a mesh) that provides ``.to(device)``.
        title: Plot title; also used as the PNG file name inside ``model_dir``.
        from_mesh: If True, sample ``points_to_sample`` points from the mesh
            surface and plot those instead of treating the input as raw points.
        points_to_sample: Number of surface samples; only used when
            ``from_mesh`` is True.
    """
    if from_mesh:
        points = sample_points_from_meshes(pointcloud.to(device), points_to_sample)
        # squeeze(0) removes only the leading batch dimension; a bare squeeze()
        # would also collapse the point dimension when a single point is sampled.
        x, y, z = points.detach().cpu().squeeze(0).unbind(1)
    else:
        x, y, z = torch.from_numpy(pointcloud).unbind(1)
    fig = plt.figure(figsize=(5, 5))
    # Axes3D(fig) is not attached to the figure on recent Matplotlib releases
    # (which yields an empty plot); add_subplot works on old and new versions.
    ax = fig.add_subplot(111, projection="3d")
    # Axes are permuted/flipped (x, z, -y) to orient the cloud for viewing.
    ax.scatter3D(x, z, -y)
    ax.set_axis_off()
    ax.set_title(title)
    ax.view_init(190, 30)
    fig.savefig(f"{model_dir}/{title}.png")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Number of points sampled per shape; passed to the data generator, the data
# reader and the surface sampler used inside the training loss.
points_to_sample = 1000
# Subdivision level of the icosphere that the network deforms.
sphere_level = 3
# Vertex count of an icosphere for each subdivision level. Computed from the
# closed form V = 10 * 4**level + 2 (12, 42, 162, 642, 2562, 10242) instead of
# a hand-typed table, so changing sphere_level does not require editing it.
sphere_level_verts = {level: 10 * 4 ** level + 2 for level in range(6)}

In [None]:
folder_path = 'data/'
data_dir = os.path.join(os.getcwd(), folder_path)
# (Re)generate the data set when the data folder is missing or holds fewer
# than 300 entries. The isdir check matters: os.listdir raises
# FileNotFoundError on a missing folder, which is exactly the first-run case
# in which the data still has to be created.
if not os.path.isdir(data_dir) or len(os.listdir(data_dir)) < 300:
    create_data.create_data('meshes/', 50, points_to_sample, sphere_level)
# Keep a copy of this notebook next to the trained model for reproducibility.
shutil.copy2('./train_model.ipynb', model_dir)

In [None]:
# Optimisation hyper-parameters: samples per mini-batch and the step size
# handed to the Adam optimizer.
batch_size, learning_rate = 16, 1e-5

In [None]:
# Read the train/test splits from disk and wrap the training split in the
# project's batch iterator.
model_names = utils.get_model_names(folder_path)
training_x, training_y, test_x, test_y = utils.read_data(
    model_names,
    folder_path,
    points_to_sample,
    sphere_level_verts[sphere_level],
)
batch_iter = batch_iterator.BatchIterator(
    training_x,
    training_y,
    batch_size,
)

In [None]:
# Build the network, sized by the vertex count of the chosen icosphere level,
# move it to the training device, then create its Adam optimizer.
model = models.PointNetCls(sphere_level_verts[sphere_level])
model = model.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [None]:
epochs = 40
train_epoch_step = int(round(batch_iter.size / batch_size))

# Validation walks over *every* test sample in chunks of at most batch_size.
# Deriving the step count with round() could silently drop tail samples,
# mis-weight a partial batch, or yield 0 steps (ZeroDivisionError only after a
# whole training epoch), so fail fast on an empty set and count batches exactly.
num_val_samples = test_x.shape[0]
if num_val_samples == 0:
    raise ValueError("Validation set is empty; cannot compute a validation loss.")
val_epoch_step = len(range(0, num_val_samples, batch_size))

# Ten (roughly) evenly spaced training steps per epoch at which progress is logged.
display_steps = np.linspace(1, train_epoch_step, 10, endpoint=True).astype(np.uint32)
train_losses = []        # loss of every training step
train_epoch_losses = []  # mean training loss per epoch
val_losses = []          # validation loss per epoch
lowest_val_loss = np.inf
best_model = copy.deepcopy(model.state_dict())

# Template mesh that the predicted per-vertex offsets are applied to.
src_mesh = ico_sphere(sphere_level, device)
# Weights of the individual loss terms.
w_chamfer = 1.0
w_edge = 0.05
w_normal = 0.0005
w_laplacian = 0.005


def _batch_loss(batch_x, pred_y):
    """Sum the weighted mesh loss over every sample of a batch.

    Args:
        batch_x: Input point clouds; each sample is transposed and given a
            leading batch dimension before the Chamfer comparison.
        pred_y: Predicted per-vertex offsets, one entry per sample, applied to
            ``src_mesh``.

    Returns:
        Tuple ``(summed_loss, last_terms)``: the *sum* (not the mean) of the
        per-sample losses, and a dict with the weighted loss terms of the last
        sample processed (used only for progress logging).
    """
    summed_loss = 0.0
    last_terms = {}
    for pointcloud, deform_verts in zip(batch_x, pred_y):
        pred_mesh = src_mesh.offset_verts(deform_verts)
        pred_pc = sample_points_from_meshes(pred_mesh, points_to_sample)
        # Centre the sampled cloud and scale each coordinate axis by its maximum.
        pred_pc = pred_pc - pred_pc.mean(dim=1, keepdim=True)
        pred_pc = pred_pc / pred_pc.max(dim=1, keepdim=True).values
        target_pc = torch.transpose(pointcloud, 0, 1).unsqueeze(0)
        loss_chamfer, _ = chamfer_distance(target_pc, pred_pc)
        loss_edge = mesh_edge_loss(pred_mesh)
        loss_normal = mesh_normal_consistency(pred_mesh)
        loss_laplacian = mesh_laplacian_smoothing(pred_mesh, method="uniform")
        last_terms = {
            "chamfer": loss_chamfer * w_chamfer,
            "edge": loss_edge * w_edge,
            "normal": loss_normal * w_normal,
            "laplacian": loss_laplacian * w_laplacian,
        }
        summed_loss = summed_loss + sum(last_terms.values())
    return summed_loss, last_terms


for epoch_i in range(1, epochs+1):
    print_and_log(f"{datetime.now()} Epoch:{epoch_i}, Training")
    model.train()
    train_epoch_loss = 0.0
    for step_i in range(1, train_epoch_step+1):
        # The targets returned by the iterator are not used by this loss.
        batch_x, _ = batch_iter.next_batch()
        batch_x = torch.transpose(torch.from_numpy(batch_x), 1, 2).to(device)

        model.zero_grad()
        pred_y, _, _ = model(batch_x)

        summed_loss, last_terms = _batch_loss(batch_x, pred_y)
        # Average over the samples actually present in this batch rather than
        # over the nominal batch_size constant.
        all_loss = summed_loss / batch_x.shape[0]

        all_loss.backward()
        optimizer.step()

        train_losses.append(all_loss.item())
        train_epoch_loss += all_loss.item()

        if step_i in display_steps:
            # The individual terms shown here belong to the last sample of the batch.
            print_and_log(f"{datetime.now()} Epoch:{epoch_i}, Training Step:{step_i}/{train_epoch_step}, "
                          f"Iter:{step_i*batch_size}/{train_epoch_step*batch_size}, Loss Chamfer:{last_terms['chamfer']:.4f}, "
                          f"Loss Edge:{last_terms['edge']:.4f}, Loss Normal:{last_terms['normal']:.4f}, Loss Laplacian:{last_terms['laplacian']:.4f}")

    train_epoch_loss /= train_epoch_step
    train_epoch_losses.append(train_epoch_loss)

    print_and_log(f"{datetime.now()} Epoch:{epoch_i}, Validation")
    model.eval()
    val_loss_sum = 0.0

    with torch.no_grad():
        for val_start in range(0, num_val_samples, batch_size):
            batch_x = test_x[val_start:val_start + batch_size]
            batch_x = torch.transpose(torch.from_numpy(batch_x), 1, 2).to(device)

            pred_y, _, _ = model(batch_x)

            summed_loss, _ = _batch_loss(batch_x, pred_y)
            val_loss_sum += summed_loss.item()

    # Per-sample mean over the whole validation set; identical to the mean of
    # per-batch means whenever the set size is a multiple of batch_size.
    val_epoch_loss = val_loss_sum / num_val_samples
    val_losses.append(val_epoch_loss)
    print_and_log(f"{datetime.now()} Epoch:{epoch_i}, Validation Loss:{val_epoch_loss:.6f}")
    if val_epoch_loss <= lowest_val_loss:
        lowest_val_loss = val_epoch_loss
        print_and_log(f"{datetime.now()} Epoch:{epoch_i}, Best Validation Loss Obtained - Saving Model")
        # Snapshot the weights in memory; they are written to disk after training.
        best_model = copy.deepcopy(model.state_dict())

In [None]:
# Restore the best-scoring weights and persist them, named by their validation loss.
model.load_state_dict(best_model)
torch.save(model.state_dict(), f"{model_dir}/model_{lowest_val_loss:.4f}.pth")

# Store the raw loss histories next to the checkpoint.
np.save(f"{model_dir}/train_losses.npy", np.asarray(train_losses))
np.save(f"{model_dir}/train_epoch_losses.npy", np.asarray(train_epoch_losses))
np.save(f"{model_dir}/val_losses.npy", np.asarray(val_losses))

# x positions: one tick per training step, and one tick per epoch placed at
# the step index where that epoch ended.
train_plot_steps = np.arange(1, len(train_losses) + 1)
val_plot_steps = np.arange(1, len(val_losses) + 1) * train_epoch_step

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Losses")
ax.plot(train_plot_steps, train_losses, label='train_loss', linewidth=3)
ax.plot(val_plot_steps, train_epoch_losses, label='train_epoch_loss', linewidth=3)
ax.plot(val_plot_steps, val_losses, label='val_loss', linewidth=3)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig(f"{model_dir}/loss.png")
plt.show()

In [None]:
# Qualitative check: plot one training point cloud next to the deformed
# icosphere that the trained network predicts for it.
plot_sample_ix = 40
plot_pointcloud(training_x[plot_sample_ix], "mesh")

# Make the inference mode explicit instead of relying on whatever mode the
# training loop left the model in, and skip autograd bookkeeping for this
# forward pass.
model.eval()
with torch.no_grad():
    pred_deform_verts, _, _ = model(torch.transpose(torch.from_numpy(training_x[plot_sample_ix]), 0, 1).to(device).unsqueeze(0))
# squeeze(0) drops only the leading batch dimension of the single prediction.
plot_pointcloud(ico_sphere(sphere_level).offset_verts(pred_deform_verts.detach().cpu().squeeze(0)), "predicted deformation", from_mesh=True, points_to_sample=points_to_sample)