# SemanticKITTI to Stamped Datum Table

**Note:** this notebook can be deleted.


In [1]:
# Parameters: edit these paths before running the notebook.

# Please follow the instructions posted on the SemanticKITTI website to obtain the data:
# http://www.semantic-kitti.org/dataset.html#download
# Additionally, if you wish to study optical flow (or use the camera image datums
# created below), you'll want to expand the KITTI zip file `data_odometry_color.zip`.
# Extract the data as described to a directory and paste that directory path here.
# The code below expects to find `dataset/sequences/<seq>/` under this root.
SEMANTICKITTI_ROOT = '/outer_root/host_mnt/Volumes/970-evo-raid0/semantickitti_odom_tmp/'

# Root directory for generated outputs (currently only referenced by the
# commented-out world-cloud fusing cells below).
OUTPUT_ROOT = '/tmp/semantickitti_fused_root/'


In [2]:
# Setup

import time
import numpy as np
import os

import open3d as o3d
from oarphpy import util as oputil

# Map of dataset split name -> list of SemanticKITTI sequence ids (zero-padded strings).
# Deduced from:
# https://github.com/PRBonn/semantic-kitti-api/blob/c2d7712964a9541ed31900c925bf5971be2107c2/auxiliary/SSCDataset.py#L20
SK_SPLIT_SEQUENCES = {
    "train": ["00", "01", "02", "03", "04", "05", "06", "07", "09", "10"],
    "valid": ["08"],
    "test": ["11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21"]
}

# Semantic label ids that mark a lidar point as belonging to a moving object.
# Points with any of these labels are what this notebook calls "movers".
SK_MOVING_LABELS = [
    252, # "moving-car"
    253, # "moving-bicyclist"
    254, # "moving-person"
    255, # "moving-motorcyclist"
    256, # "moving-on-rails"
    257, # "moving-bus"
    258, # "moving-truck"
    259, # "moving-other-vehicle"
]

def get_scene_basepath(seq):
    """Return the directory holding SemanticKITTI sequence `seq` (e.g. '00'),
    i.e. `<SEMANTICKITTI_ROOT>/dataset/sequences/<seq>`."""
    sequences_dir = os.path.join(SEMANTICKITTI_ROOT, 'dataset/sequences')
    return os.path.join(sequences_dir, seq)

# Map of sequence id -> number of lidar scans, for the training sequences only.
# Scan files live in `<scene>/velodyne/` and are named `<zero-padded index>.bin`, so
# the scan count is taken to be the largest index plus one.
SK_SEQ_TO_NSCANS = {}
for seq in SK_SPLIT_SEQUENCES['train']:
    scene_base = get_scene_basepath(seq)
    velodyne_dir = os.path.join(scene_base, 'velodyne')
    # Only consider `.bin` files so that stray directory entries (e.g. `.DS_Store`)
    # can neither win the `max()` nor break the `int()` conversion.
    scan_indices = [
        int(os.path.splitext(fname)[0])
        for fname in os.listdir(velodyne_dir)
        if fname.endswith('.bin')
    ]
    if not scan_indices:
        raise ValueError("No velodyne .bin scans found in %s" % velodyne_dir)
    n_scans = max(scan_indices) + 1
    print('Found Sequence %s with %s scans' % (seq, n_scans))
    SK_SEQ_TO_NSCANS[seq] = n_scans
print("Found %s total scans" % sum(SK_SEQ_TO_NSCANS.values()))



Found Sequence 00 with 4541 scans
Found Sequence 01 with 1101 scans
Found Sequence 02 with 4661 scans
Found Sequence 03 with 801 scans
Found Sequence 04 with 271 scans
Found Sequence 05 with 2761 scans
Found Sequence 06 with 1101 scans
Found Sequence 07 with 1101 scans
Found Sequence 09 with 1591 scans
Found Sequence 10 with 1201 scans
Found 19130 total scans


In [3]:
# import time
# import six
# from contextlib import contextmanager
# class ThruputObserver(object):
#   """A utility for measuring the runtime and throughput of a subroutine.
#   Similar in spirit to `tqdm`, except `ThruputObserver`:
#    * Tracks not just time but a size metric (e.g. memory) in bytes
#    * Reports percentiles
#    * Simply logs strings and is not terminal-interactive
  
#   While `tqdm` is useful for notebooks, `ThruputObserver` seeks to be more
#   useful for longer-running batch jobs.
#   """
  
#   def __init__(
#       self,
#       name='',
#       log_on_del=False,
#       only_stats=None,
#       log_freq=100,
#       n_total=None,
#       n_total_chunks=None):
#     self.n = 0
#     self.num_bytes = 0
#     self.ts = []
#     self.name = name
#     self.log_on_del = log_on_del
#     self.only_stats = only_stats or []
#     self.n_total = max(n_total, 1) if n_total is not None else None
#     self.n_total_chunks = (
#       max(n_total_chunks, 1) if n_total_chunks is not None else None)
#     self._start = None
#     self.__log_freq = log_freq
#     self.__last_log = 0
  
#   @contextmanager
#   def observe(self, n=0, num_bytes=0):
#     """
#     NB: contextmanagers appear to be expensive due to object creation.
#     Use ThurputObserver#{start,stop}_block() for <10ms ops. 
#     FMI https://stackoverflow.com/questions/34872535/why-contextmanager-is-slow
#     """

#     self.start_block()
#     yield
#     self.stop_block(n=n, num_bytes=num_bytes)
  
#   def start_block(self):
#     self._start = time.time()
  
#   def update_tallies(self, n=0, num_bytes=0, new_block=False):
#     self.n += n
#     self.num_bytes += num_bytes
#     if new_block:
#       self.stop_block()
#       self.start_block()
  
#   def stop_block(self, n=0, num_bytes=0):
#     end = time.time()
#     self.n += n
#     self.num_bytes += num_bytes
#     if self._start is not None:
#       self.ts.append(end - self._start)
#     self._start = None
  
#   def maybe_log_progress(self, every_n=-1):
#     if every_n >= 0:
#       self.__log_freq = every_n
#     if self.n >= self.__last_log + self.__log_freq:
#       from oarphpy.util import log
#       print("Progress for \n" + str(self)) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#       self.__last_log = self.n
#         # Track last log because `n` may increase inconsistently
#       if every_n == -1 and (self.n >= (1.7 * self.__log_freq)):
#         self.__log_freq = int(1.7 * self.__log_freq)
#           # Exponentially decay logging frequency. Don't decay quite as
#           # fast as Vowpal Wabbit did, though.

#   @staticmethod
#   def union(thruputs):
#     u = ThruputObserver()
#     for t in thruputs:
#       u += t
#     return u

#   @property
#   def total_time(self):
#     return sum(self.ts)

#   def get_stats(self):
#     import numpy as np
#     from humanfriendly import format_size
#     from humanfriendly import format_timespan

#     total_time = self.total_time

#     stats = [
#       ('Thruput', ''),
#       ('N thru', (self.n
#                     if self.n_total is None
#                     else '%s (of %s)' % (self.n, self.n_total))),
#       ('N chunks', (len(self.ts)
#                     if self.n_total_chunks is None
#                     else '%s (of %s)' % (len(self.ts), self.n_total_chunks))),
#       ('Total time', format_timespan(total_time) if total_time else '-'),
#       ('Total thru', format_size(self.num_bytes)),
#       ('Rate', 
#         format_size(self.num_bytes / total_time) + ' / sec'
#         if total_time else '-'),
#       ('Hz', float(self.n) / total_time if total_time else '-'),
#     ]
#     percent_complete = None
#     if self.n_total is not None:
#       percent_complete = 100. * float(self.n) / self.n_total
#     elif self.n_total_chunks is not None:
#       percent_complete = 100. * float(len(self.ts)) / self.n_total_chunks
#     if percent_complete is not None:
#       eta_sec = (
#         (100. - percent_complete) * 
#         (total_time / (percent_complete + 1e-10)))
#       stats.extend([
#         ('Progress', ''),
#         ('Percent Complete', percent_complete),
#         ('Est. Time To Completion', format_timespan(eta_sec)),
#       ])
#     if len(self.ts) >= 2:
#       format_t = lambda t: format_timespan(t, detailed=True)
#       stats.extend([
#         ('Latency (per chunk)', ''),
#         ('Avg', format_t(np.mean(self.ts))),
#         ('p50', format_t(np.percentile(self.ts, 50))),
#         ('p95', format_t(np.percentile(self.ts, 95))),
#         ('p99', format_t(np.percentile(self.ts, 99))),
#       ])
#     if self.only_stats:
#       stats = tuple(
#         (name, value)
#         for name, value in stats
#         if name in self.only_stats
#       )
#     return stats

#   def __iadd__(self, other):
#     self.n += other.n
#     self.num_bytes += other.num_bytes
#     self.ts.extend(other.ts)
#     return self

#   def __str__(self):
#     import tabulate
#     stats = self.get_stats()
#     summary = tabulate.tabulate(stats)
#     if self.name:
#       prefix = '%s [Pid:%s Id:%s]' % (self.name, os.getpid(), id(self))
#       summary = prefix + '\n' + summary
#     return summary
  
#   def __del__(self):
#     if self.log_on_del:
#       self.stop_block()

#       from oarphpy.util import create_log
#       log = create_log()
#       print('\n' + str(self) + '\n') #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  
#   @staticmethod
#   def monitoring_tensor(name, tensor, **observer_init_kwargs):
#     """Monitor the size of the given tensorflow `Tensor` and record a
#     text TF Summary with the contents of this ThruputObserver."""

#     class Observer(object):
#       def __init__(self, dtype_size_bytes):
#         self.observer = ThruputObserver(name=name, **observer_init_kwargs)
#         self.dtype_size_bytes = dtype_size_bytes
#       def __call__(self, t_shape):
#         import numpy as np
#         n = t_shape[0]
#         num_bytes = np.prod(t_shape) * self.dtype_size_bytes
#         self.observer.stop_block(n=n, num_bytes=num_bytes)
#         self.observer.maybe_log_progress()
        
#         # Tensorboard is very picky about wanting Markdown :P
#         import tabulatehelper as th
#         stats = self.observer.get_stats()
#         out = th.md_table(stats, headers=[name])

#         self.observer.start_block()
#         return out
    
#     import tensorflow as tf
#     obs_str_tensor = tf.compat.v1.py_func(
#               Observer(tensor.dtype.size), [tf.shape(tensor)], tf.string)
#     tf.summary.text(name + '/ThruputObserver', obs_str_tensor)
#     return obs_str_tensor
  
#   @staticmethod
#   def wrap_func(func, **observer_init_kwargs):
#     """Decorate `func` and observe a block on each call"""
#     class MonitoredFunc(object):
#       def __init__(self, func, observer_init_kwargs):
#         self.func = func
#         self.observer = ThruputObserver(**observer_init_kwargs)
#       def __call__(self, *args, **kwargs):
#         from oarphpy.util.misc import get_size_of_deep
#         self.observer.start_block()
#         ret = self.func(*args, **kwargs)
#         self.observer.stop_block(n=1, num_bytes=get_size_of_deep(ret))
#         self.observer.maybe_log_progress()
#         return ret
#     return MonitoredFunc(func, observer_init_kwargs)

#   @staticmethod
#   def monitor_generator(gen, **observer_init_kwargs): #~~~~~~~~~~~~~~~~~~~~
#     observer_init_kwargs['log_on_del'] = True
#     t = ThruputObserver(**observer_init_kwargs)
#     while True:
#         t.start_block()
#         x = six.next(gen)
#         t.stop_block(n=1, num_bytes=oputil.get_size_of_deep(x)) # ~~~~~~~~~~~~~~~~~~~
        
#         yield x
        
#         t.maybe_log_progress()
        

In [4]:
## Support Code

def get_calibration(seq):
    """Parse and return the calibration matrices stored in `calib.txt` of
    sequence `seq`; see `parse_calibration()` for the returned structure."""
    calib_path = os.path.join(get_scene_basepath(seq), 'calib.txt')
    return parse_calibration(calib_path)

def get_poses(seq):
    """Parse and return the per-scan poses stored in `poses.txt` of sequence
    `seq`; see `parse_poses()` for the returned structure."""
    poses_path = os.path.join(get_scene_basepath(seq), "poses.txt")
    return parse_poses(poses_path)
    
def parse_calibration(path):
    """Parse a calibration file and return a map to 4x4 Numpy matrices.

    Each non-blank line has the form `<key>: v0 v1 ... v11`; the twelve values
    fill the top three rows (row-major) of a 4x4 matrix whose last row is
    `[0, 0, 0, 1]`.  Blank lines are skipped.

    Important keys returned:
    * Tr - the lidar to camera static transform
    * P2 - the left camera projective matrix P
    Based upon https://github.com/PRBonn/semantic-kitti-api/blob/9b5feda3b19ea560a298493b9a5ebebe0cbe2cc2/generate_sequential.py#L14

    Args:
        path: path to the `calib.txt` file to read.

    Returns:
        dict mapping each key (the text before the first ':') to a 4x4 float array.

    Raises:
        ValueError: if a non-blank line has no ':' separator, has fewer than
            twelve values, or has a value that is not a float.
    """
    calib = {}

    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing lines rather than crash on unpacking.
                continue
            # Split on the first colon only so the value part is never truncated.
            key, mat_str = line.split(":", 1)
            values = [float(v) for v in mat_str.split()]
            mat = np.eye(4)
            # Values beyond the first twelve, if any, are ignored.
            mat[:3, :4] = np.array(values[:12]).reshape((3, 4))
            calib[key] = mat
    return calib

def parse_poses(path):
    """Read a SemanticKITTI (per-scan) poses file and return a list of 4x4 homogenous
    RT matrices that express world-to-left-camera transforms.  The index of this list is
    implicitly the scan ID.

    Each non-blank line holds twelve values that fill the top three rows
    (row-major) of a 4x4 matrix whose last row is `[0, 0, 0, 1]`.  Blank lines
    are skipped.

    NOTE(review): the direction naming above ("world-to-left-camera") is kept from
    the original author; confirm it against the KITTI odometry devkit before
    relying on it.

    Based upon: https://github.com/PRBonn/semantic-kitti-api/blob/9b5feda3b19ea560a298493b9a5ebebe0cbe2cc2/generate_sequential.py#L42

    Args:
        path: path to the `poses.txt` file to read.

    Returns:
        list of 4x4 float arrays, one per non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line has fewer than twelve values or a value
            that is not a float.
    """
    poses = []
    with open(path) as f:
        for line in f:
            values = [float(v) for v in line.split()]
            if not values:
                # Tolerate blank / trailing lines; they would otherwise fail
                # when assigned into the matrix rows.
                continue
            mat = np.eye(4)
            # Values beyond the first twelve, if any, are ignored.
            mat[:3, :4] = np.array(values[:12]).reshape((3, 4))
            poses.append(mat)
    return poses
    

    
# #     Tr = calib["Tr"]
# #     Tr_inv = np.linalg.inv(Tr)
    
    
    
#   """ read poses file with per-scan poses from given filename
#       Returns
#       -------
#       list
#           list of poses as 4x4 numpy arrays.
#   """
#   file = open(filename)

#   poses = []

#   Tr = calibration["Tr"]
# #   print('Tr', Tr)
# #   Tr = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
#   Tr_inv = np.linalg.inv(Tr)

#   for line in file:
#     values = [float(v) for v in line.strip().split()]

#     pose = np.zeros((4, 4))
#     pose[0, 0:4] = values[0:4]
#     pose[1, 0:4] = values[4:8]
#     pose[2, 0:4] = values[8:12]
#     pose[3, 3] = 1.0

#     poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
# #     poses.append(np.matmul(pose, Tr))
#   file.close()
#   return poses

### Set up Spark

In [5]:
from oarphpy.spark import NBSpark
# Source root (and the modules within it) that NBSpark uses for the session; the
# cell output below logs this root being built into an egg.
NBSpark.SRC_ROOT = '/opt/psegs/psegs'
NBSpark.SRC_ROOT_MODULES = ['psegs']
# Raise the Spark driver's memory and max-result-size settings -- presumably because
# results are collected back to the driver in later cells; TODO confirm the sizes.
NBSpark.CONF_KV.update({
    'spark.driver.maxResultSize': '10g',
    'spark.driver.memory': '16g',
  })
# NBSpark.CONF_KV.pop('spark.extraListeners')
# Session handle used by all Spark code in the cells below.
spark = NBSpark.getOrCreate()

2021-02-11 09:51:40,514	oarph 17241 : Using source root /opt/psegs/psegs 
INFO - 2021-02-11 09:51:40,514 - spark - Using source root /opt/psegs/psegs 
2021-02-11 09:51:40,516	oarph 17241 : Using source root /opt/psegs 
INFO - 2021-02-11 09:51:40,516 - spark - Using source root /opt/psegs 
2021-02-11 09:51:40,555	oarph 17241 : Generating egg to /tmp/tmpngjmscf__oarphpy_eggbuild ...
INFO - 2021-02-11 09:51:40,555 - spark - Generating egg to /tmp/tmpngjmscf__oarphpy_eggbuild ...
INFO - 2021-02-11 09:51:40,571 - driver - Generating grammar tables from /usr/lib/python3.8/lib2to3/Grammar.txt
INFO - 2021-02-11 09:51:40,626 - driver - Generating grammar tables from /usr/lib/python3.8/lib2to3/PatternGrammar.txt
2021-02-11 09:51:40,706	oarph 17241 : ... done.  Egg at /tmp/tmpngjmscf__oarphpy_eggbuild/psegs-0.0.0-py3.8.egg
INFO - 2021-02-11 09:51:40,706 - spark - ... done.  Egg at /tmp/tmpngjmscf__oarphpy_eggbuild/psegs-0.0.0-py3.8.egg
INFO - 2021-02-11 09:51:43,209 - kernelextension - Client Con

### Fuse World Clouds and Dump Them

Nota bene: Potree is an excellent viewer for large point clouds. To run it via Docker:
```
docker --context default run -it --name=potree_viewer --rm --net=host -v `pwd`:/shared  jonazpiazu/potree
```


In [6]:
# class SingleSequenceWorldCloudFuser(object):
    
#     def __init__(self, seq):
#         self.seq = seq
#         self.scene_base = get_scene_basepath(seq)
        
#         print("Loading calibration for sequence %s" % seq)
#         self.calib = get_calibration(seq)
              
#         print("Loading poses for sequence %s" % seq)
#         self.all_poses = get_poses(seq)

#     @classmethod
#     def get_moving_mask_for_scan(cls, scene_base, scan_id):
#         scan_name = str(scan_id).rjust(6, '0')
#         labels_path = os.path.join(scene_base, 'labels', scan_name + '.label')
#         labels = np.fromfile(labels_path, dtype=np.uint32)
#         labels = labels.reshape((-1))
#         sem_label = labels & 0xFFFF  # semantic label in lower half
#         inst_label = labels >> 16    # instance id in upper half
#          # NB: 22 / 252 is chase car in scene 08 !!!
        
#         moving_mask = np.logical_or.reduce(tuple((sem_label == c) for c in SK_MOVING_LABELS))
#         return moving_mask
        
#     def read_scan_get_clean_world_cloud(self, scan_id):
#         import numpy as np

#         scan_name = str(scan_id).rjust(6, '0')
#         scan_path = os.path.join(self.scene_base, 'velodyne', scan_name + '.bin')
#         lidar = np.frombuffer(open(scan_path, 'rb').read(), dtype=np.float32).reshape((-1, 4))
#         cloud = np.ones(lidar.shape)  # need homogenous for change below
#         cloud[:, 0:3] = lidar[:, 0:3]

#         # Move cloud into the world frame
#         Tr = self.calib["Tr"]
#         Tr_inv = np.linalg.inv(Tr)
#         cam2_pose = self.all_poses[scan_id]
#         pose = np.matmul(Tr_inv, np.matmul(cam2_pose, Tr))  
#         cloud = np.matmul(pose, cloud.T).T

#         # Clean out points for anything moving
# #         moving_mask = np.logical_or.reduce(tuple((sem_label == c) for c in SK_MOVING_LABELS))
# #         if not moving_mask.any():
# #             frames_no_movers.append(s)
#         moving_mask = self.get_moving_mask_for_scan(self.scene_base, scan_id)
#         static_cloud = cloud[~moving_mask][:, :3]
        
#         # TODO need to scrube the ego car !!  
#         # moving_cloud = cloud[moving_mask][:, :3]
#         return static_cloud
    

In [7]:
# for seq, n_scans in sorted(SK_SEQ_TO_NSCANS.items()):
#     print("Fusing sequence %s ..." % seq)
#     fuser = SingleSequenceWorldCloudFuser(seq)
    
#     slices = n_scans // 100
#     task_rdd = spark.sparkContext.parallelize(range(n_scans), numSlices=slices)
#     cloud_rdd = task_rdd.map(lambda s: fuser.read_scan_get_clean_world_cloud(s))
    
#     import pyspark
#     cloud_rdd = cloud_rdd.persist(pyspark.StorageLevel.MEMORY_AND_DISK)
    
    
#     iter_clouds = cloud_rdd.toLocalIterator()#prefetchPartitions=True)):  TODO FIXME USING SPARK 2.4 !!!
#     iter_clouds_t = ThruputObserver.monitor_generator(iter_clouds, n_total=n_scans, log_freq=100)
#     fused_world_cloud = np.vstack(iter_clouds_t)
    
#     print("Fused world cloud: {s} ({sz:.2f} GBytes)".format(
#         s=fused_world_cloud.shape, sz=fused_world_cloud.nbytes * 1e-9))
    
#     fused_world_root = os.path.join(OUTPUT_ROOT, 'fused_world_clouds')
#     oputil.mkdir(fused_world_root)

#     import pickle
#     path = os.path.join(fused_world_root, "%s.pkl" % seq)
#     pickle.dump(fused_world_cloud, open(path, 'wb'), protocol=4)
#     print('Saved fused world cloud pkl to %s' % path)
    
#     pcd = o3d.geometry.PointCloud()
#     pcd.points = o3d.utility.Vector3dVector(fused_world_cloud)
#     path = os.path.join(fused_world_root, "%s.ply" % seq)
#     o3d.io.write_point_cloud(path, pcd)
#     print('Saved fused world cloud to %s' % path)
# # #     n_moving_pts = sum(c.shape[0] for c in all_moving_clouds)
# # #     print('moving_cloud pts', n_moving_pts, float(n_moving_pts) / fused_world_cloud.shape[0])
# # #     print('frames_no_movers', frames_no_movers[:20])
    
    

### Search for frames with zero moving things

In [8]:



# # for seq, n_scans in sorted(SK_SEQ_TO_NSCANS.items()):
# #     print("Searching sequence %s ..." % seq)
    
# #     slices = n_scans // 100
# #     task_rdd = spark.sparkContext.parallelize(range(n_scans), numSlices=slices)
    
# #     scan_has_no_movers = lambda scan_id: (not seq_scan_has_movers(seq, scan_id))
# #     scans_no_movers = task_rdd.filter(scan_has_no_movers).collect()
    
# #     print("Sequence %s has %s frames with no moving points ..." % (seq, len(scans_no_movers)))


# import sys
# sys.path.append('/opt/psegs')

# import copy
# from psegs import datum
# from psegs import util
# from psegs.table.sd_table import StampedDatumTableBase
# class SemanticKITTIFusedSDTable(StampedDatumTableBase):
    
#     ONLY_FRAMES_WITH_NO_MOVERS = True
    
#     import sys
#     sys.path.append('/opt/psegs')
    
#     @classmethod
#     def _get_all_segment_uris(cls):
#         return [
#             datum.URI(
#                 dataset='semantikitti-psegs-fused',
#                 split='train',
#                 segment_id=str(seq))
#             for seq in SK_SEQ_TO_NSCANS.keys()
#         ]

#     @classmethod
#     def _create_datum_rdds(cls, spark, existing_uri_df=None, only_segments=None):
#         """Subclasses should create and return a list of `RDD[StampedDatum]`s

#         only_segments must be segment uris
#         TODO docs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""
        
        
#         assert existing_uri_df is None, "Resume feature not supported"
#         seg_uris = cls.get_all_segment_uris()
#         if only_segments:
#             util.log.info("Filtering to only %s segments" % len(only_segments))
#             seg_uris = [
#                 uri for uri in seg_uris
#                 if any(
#                   suri.soft_matches_segment(uri) for suri in only_segments)
#             ]
        
#         datum_rdds = []
#         for seg_uri in seg_uris:
#             seq = seg_uri.segment_id
#             if cls.ONLY_FRAMES_WITH_NO_MOVERS:
#                 util.log.info("Finding scans for sequence %s with no movering points ..." % seq)
#                 n_scans = SK_SEQ_TO_NSCANS[seq]
#                 slices = n_scans // 100
#                 task_rdd = spark.sparkContext.parallelize(range(n_scans), numSlices=slices)
#                 scan_has_no_movers = lambda scan_id: (not seq_scan_has_movers(seq, scan_id))
#                 scans_no_movers = task_rdd.filter(scan_has_no_movers).collect()
#                 util.log.info("... sequence %s has %s scans with no movers." % (seq, len(scans_no_movers)))
#                 scan_ids = scans_no_movers
#             else:
#                 scan_ids = list(range(SK_SEQ_TO_NSCANS[seq]))
            
            
#             tasks = [(seg_uri, scan_id) for scan_id in scan_ids]
            
#             # Emit camera_image RDD
#             ctask_rdd = spark.sparkContext.parallelize(tasks)
#             datum_rdd = ctask_rdd.map(lambda t: cls.create_camera_frame(*t))
#             datum_rdds.append(datum_rdd)
            
#             # Emit ego_pose RDD
#             ptask_rdd = spark.sparkContext.parallelize(tasks)
#             datum_rdd = ptask_rdd.map(lambda t: cls.create_ego_pose(*t))
#             datum_rdds.append(datum_rdd)
            
#             # Emit world cloud once
#             wc_rdd = spark.sparkContext.parallelize([seg_uri])
#             datum_rdd = wc_rdd.map(lambda t: cls.create_world_cloud(t))
#             datum_rdds.append(datum_rdd)
    
#         return datum_rdds
        
#         # Emit camera and pose RDDs
        
        
#         # for each segment emit camera and ego pose RDDs
#         # for each world cloud emit flyweight
#         # if we had cuboids, we'd emit them and object fused clouds
#         # for the fused stuff, perhaps lazy-create those? and/or require as a
#         # FIXTURES thing.
    
#     @classmethod
#     def _get_calib(cls, seq):
#         if not hasattr(cls, '_calib'):
#             cls._calib = {}
#         if seq not in cls._calib:
#             cls._calib[seq] = get_calibration(seq)
#         return cls._calib[seq]
    
#     @classmethod
#     def _get_poses(cls, seq):
#         if not hasattr(cls, '_poses'):
#             cls._poses = {}
#         if seq not in cls._poses:
#             cls._poses[seq] = get_poses(seq)
#         return cls._poses[seq]
    
#     @classmethod
#     def create_camera_frame(cls, base_uri, scan_id):
#         seq = base_uri.segment_id
#         calib = cls._get_calib(seq)
        
#         uri = copy.deepcopy(base_uri)
#         uri.topic = 'camera|left_rect'
#         uri.timestamp = int(scan_id) # HACK!

#         scene_base = get_scene_basepath(seq)
#         scan_name = str(scan_id).rjust(6, '0')
#         img_path = os.path.join(scene_base, 'image_2/', scan_name + '.png')
#         assert os.path.exists(img_path), (
#             "Did you remember to expand data_odometry_color.zip ? %s not found" % img_path)
#         with open(img_path, 'rb') as f:
#             width, height = util.get_png_wh(f.read(100)) # HACK!!!!
        
#         image_png = util.LazyThunktor(lambda: open(img_path, 'rb').read())
        
#         # HACK!!!  This is actually P !!!
#         K = calib['P2']
        
#         # hack! this is lidar to cam
#         ego_to_sensor = datum.Transform.from_transformation_matrix(
#                 calib['Tr'], src_frame='lidar', dest_frame=uri.topic)
        
#         sd_ego_pose = cls.create_ego_pose(base_uri, scan_id)
#         ego_pose = sd_ego_pose.transform
#         ci = datum.CameraImage(
#               sensor_name=uri.topic,
#               image_png=image_png,
#               width=width,
#               height=height,
#               timestamp=uri.timestamp,
#               ego_pose=ego_pose,
#               K=K,
#               ego_to_sensor=ego_to_sensor,
#               extra={'semantic_kitti.scan_id': str(scan_id)})
#         return datum.StampedDatum(uri=uri, camera_image=ci)
    
#     @classmethod
#     def create_ego_pose(cls, base_uri, scan_id):
#         seq = base_uri.segment_id
#         poses = cls._get_poses(seq)
        
#         uri = copy.deepcopy(base_uri)
#         uri.topic = 'ego_pose'
#         uri.timestamp = int(scan_id) # HACK!
        
#         # Hack! believe ego frame is lidar here?
#         ego_pose = datum.Transform.from_transformation_matrix(
#                 poses[scan_id], src_frame='world', dest_frame='ego')

#         return datum.StampedDatum(uri=uri, transform=ego_pose)      
    
#     @classmethod
#     def create_world_cloud(cls, base_uri):
#         seq = base_uri.segment_id

#         uri = copy.deepcopy(base_uri)
#         uri.topic = 'lidar|world_fused'
#         uri.timestamp = 0 # HACK!
        
#         cloud_path = os.path.join(OUTPUT_ROOT, 'fused_world_clouds', seq + '.ply')
#         def ply_to_np(path):
#             import open3d
#             pcd = open3d.io.read_point_cloud(str(path))
#             return np.asarray(pcd.points)
#         cloud = util.LazyThunktor(lambda: ply_to_np(cloud_path))
#         pc = datum.PointCloud(
#               sensor_name=uri.topic,
#               timestamp=uri.timestamp,
#               cloud=cloud,
#               ego_to_sensor=datum.Transform(),
#               ego_pose=datum.Transform(),
#               extra={'semantic_kitti.world_cloud_path': cloud_path})
#         return datum.StampedDatum(uri=uri, point_cloud=pc)

# seg_uris = SemanticKITTIFusedSDTable.get_all_segment_uris()
# sd_rdd = SemanticKITTIFusedSDTable._get_segment_datum_rdd_or_df(spark, seg_uris[0])
# print(sd_rdd.count())

In [9]:
import sys
sys.path.append('/opt/psegs')

import copy

from psegs import datum
from psegs import util
from psegs.table.sd_table import StampedDatumTableBase


def get_moving_mask_for_scan(scene_base, scan_id):
    """Return a boolean mask marking which points of a scan are labeled as moving.

    Reads `<scene_base>/labels/<scan_id zero-padded to 6 digits>.label` as a flat
    array of uint32 values, one per point.  The lower 16 bits of each value are
    the semantic label (the upper 16 bits, the instance id, are not needed here).

    Args:
        scene_base: sequence directory, as returned by `get_scene_basepath()`.
        scan_id: integer scan index within the sequence.

    Returns:
        1-D boolean array with one entry per label in the file; True where the
        semantic label is one of `SK_MOVING_LABELS`.
    """
    scan_name = str(scan_id).rjust(6, '0')
    labels_path = os.path.join(scene_base, 'labels', scan_name + '.label')
    labels = np.fromfile(labels_path, dtype=np.uint32).reshape((-1))
    sem_label = labels & 0xFFFF  # semantic label in lower half
    # NB: 22 / 252 is chase car in scene 08 !!!

    return np.isin(sem_label, SK_MOVING_LABELS)

def seq_scan_has_movers(seq, scan_id):
    """Return whether at least one point in scan `scan_id` of sequence `seq`
    carries a moving label (see `get_moving_mask_for_scan()`)."""
    mask = get_moving_mask_for_scan(get_scene_basepath(seq), scan_id)
    return mask.any()

def read_scan_get_cloud(seq, scan_id, remove_movers=True, filter_ego=True):
    """Read one velodyne scan and return its points in homogeneous coordinates.

    The scan file `<scene>/velodyne/<scan_id zero-padded to 6 digits>.bin` is read
    as float32 values with four per point; the first three are copied as x, y, z
    and the fourth is replaced with 1 so that the result can be multiplied by a
    4x4 transform.  No transform is applied here, so the points stay in whatever
    frame the file stores them in.

    Args:
        seq: sequence id string, e.g. '00'.
        scan_id: integer scan index within the sequence.
        remove_movers: if True, drop points whose label is in `SK_MOVING_LABELS`
            (requires the scan's `.label` file).
        filter_ego: currently has NO effect; removing ego-vehicle points is
            still a TODO.

    Returns:
        float64 array of shape (n_points, 4); the last column is all ones.
    """
    scan_name = str(scan_id).rjust(6, '0')
    scene_base = get_scene_basepath(seq)
    scan_path = os.path.join(scene_base, 'velodyne', scan_name + '.bin')

    # Read the raw lidar.  `np.fromfile` manages the file handle itself, so no
    # descriptor is left open (the scan is read once per task on Spark workers).
    lidar = np.fromfile(scan_path, dtype=np.float32).reshape((-1, 4))
    cloud = np.ones(lidar.shape)  # need homogenous for change below
    cloud[:, 0:3] = lidar[:, 0:3]

    if remove_movers:
        # Clean out points for anything moving
        moving_mask = get_moving_mask_for_scan(scene_base, scan_id)
        cloud = cloud[~moving_mask]

    if filter_ego:
        pass # TODO: filter out points that belong to the ego vehicle

    return cloud



class SemanticKITTISDTable(StampedDatumTableBase):
    """Stamped datum table over the SemanticKITTI training sequences.

    Each sequence in `SK_SEQ_TO_NSCANS` is one segment.  For each selected scan
    the table emits a left camera image datum, an ego pose datum, and (up to
    `MAX_POINT_CLOUDS_PER_SEGMENT` per segment) a lidar point cloud datum whose
    points have been moved into the world frame.  Scan ids are used as datum
    timestamps (a hack).
    """

    # When True: only scans with no moving-labeled points are emitted, and the
    # point cloud topic gets a `_cleaned` suffix.
    ONLY_FRAMES_WITH_NO_MOVERS = True

    # Cap on the number of scans per segment that get a point cloud datum.
    # Set to `None` to emit a point cloud for every selected scan.
    MAX_POINT_CLOUDS_PER_SEGMENT = 100

    @classmethod
    def _get_all_segment_uris(cls):
        """Return one segment URI per sequence in `SK_SEQ_TO_NSCANS`."""
        return [
            datum.URI(
                dataset='semantikitti-psegs-fused',
                split='train',
                segment_id=str(seq))
            for seq in SK_SEQ_TO_NSCANS.keys()
        ]

    @classmethod
    def _create_datum_rdds(cls, spark, existing_uri_df=None, only_segments=None):
        """Create and return a list of RDDs of `StampedDatum`s.

        For every selected segment, three RDDs are appended: camera images, ego
        poses, and world-frame point clouds.

        Args:
            spark: the Spark session to create RDDs with.
            existing_uri_df: must be None; resuming is not supported.
            only_segments: optional list of URIs; when given, only segments
                that one of them soft-matches are processed.

        Returns:
            list of RDDs, three per processed segment.
        """
        assert existing_uri_df is None, "Resume feature not supported"

        seg_uris = cls.get_all_segment_uris()
        if only_segments:
            util.log.info("Filtering to only %s segments" % len(only_segments))
            seg_uris = [
                uri for uri in seg_uris
                if any(
                  suri.soft_matches_segment(uri) for suri in only_segments)
            ]

        datum_rdds = []
        for seg_uri in seg_uris:
            seq = seg_uri.segment_id
            if cls.ONLY_FRAMES_WITH_NO_MOVERS:
                util.log.info("Finding scans for sequence %s with no movering points ..." % seq)
                n_scans = SK_SEQ_TO_NSCANS[seq]
                # Roughly 100 scans per partition, but always at least one
                # partition: a sequence with fewer than 100 scans would
                # otherwise request zero slices.
                slices = max(1, n_scans // 100)
                task_rdd = spark.sparkContext.parallelize(range(n_scans), numSlices=slices)
                # `seq` is read when the filter runs; `collect()` forces that
                # to happen within this loop iteration.
                scan_has_no_movers = lambda scan_id: (not seq_scan_has_movers(seq, scan_id))
                scans_no_movers = task_rdd.filter(scan_has_no_movers).collect()
                util.log.info("... sequence %s has %s scans with no movers." % (seq, len(scans_no_movers)))
                scan_ids = scans_no_movers
            else:
                scan_ids = list(range(SK_SEQ_TO_NSCANS[seq]))

            tasks = [(seg_uri, scan_id) for scan_id in scan_ids]

            # Emit camera_image RDD
            ctask_rdd = spark.sparkContext.parallelize(tasks)
            datum_rdd = ctask_rdd.map(lambda t: cls.create_camera_frame(*t))
            datum_rdds.append(datum_rdd)

            # Emit ego_pose RDD
            ptask_rdd = spark.sparkContext.parallelize(tasks)
            datum_rdd = ptask_rdd.map(lambda t: cls.create_ego_pose(*t))
            datum_rdds.append(datum_rdd)

            # Emit velodyne cloud RDD (capped; a `None` cap slices to all tasks)
            pctask_rdd = spark.sparkContext.parallelize(
                tasks[:cls.MAX_POINT_CLOUDS_PER_SEGMENT])
            datum_rdd = pctask_rdd.map(lambda t: cls.create_point_cloud_in_world(*t))
            datum_rdds.append(datum_rdd)
#             # Emit world cloud once
#             wc_rdd = spark.sparkContext.parallelize([seg_uri])
#             datum_rdd = wc_rdd.map(lambda t: cls.create_world_cloud(t))
#             datum_rdds.append(datum_rdd)

        return datum_rdds

    @classmethod
    def _get_calib(cls, seq):
        """Return the calibration map for `seq`, parsing it on first use and
        caching it on the class afterwards."""
        if not hasattr(cls, '_calib'):
            cls._calib = {}
        if seq not in cls._calib:
            cls._calib[seq] = get_calibration(seq)
        return cls._calib[seq]

    @classmethod
    def _get_poses(cls, seq):
        """Return the list of per-scan poses for `seq`, parsing it on first use
        and caching it on the class afterwards."""
        if not hasattr(cls, '_poses'):
            cls._poses = {}
        if seq not in cls._poses:
            cls._poses[seq] = get_poses(seq)
        return cls._poses[seq]

    @classmethod
    def create_camera_frame(cls, base_uri, scan_id):
        """Create the left camera image datum for one scan.

        Only the image header is read here (to get width and height); the pixel
        data is loaded later through `image_factory`.

        Args:
            base_uri: the segment URI; its `segment_id` is the sequence id.
            scan_id: integer scan index within the sequence.

        Returns:
            a `datum.StampedDatum` with `camera_image` set.

        Raises:
            AssertionError: if the scan's `image_2/<scan>.png` does not exist.
        """
        seq = base_uri.segment_id
        calib = cls._get_calib(seq)

        uri = copy.deepcopy(base_uri)
        uri.topic = 'camera|left_rect'
        uri.timestamp = int(scan_id) # HACK!

        scene_base = get_scene_basepath(seq)
        scan_name = str(scan_id).rjust(6, '0')
        img_path = os.path.join(scene_base, 'image_2/', scan_name + '.png')
        assert os.path.exists(img_path), (
            "Did you remember to expand data_odometry_color.zip ? %s not found" % img_path)
        with open(img_path, 'rb') as f:
            width, height = util.get_png_wh(f.read(100)) # Util only needs the first few bytes

        # Imported here rather than at module level; the factory below runs
        # lazily, wherever the datum's image is eventually requested.
        import imageio
        image_factory = lambda: imageio.imread(img_path)

        # HACK!!!  This is actually P !!!
        K = calib['P2']

        # hack! this is lidar to cam
        ego_to_sensor = datum.Transform.from_transformation_matrix(
                calib['Tr'], src_frame='lidar', dest_frame=uri.topic)

        sd_ego_pose = cls.create_ego_pose(base_uri, scan_id)
        ego_pose = sd_ego_pose.transform
        ci = datum.CameraImage(
              sensor_name=uri.topic,
              image_factory=image_factory,
              width=width,
              height=height,
              timestamp=uri.timestamp,
              ego_pose=ego_pose,
              K=K,
              ego_to_sensor=ego_to_sensor,
              extra={'semantic_kitti.scan_id': str(scan_id)})
        return datum.StampedDatum(uri=uri, camera_image=ci)

    @classmethod
    def create_ego_pose(cls, base_uri, scan_id):
        """Create the ego pose datum for one scan from the sequence's poses file.

        Args:
            base_uri: the segment URI; its `segment_id` is the sequence id.
            scan_id: integer scan index; used to index the parsed pose list.

        Returns:
            a `datum.StampedDatum` with `transform` set.
        """
        seq = base_uri.segment_id
        poses = cls._get_poses(seq)

        uri = copy.deepcopy(base_uri)
        uri.topic = 'ego_pose'
        uri.timestamp = int(scan_id) # HACK!

        # Hack! believe ego frame is lidar here?
        ego_pose = datum.Transform.from_transformation_matrix(
                poses[scan_id], src_frame='world', dest_frame='ego')

        return datum.StampedDatum(uri=uri, transform=ego_pose)

    @classmethod
    def create_point_cloud_in_world(cls, base_uri, scan_id):
        """Create the lidar point cloud datum for one scan, with points expressed
        in the world frame.

        The cloud itself is not read here; it is produced lazily through
        `cloud_factory`.

        Args:
            base_uri: the segment URI; its `segment_id` is the sequence id.
            scan_id: integer scan index within the sequence.

        Returns:
            a `datum.StampedDatum` with `point_cloud` set.
        """

        uri = copy.deepcopy(base_uri)
        uri.topic = 'lidar|world' + ('_cleaned' if cls.ONLY_FRAMES_WITH_NO_MOVERS else '')
        uri.timestamp = int(scan_id) # HACK!

        sd_ego_pose = cls.create_ego_pose(base_uri, scan_id)
        ego_pose = sd_ego_pose.transform

        def _get_cloud(seq, sid):
            # Homogeneous (n, 4) points as stored in the scan file.
            cloud = read_scan_get_cloud(
                        seq,
                        sid,
                        remove_movers=cls.ONLY_FRAMES_WITH_NO_MOVERS)

            # Move cloud into the world frame: conjugate the scan's pose with
            # the lidar-to-camera transform `Tr`, then apply it to every point.
            calib = cls._get_calib(seq)
            all_poses = cls._get_poses(seq)
            Tr = calib["Tr"]
            Tr_inv = np.linalg.inv(Tr)
            cam2_pose = all_poses[sid]
            pose = np.matmul(Tr_inv, np.matmul(cam2_pose, Tr))
            cloud = np.matmul(pose, cloud.T).T

            return cloud

        pc = datum.PointCloud(
          sensor_name=uri.topic,
          timestamp=uri.timestamp,
          cloud_factory=lambda: _get_cloud(base_uri.segment_id, scan_id),
          ego_to_sensor=datum.Transform(), # Hack! cloud is in world frame
          ego_pose=ego_pose,
          extra={'semantic_kitti.scan_id': str(scan_id)})
        return datum.StampedDatum(uri=uri, point_cloud=pc)
        

# List the segment URIs (one per training sequence) exposed by the table.
seg_uris = SemanticKITTISDTable.get_all_segment_uris()
# Disabled smoke test: build the datums for the first segment and count them.
# sd_rdd = SemanticKITTISDTable._get_segment_datum_rdd_or_df(spark, seg_uris[0])
# print(sd_rdd.count())

In [10]:
from psegs.exp.fused_lidar import FusedWorldCloudTableBase

class SemanticKITTIFusedWorldCloudTable(FusedWorldCloudTableBase):
    """Fused world-cloud table whose source datums come from
    `SemanticKITTISDTable`."""

    SRC_SD_TABLE = SemanticKITTISDTable

    @classmethod
    def _get_task_lidar_cuboid_rdd(cls, spark, segment_uri):
        """Return an RDD with one task `Row` per point-cloud datum of
        `segment_uri`.

        Each row carries `task_id` ("<segment_id>.<scan_id>"), a
        single-element `point_clouds` list, and an empty `cuboids` list.
        Datums without a point cloud are dropped.
        """

        # SemanticKITTI has no cuboids, so the Fuser algo simply concats the
        # cloud points.
        def _rows_for_partition(stamped_datums):
            from pyspark import Row
            # Not referenced below; kept exactly as in the original cell.
            from oarphpy.spark import RowAdapter
            for stamped in stamped_datums:
                cloud = stamped.point_cloud
                if cloud is None:
                    continue
                segment_id = stamped.uri.segment_id
                scan_id = cloud.extra['semantic_kitti.scan_id']
                yield Row(
                    task_id="%s.%s" % (segment_id, scan_id),
                    point_clouds=[cloud],
                    cuboids=[])

        source_rdd = cls.SRC_SD_TABLE.get_segment_datum_rdd(spark, segment_uri)
        return source_rdd.mapPartitions(_rows_for_partition)
        
# Smoke check of the fused table: fetch its segment URIs (rebinding the
# notebook-level `seg_uris`), request the segment datum RDD / DataFrame for the
# first URI, and print its row count.
# NOTE(review): `spark` is not defined in this cell -- presumably a session
# created in an earlier cell; confirm before running on a fresh kernel.
seg_uris = SemanticKITTIFusedWorldCloudTable.get_all_segment_uris()
sd_rdd = SemanticKITTIFusedWorldCloudTable._get_segment_datum_rdd_or_df(spark, seg_uris[0])
print(sd_rdd.count())

2021-02-11 09:51:44,067	ps   17241 : Filtering to only 1 segments
INFO - 2021-02-11 09:51:44,067 - fused_lidar - Filtering to only 1 segments
2021-02-11 09:51:44,068	ps   17241 : SemanticKITTIFusedWorldCloudTable building fused world clouds ...
INFO - 2021-02-11 09:51:44,068 - fused_lidar - SemanticKITTIFusedWorldCloudTable building fused world clouds ...
2021-02-11 09:51:44,069	ps   17241 : ... have 1 segments to fuse ...
INFO - 2021-02-11 09:51:44,069 - fused_lidar - ... have 1 segments to fuse ...
2021-02-11 09:51:44,070	ps   17241 : ... working on 00 ...
INFO - 2021-02-11 09:51:44,070 - fused_lidar - ... working on 00 ...
2021-02-11 09:51:44,071	ps   17241 : ... have fused cloud; skipping! /opt/psegs/dataroot/fused_world_clouds/naive_cuboid_scrubber/semantikitti-psegs-fused/train/00/fused_world.ply
INFO - 2021-02-11 09:51:44,071 - fused_lidar - ... have fused cloud; skipping! /opt/psegs/dataroot/fused_world_clouds/naive_cuboid_scrubber/semantikitti-psegs-fused/train/00/fused_world.

1
