In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
import matplotlib.pyplot as plt

from lyft_dataset_sdk.lyftdataset import LyftDataset
from lyft_dataset_sdk.utils.data_classes import LidarPointCloud, Quaternion
from lyft_dataset_sdk.utils.geometry_utils import transform_matrix

In [None]:
# Root of the Lyft 3D object-detection dataset. Set the LYFT_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific default path.
input_dir = os.environ.get('LYFT_DATA_DIR', '/run/media/hoosiki/WareHouse1/mtb/datasets/lyft-3d-od')

# Index the training split: sweep files live under <input_dir>/train and the
# JSON metadata tables under <input_dir>/train/data.
lyft_dataset = LyftDataset(data_path=os.path.join(input_dir, 'train'),
                           json_path=os.path.join(input_dir, 'train', 'data'),
                           verbose=True)

In [None]:
sample = lyft_dataset.get('sample', 'b71497fc753ec107ca1ca6427f2513c550835aa244504550a5b0e2edd341f57d')
lidar = lyft_dataset.get('sample_data', sample['data']['LIDAR_TOP'])
lidar_data_path = lyft_dataset.get_sample_data_path(sample['data']['LIDAR_TOP'])

ego_pose = lyft_dataset.get('ego_pose', lidar['ego_pose_token'])
calibrated_sensor = lyft_dataset.get('calibrated_sensor', lidar['calibrated_sensor_token'])

global_from_car = transform_matrix(ego_pose['translation'],
                                   Quaternion(ego_pose['rotation']),
                                   inverse=False)

car_from_sensor = transform_matrix(calibrated_sensor['translation'],
                                   Quaternion(calibrated_sensor['rotation']),
                                   inverse=False)

# pointcloud w.r.t sensor frame: [xyzi, n_points]
pointcloud = LidarPointCloud.from_file(lidar_data_path)
# pointcloud w.r.t car frame.
pointcloud.transform(car_from_sensor)
# pointcloud: [xyzi, n_points] -> [n_points, xyzi]
pointcloud = pointcloud.points.transpose(1, 0)

In [None]:
pointcloud = pointcloud.transpose(1, 0)
pointcloud.shape

In [None]:
# A sanity check: in car space the X and Y coordinates should be roughly centered around 0.
# NOTE(review): Z is measured from the car-frame origin, so it may sit off 0 — confirm.
fig, ax = plt.subplots()
for axis_index, axis_name in enumerate("XYZ"):
    ax.hist(pointcloud[axis_index], alpha=0.5, bins=30, label=axis_name)
ax.set(title="LIDAR_TOP point distribution per axis (car frame)",
       xlabel="Distance from car along axis",
       ylabel="Number of points")
ax.legend()
plt.show()

In [None]:
# Rough per-axis extent of the car-frame point cloud: mean ± N_SIGMA standard deviations.
N_SIGMA = 3
for axis_name, axis_values in zip("XYZ", pointcloud[:3]):
    axis_mean = axis_values.mean()
    axis_std = axis_values.std()
    lower = axis_mean - N_SIGMA * axis_std
    upper = axis_mean + N_SIGMA * axis_std
    print(f"{axis_name}: [{lower:.2f}, {upper:.2f}]")