In [1]:
import argparse
import os

import numpy as np
import torch
import yaml
from easydict import EasyDict
from matplotlib import cm
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from dataloaders.pc_dataset import SemanticKITTI
from sparse2dense._2dpapenet import get_model as DepthCompletionModel

# PIL <-> tensor converters (not used in the cells visible here; kept for later cells).
totensor = transforms.ToTensor()
topil = transforms.ToPILImage()

# Paths a reader may need to change for their machine.
CONFIG_PATH = './config/semantic.yaml'
CHECKPOINT_PATH = '/root/autodl-nas/sparse2dense_s2/best_2dpapenet.ckpt'

# Build the run configuration: YAML config overlaid with (empty) CLI arguments.
# parse_args(args=[]) avoids parsing the Jupyter kernel's own sys.argv.
parser = argparse.ArgumentParser()
cli_args = parser.parse_args(args=[])
with open(CONFIG_PATH) as stream:
    config = yaml.safe_load(stream)
config.update(vars(cli_args))
args = EasyDict(config)

dataset = SemanticKITTI(args=args)
# batch_size=1, shuffle=False: the i-th batch is the i-th frame of the dataset.
dataloader = DataLoader(dataset, 1, False)

model = DepthCompletionModel(args).cuda()
# NOTE(review): presumably PyTorch Lightning's load_from_checkpoint, which builds and
# returns a new instance with the checkpoint weights — confirm.
# .eval() is required for inference: without it BatchNorm/Dropout stay in training
# behaviour (and BN running statistics keep updating even under torch.no_grad()).
model = model.load_from_checkpoint(CHECKPOINT_PATH, args=args, strict=False).cuda().eval()

In [53]:
# For every frame: run the depth-completion model, back-project the refined dense depth
# of each of the 4 side-by-side views into the velodyne frame, drop points beyond the
# farthest projected LiDAR distance, undo T_xyz, and write the points to
# <data_path>/<SEQUENCE>/denser/<frame:06d>.bin as float32 records of 4 values each.
SEQUENCE = '03'  # sequence sub-directory the .bin files are written into
out_dir = os.path.join(args['dataset_params']['data_path'], SEQUENCE, 'denser')
os.makedirs(out_dir, exist_ok=True)  # tofile() does not create missing directories


def _to_numpy(value):
    """Return `value` as a NumPy array, detaching torch tensors and moving them to host memory."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


for frame, cur_data in enumerate(dataloader):
    H = cur_data['velodyne_proj_img0'].shape[2]
    W = cur_data['velodyne_proj_img0'].shape[3]
    # NOTE(review): assumes column 1 of denser_coordinate_lines holds image rows, so rows
    # above H_up are skipped — confirm against the dataset code.
    H_up = int(cur_data['denser_coordinate_lines'][0].squeeze()[:, 1].max())
    with torch.no_grad():
        output_data = model(cur_data)
    # Channels-last then squeezed; assumes batch size 1 and one depth channel -> (H, 4*W).
    dense_img = _to_numpy(output_data['all_refined_depth'].permute(0, 2, 3, 1)).squeeze()

    # Inverse camera / extrinsic matrices (3x3 parts only).
    K_inv = np.linalg.inv(_to_numpy(output_data['K'][0]).squeeze()[:3, :3])
    T_velo2img_inv = np.linalg.inv(_to_numpy(output_data['T_velo2img']).squeeze()[:3, :3])
    T_4img_inv = np.linalg.inv(_to_numpy(output_data['T_4img']).squeeze()[:3, :3])
    T_rot_inv = np.linalg.inv(_to_numpy(output_data['T_rot']).squeeze()[:3, :3])

    # Pixel grid as (row, col) for rows H_up..H-1 and every column of one view.
    coordinate = np.indices((H - H_up, W)).reshape((2, -1)).transpose(1, 0)
    coordinate[:, 0] += H_up
    # Homogeneous (u, v, 1) = (col, row, 1) and the rays K^-1 [u v 1]^T; both are the same
    # for every view, so they are computed once per frame instead of once per view.
    insert_uvd = np.concatenate([np.fliplr(coordinate), np.ones((coordinate.shape[0], 1))], axis=1)
    rays = (K_inv @ insert_uvd.T).T

    dense_points_list = []
    T_4img_pow = np.eye(3)  # T_4img_inv applied img_idx times, built incrementally
    for img_idx in range(4):
        # View img_idx occupies columns [img_idx*W, (img_idx+1)*W) of dense_img.
        z_axis = dense_img[coordinate[:, 0], coordinate[:, 1] + img_idx * W].reshape((-1, 1))
        insert_xyz = rays * z_axis
        insert_xyz = insert_xyz[insert_xyz[:, 2] > 0]  # keep points with positive depth
        # Same chain as: T_velo2img_inv, then T_4img_inv img_idx times, then T_rot_inv.
        img_to_velo = T_rot_inv @ T_4img_pow @ T_velo2img_inv
        dense_points_list.append(insert_xyz @ img_to_velo.T)
        T_4img_pow = T_4img_inv @ T_4img_pow
    dense_points = np.concatenate(dense_points_list, axis=0)

    # Discard points farther than the farthest projected distance.
    max_dist = float(output_data['proj_distance'].max())
    point_distance = np.linalg.norm(dense_points, axis=1).astype(np.float32)
    dense_points = dense_points[point_distance < max_dist]

    # Homogenize to (N, 4) and undo T_xyz. Squeezing T_xyz to (4, 4) keeps the result
    # (N, 4); a batched (1, 4, 4) matrix would broadcast the product to (N, 4, 1).
    cat_ones = np.ones((dense_points.shape[0], 1))
    dense_points = np.concatenate([dense_points, cat_ones], axis=1).astype(np.float32)
    T_xyz_inv = np.linalg.inv(_to_numpy(output_data['T_xyz']).squeeze())
    dense_points = (T_xyz_inv @ dense_points.T).T.astype(np.float32)
    dense_points.tofile(os.path.join(out_dir, str(frame).zfill(6) + '.bin'))

KeyboardInterrupt: 