# Cheap Optical Flow: Is it Good? Does it Boost?


## Quickstart

## Credits

Some portions of this notebook adapted from:
 * [Middlebury Flow code by Johannes Oswald](https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py)
 * [DeepDeform Demo Code](https://github.com/AljazBozic/DeepDeform)
 * [OpticalFlowToolkit by RUOTENG LI](https://github.com/liruoteng/OpticalFlowToolkit)
 * [OpenCV Samples](https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py)

In [1]:
# parameters
SHOW_DEMO_OUTPUT = False
DEMO_FPS = []

RUN_FULL_ANALYSIS = False
ALL_FP_FACTORY_CLSS = []

In [2]:
## Setup

!pip3 install pypng scikit-image
print('fixme installs')
print()
print()

import copy
import imageio
import IPython.display
import math
import os
import PIL.Image
import six
import sys
import tempfile


## General Notebook Utilities
    
def imshow(x):
    IPython.display.display(PIL.Image.fromarray(x))

def show_html(x):
    from IPython.core.display import display, HTML
    display(HTML(x))


## Create a random temporary directory for analysis library (for Spark-enabled full analysis mode)
old_cwd = os.getcwd()
tempdir = tempfile.TemporaryDirectory(suffix='_cheap_optical_flow_eval_analysis')
ALIB_SRC_DIR = tempdir.name
print("Putting analysis lib in %s" % ALIB_SRC_DIR)
os.chdir(ALIB_SRC_DIR)
!mkdir -p cheap_optical_flow_eval_analysis
!touch cheap_optical_flow_eval_analysis/__init__.py

%load_ext autoreload
%autoreload 2
sys.path.append(ALIB_SRC_DIR)


## Prepare a build of local psegs for inclusion
!cd /opt/psegs && python3 setup.py clean bdist_egg
PSEGS_EGG_PATH = '/opt/psegs/dist/psegs-0.0.1-py3.8.egg'
assert os.path.exists(PSEGS_EGG_PATH), "Build failed?"
sys.path.append('/opt/psegs')
import psegs


## Prepare Spark session with local PSegs and local Analysis Lib
from psegs.spark import NBSpark
NBSpark.SRC_ROOT = os.path.join(ALIB_SRC_DIR, 'cheap_optical_flow_eval_analysis')
NBSpark.CONF_KV.update({
    'spark.driver.maxResultSize': '2g',
    'spark.driver.memory': '16g',
    'spark.submit.pyFiles': PSEGS_EGG_PATH,
  })
spark = NBSpark.getOrCreate()

fixme installs


Putting analysis lib in /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis
running clean
running bdist_egg
running egg_info
writing psegs.egg-info/PKG-INFO
writing dependency_links to psegs.egg-info/dependency_links.txt
writing top-level names to psegs.egg-info/top_level.txt
reading manifest file 'psegs.egg-info/SOURCES.txt'
writing manifest file 'psegs.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
copying psegs/exp/fused_lidar_flow.py -> build/lib/psegs/exp

creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/psegs
copying build/lib/psegs/dummyrun.py -> build/bdist.linux-x86_64/egg/psegs
creating build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/kitti.py -> build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/idsutil.py -> build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/kitti_360.py -> build/bdist.l

2021-04-15 02:01:27,469	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:27,470	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:27,514	oarph 3501100 : Generating egg to /tmp/tmp7ftnd9ii_oarphpy_eggbuild ...
2021-04-15 02:01:27,525	oarph 3501100 : ... done.  Egg at /tmp/tmp7ftnd9ii_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [3]:
%%writefile cheap_optical_flow_eval_analysis/ofp.py

## Data Model & Utility Code

import attr
import cv2
import imageio
import math
import os
import PIL.Image
import six

import numpy as np

from psegs import datum

from oarphpy import plotting as op_plt
from oarphpy.spark import CloudpickeledCallable
img_to_data_uri = lambda x: op_plt.img_to_data_uri(x, format='png')

@attr.s(slots=True, eq=False, weakref_slot=False)
class OpticalFlowPair(object):
    """A flyweight for a pair of images with an optical flow field.
    Supports lazy-loading of large data attributes."""
    
    ## Core Attributes (Required for All Datasets)
    
    dataset = attr.ib(type=str, default='')
    """(Display name) To which dataset does this pair belong?"""
    
    id1 = attr.ib(type=str, default='')
    """Identifier or URI for the first image"""
    
    id2 = attr.ib(type=str, default='')
    """Identifier or URI for the second image"""
    
    img1 = attr.ib(default=None)
    """URI or numpy array or CloudPickleCallable for the first image (source image)"""

    img2 = attr.ib(default=None)
    """URI or numpy array or CloudpickeledCallable for the second image (target image)"""
    
    flow = attr.ib(default=None)
    """A numpy array or callable or CloudpickeledCallable representing optical flow from img1 -> img2"""
    
    uri = attr.ib(type=datum.URI, default=None, converter=datum.URI.from_str)
    """A URI addressing this pair; to make dynamic construction of the pair easier"""
    
    
    ## Optional Attributes (For Select Datasets)
    
    diff_time_sec = attr.ib(type=float, default=0.0)
    """Difference in time (in seconds) between the views / poses depicted in `img1` and `img2`."""
    
    translation_meters = attr.ib(type=float, default=0.0)
    """Difference in ego translation (in meters) between the views / poses depicted in `img1` and `img2`."""
    
    uvdviz_im1 = attr.ib(default=None)
    """An nx4 numpy array representing UVD-visible points for `img1`"""
    
    uvdviz_im2 = attr.ib(default=None)
    """An nx4 numpy array representing UVD-visible points for `img2`"""
    
    K = attr.ib(default=None)
    """A 3x3 numpy array representing the camera matrix K for both views"""
    
    # to add:
    # semantic image for frame 1, frame 2 [could be painted by cuboids]
    # instance images for frame 1, frame 2 [could be painted by cuboids]
    #   -- for colored images, at first just pivot all oflow metrics by colors
    # get uvdviz1 uvdviz2 (scene flow)
    #   * for deepeform, their load_flow will work
    #   * for kitti, we have to read their disparity images
    # get uvd1 uvd2 (lidar for nearest neighbor stuff)
    # depth image for frame 1, frame 2 [could be interpolated by cuboids]
    #   -- at first bucket the depth coarsely and pivot al oflow by colors
    
    def get_img1(self):
        if isinstance(self.img1, CloudpickeledCallable):
            self.img1 = self.img1()
        if isinstance(self.img1, six.string_types):
            self.img1 = imageio.imread(self.img1)
        return self.img1
    
    def get_img2(self):
        if isinstance(self.img2, CloudpickeledCallable):
            self.img2 = self.img2()
        if isinstance(self.img2, six.string_types):
            self.img2 = imageio.imread(self.img2)
        return self.img2
    
    def get_flow(self):
        if not isinstance(self.flow, (np.ndarray, np.generic)):
            self.flow = self.flow()
        return self.flow
    
    def has_scene_flow(self):
        return (self.uvdviz_im1 is not None and uvdviz_im2 is not None and self.K is not None)
    
    def get_sf_viz_html(self):
        uvd1 = self.uvdviz_im1[self.uvdviz_im1[:, -1] == 1, :2]
        uvd2 = self.uvdviz_im2[self.uvdviz_im2[:, -1] == 1, :2]
        xyzrgb1 = uvd_to_xyzrgb(uvd1, self.K, imgs=[self.get_img1()])
        xyzrgb2 = uvd_to_xyzrgb(uvd2, self.K, imgs=[self.get_img2()])
        html1 = create_xyzrgb_3d_plot_html(xyzrgb1)
        html2 = create_xyzrgb_3d_plot_html(xyzrgb2)
        
        html = "View 1:<br />%s<br /><br />View 2:<br />%s" % (html1, html2)
        return html
    
    def to_html(self):
        im1 = self.get_img1()
        im2 = self.get_img2()
        flow = self.get_flow()
        fviz = draw_flow(im1, flow)
        
        sf_html = ''
        if self.has_scene_flow():
            sf_html = """
                <tr><td style="text-align:left"><b>Scene Flow</b></td></tr>
                <tr><td>{viz_html></td></tr>
            """format(viz_html=self.get_sf_viz_html())
        
        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Dataset:</b> {dataset}</td></tr>
            <tr><td style="text-align:left"><b>URI:</b> {uri}</td></tr>
            
            <tr><td style="text-align:left"><b>Source Image:</b> {id1}</td></tr>
            <tr><td><img src="{im1}" /></td></tr>

            <tr><td style="text-align:left"><b>Target Image:</b> {id2}</td></tr>
            <tr><td><img src="{im2}" /></td></tr>

            <tr><td style="text-align:left"><b>Flow</b></td></tr>
            <tr><td><img src="{fviz}" /></td></tr>
            
            {sf_html}
            </table>
        """.format(
                dataset=self.dataset,
                uri=str(self.uri),
                id1=self.id1, id2=self.id2,
                im1=img_to_data_uri(im1), im2=img_to_data_uri(im2),
                fviz=img_to_data_uri(fviz),
                sf_html=sf_html)
        return html

def draw_flow(img, flow, step=8):
    """Based upon OpenCV sample: https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py"""
    h, w = img.shape[:2]
    y, x = np.mgrid[step/2:h:step, step/2:w:step].reshape(2,-1).astype(int)
    fx, fy = flow[y,x].T
    lines = np.vstack([x, y, x+fx, y+fy]).T.reshape(-1, 2, 2)
    lines = np.int32(lines + 0.5)
    vis = img.copy()
    cv2.polylines(vis, lines, 0, (0, 255, 0))
    for (x1, y1), (_x2, _y2) in lines:
        cv2.circle(vis, (x1, y1), 1, (0, 255, 0), -1)
    return vis

def uvd_to_xyzrgb(uvd, K, imgs=None):
    import numpy as np
    from psegs import datum

    fx = K[0, 0]
    cx = K[0, 2]
    fy = K[1, 1]
    cy = K[1, 2]
    
    xyz = np.zeros((uvd.shape[0], 3))
    xyz[:, 0] = (uvd[:, 0] - cx) / fx
    xyz[:, 1] = (uvd[:, 1] - cy) / fy
    xyz[:, 1] = 1.
    xyz = uvd[:, 2] * xyz / np.linalg.norm(xyz, axis=-1)[:, np.newaxis]
    
    from psegs import datum
    pc = (cloud=xyz)
    cis = [datum.CameraImage() for img in (imgs or [])]
    xyzrgb = datum.PointCloud.paint_ego_cloud(xyz, camera_images=cis)
    return xyzrgb

def create_xyzrgb_3d_plot_html(xyzrgb, max_points=100000):
    import plotly
    import plotly.graph_objects as go
    import pandas as pd

    pcloud_df = pd.DataFrame(xyzrgb, columns=['x', 'y', 'z', 'r', 'g', 'b'])
    pcloud_df = pcloud_df.sample(n=min(xyzrgb.shape[0], max_points))
    scatter = go.Scatter3d(
                x=pcloud_df['x'], y=pcloud_df['y'], z=pcloud_df['z'],
                mode='markers',
                marker=dict(size=2, color=pcloud_df[['r', 'g', 'b']], opacity=0.9))
    fig = go.Figure(data=[scatter])
    fig.update_layout(
          width=1000, height=700,
          scene_aspectmode='data')
    html = plotly.offline.plot(fig, output_type='div')
    return html

class FlowPairFactoryBase(object):
    DATASET = ''

    @classmethod
    def list_fp_uris(cls, spark):
        return []
    
    @classmethod
    def get_fp_rdd_for_uris(cls, spark, uris):
        uris = [datum.URI.from_str(u) for u in uris]
        uris = [u for u in uris if u.dataset == cls.DATASET]
        if not uris:
            return None
        return cls._get_fp_rdd_for_uris(spark, uris)

    @classmethod
    def _get_fp_rdd_for_uris(cls, spark, uris):
        return None

class FlowPairUnionFactory(FlowPairFactoryBase):
    FACTORIES = []
    
    @classmethod
    def list_fp_uris(cls, spark):
        import itertools
        return list(itertools.chain.from_iterable(F.list_fp_uris(spark) for F in cls.FACTORIES))
    
    @classmethod
    def get_fp_rdd_for_uris(cls, spark, uris):
        rdds = []
        for F in cls.FACTORIES:
            rdd = F.get_fp_rdd_for_uris(spark, uris)
            if rdd is not None:
                rdds.append(rdd)
        assert rdds, "No RDDs for %s" % uris
        return spark.sparkContext.union(rdds)


Writing cheap_optical_flow_eval_analysis/ofp.py


In [4]:
from cheap_optical_flow_eval_analysis.ofp import *

2021-04-15 02:01:31,714	oarph 3501100 : Source has changed! Rebuilding Egg ...
2021-04-15 02:01:31,714	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,715	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,717	oarph 3501100 : Generating egg to /tmp/tmpcw0qhfcy_oarphpy_eggbuild ...
2021-04-15 02:01:31,724	oarph 3501100 : ... done.  Egg at /tmp/tmpcw0qhfcy_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


## Middlebury Optical Flow



In [5]:
# TODO talk configs



In [6]:
%%writefile cheap_optical_flow_eval_analysis/midd.py

from psegs import datum

from cheap_optical_flow_eval_analysis.ofp import *

# Please unzip `other-color-allframes.zip` and `other-gt-flow.zip` to a directory and provide the target below:
MIDD_DATA_ROOT = '/opt/psegs/ext_data/middlebury-flow/'

# For the Middlebury Flow dataset, we only consider the real scenes
MIDD_SCENES = [
    {
        'input': 'other-data/Dimetrodon/frame10.png',
        'expected_out': 'other-data/Dimetrodon/frame11.png',
        'flow_gt': 'other-gt-flow/Dimetrodon/flow10.flo',
    },
        {
        'input': 'other-data/Hydrangea/frame10.png',
        'expected_out': 'other-data/Hydrangea/frame11.png',
        'flow_gt': 'other-gt-flow/Hydrangea/flow10.flo',
    },
        {
        'input': 'other-data/RubberWhale/frame10.png',
        'expected_out': 'other-data/RubberWhale/frame11.png',
        'flow_gt': 'other-gt-flow/RubberWhale/flow10.flo',
    },
]


def midd_read_flow(path):
    import os
    import numpy as np
    # Based upon: https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py
    # compute colored image to visualize optical flow file .flo
    # Author: Johannes Oswald, Technical University Munich
    # Contact: johannes.oswald@tum.de
    # Date: 26/04/2017
    # For more information, check http://vision.middlebury.edu/flow/ 
    assert os.path.exists(path) and path.endswith('.flo'), path
    f = open(path, 'rb')
    flo_number = np.fromfile(f, np.float32, count=1)[0]
    TAG_FLOAT = 202021.25
    assert flo_number == TAG_FLOAT, 'Flow number %r incorrect.' % flo_number
    w = np.fromfile(f, np.int32, count=1)
    h = np.fromfile(f, np.int32, count=1)

    #if error try: data = np.fromfile(f, np.float32, count=2*w[0]*h[0])
    data = np.fromfile(f, np.float32, count=int(2*w*h))

    # Reshape data into 3D array (columns, rows, bands)
    flow = np.resize(data, (int(h), int(w), 2))	
    f.close()

    # We found that there are some invalid (?) (i.e. very large) flows, so we're going
    # to ignore those for this experiment.
    invalid = (flow >= 1666)
    flow[invalid] = 0

    return flow

def midd_create_fp(uri):
    scene_idx = int(uri.extra['midd.scene_idx'])
    scene = MIDD_SCENES[scene_idx]
    data_root = uri.extra['midd.dataroot']
    return OpticalFlowPair(
                uri=uri,
                dataset="Middlebury Optical Flow",
                id1=scene['input'],
                img1='file://' + os.path.join(data_root, scene['input']),
                id2=scene['expected_out'],
                img2='file://' + os.path.join(data_root, scene['expected_out']),
                flow=CloudpickeledCallable(lambda: midd_read_flow(os.path.join(data_root, scene['flow_gt']))))
    

class MiddFactory(FlowPairFactoryBase):
    DATASET = 'midd_oflow'
    
    @classmethod
    def list_fp_uris(cls, spark):
        return [
            datum.URI(dataset=cls.DATASET, extra={'midd.scene_idx': i, 'midd.dataroot': MIDD_DATA_ROOT})
            for i, scene in enumerate(MIDD_SCENES)
        ]
    
    @classmethod
    def _get_fp_rdd_for_uris(cls, spark, uris):
        uri_rdd = spark.sparkContext.parallelize(uris)
        fp_rdd = uri_rdd.map(midd_create_fp)
        return fp_rdd


Writing cheap_optical_flow_eval_analysis/midd.py


In [7]:
from cheap_optical_flow_eval_analysis.midd import MiddFactory

2021-04-15 02:01:31,835	oarph 3501100 : Source has changed! Rebuilding Egg ...
2021-04-15 02:01:31,836	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,837	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,838	oarph 3501100 : Generating egg to /tmp/tmpnmfxezoj_oarphpy_eggbuild ...
2021-04-15 02:01:31,845	oarph 3501100 : ... done.  Egg at /tmp/tmpnmfxezoj_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [8]:
ALL_FP_FACTORY_CLSS.append(MiddFactory)

print("Found %s Midd scenes" % len(MiddFactory.list_fp_uris(spark)))

if SHOW_DEMO_OUTPUT:
    demo_uris = MiddFactory.list_fp_uris(spark)
    fp_rdd = MiddFactory.get_fp_rdd_for_uris(spark, demo_uris)
    fps = fp_rdd.collect()
    
    for fp in fps:
        show_html(fp.to_html() + "<br/><br/><br/>")
        DEMO_FPS.append(fp)

# for i, scene in enumerate(MIDD_SCENES):
#     p = OpticalFlowPair(
#             dataset="Middlebury Optical Flow",
#             id1=scene['input'],
#             img1='file://' + os.path.join(MIDD_DATA_ROOT, scene['input']),
#             id2=scene['expected_out'],
#             img2='file://' + os.path.join(MIDD_DATA_ROOT, scene['expected_out']),
#             flow=CloudpickeledCallable(lambda: midd_read_flow(os.path.join(MIDD_DATA_ROOT, scene['flow_gt']))))
    
#     if RUN_FULL_ANALYSIS:
#         ALL_FPS.append(copy.deepcopy(p))
    
#     if SHOW_DEMO_OUTPUT:
#         show_html(p.to_html() + "<br/><br/><br/>")
#         DEMO_FPS.append(p)

Found 3 Midd scenes


## DeepDeform

In [9]:
%%writefile cheap_optical_flow_eval_analysis/deepdeform.py

from psegs import datum

from cheap_optical_flow_eval_analysis.ofp import *

# Please extract deepdeform_v1.7z to a directory and provide the target below:
DD_DATA_ROOT = '/opt/psegs/ext_data/deepdeform_v1/'

def dd_load_oflow(path):
    # Based upon https://github.com/AljazBozic/DeepDeform/blob/master/utils.py#L1
    import shutil
    import struct
    import os
    import numpy as np

    # Flow is stored row-wise in order [channels, height, width].
    assert os.path.isfile(path)

    flow_gt = None
    with open(path, 'rb') as fin:
        width = struct.unpack('I', fin.read(4))[0]
        height = struct.unpack('I', fin.read(4))[0]
        channels = struct.unpack('I', fin.read(4))[0]
        n_elems = height * width * channels

        flow = struct.unpack('f' * n_elems, fin.read(n_elems * 4))
        flow_gt = np.asarray(flow, dtype=np.float32).reshape([channels, height, width])

    # Match format used in this analysis
    flow_gt = np.moveaxis(flow_gt, 0, -1) # (h, w, 2)
    invalid_flow = flow_gt == -np.Inf
    flow_gt[invalid_flow] = 0.0
    return flow_gt

def dd_create_fp(uri):
    return OpticalFlowPair(
                uri=uri,
                dataset="DeepDeform Semi-Synthetic Optical Flow",
                id1=uri.extra['dd.input'],
                img1='file://' + os.path.join(DD_DATA_ROOT, uri.extra['dd.input']),
                id2=uri.extra['dd.expected_out'],
                img2='file://' + os.path.join(DD_DATA_ROOT, uri.extra['dd.expected_out']),
                flow=dd_load_oflow(os.path.join(DD_DATA_ROOT, uri.extra['dd.flow_gt'])))


class DDFactory(FlowPairFactoryBase):
    DATASET = 'deep_deform'
    
    @classmethod
    def _get_all_scenes(cls):
        import json
        DD_ALIGNMENTS = json.load(open(os.path.join(DD_DATA_ROOT, 'train_alignments.json')))
        ALL_DD_SCENES = [
            {
                "dd.input": ascene['source_color'],
                "dd.expected_out": ascene['target_color'],
                "dd.flow_gt": ascene['optical_flow'],
            }
            for ascene in DD_ALIGNMENTS
        ]
        return ALL_DD_SCENES
    
    @classmethod
    def list_fp_uris(cls, spark):
        scenes = cls._get_all_scenes()
        return [
            datum.URI(dataset=cls.DATASET, extra=scene)
            for scene in scenes
        ]
    
    @classmethod
    def _get_fp_rdd_for_uris(cls, spark, uris):
        uri_rdd = spark.sparkContext.parallelize(uris)
        fp_rdd = uri_rdd.map(dd_create_fp)
        return fp_rdd


Writing cheap_optical_flow_eval_analysis/deepdeform.py


In [10]:
from cheap_optical_flow_eval_analysis.deepdeform import *

2021-04-15 02:01:31,903	oarph 3501100 : Source has changed! Rebuilding Egg ...
2021-04-15 02:01:31,903	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,904	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:31,905	oarph 3501100 : Generating egg to /tmp/tmphp1vdjl9_oarphpy_eggbuild ...
2021-04-15 02:01:31,912	oarph 3501100 : ... done.  Egg at /tmp/tmphp1vdjl9_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [11]:
from psegs import datum

DD_DEMO_URIS = [
    datum.URI(dataset=DDFactory.DATASET, extra={
        "dd.input": "train/seq000/color/000000.jpg",
        "dd.expected_out": "train/seq000/color/000200.jpg",
        "dd.flow_gt": "train/seq000/optical_flow/blackdog_000000_000200.oflow",
    }),
    datum.URI(dataset=DDFactory.DATASET, extra={
        "dd.input": "train/seq000/color/000000.jpg",
        "dd.expected_out": "train/seq000/color/001200.jpg",
        "dd.flow_gt": "train/seq000/optical_flow/blackdog_000000_001200.oflow",
    }),
    datum.URI(dataset=DDFactory.DATASET, extra={
        "dd.input": "train/seq001/color/003400.jpg",
        "dd.expected_out": "train/seq001/color/003600.jpg",
        "dd.flow_gt": "train/seq001/optical_flow/lady_003400_003600.oflow",
    }),
    datum.URI(dataset=DDFactory.DATASET, extra={
        "dd.input": "train/seq337/color/000050.jpg",
        "dd.expected_out": "train/seq337/color/000350.jpg",
        "dd.flow_gt": "train/seq337/optical_flow/adult_000050_000350.oflow",
    }),
]

ALL_FP_FACTORY_CLSS.append(DDFactory)

print("Found %s DeepDeform scenes" % len(DDFactory.list_fp_uris(spark)))

if SHOW_DEMO_OUTPUT:
    fp_rdd = DDFactory.get_fp_rdd_for_uris(spark, DD_DEMO_URIS)
    fps = fp_rdd.collect()
    
    for fp in fps:
        show_html(fp.to_html() + "<br/><br/><br/>")
        DEMO_FPS.append(fp)



# import json
# DD_ALIGNMENTS = json.load(open(os.path.join(DD_DATA_ROOT, 'train_alignments.json')))
# ALL_DD_SCENES = [
#     {
#         "input": ascene['source_color'],
#         "expected_out": ascene['target_color'],
#         "flow_gt": ascene['optical_flow'],
#     }
#     for ascene in DD_ALIGNMENTS
# ]

# print("Found %s DeepDeform scenes" % len(ALL_DD_SCENES))
# if SHOW_DEMO_OUTPUT:
#     for scene in DD_DEMO_SCENES:
#         p = dd_create_fp(scene)
#         show_html(p.to_html())
#         DEMO_FPS.append(p)

# if RUN_FULL_ANALYSIS:
#     for scene in ALL_DD_SCENES:
#         p = dd_create_fp(scene)
#         ALL_FPS.append(p)
        

Found 4540 DeepDeform scenes


## Kitti Scene Flow Benchmark (2015)


In [12]:
# # Please unzip `data_scene_flow.zip` and `data_scene_flow_calib.zip` to a directory and provide that target below:
# KITTI_SF15_DATA_ROOT = '/opt/psegs/ext_data/kitti_scene_flow_2015/'



# from oarphpy import util as oputil
# KITTI_SF15_ALL_FLOW_OCC = [
#     os.path.basename(p)
#     for p in oputil.all_files_recursive(
#         os.path.join(KITTI_SF15_DATA_ROOT, 'training/flow_occ'), pattern='*.png')
# ]
    
# KITTI_SF15_ALL_SCENES = [
#     {
#         "input": 'training/image_2/%s' % fname,
#         "expected_out": 'training/image_2/%s' % fname.replace('_10', '_11'),
#         "flow_gt": 'training/flow_occ/%s' % fname,
#     }
#     for fname in KITTI_SF15_ALL_FLOW_OCC
# ]
# print("Found %s KITTI SceneFlow 2015 scenes" % len(KITTI_SF15_ALL_SCENES))



In [13]:
%%writefile cheap_optical_flow_eval_analysis/kittisf15.py

from psegs import datum

from cheap_optical_flow_eval_analysis.ofp import *

# Please unzip `data_scene_flow.zip` and `data_scene_flow_calib.zip` to a directory and provide that target below:
KITTI_SF15_DATA_ROOT = '/opt/psegs/ext_data/kitti_scene_flow_2015/'


def kittisf15_load_flow(path):
    # Based upon https://github.com/liruoteng/OpticalFlowToolkit/blob/master/lib/flowlib.py#L559
    import png
    import numpy as np
    flow_object = png.Reader(filename=path)
    flow_direct = flow_object.asDirect()
    flow_data = list(flow_direct[2])
    w, h = flow_direct[3]['size']
    flow = np.zeros((h, w, 3), dtype=np.float64)
    for i in range(len(flow_data)):
        flow[i, :, 0] = flow_data[i][0::3]
        flow[i, :, 1] = flow_data[i][1::3]
        flow[i, :, 2] = flow_data[i][2::3]

    invalid_idx = (flow[:, :, 2] == 0)
    flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
    flow[invalid_idx, 0] = 0
    flow[invalid_idx, 1] = 0
    return flow[:, :, :2]

def kittisf15_create_fp(uri):
    return OpticalFlowPair(
                uri=uri,
                dataset="KITTI Scene Flow 2015",
                id1=uri.extra['ksf15.input'],
                img1='file://' + os.path.join(KITTI_SF15_DATA_ROOT, uri.extra['ksf15.input']),
                id2=uri.extra['ksf15.expected_out'],
                img2='file://' + os.path.join(KITTI_SF15_DATA_ROOT, uri.extra['ksf15.expected_out']),
                flow=kittisf15_load_flow(os.path.join(KITTI_SF15_DATA_ROOT, uri.extra['ksf15.flow_gt'])))


class KITTISF15Factory(FlowPairFactoryBase):
    DATASET = 'kitti_sf15'
    
    @classmethod
    def _get_all_scenes(cls):
        from oarphpy import util as oputil
        KITTI_SF15_ALL_FLOW_OCC = [
            os.path.basename(p)
            for p in oputil.all_files_recursive(
                os.path.join(KITTI_SF15_DATA_ROOT, 'training/flow_occ'), pattern='*.png')
        ]

        KITTI_SF15_ALL_SCENES = [
            {
                "ksf15.input": 'training/image_2/%s' % fname,
                "ksf15.expected_out": 'training/image_2/%s' % fname.replace('_10', '_11'),
                "ksf15.flow_gt": 'training/flow_occ/%s' % fname,
            }
            for fname in KITTI_SF15_ALL_FLOW_OCC
        ]
        return KITTI_SF15_ALL_SCENES

    
    @classmethod
    def list_fp_uris(cls, spark):
        scenes = cls._get_all_scenes()
        return [
            datum.URI(dataset=cls.DATASET, extra=scene)
            for scene in scenes
        ]
    
    @classmethod
    def _get_fp_rdd_for_uris(cls, spark, uris):
        uri_rdd = spark.sparkContext.parallelize(uris)
        fp_rdd = uri_rdd.map(kittisf15_create_fp)
        return fp_rdd


Writing cheap_optical_flow_eval_analysis/kittisf15.py


In [14]:
from cheap_optical_flow_eval_analysis.kittisf15 import *

2021-04-15 02:01:32,060	oarph 3501100 : Source has changed! Rebuilding Egg ...
2021-04-15 02:01:32,060	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:32,061	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:32,062	oarph 3501100 : Generating egg to /tmp/tmprdjz7bio_oarphpy_eggbuild ...
2021-04-15 02:01:32,069	oarph 3501100 : ... done.  Egg at /tmp/tmprdjz7bio_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [15]:
from psegs import datum

# You have to ls flow_occ to get the paths
KITTI_SF15_DEMO_URIS = [
    datum.URI(dataset=KITTISF15Factory.DATASET, extra={
        'ksf15.input': 'training/image_2/000000_10.png',
        'ksf15.expected_out': 'training/image_2/000000_11.png',
        'ksf15.flow_gt': 'training/flow_occ/000000_10.png',
    }),
    datum.URI(dataset=KITTISF15Factory.DATASET, extra={
        'ksf15.input': 'training/image_2/000007_10.png',
        'ksf15.expected_out': 'training/image_2/000007_11.png',
        'ksf15.flow_gt': 'training/flow_occ/000007_10.png',
    }),
    datum.URI(dataset=KITTISF15Factory.DATASET, extra={
        'ksf15.input': 'training/image_2/000023_10.png',
        'ksf15.expected_out': 'training/image_2/000023_11.png',
        'ksf15.flow_gt': 'training/flow_occ/000023_10.png',
    }),
    datum.URI(dataset=KITTISF15Factory.DATASET, extra={
        'ksf15.input': 'training/image_2/000051_10.png',
        'ksf15.expected_out': 'training/image_2/000051_11.png',
        'ksf15.flow_gt': 'training/flow_occ/000051_10.png',
    }),
    datum.URI(dataset=KITTISF15Factory.DATASET, extra={
        'ksf15.input': 'training/image_2/000003_10.png',
        'ksf15.expected_out': 'training/image_2/000003_11.png',
        'ksf15.flow_gt': 'training/flow_occ/000003_10.png',
    }),
]

ALL_FP_FACTORY_CLSS.append(KITTISF15Factory)

print("Found %s Kitti Scene Flow 2015 scenes" % len(KITTISF15Factory.list_fp_uris(spark)))

if SHOW_DEMO_OUTPUT:
    fp_rdd = KITTISF15Factory.get_fp_rdd_for_uris(spark, KITTI_SF15_DEMO_URIS)
    fps = fp_rdd.collect()
    
    for fp in fps:
        show_html(fp.to_html() + "<br/><br/><br/>")
        DEMO_FPS.append(fp)






# def kitti_sf15_create_fp(info):
#      return OpticalFlowPair(
#                 dataset="KITTI Scene Flow 2015",
#                 id1=scene['input'],
#                 img1='file://' + os.path.join(KITTI_SF15_DATA_ROOT, scene['input']),
#                 id2=scene['expected_out'],
#                 img2='file://' + os.path.join(KITTI_SF15_DATA_ROOT, scene['expected_out']),
#                 flow=KITTISF15LoadFlowFromPng(os.path.join(KITTI_SF15_DATA_ROOT, scene['flow_gt'])))

# if SHOW_DEMO_OUTPUT:
#     for scene in KITTI_SF15_DEMO_SCENES:
#         p = kitti_sf15_create_fp(scene)
#         show_html(p.to_html())
#         DEMO_FPS.append(p)

# if RUN_FULL_ANALYSIS:
#     for scene in KITTI_SF15_ALL_SCENES:
#         p = kitti_sf15_create_fp(scene)
#         ALL_FPS.append(p)

Found 200 Kitti Scene Flow 2015 scenes


## PSegs Synthetic Flow from Fused Lidar

In [16]:
# PSEGS_SYNTHFLOW_PARQUET_ROOT = '/outer_root/media/rocket4q/psegs_flow_records_short'

# from psegs.exp.fused_lidar_flow import FlowRecTable

# T = FlowRecTable(spark, PSEGS_SYNTHFLOW_PARQUET_ROOT)
# synthflow_record_uris = T.get_record_uris()
# print("Found %s PSegs SynthFlow records" % len(synthflow_record_uris))





# fr_samp_rdd = T.get_records_with_samples_rdd(
#                     record_uris=[PSEGS_SYNTHFLOW_DEMO_RECORD_URIS[0]],
#                     include_cameras=False,
#                     include_cuboids=False,
#                     include_point_clouds=False)
# flow_rec = fr_samp_rdd.take(1)[0][0]

# print("Sample record:")
# show_html(flow_rec.to_html())


# PSEGS_SYNTHFLOW_DEMO_FPS_DO_CACHE = True
# PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH = '/tmp/psegs_synthflow_demo.pkl'



In [17]:
%%writefile cheap_optical_flow_eval_analysis/psegs_synthflow.py

from psegs import datum
from psegs.exp.fused_lidar_flow import FlowRecTable

from cheap_optical_flow_eval_analysis.ofp import *

from oarphpy.spark import CloudpickeledCallable


# Please provide the PSegs synthetic flow Parquet directory root below:
PSEGS_SYNTHFLOW_PARQUET_ROOT = '/outer_root/media/rocket4q/psegs_flow_records_short_fixed'
# PSEGS_SYNTHFLOW_PARQUET_ROOT = '/outer_root/media/rocket4q/psegs_synthflow.parquet'

def psegs_synthflow_flow_rec_to_fp(flow_rec, sample):
  fr = flow_rec

  uri_str_to_datum = sample.get_uri_str_to_datum()

  # Find the camera_images associated with `flow_rec`
  ci1_url_str = str(flow_rec.clouds[0].ci_uris[0])
  ci1_sd = uri_str_to_datum[ci1_url_str]
  ci1 = ci1_sd.camera_image

  ci2_url_str = str(flow_rec.clouds[1].ci_uris[0])
  ci2_sd = uri_str_to_datum[ci2_url_str]
  ci2 = ci2_sd.camera_image

  import numpy as np
  world_T1 = ci1.ego_pose.translation
  world_T2 = ci2.ego_pose.translation
  translation_meters = np.linalg.norm(world_T2 - world_T1)

  id1 = ci1_url_str + '&extra.psegs_flow_sids=' + str(fr.clouds[0].sample_id)
  id2 = ci2_url_str + '&extra.psegs_flow_sids=' + str(fr.clouds[1].sample_id)

  import urllib.parse
  eval_uri = datum.URI(dataset=PSegsSynthFlowFactory.DATASET, extra={'pssf.ruri': urllib.parse.quote(str(fr.uri))})

  uvdviz_im1 = flow_rec.clouds[0].uvdvis
  uvdviz_im2 = flow_rec.clouds[1].uvdvis
  K = ci1.K

  fp = OpticalFlowPair(
          uri=eval_uri,
          dataset="PSegs SynthFlow for %s (%s)" % (fr.uri.dataset, fr.uri.split),
          id1=id1,
          id2=id2,
          img1=CloudpickeledCallable(lambda: ci1.image),
          img2=CloudpickeledCallable(lambda: ci2.image),
          flow=CloudpickeledCallable(lambda: fr.to_optical_flow()),

          diff_time_sec=abs(ci2_sd.uri.timestamp - ci1_sd.uri.timestamp),
          translation_meters=translation_meters,
      
          uvdviz_im1=uvdviz_im1,
          uvdviz_im2=uvdviz_im2,
          K=K)
  return fp

def psegs_synthflow_create_fps(
        spark,
        flow_record_pq_table_path,
        record_uris,
        include_cuboids=False,
        include_point_clouds=False):

  T = FlowRecTable(spark, flow_record_pq_table_path)
  rec_sample_rdd = T.get_records_with_samples_rdd(
                          record_uris=record_uris,
                          include_cameras=True,
                          include_cuboids=include_cuboids,
                          include_point_clouds=include_point_clouds)

  fps = [
    flow_rec_to_fp(flow_rec, sample)
    for flow_rec, sample in rec_sample_rdd.collect()
  ]

  return fps


class PSegsSynthFlowFactory(FlowPairFactoryBase):
    DATASET = 'psegs_synthflow'
    
    @classmethod
    def _get_frec_table(cls, spark):
        if not hasattr(cls, '_frec_table'):
            cls._frec_table = FlowRecTable(spark, PSEGS_SYNTHFLOW_PARQUET_ROOT)
        return cls._frec_table
    
    @classmethod
    def list_fp_uris(cls, spark):
        import urllib.parse
        T = cls._get_frec_table(spark)
        ruris = T.get_record_uris()
        return [
            datum.URI(dataset=cls.DATASET, extra={'pssf.ruri': urllib.parse.quote(str(ruri))})
            for ruri in ruris
        ]
    
    @classmethod
    def _get_fp_rdd_for_uris(cls, spark, uris):
        import urllib.parse
        T = cls._get_frec_table(spark)
        ruris = [urllib.parse.unquote(uri.extra['pssf.ruri']) for uri in uris]
        rec_sample_rdd = T.get_records_with_samples_rdd(
                          record_uris=ruris,
                          include_cameras=True,
                          include_cuboids=False,
                          include_point_clouds=False)
        fp_rdd = rec_sample_rdd.map(lambda fs: psegs_synthflow_flow_rec_to_fp(*fs))
        return fp_rdd
        


# def psegs_synthflow_iter_fp_rdds(
#         spark,
#         flow_record_pq_table_path,
#         fps_per_rdd=100,
#         include_cuboids=False,
#         include_point_clouds=False):
  
#   T = FlowRecTable(spark, flow_record_pq_table_path)
#   ruris = T.get_record_uris()

#   # Ensure a sort so that pairs from similar segments will load in the same
#   # RDD -- that makes joins smaller and faster
#   ruris = sorted(ruris)

#   from oarphpy import util as oputil
#   for ruri_chunk in oputil.ichunked(ruris, fps_per_rdd):
#     frec_sample_rdd = T.get_records_with_samples_rdd(
#                           record_uris=rids,
#                           include_cuboids=include_cuboids,
#                           include_point_clouds=include_point_clouds)
#     fp_rdd = frec_sample_rdd.map(flow_rec_to_fp)
#     yield fp_rdd







Writing cheap_optical_flow_eval_analysis/psegs_synthflow.py


In [18]:
from cheap_optical_flow_eval_analysis.psegs_synthflow import *

2021-04-15 02:01:32,337	oarph 3501100 : Source has changed! Rebuilding Egg ...
2021-04-15 02:01:32,338	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-15 02:01:32,339	oarph 3501100 : Using source root /tmp/tmp4376nzva_cheap_optical_flow_eval_analysis 
2021-04-15 02:01:32,340	oarph 3501100 : Generating egg to /tmp/tmphccawm3g_oarphpy_eggbuild ...
2021-04-15 02:01:32,346	oarph 3501100 : ... done.  Egg at /tmp/tmphccawm3g_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [19]:
from psegs import datum

import urllib.parse

PSEGS_SYNTHFLOW_DEMO_RECORD_RURIS = (
  'psegs://dataset=kitti-360&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=4340,4339',
  'psegs://dataset=kitti-360&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=11219,11269',

  'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=40009,40010',
  'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=50013,50014',

#   'psegs://dataset=kitti-360-fused&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=11103,11104',
#   'psegs://dataset=kitti-360-fused&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=1181,1182',

#   'psegs://dataset=nuscenes&split=train_detect&segment_id=scene-0002&extra.psegs_flow_sids=10016,10017',
#   'psegs://dataset=nuscenes&split=train_detect&segment_id=scene-0582&extra.psegs_flow_sids=60035,60036',

#   'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0393&extra.psegs_flow_sids=50017,50018',
#   'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=40019,40020',
)

PSEGS_SYNTHFLOW_DEMO_URIS = [
    datum.URI(dataset=PSegsSynthFlowFactory.DATASET, extra={
        'pssf.ruri': urllib.parse.quote(ruri_str)
    })
    for ruri_str in PSEGS_SYNTHFLOW_DEMO_RECORD_RURIS
]

ALL_FP_FACTORY_CLSS.append(PSegsSynthFlowFactory)

print("Found %s PSegs SynthFlow scenes" % len(PSegsSynthFlowFactory.list_fp_uris(spark)))

if SHOW_DEMO_OUTPUT:
    if os.path.exists(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH):
        print("Loading demo FlowPairs from %s" % PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH)
        import pickle
        fps = pickle.load(open(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH, 'rb'))
    else:
        print("Building Demo FlowPairs, this might take a while ....")
        fp_rdd = PSegsSynthFlowFactory.get_fp_rdd_for_uris(spark, PSEGS_SYNTHFLOW_DEMO_URIS)
        fps = fp_rdd.collect()
        if PSEGS_SYNTHFLOW_DEMO_FPS_DO_CACHE:
            print("Saving demo FlowPairs to %s ..." % PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH)
            import pickle
            with open(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH, 'wb') as f:
                pickle.dump(fps, f, protocol=4)
    
    for fp in fps:
        show_html(fp.to_html())
        DEMO_FPS.append(fp)
    
    
    
#     import urllib.parse
    
    
    
    
#     for fp in fps:
#         show_html(fp.to_html() + "<br/><br/><br/>")
#         DEMO_FPS.append(fp)






# if SHOW_DEMO_OUTPUT:
#     if os.path.exists(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH):
#         print("Loading demo FlowPairs from %s" % PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH)
#         import pickle
#         fps = pickle.load(open(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH, 'rb'))
#     else:
#         print("Building Demo FlowPairs, this might take a while ....")
#         fps = psegs_synthflow_create_fps(spark, PSEGS_SYNTHFLOW_PARQUET_ROOT, PSEGS_SYNTHFLOW_DEMO_RECORD_URIS)
#         if PSEGS_SYNTHFLOW_DEMO_FPS_DO_CACHE:
#             print("Saving demo FlowPairs to %s ..." % PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH)
#             import pickle
#             with open(PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH, 'wb') as f:
#                 pickle.dump(fps, f, protocol=4)
    
#     for fp in fps:
#         show_html(fp.to_html())
#         DEMO_FPS.append(fp)
        


Found 191 PSegs SynthFlow scenes


## Reconstruction via Optical Flow

In [20]:
## Reconstruction via Optical Flow

def zero_flow(flow):
    return (flow[:, :, :2] == np.array([0, 0])).all(axis=-1)

def warp_flow_backwards(img, flow):
    """Given an image, apply the inverse of `flow`"""
    h, w = flow.shape[:2]
    flow = -flow
    flow[:,:,0] += np.arange(w)
    flow[:,:,1] += np.arange(h)[:,np.newaxis]
    res = cv2.remap(img, flow.astype(np.float32), None, cv2.INTER_LINEAR)
    return res
    
def warp_flow_forwards(img, flow):
    """Given an image, apply the given optical flow `flow`.  Returns not only the warped
    image, but a `mask` indicating warped pixels (i.e. there was non-zero flow *into* these pixels ).
    With some help from https://stackoverflow.com/questions/41703210/inverting-a-real-valued-index-grid/46009462#46009462
    """
    h, w = img.shape[:2]
    pts = flow.copy()
    pts[:, :, 0] += np.arange(w)
    pts[:, :, 1] += np.arange(h)[:, np.newaxis]
    exclude = zero_flow(flow)
    if exclude.all():
        # No flow anywhere!
        return img.copy(), np.zeros((h, w)).astype(np.bool)
    else:
        inpts = pts[~exclude]
    
    from scipy.interpolate import griddata
    inpts = np.reshape(inpts, [-1, 2])
    grid_y, grid_x = np.mgrid[:h, :w]
    chan_out = []
    for ch in range(img.shape[-1]):
        spts = img[:, :, ch][~exclude].reshape([-1, 1])
        mapped = griddata(inpts, spts, (grid_x, grid_y), method='linear')
        chan_out.append(mapped.astype(img.dtype))
    out = np.stack(chan_out, axis=-1)
    out = out.reshape([h, w, len(chan_out)])

    mask = np.reshape(inpts, [-1, 2])
    mask = np.rint(mask).astype(np.int)
    mask = mask[np.where((mask[:, 0] >= 0) & (mask[:, 0] < w) & (mask[:, 1] >= 0) & (mask[:, 1] < h))]
    valid_mask = np.zeros((h, w))
    valid_mask[mask[:, 1], mask[:, 0]] = 1
    
    return out, valid_mask.astype(np.bool)

# @attr.s(slots=True, eq=False, weakref_slot=False)
class FlowReconstructedImagePair(object):
    """A pair of reconstructed images using an input pair of images and optical
    flow field (i.e. an `OpticalFlowPair` instance)."""

    slots = (
        'opair',
        'img2_recon_fwd',
        'img2_recon_fwd_valid',
        'img1_recon_bkd',
        'img1_recon_bkd_valid'
    )
    
    def __init__(self, **kwargs):
        for k in self.slots:
            setattr(self, k, kwargs.get(k))
    
#     opair = attr.ib(default=OpticalFlowPair())
#     """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""
    
#     img2_recon_fwd = attr.ib(default=np.array([]))
#     """A Numpy image containing the result of FORWARDS-WARPING OpticalFlowPair::img1
#     via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img2"""

#     img2_recon_fwd_valid = attr.ib(default=np.array([]))
#     """A Numpy boolean mask indicating which pixels of `img2_recon_fwd` were modified via non-zero flow"""
    
#     img1_recon_bkd = attr.ib(default=np.array([]))
#     """A Numpy image containing the result of BACKWARDS-WARPING OpticalFlowPair::img2
#     via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img1"""

#     img1_recon_bkd_valid = attr.ib(default=np.array([]))
#     """A Numpy boolean mask indicating which pixels of `img1_recon_bkd` were modified via non-zero flow"""
        
    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        flow = oflow_pair.get_flow()
        
        # Forward Warp
        fwarped, fvalid = warp_flow_forwards(oflow_pair.get_img1(), flow)

        # Backwards Warp
        exclude = zero_flow(flow)
        bwarped = warp_flow_backwards(oflow_pair.get_img2(), -flow[:, :, :2])
        bvalid = ~exclude
        
        return FlowReconstructedImagePair(
                opair=oflow_pair,
                img2_recon_fwd=fwarped,
                img2_recon_fwd_valid=fvalid,
                img1_recon_bkd=bwarped,
                img1_recon_bkd_valid=bvalid)
    
    def to_html(self):
        # We use pixels from the destination image in order to make the reconstruction 
        # easier to interpret; we'll fade them in intensity so that they are more
        # conspicuous.        
        FADE_UNTOUCHED_PIXELS = 0.3
        
        viz_fwd = self.img2_recon_fwd.copy().astype(np.float32)
        im2 = self.opair.get_img2()
        if (~self.img2_recon_fwd_valid).any():
            viz_fwd[~self.img2_recon_fwd_valid] = im2[~self.img2_recon_fwd_valid]
            viz_fwd[~self.img2_recon_fwd_valid] *= FADE_UNTOUCHED_PIXELS
        else:
            # viz_fwd = im2.copy() * FADE_UNTOUCHED_PIXELS
            print('no invalids forward!')
        
        viz_bkd = self.img1_recon_bkd.copy().astype(np.float32)
        im1 = self.opair.get_img1()
        if (~self.img1_recon_bkd_valid).any():
            viz_bkd[~self.img1_recon_bkd_valid] = im1[~self.img1_recon_bkd_valid]
            viz_bkd[~self.img1_recon_bkd_valid] *= FADE_UNTOUCHED_PIXELS
        else:
            # viz_bkd = im1.copy() * FADE_UNTOUCHED_PIXELS
            print('no invalids backwards!')
        
        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Forwards Warped <i>(dark pixels unwarped)</i></b></td></tr>
            <tr><td><img src="{viz_fwd}" width="100%" /></td></tr>

            <tr><td style="text-align:left"><b>Backwards Warped <i>(dark pixels unwarped)</i></b></td></tr>
            <tr><td><img src="{viz_bkd}" width="100%" /></td></tr>

            </table>
        """.format(
                viz_fwd=img_to_data_uri(viz_fwd.astype(np.uint8)),
                viz_bkd=img_to_data_uri(viz_bkd.astype(np.uint8)))
        return html

        
if SHOW_DEMO_OUTPUT:
    DEMO_RECONS = []
    for p in DEMO_FPS:
        recon = FlowReconstructedImagePair.create_from(p)
        show_html(recon.to_html() + "</br></br></br>")
        DEMO_RECONS.append(recon)


## Analysis: Demo

In [21]:
# Analysis Utils

def mse(i1, i2, valid):
    return np.mean((i1[valid] - i2[valid]) ** 2)

def rmse(i1, i2, valid):
    return math.sqrt(mse(i1, i2, valid))

def psnr(i1, i2, valid):
    return 20 * math.log10(255) - 10 * math.log10(max((mse(i1, i2, valid), 1e-12)))

def ssim(i1, i2, valid):
    # Some variance out there ...
    # https://github.com/scikit-image/scikit-image/blob/master/skimage/metrics/_structural_similarity.py#L12-L232
    # https://github.com/nianticlabs/monodepth2/blob/13200ab2f29f2f10dec3aa5db29c32a23e29d376/layers.py#L218
    # https://cvnote.ddlee.cn/2019/09/12/psnr-ssim-python
    # We will just use SKImage for now ...
    from skimage.metrics import structural_similarity as ssim
    mssim, S = ssim(i1, i2, win_size=11, multichannel=True, full=True)
    return np.mean(S[valid])

def to_edge_im(img):
    return np.stack([
        cv2.Laplacian(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, ksize=1),
        cv2.Sobel(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, 1, 0, ksize=3),
        cv2.Sobel(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, 0, 1, ksize=3),
    ], axis=-1)

def edges_mse(i1, i2, valid):
    return mse(to_edge_im(i1), to_edge_im(i2), valid)


def oflow_coverage(valid):
    return valid.sum() / (valid.shape[0] * valid.shape[1])

def oflow_magnitude_hist(flow, valid, bins=50):
    flow_l2s = np.sqrt( flow[valid][:, 0] ** 2 + flow[valid][:, 1] ** 2 )
    bin_counts, bin_edges = np.histogram(flow_l2s, bins=bins)
    return bin_edges, bin_counts


# Analysis Data Model

class OFlowReconErrors(object):
    """Various measures of reconstruction error for a `FlowReconstructedImagePair` instance.
    Encapsulated as two dictionaries of stats for easy interop with Spark SQL."""

    RECONSTRUCTION_ERR_METRICS = {
        'SSIM': ssim,
        'MSE': mse,
        'RMSE': rmse,
        'PSNR': psnr,
        'Edges_MSE': edges_mse,
    }
    
    def __init__(self, recon_pair: FlowReconstructedImagePair):
        im2 = recon_pair.opair.get_img2()
        img2_recon_fwd = recon_pair.img2_recon_fwd
        img2_recon_fwd_valid = recon_pair.img2_recon_fwd_valid
        self.forward_stats = dict(
            (name, func(im2, img2_recon_fwd, img2_recon_fwd_valid))
            for name, func in self.RECONSTRUCTION_ERR_METRICS.items())
        
        im1 = recon_pair.opair.get_img1()
        img1_recon_fwd = recon_pair.img1_recon_bkd
        img1_recon_fwd_valid = recon_pair.img1_recon_bkd_valid
        self.backward_stats = dict(
            (name, func(im1, img1_recon_fwd, img1_recon_fwd_valid))
            for name, func in self.RECONSTRUCTION_ERR_METRICS.items())

    def to_html(self):
        stat_names = self.RECONSTRUCTION_ERR_METRICS.keys()

        rows = [
            """
            <tr>
              <td style="text-align:left"><b>{name}</b></td>
              <td style="text-align:left">{fwd:.2f}</td>
              <td style="text-align:left">{bkd:.2f}</td>
            </tr>
            """.format(name=name, fwd=self.forward_stats[name], bkd=self.backward_stats[name])
            for name in stat_names
        ]
        
        
        html = """
            <table>
              <tr>
                  <th></th> <th><b>Forwards Warp</b></th> <th><b>Backwards Warp</b></th>
              </tr>

              {table_rows}

            </table>
        """.format(table_rows="".join(rows))
        
        return html
            
# @attr.s(slots=True, eq=False, weakref_slot=False)
class OFlowStats(object):
    """Stats on the optical flow of a `OpticalFlowPair` instance"""

    slots = (
        'opair',
        'coverage',
        'magnitude_hist',
    )
    
    def __init__(self, **kwargs):
        for k in self.slots:
            setattr(self, k, kwargs.get(k))
    
#     opair = attr.ib(default=OpticalFlowPair())
#     """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""
    
#     coverage = attr.ib(default=0)
#     """Fraction of the image with valid flow"""
    
#     magnitude_hist = attr.ib(default=[np.array([]), np.array([])])
#     """Histogram [bin edges, bin counts] of flow magnitudes"""
    
    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        flow = oflow_pair.get_flow()
        valid = ~zero_flow(flow)
        return OFlowStats(
                 opair=oflow_pair,
                 coverage=oflow_coverage(valid),
                 magnitude_hist=oflow_magnitude_hist(flow, valid))
                 
    def to_html(self):
        import matplotlib.pyplot as plt
        fig = plt.figure()
        bin_edges, bin_counts = self.magnitude_hist
        plt.bar(bin_edges[:-1], bin_counts)
        plt.title("Histogram of Flow Magnitudes")
        plt.xlabel('Flow Magnitude (pixels)')
        plt.ylabel('Count')

        hist_img = matplotlib_fig_to_img(fig)
        
        html = """
            <table>           
            <tr><td style="text-align:left"><b>Flow Coverage:</b> {coverage:.2f}% </td></tr>
            <tr><td><img src="{flow_hist}" width="100%" /></td></tr>
            </table>
        """.format(
                coverage=100. * self.coverage,
                flow_hist=img_to_data_uri(matplotlib_fig_to_img(hist_img)))
        return html


# Misc

def matplotlib_fig_to_img(fig):
    import io
    import matplotlib.pyplot as plt
    from PIL import Image
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    im = Image.open(buf)
    im.show()
    buf.seek(0)

    import imageio
    hist_img = imageio.imread(buf)
    buf.close()
    return hist_img


if SHOW_DEMO_OUTPUT:
    %matplotlib agg
    for recon in DEMO_RECONS:
        p = recon.opair
        errors = OFlowReconErrors(recon)
        err_html = errors.to_html()  
            
        fstats = OFlowStats.create_from(p)
        stats_html = fstats.to_html()
            
        title = "<b>{dataset} {id1} -> {id2}</b>".format(dataset=p.dataset, id1=p.id1, id2=p.id2)
        
        show_html(title + stats_html + err_html + "</br></br></br>")
            

## Scene Flow Analysis (where depth and intrinsics are available)

In [22]:
"""
 * for psegs, we have uvd and K
 * for kitti tracking, we'll be able to have uvd and K
 * for deepdeform, the intrinsics are in each seq.  also a mask for maybe the images of interest?
 * for kitti sf, we can get K (P) from kitti-like file.  !!! kitti has obj_map colors image!!  
     need to figure out depth meters from disparity ...  derrrp to get the raw velodynes we have to cross-ref
     with odometry dataset. let's talk to yiyi about that...
 * !!! do a test where you use nearest neighbor correspondence on raw clouds for OFlow. then can see how bad
     the pairing is sometimes
 
 * metrics: end-pt-error for NN forward; same for backward; then also do a chamfer distance metric
 * (do all this again but first do an ICP on the raw depths-- the rigid background should probably align, right?
     use the ICP's RT to pose raw and 
 * a common class for all these is background / foreground.  want to break down chamfer dist etc bucket by at least
      background / foreground
 * debug image: surface pairs of points with end pt error larger than E and plot on the image
 
 * another good test: (1) train self-sup SF on raw clouds.  then test on large displacement pair 
     (walk a prediction forward many time steps). then can see how well that holds up vs our "GT"
 
# From code above, which we won't run every time since
# it's complicated and just gets static information.
f, cx, cy, w, h = 1144.27150333,  960. ,  540., 1920, 1080
K = np.array([
      [f, 0, cx],
      [0, f, cy],
      [0, 0, 1],
])

px_y = np.tile(np.arange(h)[:, np.newaxis], [1, w])
px_x = np.tile(np.arange(w)[np.newaxis, :], [h, 1])
PYX = np.concatenate([px_y[:,:,np.newaxis], px_x[:, :,np.newaxis]], axis=-1)
RAYS_FOR_CAM = np.zeros((h, w, 3))
RAYS_FOR_CAM[:, :, 0] = (PYX[:, :, 0] - cy) / f
RAYS_FOR_CAM[:, :, 1] = (PYX[:, :, 1] - cx) / f
RAYS_FOR_CAM[:, :, 2] = 1

yxz = RAYS_FOR_CAM * (demo[:, :, 3][:, :,np.newaxis])
yxz = yxz.reshape([-1, 3])
yxzrgb = np.concatenate([yxz, demo[:, :, :3].reshape([-1, 3])], axis=-1)
"""



def nn_distance(xyz_src, xyz_target):
    import numpy as np
    import open3d as o3d
    pcds = o3d.geometry.PointCloud()
    pcds.points = o3d.utility.Vector3dVector(xyz_src)
    pcdt = o3d.geometry.PointCloud()
    pcdt.points = o3d.utility.Vector3dVector(xyz_target)
    dists = pcds.compute_point_cloud_distance(pcdt)
    dists = np.asarray(dists)
    return dists


class SFlowStats(object):
    """Stats on the scene flow of a `OpticalFlowPair` instance (that has scene flow data)"""

    slots = (
        'fwd_nn_end_point_error',
        'bkd_nn_end_point_error',
        'chamfer_distance',
        'fwd_epe_50th',
        'fwd_epe_75th',
        'fwd_epe_95th',
#         'icp_fwd_nn_end_point_error',
#         'icp_bkd_nn_end_point_error',
#         'icp_chamfer_distance',
        
        'opair',
        'coverage',
        'magnitude_hist',
    )
    
    def __init__(self, **kwargs):
        for k in self.slots:
            setattr(self, k, kwargs.get(k))
    
#     opair = attr.ib(default=OpticalFlowPair())
#     """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""
    
#     coverage = attr.ib(default=0)
#     """Fraction of the image with valid flow"""
    
#     magnitude_hist = attr.ib(default=[np.array([]), np.array([])])
#     """Histogram [bin edges, bin counts] of flow magnitudes"""
    
#     @classmethod
#     def create_from(cls, oflow_pair: OpticalFlowPair):
#         flow = oflow_pair.get_flow()
#         valid = ~zero_flow(flow)
#         return OFlowStats(
#                  opair=oflow_pair,
#                  coverage=oflow_coverage(valid),
#                  magnitude_hist=oflow_magnitude_hist(flow, valid))
                 
#     def to_html(self):
#         import matplotlib.pyplot as plt
#         fig = plt.figure()
#         bin_edges, bin_counts = self.magnitude_hist
#         plt.bar(bin_edges[:-1], bin_counts)
#         plt.title("Histogram of Flow Magnitudes")
#         plt.xlabel('Flow Magnitude (pixels)')
#         plt.ylabel('Count')

#         hist_img = matplotlib_fig_to_img(fig)
        
#         html = """
#             <table>           
#             <tr><td style="text-align:left"><b>Flow Coverage:</b> {coverage:.2f}% </td></tr>
#             <tr><td><img src="{flow_hist}" width="100%" /></td></tr>
#             </table>
#         """.format(
#                 coverage=100. * self.coverage,
#                 flow_hist=img_to_data_uri(matplotlib_fig_to_img(hist_img)))
#         return html
    


## Analysis on Full Datasets

In [23]:
# import sys
# sys.path.append('/opt/psegs')

# from oarphpy.spark import NBSpark
# NBSpark.SRC_ROOT = os.path.join(ALIB_SRC_DIR, 'cheap_optical_flow_eval_analysis')
# NBSpark.CONF_KV.update({
#     'spark.driver.maxResultSize': '2g',
#     'spark.driver.memory': '16g',
#   })
# spark = NBSpark.getOrCreate()


from oarphpy.spark import RowAdapter

from pyspark import Row


def flow_pair_to_full_row(fp):
    from threadpoolctl import threadpool_limits
    with threadpool_limits(limits=1, user_api='blas'):
        recon = FlowReconstructedImagePair.create_from(fp)
        fstats = OFlowStats.create_from(fp)
        errors = OFlowReconErrors(recon)

        rowdata = dict(
                fp_datset=fp.dataset,
                fp_uri=str(fp.uri),
                flow_coverage=fstats.coverage,
                diff_time_sec=fp.diff_time_sec,
                translation_meters=fp.translation_meters,
        )
        rowdata.update(
            ('Forwards_' + k, float(v))
            for k, v in errors.forward_stats.items())
        rowdata.update(
            ('Backwards_' + k, float(v))
            for k, v in errors.backward_stats.items())
        return RowAdapter.to_row(rowdata)


analysis_uris_demo = MiddFactory.list_fp_uris(spark) + PSEGS_SYNTHFLOW_DEMO_URIS + KITTI_SF15_DEMO_URIS + DD_DEMO_URIS


class UnionFactory(FlowPairUnionFactory):
    FACTORIES = ALL_FP_FACTORY_CLSS

analysis_uris_full = UnionFactory.list_fp_uris(spark)
# analysis_uris_full = analysis_uris_full[4900:]
print('analysis_uris_full', len(analysis_uris_full))

from oarphpy import util as oputil
thru = oputil.ThruputObserver(name='run_analysis', n_total=len(analysis_uris_full))
for uri_chunk in oputil.ichunked(analysis_uris_full, 100):
    thru.start_block()
    fp_rdd = UnionFactory.get_fp_rdd_for_uris(spark, uri_chunk)
    result_rdd = fp_rdd.map(flow_pair_to_full_row)
    df = spark.createDataFrame(result_rdd)
    df.write.save(
            mode='append',
            path='/outer_root/media/rocket4q/oflow_pq_eval_test.parquet',
            format='parquet',
            compression='lz4')
    thru.stop_block(n=len(uri_chunk))
    thru.maybe_log_progress(every_n=1)


# if True:#RUN_FULL_ANALYSIS:
# #     spark = Spark.getOrCreate()
    
# #     for p in ALL_FPS:
# #         import cloudpickle
# #         try:
# #             cloudpickle.dumps(p)
# #         except Exception:
# #             assert False, p
# #     print('all good')
    
#     import pickle
#     fp_rdd = spark.sparkContext.parallelize(ALL_FPS, numSlices=200)
# #     print(fp_rdd.count())
#     df = spark.createDataFrame(fp_rdd.map(flow_pair_to_full_row)).persist()

#     print(df.count())
#     df.show(10)
#     df.printSchema()

In [24]:
results_df = spark.read.parquet('/outer_root/media/rocket4q/oflow_pq_eval_test.parquet')


def add_dataset(row):
    from psegs import datum
    row = row.asDict()
    uri = datum.URI.from_str(row['fp_uri'])
    row['fp_dataset'] = uri.dataset
    return row

results_df = spark.createDataFrame(results_df.rdd.map(add_dataset))
results_df = results_df.persist()

results_df.show()
results_df.count()



+-------------------+------------------+------------------+-----------------+-------------------+------------------+------------------+------------------+------------------+-------------------+--------------------+-----------+--------------------+
|Backwards_Edges_MSE|     Backwards_MSE|    Backwards_PSNR|   Backwards_RMSE|     Backwards_SSIM|Forwards_Edges_MSE|      Forwards_MSE|     Forwards_PSNR|     Forwards_RMSE|      Forwards_SSIM|       flow_coverage| fp_dataset|              fp_uri|
+-------------------+------------------+------------------+-----------------+-------------------+------------------+------------------+------------------+------------------+-------------------+--------------------+-----------+--------------------+
|  744.0916748046875| 57.34668296563634| 30.54572058271115|7.572759270281629| 0.5458482847253787|  998.046142578125| 55.85433777856281| 30.66023453780854|  7.47357596994657| 0.5581415463961072| 0.22368489583333334|deep_deform|psegs://dataset=d...|
|    842

19214

In [28]:
import os

def fp_uri_to_fname(fp_uri):
    fp_uri = str(fp_uri)
    import urllib.parse
    fname = urllib.parse.quote(fp_uri)
    from slugify import slugify
    fname = slugify(fname)
    return fname

def extract_fp_uris_from_html(html):
    import re
    matches = list(set(re.findall(r'alt=\\"(.*?)\\"', html)))
    import html
    return set(html.unescape(s) for s in matches)

FLOW_EVAL_REPORT_BASEDIR = '/tmp/flow_eval/'
from oarphpy import util as oputil
oputil.mkdir(FLOW_EVAL_REPORT_BASEDIR)

from oarphpy import plotting as pl
class Plotter(pl.HistogramWithExamplesPlotter):
    NUM_BINS = 50
    ROWS_TO_DISPLAY_PER_BUCKET = 10
    SUB_PIVOT_COL = 'fp_dataset'

    def display_bucket(self, sub_pivot, bucket_id, irows):
        from oarphpy.spark import RowAdapter
        from psegs import datum
        
        # Sample from irows using reservior sampling
        import random
        rows = []
        for i, row in enumerate(irows):
            r = random.randint(0, i)
            if r < cls.ROWS_TO_DISPLAY_PER_BUCKET:
                if i < cls.ROWS_TO_DISPLAY_PER_BUCKET:
                    rows.insert(r, row)
                else:
                    rows[r] = row
        
        # Now render each row to HTML
        row_htmls = []
        for row in rows:
            rowdata = RowAdapter.from_row(row)
            
            fp_uri_str = rowdata['fp_uri']
            fp_uri = datum.URI.from_str(fp_uri_str)
            fp_page_uri = fp_uri_to_fname(fp_uri_str) + '.html'
            dataset = fp_uri.dataset
            id1 = "TODO"
            id2 = "TODO"
            
            row_html = f"""
                <a href="{fp_page_uri}" alt="{fp_uri_str}">
                    {fp_uri.dataset} {fp_uri.split} {fp_uri.segment_id} {id1} -> {id2}
                </a><br />"""
            row_htmls.append(row_html)
        
        HTML = """
        <b>Pivot: {spv} Bucket: {bucket_id} </b> <br/>
        
        {row_bodies}
        """.format(
              spv=sub_pivot,
              bucket_id=bucket_id,
              row_bodies="<br/><br/><br/>".join(row_htmls))
        
        return bucket_id, HTML

plotter = Plotter()

chosen_fp_uris = set()
histogram_htmls = []
cols = [col for col in results_df.columns if col not in ('fp_uri', 'fp_dataset')]
print("Rendering %s histograms" % len(cols))
for col in cols:
    print("Working on %s" % col)
#     fig = plotter.run(results_df, col)
    dest = os.path.join(FLOW_EVAL_REPORT_BASEDIR, '%s.html' % col)
#     pl.save_bokeh_fig(fig, dest)
    
    with open(dest, 'r') as f:
        cur_chosen_fp_uris = extract_fp_uris_from_html(f.read())
        print(len(cur_chosen_fp_uris))
    chosen_fp_uris |= cur_chosen_fp_uris
# assert False, len(chosen_fp_uris)
    
print("Rendering %s histogram bucket pages" % len(chosen_fp_uris))
class UnionFactory(FlowPairUnionFactory):
    FACTORIES = ALL_FP_FACTORY_CLSS

analysis_uris_full = UnionFactory.list_fp_uris(spark)

fp_rdd = UnionFactory.get_fp_rdd_for_uris(spark, list(chosen_fp_uris))
def render_and_save(fp):
    from threadpoolctl import threadpool_limits
    with threadpool_limits(limits=1, user_api='blas'):
        import os
        recon = FlowReconstructedImagePair.create_from(fp)
        fstats = OFlowStats.create_from(fp)
        errors = OFlowReconErrors(recon)
        page_html = "<br/>".join((fp.to_html(), recon.to_html(), fstats.to_html(), errors.to_html()))

        dest = os.path.join(FLOW_EVAL_REPORT_BASEDIR, fp_uri_to_fname(fp.uri) + '.html')
        with open(dest, 'w') as f:
            f.write(page_html)
fp_rdd.foreach(render_and_save)
    
    
# from bokeh.io import output_notebook
# output_notebook()
# from bokeh.plotting import show
# show(fig)

[autoreload of psegs.datasets.kitti_360 failed: Traceback (most recent call last):
  File "/usr/lib/python3/dist-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/lib/python3/dist-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/usr/lib/python3/dist-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/usr/lib/python3/dist-packages/IPython/extensions/autoreload.py", line 292, in update_class
    if (old_obj == new_obj) is True:
  File "<attrs generated eq attr._make.Attribute>", line 4, in __eq__
    return  (
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
]


Rendering 11 histograms
Working on Backwards_Edges_MSE
547
Working on Backwards_MSE
938
Working on Backwards_PSNR
892
Working on Backwards_RMSE
932
Working on Backwards_SSIM
945
Working on Forwards_Edges_MSE
548
Working on Forwards_MSE
891
Working on Forwards_PSNR
892
Working on Forwards_RMSE
888
Working on Forwards_SSIM
867
Working on flow_coverage
792
Rendering 2031 histogram bucket pages


2021-04-15 02:08:31,794	ps   3501100 : Building DF for psegs://dataset=kitti-360&split=train&segment_id=2013_05_28_drive_0004_sync
2021-04-15 02:08:31,795	ps   3501100 : Loading /opt/psegs/dataroot/stamped_datum/stamped_datums ...
2021-04-15 02:08:31,901	ps   3501100 : Creating datums for KITTI-360 ...
2021-04-15 02:08:31,901	ps   3501100 : Filtering to only 1 segments
2021-04-15 02:08:35,435	ps   3501100 : ... seq 2013_05_28_drive_0004_sync has 99660 URIs spanning 1211 sec, creating 389 slices ...
2021-04-15 02:08:35,966	ps   3501100 : ... partitioned datums into 1 RDDs.
2021-04-15 02:08:36,004	ps   3501100 : Going to write in 1 chunks ...


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 9 in stage 34.0 failed 1 times, most recent failure: Lost task 9.0 in stage 34.0 (TID 13400, 192.168.0.213, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 605, in main
    process()
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 597, in process
    serializer.dump_stream(out_iter, outfile)
  File "/opt/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 271, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/spark/python/lib/pyspark.zip/pyspark/util.py", line 107, in wrapper
    return f(*args, **kwargs)
  File "/opt/psegs/psegs/datasets/kitti_360.py", line 449, in create_stamped_datum
    if uri.topic.startswith('camera'):
AttributeError: 'tuple' object has no attribute 'topic'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:503)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:638)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:489)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$1.hasNext(InMemoryRelation.scala:132)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1388)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1298)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1362)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1186)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:360)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:311)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:99)
	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:446)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:449)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2008)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2007)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2007)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:973)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2239)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2188)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2177)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:775)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2120)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2139)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2164)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1004)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:388)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1003)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:168)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 605, in main
    process()
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 597, in process
    serializer.dump_stream(out_iter, outfile)
  File "/opt/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 271, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/spark/python/lib/pyspark.zip/pyspark/util.py", line 107, in wrapper
    return f(*args, **kwargs)
  File "/opt/psegs/psegs/datasets/kitti_360.py", line 449, in create_stamped_datum
    if uri.topic.startswith('camera'):
AttributeError: 'tuple' object has no attribute 'topic'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:503)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:638)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:489)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$1.hasNext(InMemoryRelation.scala:132)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1388)
	at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1298)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1362)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1186)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:360)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:311)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:99)
	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:446)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:449)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
