In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import os, sys
from os.path import join, abspath
import time
import pdb
import glob

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.filters
import skimage.io

from PIL import  Image
from multiprocessing import Pool
import cv2

import ctypes as ct

import plotly.graph_objects as go

In [None]:
# Make the project's helper packages importable from this notebook.
sys.path.extend(['../utils', '../chamfer_utils'])

In [None]:
# from helper_funcs import create_folder, remove_outliers
from shapenet_taxonomy import shapenet_category_to_id

In [None]:
# Experiment outputs: the .npy files listed in `names` below are read from log_dir.
# NOTE(review): exp_dir is named 'expts_chair' while categ is 'car' -- confirm
# that this pairing is intended.
exp_dir = '/home/ubuntu/ssl_3d_recon/expts_chair'
log_dir = join(exp_dir, 'log_proj_pcl_disp_test')
# Human-readable category name, immediately replaced by the id that
# shapenet_category_to_id maps it to; the id is what the paths below use.
categ = 'car'
categ = shapenet_category_to_id[categ]

# Per-category roots: viz_pcl() reads 'pointcloud_1024.npy' under pcl_data_dir
# and 'render_<id>.png' under data_dir.
pcl_data_dir = '/home/ubuntu/ssl_3d_recon/data/ShapeNet_v1/%s'%(categ)
data_dir = '/home/ubuntu/ssl_3d_recon/data/ShapeNet_rendered/%s'%(categ)
mode = 'test'
# Sorted entries of the split file for this category/mode.
# allow_pickle=True unpickles objects stored in the file, so only load split
# files from a trusted source.
models = sorted(np.load('/home/ubuntu/ssl_3d_recon/splits/images_list_%s_%s.npy'%(categ, mode), allow_pickle=True))

# Sorted .npy files found in log_dir. viz_pcl() pairs names[i] with models[i]
# purely by position -- presumably both orderings line up; TODO confirm.
names = sorted(glob.glob(join(log_dir, '*.npy')))

In [None]:
def remove_outliers(pcl, min_val=-.5, max_val=0.5):
    '''
    Remove outlier points in pcl and replace with existing points --> used only
    during visualization, SHOULD NOT be used during metric calculation.

    A point is an outlier when any of its co-ordinates lies outside
    [min_val, max_val] (NaN co-ordinates also count as outliers). Every outlier
    is overwritten with the first in-range point of the same cloud, so the
    number of points does not change.

    Args:
            pcl: float ndarray, (N_PTS,3) or (BS,N_PTS,3); input point cloud(s)
                        with outliers. Modified IN PLACE.
            min_val, max_val: float, (); minimum and maximum value of the
                        co-ordinates, beyond which point is treated as outlier
    Returns:
            pcl: the same array object that was passed in, cleaned. A cloud
                        that is empty or has no in-range point at all is left
                        untouched, because there is nothing valid to copy from.
    Raises:
            ValueError: if pcl is neither 2-D nor 3-D.
    '''
    if pcl.ndim == 3:
        # Batched input: iterating yields views, so cleaning each cloud
        # writes straight into the caller's array.
        for cloud in pcl:
            remove_outliers(cloud, min_val, max_val)
        return pcl
    if pcl.ndim != 2:
        raise ValueError('pcl must have shape (N_PTS,3) or (BS,N_PTS,3), got %s'
                         % (pcl.shape,))

    # Written as "not in range" so that NaN co-ordinates are flagged as well.
    in_range = (pcl >= min_val) & (pcl <= max_val)
    is_outlier = ~in_range.all(axis=-1)
    inlier_rows = np.where(~is_outlier)[0]
    if inlier_rows.size:
        # Use the first in-range point as the replacement; pcl[0] itself may
        # be an outlier.
        pcl[is_outlier] = pcl[inlier_rows[0]]
    return pcl

In [None]:
# Create temporary output directories under /home/ubuntu/ssl_3d_recon/VIZ.
# Note: the VM has no X server, so images cannot be displayed on it directly.
# They are written to the output directory below instead and then copied with
# scp to a local machine for viewing.
import os
pcl_viz_dirs = ['images', 'pcl']
for dir_name in pcl_viz_dirs:
    # exist_ok=True keeps this cell safe to re-run.
    os.makedirs('/home/ubuntu/ssl_3d_recon/VIZ/%s'%(dir_name), exist_ok=True)

# Directory into which viz_pcl() writes the input images.
output_img_base_path = '/home/ubuntu/ssl_3d_recon/VIZ/images'

In [None]:
def get_pcl(pcl_data):
    '''
    Split a 2-D point array into its first three columns.
    Args:
            pcl_data: array indexable as pcl_data[:, k]; one point per row
    Returns:
            (x, y, z): tuple holding columns 0, 1 and 2 of pcl_data
    '''
    return tuple(pcl_data[:, axis] for axis in range(3))

In [None]:
def showpoints(xyz, c0=None, c1=None, c2=None, waittime=0, showrot=False,
               magnifyBlue=0, freezerot=False, background=(0,0,0), normalizecolor=True,
               ballradius=10):
    '''
    Pre-process a point cloud and its per-point colour channels for display.

    Only the pre-processing is implemented in this notebook: nothing is drawn
    and the function returns None. All work is done on private copies, so the
    arrays passed in are never modified.

    Args:
            xyz: array of points; the last axis holds the co-ordinates. A copy
                        is rescaled so that the farthest point from the origin
                        ends up at distance 1/2.2.
            c0, c1, c2: optional per-point colour channels. c0 defaults to 255
                        for every point; c1 and c2 default to c0.
            normalizecolor: bool; when True each channel is rescaled so that
                        its maximum becomes 255.
            waittime, showrot, magnifyBlue, freezerot, background, ballradius:
                        accepted but not used by this function.
    Returns:
            None
    '''
    # Scale a float copy. An in-place `/=` on the argument would shrink the
    # caller's cloud a little more on every call and raises for integer arrays.
    xyz = np.array(xyz, dtype=float)
    # xyz=xyz-xyz.mean(axis=0)
    if xyz.size:
        radius = ((xyz**2).sum(axis=-1)**0.5).max()
        if radius > 0:  # an all-zero cloud has nothing to scale
            xyz /= (radius*2.2)

    # Colour channels are copied to float32 for the same reason: uint8 colours
    # cannot be divided in place, and callers keep their own arrays intact.
    if c0 is None:
        c0 = np.zeros((len(xyz),), dtype='float32') + 255
    else:
        c0 = np.array(c0, dtype='float32')
    c1 = c0.copy() if c1 is None else np.array(c1, dtype='float32')
    c2 = c0.copy() if c2 is None else np.array(c2, dtype='float32')

    if normalizecolor:
        for channel in (c0, c1, c2):
            if channel.size:
                # float() keeps the 1e-14 guard from vanishing in float32.
                channel /= (float(channel.max()) + 1e-14) / 255.0

    c0 = np.require(c0, 'float32', 'C')
    c1 = np.require(c1, 'float32', 'C')
    c2 = np.require(c2, 'float32', 'C')
    
    

In [None]:
# Maximum number of test samples viz_pcl() visits (it iterates over range(n_plots)).
n_plots = 100


def viz_pcl(ballradius=3):
    '''
    Save the input image of each test sample and plot its predicted point cloud.

    Visits the first n_plots entries of the module-level `models` / `names`
    lists (fewer when either list is shorter). For every sample whose rendered
    image exists under data_dir, the image is written as a PNG into
    output_img_base_path and the predicted cloud is shown as an interactive
    plotly 3-D scatter. Samples without a rendered image are skipped. The
    number of samples processed is printed at the end.

    Args:
            ballradius: not used by this function; only the commented-out
                        viewer calls after it reference it.
    Returns:
            None
    '''
    num_images = 0
    # Stop at the shortest list so a small split cannot raise IndexError.
    n_samples = min(n_plots, len(models), len(names))
    for idx in range(n_samples):
        img_name, img_id = models[idx][0].split('_')

        # Skip samples without a rendered input image before loading any
        # point-cloud file for them.
        image_path = join(data_dir, img_name,'render_%s.png'%(img_id))
        if not os.path.exists(image_path):
            continue

        # Load the gt and pred point clouds. The GT cloud is not plotted yet;
        # loading it makes a missing GT file surface here.
        gt_path = join(pcl_data_dir, img_name, 'pointcloud_1024.npy')
        print('GT path: ', gt_path)
        gt_pcl = np.load(gt_path)

        pcl = np.load(names[idx])[:,:3]
        # Optional clean-up for display only: pcl = remove_outliers(pcl)

        # Load and save the input image.
        ip_img = skimage.io.imread(image_path)
        num_images += 1

        # skimage loads RGB(A) and PIL saves RGB, so the channels must NOT be
        # reversed (that was only needed for cv2); just drop any alpha channel.
        output_img_path = join(output_img_base_path, '%s_%s.png'%(img_name, img_id))
        Image.fromarray(np.uint8(ip_img[:,:,:3])).save(output_img_path)

        # Plot the predicted point cloud.
        x, y, z = get_pcl(pcl)
        fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                   mode='markers', marker=dict(size=1))])
        fig.show()

    print('Num test images that do exist: ', num_images)

#     showpoints(gt_pcl, ballradius=ballradius)
#     showpoints(pcl, ballradius=ballradius)
        # saveBool = show3d_balls.showtwopoints(gt_pcl, pcl, ballradius=ballradius)

In [None]:
# Run the visualisation with the default ballradius.
viz_pcl()