# Synthetic Optical Flow from Fused Lidar


In [1]:
import sys
sys.path.append('/opt/psegs')

import numpy as np

from psegs.exp.fused_lidar_flow import FusedLidarCloudTableBase
from psegs.exp.fused_lidar_flow import TaskLidarCuboidCameraDFFactory

import IPython.display
import PIL.Image


## General Notebook Utilities
    
def imshow(x):
    """Display the array `x` inline in the notebook as an image.

    `x` is converted with `PIL.Image.fromarray` and the resulting image is
    handed to IPython's `display`.
    """
    image = PIL.Image.fromarray(x)
    IPython.display.display(image)

def show_html(x):
    """Render the HTML markup `x` inline in the notebook output."""
    # Import from the public `IPython.display` module: importing `display`
    # from `IPython.core.display` is deprecated.
    from IPython.display import display, HTML
    display(HTML(x))

## SemanticKITTI

In [2]:
from psegs.exp.semantic_kitti import SemanticKITTISDTable

class SemanticKITTILCCDFFactory(TaskLidarCuboidCameraDFFactory):
    """Builds the per-scan task DataFrame for SemanticKITTI segments.

    Each task row bundles the camera images and point clouds that share a
    `semantic_kitti.scan_id`; the cuboids column is always empty.
    """

    SRC_SD_TABLE = SemanticKITTISDTable

    @classmethod
    def build_df_for_segment(cls, spark, segment_uri):
        """Return a persisted DataFrame with one task row per scan.

        Args:
            spark: Spark session used to read the datums and create the
                DataFrame.
            segment_uri: URI of the segment to read; passed through to
                `SRC_SD_TABLE.get_segment_datum_rdd`.

        Returns:
            A persisted DataFrame using `cls.table_schema()`.
        """

        def scan_id_of(datum):
            # Grouping key: the scan id stored in the datum URI's extras.
            return datum.uri.extra['semantic_kitti.scan_id']

        def to_task_row(scan_id_and_datums):
            # Imported here, inside the function that Spark ships to workers.
            from pyspark import Row
            from oarphpy.spark import RowAdapter

            scan_id, datums = scan_id_and_datums

            images = []
            clouds = []
            for datum in datums:
                # A datum contributes its camera image if it has one;
                # otherwise its point cloud, if it has one.
                image = datum.camera_image
                if image is not None:
                    images.append(image)
                    continue
                cloud = datum.point_cloud
                if cloud is not None:
                    clouds.append(cloud)

            task_row = Row(
                task_id=int(scan_id),
                point_clouds=clouds,
                cuboids=[],  # SemanticKITTI has no cuboids
                camera_images=images)
            return RowAdapter.to_row(task_row)

        datum_rdd = cls.SRC_SD_TABLE.get_segment_datum_rdd(spark, segment_uri)
        task_row_rdd = datum_rdd.groupBy(scan_id_of).map(to_task_row)

        task_df = spark.createDataFrame(task_row_rdd, schema=cls.table_schema())
        return task_df.persist()

class SemanticKITTIFusedWorldCloudTable(FusedLidarCloudTableBase):
    """Fused lidar cloud table configured for SemanticKITTI."""

    # SemanticKITTI has no cuboids, so object clouds are turned off.
    # (Presumably this makes the base class skip its object-cloud step --
    # confirm in `FusedLidarCloudTableBase`.)
    HAS_OBJ_CLOUDS = False

    # Task rows come from the SemanticKITTI task DataFrame factory.
    TASK_DF_FACTORY = SemanticKITTILCCDFFactory


## KITTI-360

In [3]:
from psegs.datasets.kitti_360 import KITTI360SDTable
class KITTI360OurFusedClouds(KITTI360SDTable):
    """KITTI-360 datum table with fisheyes and dataset fused clouds disabled."""

    # Use our own fused clouds rather than the ones included with the table.
    INCLUDE_FUSED_CLOUDS = False

    # Fisheye cameras are excluded as well.
    INCLUDE_FISHEYES = False

class KITTI360OurFusedWorldCloudTable(FusedLidarCloudTableBase):
    """Fused lidar cloud table for KITTI-360, built from `KITTI360OurFusedClouds`."""

    SRC_SD_TABLE = KITTI360OurFusedClouds
    
    @classmethod
    def _get_task_lidar_cuboid_rdd(cls, spark, segment_uri):
        """Return an RDD of task rows, one per KITTI-360 frame of the segment.

        Datums whose topic contains 'cuboid' or 'lidar' are grouped by
        `<segment_id>.<kitti-360.frame_id>`. Each row has the columns
        `task_id`, `cuboids` (flattened across the group) and `point_clouds`.

        Side effects: (re)registers the temp view `datums`, drops and
        re-creates the cached table `culi_tasks_df` (DISK_ONLY), and prints
        progress messages.

        Args:
            spark: Spark session to run the SQL on.
            segment_uri: URI of the segment; its `segment_id` is printed and
                the URI is passed to `SRC_SD_TABLE.get_segment_datum_df`.
        """
        datum_df = cls.SRC_SD_TABLE.get_segment_datum_df(spark, segment_uri)
        # `registerTempTable` has been deprecated since Spark 2.0;
        # `createOrReplaceTempView` is the equivalent replacement.
        datum_df.createOrReplaceTempView('datums')
        # Drop any tasks table left over from a previous segment.
        spark.catalog.dropTempView('culi_tasks_df')
        print('Building tasks table for %s ...' % segment_uri.segment_id)
        spark.sql("""
          CACHE TABLE culi_tasks_df OPTIONS ( 'storageLevel' 'DISK_ONLY' ) AS
          SELECT 
              CONCAT(uri.segment_id, '.', uri.extra.`kitti-360.frame_id`) AS task_id,
              FLATTEN(COLLECT_LIST(cuboids)) AS cuboids, 
              COLLECT_LIST(point_cloud) AS point_clouds
          FROM datums
          WHERE 
              uri.topic LIKE '%cuboid%' OR uri.topic LIKE '%lidar%'
          GROUP BY task_id
        """)
        
        
        # TODO! for lidar and camera image!
        #         both_have_ego_pose = (
        #             ci1.extra.get('kitti-360.has-valid-ego-pose') and
        #             ci2.extra.get('kitti-360.has-valid-ego-pose'))
        
        tasks_df = spark.sql('SELECT * FROM culi_tasks_df')
        print('... done.')
        return tasks_df.rdd


## NuScenes

In [4]:
# !pip3 install nuscenes-devkit==1.1.2
from psegs.datasets.nuscenes import NuscStampedDatumTableBase, NuscStampedDatumTableLabelsAllFrames
class NuscFusedWorldCloudTableBase(FusedLidarCloudTableBase):
    """Base fused lidar cloud table for NuScenes.

    Subclasses must set `SRC_SD_TABLE` to pick a NuScenes configuration;
    `_get_task_lidar_cuboid_rdd` reads `SRC_SD_TABLE.LABELS_KEYFRAMES_ONLY`
    to decide how datums are grouped into tasks.
    """
    #SRC_SD_TABLE = NuscStampedDatumTableBase -- Subclass must pick a NuScenes configuration!
    
    SPLITS = ['train_detect', 'train_track']
    
    @classmethod
    def _filter_ego_vehicle(cls, cloud_ego):
        """Remove lidar returns that fall inside a box around the ego vehicle.

        Args:
            cloud_ego: point array whose columns 0, 1, 2 are x, y, z
                (assumes an (N, >=3) numpy array in the lidar frame with
                coordinates in meters -- TODO confirm units).

        Returns:
            The rows of `cloud_ego` that lie outside the box
            |x| <= 1.5, |y| <= 2.5, |z| <= 1.5.
        """
        # Note: NuScenes authors have already corrected clouds for ego motion:
        # https://github.com/nutonomy/nuscenes-devkit/issues/481#issuecomment-716250423
        # But have not filtered out ego self-returns
        x = cloud_ego[:, 0]  # Nusc lidar +x is +right
        y = cloud_ego[:, 1]  # Nusc lidar +y is +forward
        z = cloud_ego[:, 2]  # Nusc lidar +z is +up
        # Each axis is tested against its own column. Previously the y lower
        # bound and both z bounds were applied to the wrong columns, which
        # discarded points well outside the vehicle box and ignored z.
        in_ego_box = (
            (x >= -1.5) & (x <= 1.5) &
            (y >= -2.5) & (y <= 2.5) &
            (z >= -1.5) & (z <= 1.5))
        return cloud_ego[~in_ego_box]
    
    @classmethod
    def _get_task_lidar_cuboid_rdd(cls, spark, segment_uri):
        """Return an RDD of task rows (`task_id`, `cuboids`, `point_clouds`).

        Only datums whose topic contains 'cuboid' or 'lidar', and whose
        `nuscenes-label-channel` is NULL or contains 'LIDAR', are used.
        When `SRC_SD_TABLE.LABELS_KEYFRAMES_ONLY` is truthy, keyframe datums
        are grouped by `nuscenes-sample-token`; otherwise all datums are
        grouped by `<segment_id>.<timestamp>` and groups lacking either
        cuboids or point clouds are dropped.

        Side effects: (re)registers the temp view `datums`, drops and
        re-creates the cached table `culi_tasks_df` (DISK_ONLY), and prints
        progress messages.

        Args:
            spark: Spark session to run the SQL on.
            segment_uri: URI of the segment; its `segment_id` is printed and
                the URI is passed to `SRC_SD_TABLE.get_segment_datum_df`.
        """
        datum_df = cls.SRC_SD_TABLE.get_segment_datum_df(spark, segment_uri)
        # `registerTempTable` has been deprecated since Spark 2.0;
        # `createOrReplaceTempView` is the equivalent replacement.
        datum_df.createOrReplaceTempView('datums')
        # Drop any tasks table left over from a previous segment.
        spark.catalog.dropTempView('culi_tasks_df')
        print('Building tasks table for %s ...' % segment_uri.segment_id)
        if cls.SRC_SD_TABLE.LABELS_KEYFRAMES_ONLY:
            # For Nusc: group by nuscenes-sample-token WITH KEYFRAMES
            spark.sql("""
              CACHE TABLE culi_tasks_df OPTIONS ( 'storageLevel' 'DISK_ONLY' ) AS
              SELECT 
                  uri.extra.`nuscenes-sample-token` AS task_id,
                  FLATTEN(COLLECT_LIST(cuboids)) AS cuboids, 
                  COLLECT_LIST(point_cloud) AS point_clouds
              FROM datums
              WHERE 
                uri.extra.`nuscenes-is-keyframe` = 'True' AND (
                  uri.extra['nuscenes-label-channel'] is NULL OR 
                  uri.extra['nuscenes-label-channel'] LIKE '%LIDAR%'
                ) AND (
                  uri.topic LIKE '%cuboid%' OR
                  uri.topic LIKE '%lidar%'
                )
              GROUP BY task_id
            """)
        else:
            # For Nusc: group by segment id + timestamp WITH ALL FRAMES, and
            # keep only groups that have both cuboids and point clouds.
            spark.sql("""
              CACHE TABLE culi_tasks_df OPTIONS ( 'storageLevel' 'DISK_ONLY' ) AS
              SELECT 
                  CONCAT(uri.segment_id, '.', uri.timestamp) AS task_id,
                  FLATTEN(COLLECT_LIST(cuboids)) AS cuboids, 
                  COLLECT_LIST(point_cloud) AS point_clouds
              FROM datums
              WHERE 
                (
                  uri.extra['nuscenes-label-channel'] is NULL OR 
                  uri.extra['nuscenes-label-channel'] LIKE '%LIDAR%'
                ) AND (
                  uri.topic LIKE '%cuboid%' OR
                  uri.topic LIKE '%lidar%'
                )
              GROUP BY task_id
              HAVING SIZE(cuboids) > 0 AND SIZE(point_clouds) > 0
            """)
        
        tasks_df = spark.sql('SELECT * FROM culi_tasks_df')
        print('... done.')
        return tasks_df.rdd

class NuscFusedWorldCloudKeyframesOnlyTable(NuscFusedWorldCloudTableBase):
    """NuScenes fused-cloud table that reads from `NuscStampedDatumTableBase`."""
    # Presumably this source table has LABELS_KEYFRAMES_ONLY set, as the class
    # name suggests -- confirm in psegs.datasets.nuscenes.
    SRC_SD_TABLE = NuscStampedDatumTableBase

class NuscFusedWorldCloudAllFramesTable(NuscFusedWorldCloudTableBase):
    """NuScenes fused-cloud table that reads from `NuscStampedDatumTableLabelsAllFrames`."""
    # Presumably this source table has LABELS_KEYFRAMES_ONLY unset, as the
    # class name suggests -- confirm in psegs.datasets.nuscenes.
    SRC_SD_TABLE = NuscStampedDatumTableLabelsAllFrames
    

## Start Spark

In [5]:
from psegs.spark import NBSpark
# Notebook-wide Spark session; passed to the build call further down.
spark = NBSpark.getOrCreate()

2021-02-18 09:24:15,079	oarph 1137111 : Using source root /opt/psegs/psegs 
2021-02-18 09:24:15,080	oarph 1137111 : Using source root /opt/psegs 
2021-02-18 09:24:15,164	oarph 1137111 : Generating egg to /tmp/tmpm8685e53_oarphpy_eggbuild ...
2021-02-18 09:24:15,246	oarph 1137111 : ... done.  Egg at /tmp/tmpm8685e53_oarphpy_eggbuild/psegs-0.0.0-py3.8.egg


## Build Fused Lidar Assets

Start a Potree viewer container, with the current directory mounted at `/shared`:

```
docker --context default run -it --name=potree_viewer --rm --net=host -v `pwd`:/shared  jonazpiazu/potree
```

In [6]:
# T = KITTI360OurFusedWorldCloudTable
# rdds = T._create_datum_rdds(spark)
# print([r.count() for r in rdds])

# seg_uris = T.get_all_segment_uris()
# samp = T.get_sample(seg_uris[0], spark=spark)


In [7]:
# print([lc.sensor_name for lc in samp.lidar_clouds][:10])
# c = samp.lidar_clouds[0]#[lc for lc in samp.lidar_clouds if lc.sensor_name == '11002'][0]
# print(c.get_cloud().shape)
# imshow(c.get_bev_debug_image(x_bounds_meters=None, y_bounds_meters=None))
# imshow(c.get_front_rv_debug_image(y_bounds_meters=None, z_bounds_meters=None))

## Compute Candidate Optical Flow Pairs

## Render Optical Flow

In [8]:
from psegs.exp.fused_lidar_flow import OpticalFlowRenderBase

class SemanticKITTIOFlowRenderer(OpticalFlowRenderBase):
    """Optical flow renderer that uses the SemanticKITTI fused lidar table."""
    FUSED_LIDAR_SD_TABLE = SemanticKITTIFusedWorldCloudTable

# Short alias for the renderer class used in this cell.
R = SemanticKITTIOFlowRenderer
# All segment URIs known to the renderer's fused lidar table.
seg_uris = R.FUSED_LIDAR_SD_TABLE.get_all_segment_uris()
# Restrict the build to the first segment only.
R.build(spark=spark, only_segments=[seg_uris[0]])

2021-02-18 09:24:18,520	ps   1137111 : Found Sequence 00 with 4541 scans
2021-02-18 09:24:18,522	ps   1137111 : Found Sequence 01 with 1101 scans
2021-02-18 09:24:18,525	ps   1137111 : Found Sequence 02 with 4661 scans
2021-02-18 09:24:18,526	ps   1137111 : Found Sequence 03 with 801 scans
2021-02-18 09:24:18,528	ps   1137111 : Found Sequence 05 with 2761 scans
2021-02-18 09:24:18,529	ps   1137111 : Found Sequence 06 with 1101 scans
2021-02-18 09:24:18,530	ps   1137111 : Found Sequence 07 with 1101 scans
2021-02-18 09:24:18,532	ps   1137111 : Found Sequence 09 with 1591 scans
2021-02-18 09:24:18,533	ps   1137111 : Found Sequence 10 with 1201 scans
2021-02-18 09:24:18,533	ps   1137111 : Found 18859 total scans
2021-02-18 09:24:18,535	ps   1137111 : Filtering to only 1 segments
2021-02-18 09:24:18,535	ps   1137111 : Finding scans for sequence 00 with no moving points ...
2021-02-18 09:24:21,964	ps   1137111 : ... sequence 00 has 3097 scans with no movers.
2021-02-18 09:24:23,501	ps   113

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 9.0 failed 1 times, most recent failure: Lost task 2.0 in stage 9.0 (TID 168, 192.168.0.213, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 605, in main
    process()
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 595, in process
    out_iter = func(split_index, iterator)
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 425, in func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 1141, in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 1141, in <genexpr>
  File "/tmp/spark-7208c596-d864-4eb7-a4e2-d4103c63798b/userFiles-e943c48c-d96c-4703-bc73-783dc2f2adb0/psegs-0.0.0-py3.8.egg/psegs/exp/fused_lidar_flow.py", line 1062, in _render_oflow_tasks
    if 'lidar|objects_fused' in p.uri.topic:
AttributeError: 'PointCloud' object has no attribute 'uri'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:503)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:638)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1004)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2139)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:446)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:449)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2008)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2007)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2007)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:973)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2239)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2188)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2177)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:775)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2120)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2139)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2164)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1004)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:388)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1003)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:168)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 605, in main
    process()
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 595, in process
    out_iter = func(split_index, iterator)
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 2596, in pipeline_func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 425, in func
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 1141, in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pyspark-3.0.1-py3.8.egg/pyspark/rdd.py", line 1141, in <genexpr>
  File "/tmp/spark-7208c596-d864-4eb7-a4e2-d4103c63798b/userFiles-e943c48c-d96c-4703-bc73-783dc2f2adb0/psegs-0.0.0-py3.8.egg/psegs/exp/fused_lidar_flow.py", line 1062, in _render_oflow_tasks
    if 'lidar|objects_fused' in p.uri.topic:
AttributeError: 'PointCloud' object has no attribute 'uri'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:503)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:638)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1004)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2139)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:446)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:449)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
