In [1]:
from torchvision import datasets, transforms
import torch.utils.data
import torch
import sys
import argparse
import matplotlib.pyplot as plt
# from utils import * 
from utils import *
import open3d as o3d
from models import *
from collections import OrderedDict
import os, shutil, gc
from tqdm import tqdm_notebook
import pickle

In [2]:
%matplotlib notebook

In [3]:
parser = argparse.ArgumentParser(description='VAE training of LiDAR')
parser.add_argument('--batch_size',         type=int,   default=16,             help='size of minibatch used during training')
parser.add_argument('--use_selu',           type=int,   default=0,              help='replaces batch_norm + act with SELU')
parser.add_argument('--base_dir',           type=str,   default='runs/test',    help='root of experiment directory')
parser.add_argument('--no_polar',           type=int,   default=0,              help='if True, the representation used is (X,Y,Z), instead of (D, Z), where D=sqrt(X^2+Y^2)')
parser.add_argument('--lr',                 type=float, default=1e-3,           help='learning rate value')
parser.add_argument('--z_dim',              type=int,   default=1024,            help='size of the bottleneck dimension in the VAE, or the latent noise size in GAN')
parser.add_argument('--autoencoder',        type=int,   default=1,              help='if True, we do not enforce the KL regularization cost in the VAE')
parser.add_argument('--atlas_baseline',     type=int,   default=0,              help='If true, Atlas model used. Also determines the number of primitives used in the model')
parser.add_argument('--panos_baseline',     type=int,   default=0,              help='If True, Model by Panos Achlioptas used')
parser.add_argument('--kl_warmup_epochs',   type=int,   default=150,            help='number of epochs before fully enforcing the KL loss')
parser.add_argument('--debug', action='store_true')

_StoreTrueAction(option_strings=['--debug'], dest='debug', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)

In [4]:
args = parser.parse_args([])
args

Namespace(atlas_baseline=0, autoencoder=1, base_dir='runs/test', batch_size=16, debug=False, kl_warmup_epochs=150, lr=0.001, no_polar=0, panos_baseline=0, use_selu=0, z_dim=1024)

In [5]:
MODEL_BASE_PATH = "/home/saby/Projects/ati/ati_motors/adversarial_based/static_reconstruction_method/"

In [6]:
MODEL_FOLDER_NAME = "learning_to_filter_64beam_64f"
MODEL_FILE_NAME = "gen_62.pth"
model = Unet_filtered(args, n_filters=64).cuda()
LEARN_TO_FILTER = True
MODEL_USED_DATA_PARALLEL = False

In [7]:
LIDAR_RANGE = 120

In [8]:
MODEL_TEST_PATH = os.path.join(MODEL_BASE_PATH, MODEL_FOLDER_NAME, 'models', MODEL_FILE_NAME)
if not os.path.exists(MODEL_TEST_PATH):
    print("No Model file found at : {}".format(MODEL_TEST_PATH))
    assert False

In [9]:
# model = VAE_filtered(args, n_filters=64).cuda()
print("Loading model from {}".format(MODEL_TEST_PATH))
network=torch.load(MODEL_TEST_PATH)

if MODEL_USED_DATA_PARALLEL:
    # original saved file with DataParallel
    state_dict = network
    # create new OrderedDict that does not contain `module.`
    new_state_dict = OrderedDict()

    for k, v in state_dict.items():
        name = k[7:] # remove `module.`
        new_state_dict[name] = v

    # load params
    model.load_state_dict(new_state_dict)
else:
    model.load_state_dict(network)
model.eval()

Loading model from /home/saby/Projects/ati/ati_motors/adversarial_based/static_reconstruction_method/learning_to_filter_64beam_64f/models/gen_62.pth


Unet_filtered(
  (unet): Unet(
    (encoder_conv1): Doubleconv(
      (double_conv): Sequential(
        (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (encoder_down1): DownBlock(
      (down_double_conv): Sequential(
        (0): Down(
          (down): Sequential(
            (0): MaxPool2d(kernel_size=(2, 4), stride=(2, 4), padding=0, dilation=1, ceil_mode=False)
          )
        )
        (1): Doubleconv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05

In [10]:
def getint(name):
    return int(name.split('.')[0])
    
def draw_pcd(pcd, where='opn_nb'):
    if where is 'opn_nb':
        visualizer = o3d.JVisualizer()
        visualizer.add_geometry(pcd)
        visualizer.show()
    elif where is 'opn_view':
        o3d.visualization.draw_geometries([pcd], width=1280, height=800)
    elif where is 'mat_3d':
        plt.figure()
        pts = np.asarray(pcd.points)
        plt.scatter(pts[:,0], pts[:,1], pts[:,2])
        plt.grid()
        plt.show()
    elif where is 'mat_2d':
        plt.figure()
        pts = np.asarray(pcd.points)
        plt.scatter(pts[:,0], pts[:,1])
        plt.grid()
        plt.show()
        
def draw_registration_result(src_pcd, dst_pcd, x_pt, y_pt, theta):    
    src_pcd_tmp = copy.deepcopy(src_pcd)
    dst_pcd_tmp = copy.deepcopy(dst_pcd)
    
    src_pcd_tmp.paint_uniform_color([1, 0, 0])  # red source
    dst_pcd_tmp.paint_uniform_color([0, 0, 1])  # blue target
    
    transform_mat = pose2matrix([x_pt, y_pt, 0], [0,0,theta])
    dst_pcd_tmp.transform(transform_mat)
    
    visualizer = o3d.JVisualizer()
    visualizer.add_geometry(src_pcd_tmp)
    visualizer.add_geometry(dst_pcd_tmp)
    visualizer.show()
    
process_input = from_polar if args.no_polar else lambda x : x

In [11]:
def masked_dynamic_recon(dynamic, recon, mask):
    # Assuming channel 1 to be dynamic
    # if channel 1 rounds to 0 (static) then take points from dynamic (because these are static points in dynamic frame)
    # else if channel 1 rounds to 1 (dynamic) then take points from reconstructed static (because these are dynamic points in dynamic frame)
    shape_tuple = (mask.shape[0], 1, mask.shape[2], mask.shape[3])
    bin_mask_orig = mask[:,1].round().view(shape_tuple)
    bin_mask = torch.cat([bin_mask_orig, bin_mask_orig], axis=1)

    new_recon = (dynamic * (1-bin_mask)) + (bin_mask * recon)
    new_dynamic = dynamic * (1-bin_mask)
    return new_recon, bin_mask_orig

In [12]:
TEST_NPY_PATH = "../training_data/small_map/64beam/dynamic_prepreprocess_dir/3.npy"
with open(TEST_NPY_PATH, 'rb') as pkl_file:
        test_arr = pickle.load(pkl_file)

test_arr.shape

(1762, 2, 64, 1024)

In [13]:
test_dataloader = torch.utils.data.DataLoader(test_arr, batch_size=args.batch_size,
                            shuffle=False, num_workers=8, drop_last=False)

In [14]:
for i, img_data in tqdm_notebook(enumerate(test_dataloader), total=len(test_dataloader)):
    if i != 0:
        continue
    dynamic_img = img_data.cuda()

    recon, xmask = model(process_input(dynamic_img))
    masked_recon, bin_mask = masked_dynamic_recon(dynamic_img, recon, xmask)
    recon=masked_recon

    recons=recon
    recons_temp=np.array(recons.detach().cpu())

    ###### Save all pcds
#     for frame_num in range(recons_temp.shape[0]):
#         frame = from_polar(recons[frame_num:frame_num+1,:,:,:]).detach().cpu().numpy()[0]
#         frame_actual = np.array([frame_image for frame_image in frame])
#         frame_flat = frame_actual.reshape((3,-1))
#         frame_crop = frame_flat#[:,(frame_flat[2]  > 0.005)]
#         some_pcd = o3d.geometry.PointCloud()
#         some_arr = frame_crop.T * LIDAR_RANGE
#         some_pcd.points = o3d.utility.Vector3dVector(some_arr)
#         pcd_fname = str(ply_idx) + ".ply"
#         single_pcd_path = os.path.join(OUTPUT_PCD_FOLDER_PATH, pcd_fname)
#         o3d.io.write_point_cloud(single_pcd_path, some_pcd)
#         ply_idx += 1
#     gc.collect()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=111.0), HTML(value='')))




In [15]:
frame_num = 10
frame = from_polar(dynamic_img[frame_num:frame_num+1,:,:,:]).detach().cpu().numpy()[0]
frame_actual = np.array([frame_image for frame_image in frame])
frame_flat = frame_actual.reshape((3,-1))
frame_crop = frame_flat#[:,(frame_flat[2]  > 0.005)]
frame_crop.shape

(3, 65536)

In [16]:
mask = bin_mask.detach().cpu().numpy()[0]
mask.shape

(1, 64, 1024)

In [17]:
color_arr = np.concatenate((mask, np.zeros(mask.shape), np.zeros(mask.shape)), axis=0).reshape((3,-1)).T
color_arr.shape

(65536, 3)

In [18]:
some_pcd = o3d.geometry.PointCloud()
some_arr = frame_crop.T * LIDAR_RANGE
some_pcd.points = o3d.utility.Vector3dVector(some_arr)
some_pcd.colors = o3d.utility.Vector3dVector(color_arr)
some_pcd

geometry::PointCloud with 65536 points.

In [19]:
draw_pcd(some_pcd, where='opn_nb')

JVisualizer with 1 geometries