Imports

In [None]:
from torchvision import datasets, transforms
import torch.utils.data
import torch
import sys
import argparse
import matplotlib.pyplot as plt
from utils import * 
from utils import *
import open3d as o3d
from models import *
from collections import OrderedDict
import os, shutil, gc
from tqdm import tqdm_notebook

In [None]:
# Use matplotlib's interactive, notebook-embedded backend for the figures drawn below.
%matplotlib notebook

Args

In [None]:
# Hyper-parameters handed to `VAE(args)`. The notebook never reads its real
# command line, so these defaults fully describe the model that is built.
# Integer 0/1 options act as booleans ("if True" in their help text).
parser = argparse.ArgumentParser(description='VAE training of LiDAR')

parser.add_argument(
    '--batch_size', type=int, default=32,
    help='size of minibatch used during training',
)
parser.add_argument(
    '--use_selu', type=int, default=0,
    help='replaces batch_norm + act with SELU',
)
parser.add_argument(
    '--base_dir', type=str, default='runs/test',
    help='root of experiment directory',
)
parser.add_argument(
    '--no_polar', type=int, default=0,
    help='if True, the representation used is (X,Y,Z), instead of (D, Z), where D=sqrt(X^2+Y^2)',
)
parser.add_argument(
    '--lr', type=float, default=1e-3,
    help='learning rate value',
)
parser.add_argument(
    '--z_dim', type=int, default=1024,
    help='size of the bottleneck dimension in the VAE, or the latent noise size in GAN',
)
parser.add_argument(
    '--autoencoder', type=int, default=1,
    help='if True, we do not enforce the KL regularization cost in the VAE',
)
parser.add_argument(
    '--atlas_baseline', type=int, default=0,
    help='If true, Atlas model used. Also determines the number of primitives used in the model',
)
parser.add_argument(
    '--panos_baseline', type=int, default=0,
    help='If True, Model by Panos Achlioptas used',
)
parser.add_argument(
    '--kl_warmup_epochs', type=int, default=150,
    help='number of epochs before fully enforcing the KL loss',
)
parser.add_argument('--debug', action='store_true')

In [None]:
# Parse an empty argv: with no argument, argparse would read sys.argv, which
# inside a notebook is the kernel's own command line, not ours.
args = parser.parse_args(args=[])
args  # show the resolved configuration

In [None]:
# Checkpoint to evaluate. Set the MODEL_FILE environment variable to evaluate a
# different checkpoint without editing the notebook.
# Earlier checkpoint used with this notebook:
#   /home/sabyasachi/Projects/ati/ati_motors/adversarial_based/prashVAE/new_runs/unet_more_layers_restarted_with_more_data_ctd/models/gen_118.pth
model_file = os.environ.get(
    'MODEL_FILE',
    '/home/sabyasachi/Projects/ati/ati_motors/adversarial_based/prashVAE/aws_runs/unet_more_layers_correct/models/gen_471.pth',
)

# True when the checkpoint was saved from a torch.nn.DataParallel-wrapped model,
# whose state_dict keys carry a leading `module.` prefix.
MODEL_USED_DATA_PARALLEL = True

Load Model

In [None]:
# Build the architecture described by `args` on the GPU, then restore the trained weights.
model = VAE(args).cuda()
print("Loading model from {}".format(model_file))
# Load onto the CPU first: load_state_dict copies the tensors into the CUDA
# parameters anyway, and this avoids depending on the GPU index the
# checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
network = torch.load(model_file, map_location='cpu')
state_dict = network

if MODEL_USED_DATA_PARALLEL:
    # DataParallel saves parameters as `module.<name>`; strip that prefix so the
    # keys match the unwrapped VAE. Strip it only where it is actually present:
    # an unconditional 7-character slice would silently mangle any key saved
    # without the prefix and make load_state_dict fail with confusing names.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

model.load_state_dict(state_dict)
model.eval()

Helper functions and data loading

In [None]:
def getint(name):
    """Return the integer written before the first '.' in a file name.

    Used as a sort key so numbered files sort numerically ('2.npy' before
    '10.npy') instead of lexicographically. Raises ValueError when that
    prefix is not an integer (e.g. an empty prefix for '.hidden').
    """
    stem, _sep, _rest = name.partition('.')
    return int(stem)
    
def draw_pcd(pcd, where='opn_nb'):
    """Display a point cloud with the selected backend.

    Args:
        pcd: point cloud exposing ``.points`` (e.g. an Open3D ``PointCloud``).
        where: backend selector --
            'opn_nb'   : Open3D Jupyter widget (``o3d.JVisualizer``),
            'opn_view' : Open3D window, 1280x800,
            'mat_3d'   : matplotlib 3-D scatter of columns 0, 1, 2 of the points,
            'mat_2d'   : matplotlib 2-D scatter of columns 0 and 1 of the points.

    Raises:
        ValueError: if ``where`` is not one of the options above.
    """
    # Compare with ``==``: ``is`` tests object identity, which only works by
    # accident for interned literals (and is a SyntaxWarning on Python >= 3.8).
    if where == 'opn_nb':
        visualizer = o3d.JVisualizer()
        visualizer.add_geometry(pcd)
        visualizer.show()
    elif where == 'opn_view':
        o3d.visualization.draw_geometries([pcd], width=1280, height=800)
    elif where == 'mat_3d':
        # Registers the '3d' projection on older matplotlib versions.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        pts = np.asarray(pcd.points)
        fig = plt.figure()
        # A 3-D axes is required: the third positional argument of
        # plt.scatter is the marker size, not a z coordinate.
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2])
        ax.grid(True)
        plt.show()
    elif where == 'mat_2d':
        pts = np.asarray(pcd.points)
        fig, ax = plt.subplots()
        ax.scatter(pts[:, 0], pts[:, 1])
        ax.grid(True)
        plt.show()
    else:
        # Previously an unknown value silently drew nothing.
        raise ValueError(
            "unknown draw target {!r}; expected one of "
            "'opn_nb', 'opn_view', 'mat_3d', 'mat_2d'".format(where)
        )
        
def draw_registration_result(src_pcd, dst_pcd, x_pt, y_pt, theta):
    """Overlay two point clouds in the Jupyter Open3D viewer after moving one.

    ``dst_pcd`` is transformed by the pose built from (x_pt, y_pt, theta) and
    drawn in blue over ``src_pcd`` in red. Both inputs are deep-copied first,
    so the caller's clouds are neither recoloured nor moved.

    Args:
        src_pcd: source point cloud, drawn red, left untransformed.
        dst_pcd: target point cloud, drawn blue after the transform.
        x_pt, y_pt: translation components passed to ``pose2matrix`` (z fixed at 0).
        theta: third rotation component passed to ``pose2matrix`` -- presumably
            yaw about z; confirm units (rad/deg) against ``pose2matrix``.
    """
    # `copy` is not among this notebook's explicit imports; import it here
    # instead of relying on a star import to leak it.
    import copy

    src_vis = copy.deepcopy(src_pcd)
    dst_vis = copy.deepcopy(dst_pcd)
    src_vis.paint_uniform_color([1, 0, 0])  # red source
    dst_vis.paint_uniform_color([0, 0, 1])  # blue target

    dst_vis.transform(pose2matrix([x_pt, y_pt, 0], [0, 0, theta]))

    visualizer = o3d.JVisualizer()
    visualizer.add_geometry(src_vis)
    visualizer.add_geometry(dst_vis)
    visualizer.show()

In [None]:
# Input .npy files and the folder the reconstructed .ply frames are written to.
# Both live under one recording directory, so switching recordings is a single
# edit of RUN_DIR (previously also used: .../few_dynamic_runs/110k/dynamic/scarce_1).
RUN_DIR = "/home/sabyasachi/Projects/ati/data/data/datasets/Carla/few_dynamic_runs/110k/dynamic/no2whl_2"
test_folder = os.path.join(RUN_DIR, "_out_out_npy")
out_folder = os.path.join(RUN_DIR, "_model_out")

# Process files in numeric order (2 before 10). Hidden entries (names starting
# with '.') have no integer prefix and would make getint raise, so skip them.
test_files = sorted(
    (fname for fname in os.listdir(test_folder) if not fname.startswith('.')),
    key=getint,
)

In [None]:
# Start from an empty output folder so no .ply files from a previous run remain.
if os.path.exists(out_folder):
    shutil.rmtree(out_folder)
os.makedirs(out_folder)

Evaluate the model on the test data and write each reconstructed frame to a .ply file

In [None]:
# Transform applied to every batch before it reaches the model. With --no_polar
# the batch goes through `from_polar` (presumably to Cartesian X, Y, Z, per the
# flag's help text); otherwise it is passed through unchanged.
if args.no_polar:
    process_input = from_polar
else:
    def process_input(x):
        return x

In [None]:
# Reconstruct every frame of every input file and write each one as `<n>.ply`
# in out_folder, numbered consecutively across files in processing order.

# Rows kept from each reconstructed channel before flattening to points.
# NOTE(review): presumably the number of valid LiDAR beams -- confirm.
N_ROWS_KEPT = 29
# Extra factor applied together with the per-file normalization_factor when
# writing points out. NOTE(review): presumably undoes a fixed scale used in
# `preprocess`; confirm against utils.
OUTPUT_SCALE = 25

ply_idx = 0
# Pure inference: no_grad stops autograd from recording a graph for every
# batch, which would otherwise hold GPU memory for nothing.
with torch.no_grad():
    for test_file in test_files:
        # Load corresponding dataset batch
        print("processing {}".format(test_file))
        dataset_val = np.load(os.path.join(test_folder, test_file))
        dataset_val, normalization_factor = preprocess(dataset_val, give_factor=True)
        dataset_val = dataset_val.astype('float32')
        val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size,
                                                 shuffle=False, num_workers=1, drop_last=False)
        print("done")

        # Iterate the loader directly (no enumerate) so tqdm knows the total.
        for img_data in tqdm_notebook(val_loader):
            dynamic_img = img_data.cuda()
            recon, _kl_cost, _hidden_z = model(process_input(dynamic_img))

            # Only the batch size is needed, so read it from the tensor instead
            # of copying the whole batch to host memory first.
            for frame_num in range(recon.shape[0]):
                # from_polar on a one-frame slice; [0] drops the batch axis.
                frame = from_polar(recon[frame_num:frame_num + 1, :, :, :]).detach().cpu().numpy()[0]
                # Keep the first N_ROWS_KEPT rows of every channel, then flatten
                # to 3 rows (presumably x, y, z) with one column per point.
                frame_flat = frame[:, :N_ROWS_KEPT].reshape((3, -1))
                points = frame_flat.T * normalization_factor * OUTPUT_SCALE

                pcd = o3d.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(points)
                o3d.io.write_point_cloud(os.path.join(out_folder, "{}.ply".format(ply_idx)), pcd)
                ply_idx += 1
            gc.collect()