### Imports

In [None]:
from torchvision import datasets, transforms
import torch.utils.data
import torch
import sys
import argparse
import matplotlib.pyplot as plt
# from utils import * 
from utils import *
import open3d as o3d
from models import *
from collections import OrderedDict
import os, shutil, gc
from tqdm import tqdm_notebook

In [None]:
# Use the interactive notebook backend so matplotlib figures (e.g. draw_pcd's
# 'mat_2d' / 'mat_3d' modes) are rendered as interactive inline widgets.
%matplotlib notebook

### Args

In [None]:
# Model/experiment options. Only the defaults are used in this notebook
# (see the parse_args([]) call), but the flags are kept so the model
# constructors receive the same `args` namespace they expect.
parser = argparse.ArgumentParser(description='VAE training of LiDAR')

# (flag, type, default, help) for every typed option.
_ARG_SPECS = (
    ('--batch_size',       int,   16,          'size of minibatch used during training'),
    ('--use_selu',         int,   0,           'replaces batch_norm + act with SELU'),
    ('--base_dir',         str,   'runs/test', 'root of experiment directory'),
    ('--no_polar',         int,   0,           'if True, the representation used is (X,Y,Z), instead of (D, Z), where D=sqrt(X^2+Y^2)'),
    ('--lr',               float, 1e-3,        'learning rate value'),
    ('--z_dim',            int,   1024,        'size of the bottleneck dimension in the VAE, or the latent noise size in GAN'),
    ('--autoencoder',      int,   1,           'if True, we do not enforce the KL regularization cost in the VAE'),
    ('--atlas_baseline',   int,   0,           'If true, Atlas model used. Also determines the number of primitives used in the model'),
    ('--panos_baseline',   int,   0,           'If True, Model by Panos Achlioptas used'),
    ('--kl_warmup_epochs', int,   150,         'number of epochs before fully enforcing the KL loss'),
)
for _flag, _kind, _default, _help in _ARG_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

# Plain on/off switch (defaults to False).
parser.add_argument('--debug', action='store_true')

In [None]:
# Parse an explicit empty list instead of sys.argv (which holds the Jupyter
# kernel's own arguments here), so every option keeps its default.
args = parser.parse_args(args=[])
args  # echo the resolved options as the cell output

### Set model paths

In [None]:
# Root directory holding one sub-folder per trained experiment; checkpoints
# live at `<MODEL_BASE_PATH>/<experiment>/models/<file>.pth`.
# Set LIDAR_MODEL_BASE_PATH in the environment to use another machine's layout.
MODEL_BASE_PATH = os.environ.get(
    "LIDAR_MODEL_BASE_PATH",
    "/home/saby/Projects/ati/ati_motors/adversarial_based/static_reconstruction_method/",
)

In [None]:
# Registry of trained checkpoints this notebook can evaluate. Keys are
# experiment folder names under MODEL_BASE_PATH; each value records
#   file            - checkpoint file inside `<folder>/models/`
#   build           - zero-argument factory for the network; deferred so only
#                     the selected model is ever constructed and moved to GPU
#   learn_to_filter - the model returns (recon, mask) instead of just recon
#   data_parallel   - checkpoint keys carry nn.DataParallel's `module.` prefix
# To evaluate another run, add an entry and change SELECTED_EXPERIMENT.


def _experiment(file, build, learn_to_filter=True, data_parallel=False):
    """Bundle the settings needed to reload one trained checkpoint."""
    return {
        "file": file,
        "build": build,
        "learn_to_filter": learn_to_filter,
        "data_parallel": data_parallel,
    }


EXPERIMENTS = {
    # --- earlier runs ---
    "second_attempt_triple_data_restarted_correctly_1024": _experiment(
        "gen_150.pth", lambda: VAE(args, n_filters=64), learn_to_filter=False),
    "second_attempt_filtered_32f_triple_data_restarted_again2_correctly_1024": _experiment(
        "gen_245.pth", lambda: VAE_filtered(args, n_filters=32)),
    "second_attempt_filtered_64f_triple_data_restarted_correctly_1024": _experiment(
        "gen_105.pth", lambda: VAE_filtered(args, n_filters=64)),
    "second_attempt_ground_weighted_filtered_64f_triple_data_continued_correctly_1024": _experiment(
        "gen_260.pth", lambda: VAE_filtered(args, n_filters=64)),
    "fourth_attempt_ground_weighted_filtered_polar_new_unet_64f_2048": _experiment(
        "gen_300.pth", lambda: Unet_filtered(args, n_filters=64)),
    "fifth_attempt_no_sem_ground_weighted_filtered_polar_old_unet_32f_1024": _experiment(
        "gen_498.pth", lambda: VAE_filtered(args, n_filters=32)),
    "fifth_attempt_sem_ground_weighted_filtered_polar_old_unet_32f_1024_continued": _experiment(
        "gen_145.pth", lambda: VAE_filtered(args, n_filters=32)),
    "fifth_attempt_slam_weighted_sem_ground_weighted_filtered_polar_old_unet_32f_1024": _experiment(
        "gen_115.pth", lambda: VAE_filtered(args, n_filters=32)),
    "fifth_attempt_inpainting_weighted_no_sem_ground_weighted_filtered_polar_old_unet_32f_1024": _experiment(
        "gen_280.pth", lambda: VAE_filtered(args, n_filters=32)),
    # --- later runs ---
    "first_new_attempt_new_unet_64f": _experiment(
        "gen_125.pth", lambda: Unet_filtered(args, n_filters=64)),
    "first_new_attempt_new_unet_64f_no_dropout": _experiment(
        "gen_84.pth", lambda: Unet_filtered(args, n_filters=64)),
    "trying_new_unet_correctly_64f_continued": _experiment(
        "gen_80.pth", lambda: Unet_filtered(args, n_filters=64)),
    "trying_new_unet_correctly_64f_no_inpaint": _experiment(
        "gen_183.pth", lambda: Unet_filtered(args, n_filters=64)),
}

SELECTED_EXPERIMENT = "trying_new_unet_correctly_64f_no_inpaint"

_selected = EXPERIMENTS[SELECTED_EXPERIMENT]
MODEL_FOLDER_NAME = SELECTED_EXPERIMENT
MODEL_FILE_NAME = _selected["file"]
model = _selected["build"]().cuda()
LEARN_TO_FILTER = _selected["learn_to_filter"]
MODEL_USED_DATA_PARALLEL = _selected["data_parallel"]

In [None]:
# Full path of the checkpoint selected above; fail fast if it is missing.
MODEL_TEST_PATH = os.path.join(MODEL_BASE_PATH, MODEL_FOLDER_NAME, 'models', MODEL_FILE_NAME)
if not os.path.isfile(MODEL_TEST_PATH):
    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # and the exception message itself names the missing path.
    raise FileNotFoundError("No Model file found at : {}".format(MODEL_TEST_PATH))

### Load Model

In [None]:
# Load the checkpoint (a state dict) into the model built in the selection cell.
print("Loading model from {}".format(MODEL_TEST_PATH))
network = torch.load(MODEL_TEST_PATH)

state_dict = network
if MODEL_USED_DATA_PARALLEL:
    # nn.DataParallel saves every key as `module.<name>`; strip that prefix so
    # the keys match the unwrapped model. Only keys that actually carry the
    # prefix are changed, so a dict without it is not silently corrupted.
    _prefix = 'module.'
    state_dict = OrderedDict(
        (k[len(_prefix):] if k.startswith(_prefix) else k, v)
        for k, v in network.items()
    )

model.load_state_dict(state_dict)
model.eval()  # inference mode (affects e.g. dropout / batch-norm layers)

### Set data paths

In [None]:
# Root of the test set; each entry of DATA_TEST_FOLDER_LIST is a sub-folder
# whose TEST_NPY_FOLDER holds the numbered .npy input files.
# Set LIDAR_DATA_BASE_PATH in the environment to point somewhere else.
DATA_BASE_PATH = os.environ.get(
    "LIDAR_DATA_BASE_PATH",
    "/home/saby/Projects/ati/data/data/datasets/Carla/16beam-Data/small_map/testing",
)
DATA_TEST_FOLDER_LIST = ["8", "24", "48"]
SAVE_PCD_NPY = True  # also write the reconstructions (+ a mask channel) as .npy
TEST_NPY_FOLDER = "_out_out_npy"

In [None]:
# Scale passed to preprocess() for the inputs; reconstructed points are
# multiplied by the same value before being written out as .ply.
LIDAR_RANGE: int = 100

### Start Testing

In [None]:
def getint(name):
    """Return the integer before the first '.' of a file name ('12.npy' -> 12).

    Used as a sort key so numbered files sort numerically, not lexically.
    """
    stem, _sep, _rest = name.partition('.')
    return int(stem)
    
def draw_pcd(pcd, where='opn_nb'):
    """Display an Open3D point cloud with the chosen backend.

    Args:
        pcd: an ``o3d.geometry.PointCloud``.
        where: one of
            'opn_nb'   - inline Jupyter viewer (``o3d.JVisualizer``),
            'opn_view' - Open3D window (1280x800),
            'mat_3d'   - matplotlib 3-D scatter of x, y, z,
            'mat_2d'   - matplotlib scatter of x, y.

    Raises:
        ValueError: if ``where`` is not one of the options above.
    """
    # Compare strings with ``==``: ``is`` tests object identity and only
    # matches when CPython happens to intern both strings.
    if where == 'opn_nb':
        visualizer = o3d.JVisualizer()
        visualizer.add_geometry(pcd)
        visualizer.show()
    elif where == 'opn_view':
        o3d.visualization.draw_geometries([pcd], width=1280, height=800)
    elif where == 'mat_3d':
        # Registers the '3d' projection on matplotlib releases older than 3.2.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        pts = np.asarray(pcd.points)
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        # z must go to a 3-D axes: plt.scatter's third positional argument is
        # the marker size, not a z coordinate.
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2])
        ax.grid(True)
        plt.show()
    elif where == 'mat_2d':
        plt.figure()
        pts = np.asarray(pcd.points)
        plt.scatter(pts[:, 0], pts[:, 1])
        plt.grid()
        plt.show()
    else:
        raise ValueError(
            "where must be one of 'opn_nb', 'opn_view', 'mat_3d', 'mat_2d'; got {!r}".format(where))
        
def draw_registration_result(src_pcd, dst_pcd, x_pt, y_pt, theta):
    """Overlay two point clouds in the notebook viewer to inspect an alignment.

    The source cloud is drawn red. A copy of the destination cloud is drawn
    blue after being transformed by ``pose2matrix([x_pt, y_pt, 0], [0, 0, theta])``
    (presumably a planar translation plus yaw rotation; see pose2matrix).
    Deep copies are drawn, so the input clouds are left unmodified.
    """
    # Explicit import: `copy` is not imported anywhere in this notebook except
    # possibly through a star import.
    import copy

    src_pcd_tmp = copy.deepcopy(src_pcd)
    dst_pcd_tmp = copy.deepcopy(dst_pcd)

    src_pcd_tmp.paint_uniform_color([1, 0, 0])  # red source
    dst_pcd_tmp.paint_uniform_color([0, 0, 1])  # blue target

    transform_mat = pose2matrix([x_pt, y_pt, 0], [0, 0, theta])
    dst_pcd_tmp.transform(transform_mat)

    visualizer = o3d.JVisualizer()
    visualizer.add_geometry(src_pcd_tmp)
    visualizer.add_geometry(dst_pcd_tmp)
    visualizer.show()
    
# Transform applied to every batch before it is fed to the model: with
# --no_polar the input is passed through from_polar (presumably yielding
# (X, Y, Z), per that flag's help text); otherwise it is used unchanged.
if args.no_polar:
    process_input = from_polar
else:
    def process_input(x):
        return x



In [None]:
def masked_dynamic_recon(dynamic, recon, mask):
    """Blend the input scan with the model reconstruction through ``mask``.

    Args:
        dynamic: the tensor that was fed to the model.
        recon: the model reconstruction; must broadcast against ``dynamic``.
        mask: blending weights, broadcastable to ``dynamic``. Where the mask
            is 1 the reconstruction is used; where it is 0 the input is kept.
            (Values presumably lie in [0, 1]; this is not checked.)

    Returns:
        ``(masked_dynamic, masked_recon)`` where
        ``masked_dynamic = dynamic * (1 - mask)`` and
        ``masked_recon = masked_dynamic + mask * recon``.
    """
    # An unused, rounded copy of mask[:, 1] used to be computed here; it only
    # cost time and raised IndexError for single-channel masks.
    masked_dynamic = dynamic * (1 - mask)
    masked_recon = masked_dynamic + mask * recon
    return masked_dynamic, masked_recon

In [None]:
def _reset_dir(path):
    """Delete ``path`` (if it exists) and recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _write_plys(xyz_batch, out_dir, start_idx, scale):
    """Write every frame of ``xyz_batch`` to ``out_dir/<idx>.ply``.

    Each frame is reshaped to (3, N), transposed to N points and multiplied
    by ``scale``. Files are numbered consecutively from ``start_idx``; the
    next unused index is returned.
    """
    idx = start_idx
    for frame in xyz_batch:
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(frame.reshape((3, -1)).T * scale)
        o3d.io.write_point_cloud(os.path.join(out_dir, str(idx) + ".ply"), pcd)
        idx += 1
    return idx


for DATA_TEST_FOLDER in DATA_TEST_FOLDER_LIST:
    print("\n\n\n    Test folder : {}".format(DATA_TEST_FOLDER))
    ######## Set paths
    test_root = os.path.join(DATA_BASE_PATH, DATA_TEST_FOLDER)
    OUTPUT_PCD_FOLDER = MODEL_FOLDER_NAME + "_" + MODEL_FILE_NAME.split(".")[0] + "_pcd"
    TEST_NPY_FOLDER_PATH = os.path.join(test_root, TEST_NPY_FOLDER)
    OUTPUT_PCD_FOLDER_PATH = os.path.join(test_root, OUTPUT_PCD_FOLDER)
    if SAVE_PCD_NPY:
        OUTPUT_NPY_FOLDER = OUTPUT_PCD_FOLDER + "_out_npy"
        OUTPUT_NPY_FOLDER_PATH = os.path.join(test_root, OUTPUT_NPY_FOLDER)

    # List the inputs before wiping any output folder, so a bad input path
    # fails without destroying earlier results.
    test_files = sorted(os.listdir(TEST_NPY_FOLDER_PATH), key=getint)

    _reset_dir(OUTPUT_PCD_FOLDER_PATH)
    if SAVE_PCD_NPY:
        _reset_dir(OUTPUT_NPY_FOLDER_PATH)

    ply_idx = 1  # .ply files are numbered from 1 across all inputs of this folder
    npy_idx = 0  # one .npy per input file, numbered from 0

    for test_file in test_files:
        ###### Load corresponding dataset batch
        print("processing {}".format(test_file))
        dataset_val = np.load(os.path.join(TEST_NPY_FOLDER_PATH, test_file))
        dataset_val = preprocess(dataset_val, LIDAR_RANGE).astype('float32')
        val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size,
                                                 shuffle=False, num_workers=12, drop_last=False)
        print("done")
        print("Saving pcds to {}".format(OUTPUT_PCD_FOLDER_PATH))

        # Per-batch 4-D arrays, concatenated once at the end; concatenating
        # inside the loop would re-copy everything each batch (quadratic).
        batch_outputs = []
        # Pure inference: no_grad stops autograd from recording, and holding
        # GPU memory for, the forward graph of every batch.
        with torch.no_grad():
            for img_data in tqdm_notebook(val_loader, total=len(val_loader)):
                dynamic_img = img_data.cuda()

                if LEARN_TO_FILTER:
                    recon, xmask = model(process_input(dynamic_img))
                    _, recon = masked_dynamic_recon(dynamic_img, recon, xmask)
                else:
                    recon = model(process_input(dynamic_img))

                # Convert once per batch and reuse it for both .ply and .npy.
                # NOTE(review): assumes from_polar converts each sample
                # independently, so a whole-batch call equals per-frame calls.
                recon_arr = from_polar(recon).cpu().numpy()

                ###### Save all pcds
                ply_idx = _write_plys(recon_arr, OUTPUT_PCD_FOLDER_PATH, ply_idx, LIDAR_RANGE)

                ##### Append model outputs array
                if SAVE_PCD_NPY:
                    if LEARN_TO_FILTER:
                        # Rounded channel 0 of the predicted mask as the extra channel.
                        color_arr = xmask[:, :1].round().cpu().numpy()
                    else:
                        # No mask available: use an all-zero extra channel.
                        color_arr = np.zeros((recon_arr.shape[0], 1, recon_arr.shape[2], recon_arr.shape[3]))
                    batch_outputs.append(np.concatenate((recon_arr, color_arr), axis=1))
                gc.collect()
        print("done")

        ##### Save model outputs array npy if necessary
        if SAVE_PCD_NPY:
            npy_path = os.path.join(OUTPUT_NPY_FOLDER_PATH, str(npy_idx) + ".npy")
            # Advance even when nothing is saved, so index i always maps to
            # the i-th input file.
            npy_idx += 1
            if not batch_outputs:
                print("No frames in {}, skipping {}".format(test_file, npy_path))
            else:
                # (N, C, H, W) -> (N, H, W, C): channels-last for the saved array.
                total_recon = np.concatenate(batch_outputs, axis=0).transpose(0, 2, 3, 1)
                print("Saving to {}".format(npy_path))
                np.save(npy_path, total_recon)
                print("done")
print("Done for all folders")

In [None]:
# Load a single real-world static-scan .npy file.
REAL_WORLD_STATIC_NPY = "/home/saby/Projects/ati/data/data/datasets/Real_World/Real_Train_Data/static_npy_data/1.npy"
dataset_val = np.load(REAL_WORLD_STATIC_NPY)

In [None]:
# Use LIDAR_RANGE rather than a literal 100, so this stays in sync with the
# scaling applied to the test folders.
dataset_val = preprocess(dataset_val, LIDAR_RANGE)