## This notebook will help you train a latent Point-Cloud GAN.

(Assumes `latent_3d_points` is on the PYTHONPATH, that a trained AE model exists, and that Open3D is installed for the point-cloud visualisations.)

In [None]:
import numpy as np
import os.path as osp
import matplotlib.pylab as plt

from latent_3d_points.src.point_net_ae import PointNetAutoEncoder
from latent_3d_points.src.autoencoder import Configuration as Conf
from latent_3d_points.src.neural_net import MODEL_SAVER_ID

from latent_3d_points.src.in_out import snc_category_to_synth_id, create_dir, PointCloudDataSet, \
                                        load_all_point_clouds_under_folder

from latent_3d_points.src.general_utils import plot_3d_point_cloud
from latent_3d_points.src.tf_utils import reset_tf_graph

from latent_3d_points.src.vanilla_gan import Vanilla_GAN
from latent_3d_points.src.w_gan_gp import W_GAN_GP
from latent_3d_points.src.generators_discriminators import latent_code_discriminator_two_layers,\
latent_code_generator_two_layers

import numpy as np
import open3d as o3d
from open3d import JVisualizer

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

Specify where the raw point-clouds and the pre-trained AE are.

In [None]:
# Top-level directory that holds the raw point-cloud files.
top_in_dir = '../data/shape_net_core_uniform_samples_2048/'
# Saved configuration file of the pre-trained auto-encoder.
ae_configuration = '../data/single_class_ae/configuration'

In [None]:
# Output root (GAN checkpoints, synthetic samples) and the name of this run.
top_out_dir = '../data/'
experiment_name = 'latent_gan_with_chamfer_ae'

# Auto-encoder settings; these should match the pre-trained AE that gets loaded.
ae_epoch = 500       # Checkpoint epoch of the AE to restore.
bneck_size = 128     # Size of the AE bottleneck (latent code length).
n_pc_points = 2048   # Number of points in each point-cloud.

# Object class to work on (interactive alternative:
#   class_name = raw_input('Give me the class name (e.g. "chair"): ').lower()
class_name = "chair".lower()

In [None]:
# Load the raw point-clouds.
# The synth-id lookup is kept (it fails early for an unknown `class_name`), but the
# clouds are read from the custom "own" sub-folder instead of the category folder.
# Alternatives: osp.join(top_in_dir, syn_id) or
#               osp.join(top_in_dir, snc_category_to_synth_id()["table"]).
syn_id = snc_category_to_synth_id()[class_name]
class_dir = osp.join(top_in_dir, "own")
all_pc_data = load_all_point_clouds_under_folder(
    class_dir,
    n_threads=8,
    file_ending='.ply',
    verbose=True,
)
print('Shape of DATA = %s' % (all_pc_data.point_clouds.shape,))

In [None]:
# Load the pre-trained AE: rebuild a clean graph from the saved configuration
# (with encoder/decoder logging silenced) and restore the weights of `ae_epoch`.
reset_tf_graph()
ae_conf = Conf.load(ae_configuration)
for coder_args in (ae_conf.encoder_args, ae_conf.decoder_args):
    coder_args['verbose'] = False
ae = PointNetAutoEncoder(ae_conf.experiment_name, ae_conf)
ae.restore_model(ae_conf.train_dir, ae_epoch, verbose=True)

In [None]:
# Encode every raw point-cloud with the AE; the GAN is trained on these codes.
latent_codes = ae.get_latent_codes(all_pc_data.point_clouds)
latent_data = PointCloudDataSet(latent_codes)
print('Shape of DATA = %s' % (latent_data.point_clouds.shape,))

In [None]:
# Check that the AE reconstructions (decoded latent codes) look decent by
# plotting a couple of examples.
L = ae.decode(latent_codes)
for i in (0, 20):
    # Skip indices the decoded set does not contain; small datasets would
    # otherwise raise IndexError here.
    if i < len(L):
        plot_3d_point_cloud(L[i][:, 0], L[i][:, 1], L[i][:, 2], in_u_sphere=True)

In [None]:
def _show_and_save_cloud(xyz, out_path):
    """Wrap `xyz` in an Open3D cloud, display it with JVisualizer and write it to `out_path`.

    `xyz` is presumably an (N, 3) array of coordinates -- TODO confirm.
    Returns the created `o3d.geometry.PointCloud`.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    viewer = JVisualizer()
    viewer.add_geometry(cloud)
    viewer.show()
    o3d.io.write_point_cloud(out_path, cloud, write_ascii=False, compressed=False, print_progress=False)
    return cloud


# Show and export the first AE reconstruction ...
i = 0
L = ae.decode(latent_codes)
for i in range(1):
    pcd = _show_and_save_cloud(L[i], "pcd_aereconstruction" + str(i) + ".pcd")

# ... and the matching raw input cloud.
for i in range(1):
    pcdin = all_pc_data.point_clouds
    pointsin = pcdin[i]
    print(type(pointsin))
    pcd_in = _show_and_save_cloud(pointsin, "pcd_input" + str(i) + ".pcd")

# Re-encode the inputs and print the first latent code for inspection.
latent_codes = ae.get_latent_codes(all_pc_data.point_clouds)
print(latent_codes[0])

In [None]:
# Report the installed Open3D version (handy when debugging environment issues).
print('%s' % (o3d.__version__,))

In [None]:
# ---- GAN hyper-parameters ----------------------------------------------------
use_wgan = True     # True: Wasserstein GAN with gradient penalty; False: vanilla GAN.
n_epochs = 1        # Number of training epochs.

plot_train_curve = True
save_gan_model = False
# Checkpoint epochs: 1, 5, 10, then every 50th epoch up to n_epochs.
saver_step = np.concatenate([np.array([1, 5, 10]), np.arange(50, n_epochs + 1, 50)])

# If True, create the directory that synthetic samples produced during training
# are written to.
save_synthetic_samples = True
# Number of synthetic samples drawn each time (one per training example).
n_syn_samples = latent_data.num_examples

# Optimisation settings.
init_lr = 0.0001
batch_size = 50
noise_params = dict(mu=0, sigma=0.2)
noise_dim = bneck_size
beta = 0.5   # ADAM's momentum, passed to the GAN constructor as `beta`.

# The generator emits vectors of the AE bottleneck size.
n_out = [bneck_size]

if save_synthetic_samples:
    synthetic_data_out_dir = osp.join(top_out_dir, 'OUT/synthetic_samples/', experiment_name)
    create_dir(synthetic_data_out_dir)

if save_gan_model:
    train_dir = osp.join(top_out_dir, 'OUT/latent_gan', experiment_name)
    create_dir(train_dir)

In [None]:
# Build the latent GAN in a fresh graph. Both variants use the two-layer
# latent-code discriminator and generator; W-GAN-GP additionally takes lambda.
reset_tf_graph()

if not use_wgan:
    gan = Vanilla_GAN(experiment_name, init_lr, n_out, noise_dim,
                      latent_code_discriminator_two_layers,
                      latent_code_generator_two_layers,
                      beta=beta)
else:
    lam = 10  # Gradient-penalty weight (lambda) of W-GAN-GP.
    gan = W_GAN_GP(experiment_name, init_lr, lam, n_out, noise_dim,
                   latent_code_discriminator_two_layers,
                   latent_code_generator_two_layers,
                   beta=beta)

In [None]:
# Per-epoch (epoch, *losses) tuples appended by the training loop.
train_stats = []
# Not used by the training loop in this notebook.
accum_syn_data = []

In [None]:
def generate_hack(self, n_samples, noise_params):
    """Run `n_samples` noise vectors through the generator and return its output.

    Written as a free function taking `self`, presumably so it can be attached
    to a GAN instance (not used elsewhere in this notebook). Reads
    `self.noise_dim`, `self.noise`, `self.generator_out`, `self.sess` and
    `self.generator_noise_distribution_hack`; `noise_params` is expanded as
    keyword arguments (e.g. `mu`, `sigma`) of that noise sampler.
    """
    sampled = self.generator_noise_distribution_hack(n_samples, self.noise_dim, **noise_params)
    fetched = self.sess.run([self.generator_out], feed_dict={self.noise: sampled})
    return fetched[0]


def generator_noise_distribution_hack(self, n_samples, ndims, mu, sigma):
    """Return an (n_samples, ndims) array of Gaussian noise with mean `mu` and std `sigma`.

    `self` is unused; it is present only so the function can be bound as a method.
    """
    return np.random.normal(loc=mu, scale=sigma, size=(n_samples, ndims))

In [None]:
# Train the GAN. After every epoch: optionally checkpoint the model, sample
# synthetic latent codes, decode them with the AE, save them (when enabled) and
# display / export the first few decoded clouds.
n_show = 10  # Maximum number of decoded synthetic clouds shown/exported per epoch.

for _ in range(n_epochs):
    loss, duration = gan._single_epoch_train(latent_data, batch_size, noise_params)
    epoch = int(gan.sess.run(gan.increment_epoch))
    print('%s %s' % (epoch, loss))

    if save_gan_model and epoch in saver_step:
        checkpoint_path = osp.join(train_dir, MODEL_SAVER_ID)
        gan.saver.save(gan.sess, checkpoint_path, global_step=gan.epoch)

    # Samples are produced every epoch (the `epoch in saver_step` restriction is
    # intentionally disabled). Hard-coded alternatives used during experiments:
    #   gan.generate_chair(...), gan.generate_chair_ml(...), gan.generate_table(...)
    syn_latent_data = gan.generate(n_syn_samples, noise_params)
    syn_data = ae.decode(syn_latent_data)
    if save_synthetic_samples:
        # synthetic_data_out_dir is only defined when saving was enabled at setup.
        np.savez(osp.join(synthetic_data_out_dir, 'epoch_' + str(epoch)), syn_data)
    print(syn_data[0].size)

    # Show/export the first few synthetic clouds (never more than were generated).
    for k in range(min(n_show, len(syn_data))):
        points = syn_data[k]
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        # One RGB colour (red) per point: the colour count must equal the point
        # count, i.e. len(points), not points.size (which is 3x larger).
        pcd.colors = o3d.utility.Vector3dVector(np.tile([1.0, 0.0, 0.0], (len(points), 1)))

        visualizer = JVisualizer()
        visualizer.add_geometry(pcd)
        visualizer.show()
        o3d.io.write_point_cloud("pcd_generated"+str(k)+".pcd", pcd, write_ascii=False, compressed=False, print_progress=False)

    train_stats.append((epoch, ) + loss)

In [None]:
# Plot the per-epoch discriminator and generator losses stored in train_stats.
# Assumes each record is (epoch, d_loss, g_loss, ...), as the legend implies.
if plot_train_curve:
    x = [t[0] for t in train_stats]  # Recorded epoch of each entry.
    d_loss = [t[1] for t in train_stats]
    g_loss = [t[2] for t in train_stats]
    plt.plot(x, d_loss, '--')
    plt.plot(x, g_loss)
    plt.title('Latent GAN training. (%s)' %(class_name))
    plt.legend(['Discriminator', 'Generator'], loc=0)

    # Hide tick marks. tick_params expects booleans; the 'on'/'off' string
    # form is deprecated and a non-empty string is truthy.
    plt.tick_params(axis='x', which='both', bottom=False, top=False)
    plt.tick_params(axis='y', which='both', left=False, right=False)

    plt.xlabel('Epochs.')
    plt.ylabel('Loss.')