In [1]:
import pandas as pd
from os.path import join
import glob
from scipy.spatial.transform import Rotation
import numpy as np
import laspy
import os
import math
import torch
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class EgoPose():
    """Ego-vehicle poses read from a feather file.

    The file must provide the quaternion columns ``qw``, ``qx``, ``qy``,
    ``qz`` and the translation columns ``tx_m``, ``ty_m``, ``tz_m``.
    Poses are addressed by row position.
    """
    def __init__(self, path:str):
        """Load every pose stored in the feather file at ``path``."""
        poses = pd.read_feather(path)
        # Quaternion components are kept as separate pandas Series.
        self.qws = poses['qw']
        self.qxs = poses['qx']
        self.qys = poses['qy']
        self.qzs = poses['qz']

        x = torch.tensor(poses['tx_m'])
        y = torch.tensor(poses['ty_m'])
        z = torch.tensor(poses['tz_m'])
        # One (x, y, z) translation per row.
        self.coordinates = torch.stack([x, y, z], dim=1)

    def get_rotation_matrix(self, index:int)->torch.Tensor:
        """Return the 3x3 float32 rotation matrix of the pose at row ``index``.

        Rows are looked up by position, so negative indices count from the
        end, exactly as they do in ``get_xyz``.
        """
        # scipy expects scalar-last quaternions: (x, y, z, w).
        quaternion = [
            self.qxs.iloc[index],
            self.qys.iloc[index],
            self.qzs.iloc[index],
            self.qws.iloc[index],
        ]
        # from_quat is the documented factory; the bare constructor is discouraged.
        rotation = Rotation.from_quat(quaternion)
        return torch.from_numpy(rotation.as_matrix()).type(torch.float32)

    def get_xyz(self, index:int)->torch.Tensor:
        """Return the (x, y, z) translation of the pose at row ``index``."""
        return self.coordinates[index]
    
# class Calibration():
#     def __init__(self, path):
#         calib = pd.read_feather(path)

class Lidar():
    """A single lidar sweep loaded from a feather file.

    Attributes:
        coordinates: float32 tensor with one (x, y, z) row per point.
        intensities: the ``intensity`` column, in its stored dtype.
        laser_number: the ``laser_number`` column, in its stored dtype.
        gps_time: the ``offset_ns`` column, in its stored dtype.
    """
    def __init__(self, path:str):
        """Read the sweep stored at ``path`` and convert its columns to tensors."""
        sweep = pd.read_feather(path)

        def column(name:str)->torch.Tensor:
            # Wrap one DataFrame column as a tensor without changing its dtype.
            return torch.from_numpy(sweep[name].to_numpy())

        axes = [column(axis).type(torch.float32) for axis in ('x', 'y', 'z')]
        self.coordinates = torch.stack(axes, dim=1)

        self.intensities = column('intensity')
        self.laser_number = column('laser_number')
        self.gps_time = column('offset_ns')
        
class Argoverse():
    """Lidar sweeps and ego poses of one sequence directory.

    Expects ``<root>/sensors/lidar/*.feather`` (one file per sweep) and
    ``<root>/city_SE3_egovehicle.feather`` (ego poses).

    NOTE(review): sweep ``i`` is paired with pose row ``i``. This assumes the
    pose file has exactly one row per sweep, in the same order as the sorted
    sweep file names. Confirm this against the dataset; if the pose file
    contains extra rows, the pairing must be done on timestamps instead.
    """
    def __init__(self, root:str) -> None:
        """Load every sweep and the pose table found under ``root``."""
        self.root = root
        # glob.glob returns matches in arbitrary order; sort so that sweep
        # index i is deterministic and follows the file-name order.
        lidar_files = sorted(glob.glob(join(root, 'sensors', 'lidar', '*.feather')))
        self.lidars = [Lidar(file) for file in lidar_files]
        self.ego_pose = EgoPose(join(root, 'city_SE3_egovehicle.feather'))
        self.num_index = len(self.lidars)

    def get_points(self, i:int)->torch.Tensor:
        """Return sweep ``i`` transformed by pose ``i``.

        Each point is rotated by the pose's rotation matrix and then shifted
        by the pose's translation. The result has one row per point with the
        columns (x, y, z, intensity).
        """
        lidar = self.lidars[i]
        intensities = lidar.intensities
        coordinates = lidar.coordinates

        # rotation: matmul allocates a new tensor, so the sweep's stored
        # coordinates are not modified by the in-place translation below.
        rot_mat = self.ego_pose.get_rotation_matrix(i)
        coordinates = torch.matmul(rot_mat, coordinates.T).T

        # translation
        coordinates += self.ego_pose.get_xyz(i)

        points = torch.cat([coordinates, intensities.view(-1, 1)], dim=1)
        return points

In [7]:
def normalize_points(points:torch.Tensor)->torch.Tensor:
    """Center and scale the xyz channels of a batch of point clouds, in place.

    For every cloud in the batch, the first three channels are shifted to
    zero mean and then divided, axis by axis, by the largest absolute value
    along that axis. Any remaining channels are left untouched.

    Args:
        points: tensor of shape (batch, num_points, channels) with at least
            three channels; it is modified in place.

    Returns:
        The same ``points`` tensor, for convenient chaining.
    """
    xyz = points[:, :, :3]  # a view: in-place updates write through to ``points``
    # keepdim=True keeps the statistics at (batch, 1, 3) so they broadcast
    # per cloud; without it the code only worked for a batch of one.
    xyz -= xyz.mean(dim=1, keepdim=True)
    extent = xyz.abs().amax(dim=1, keepdim=True)
    # An axis on which all points coincide has zero extent; divide by 1
    # there so the axis becomes 0 instead of NaN (0 / 0).
    xyz /= torch.where(extent > 0, extent, torch.ones_like(extent))
    return points

def _inference_device(model) -> torch.device:
    """Return the device of ``model``'s first parameter, defaulting to CUDA."""
    parameters = getattr(model, 'parameters', None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    # Models without parameters keep the original behaviour of running on the GPU.
    return torch.device('cuda')


def create_las(path:str, argoverse:Argoverse, model)->None:
    """Classify every sweep of ``argoverse`` and write one merged LAS file.

    Each sweep is normalized, fed to ``model`` on its own, and the predicted
    class of every point is stored in the LAS ``classification`` field. The
    exported coordinates and intensities are the un-normalized values that
    ``argoverse.get_points`` returns.

    Args:
        path: destination of the LAS file.
        argoverse: dataset providing ``num_index`` and ``get_points(i)``.
        model: callable mapping a (1, num_points, 4) float32 tensor to
            per-point class scores. The scores are presumably shaped
            (1, num_points, num_classes), since they are transposed before
            the argmax — confirm against the model.

    Raises:
        ValueError: if the dataset contains no sweeps.
    """
    if argoverse.num_index == 0:
        raise ValueError('The dataset contains no lidar sweeps; nothing to export.')

    # Run inference wherever the model lives instead of assuming CUDA.
    device = _inference_device(model)
    frame_points = []
    frame_classes = []

    for i in range(argoverse.num_index):
        frame = argoverse.get_points(i)
        frame_points.append(frame)

        # copy=True guarantees a separate tensor even on the CPU, so the
        # in-place normalization never touches the points kept for export.
        batch = frame.unsqueeze(0).to(device=device, dtype=torch.float32, copy=True)
        batch = normalize_points(batch)
        # Scale the intensity channel; assumes intensities lie in [0, 255].
        batch[:,:,3] = batch[:,:,3] / 255.0

        with torch.no_grad():
            scores = model(batch).transpose(1, 2)
            # argmax of the raw scores equals argmax of their softmax, so the
            # softmax is skipped. squeeze(0) removes only the batch axis, so
            # a one-point sweep stays 1-D and can still be concatenated.
            labels = scores.argmax(dim=1).squeeze(0).cpu()
            print(labels.shape, labels[labels > 0].shape)
            frame_classes.append(labels)

    merged = torch.cat(frame_points, dim=0).numpy()
    classes = torch.cat(frame_classes, dim=0).numpy()
    intensity = merged[:, 3].astype(np.uint8)

    header = laspy.LasHeader(point_format=3, version="1.4")
    # Anchor the offsets at the minimum corner; coordinates use a 0.01 scale.
    header.offsets = np.min(merged[:,:3], axis=0)
    header.scales = np.array([0.01, 0.01, 0.01])

    las = laspy.LasData(header)
    las.x = merged[:,0]
    las.y = merged[:,1]
    las.z = merged[:,2]
    las.intensity = intensity
    las.classification = classes
    # gps_time and point_source_id are not written.

    las.write(path)
    print(f'{merged.shape[0]} points were saved on {path}.')

In [4]:
from vision.PointTransformer import PointTransformerSeg
num_classes = 5
# The first argument is presumably the number of input channels
# (x, y, z, intensity) — confirm against PointTransformerSeg.
seg_model = PointTransformerSeg(4, num_classes)
# Deserialize onto the CPU first: without map_location, torch.load restores
# tensors onto the device they were saved from, which fails when that device
# is missing. load_state_dict copies the values into the model, and the model
# is moved to the GPU afterwards.
seg_model.load_state_dict(torch.load('data/point_transformer_seg.pt', map_location='cpu'))
seg_model.cuda()
seg_model.eval()
# Keeps the model repr of the previous line out of the cell output.
print()




In [5]:
# Load every lidar sweep and the ego poses of one sequence directory.
argoverse = Argoverse('data/0dr6jn0kF6YjT9Qr1mtpYrE0ihkGpKsd')

In [8]:
# Classify every sweep with seg_model and write the merged cloud to ./test.las.
create_las('./test.las', argoverse, seg_model)

torch.Size([54601]) torch.Size([2799])
torch.Size([54535]) torch.Size([2821])
torch.Size([54514]) torch.Size([2787])
torch.Size([54550]) torch.Size([2530])
torch.Size([54592]) torch.Size([1954])
torch.Size([54505]) torch.Size([2692])
torch.Size([54500]) torch.Size([2387])
torch.Size([54584]) torch.Size([2829])
torch.Size([54492]) torch.Size([2202])
torch.Size([54415]) torch.Size([1924])
torch.Size([54456]) torch.Size([2591])
torch.Size([54482]) torch.Size([2573])
torch.Size([54422]) torch.Size([2024])
torch.Size([54441]) torch.Size([2789])
torch.Size([54475]) torch.Size([2347])
torch.Size([54495]) torch.Size([2307])
torch.Size([54531]) torch.Size([1844])
torch.Size([54441]) torch.Size([2353])
torch.Size([54413]) torch.Size([2628])
torch.Size([54359]) torch.Size([2183])
torch.Size([54403]) torch.Size([2700])
torch.Size([54412]) torch.Size([2531])
torch.Size([54384]) torch.Size([2361])
torch.Size([54549]) torch.Size([2371])
torch.Size([54388]) torch.Size([2766])
torch.Size([54470]) torch