# Cheap Optical Flow: Is it Good? Does it Boost?


## Quickstart

## Credits

Portions of this notebook are adapted from:
 * [Middlebury Flow code by Johannes Oswald](https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py)
 * [DeepDeform Demo Code](https://github.com/AljazBozic/DeepDeform)
 * [OpticalFlowToolkit by RUOTENG LI](https://github.com/liruoteng/OpticalFlowToolkit)
 * [OpenCV Samples](https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py)

In [1]:
# parameters
SHOW_DEMO_OUTPUT = True
DEMO_FPS = []

RUN_FULL_ANALYSIS = False
ALL_FPS = []

In [2]:
## Setup

!pip3 install pypng scikit-image
print('fixme installs')
print()
print()

import copy
import imageio
import IPython.display
import math
import os
import PIL.Image
import six
import sys
import tempfile


## General Notebook Utilities
    
def imshow(x):
    IPython.display.display(PIL.Image.fromarray(x))

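def show_html(x):
    # `display` and `HTML` live in `IPython.display`; importing them from
    # `IPython.core.display` is deprecated
    from IPython.display import display, HTML
    display(HTML(x))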


## Create a random temporary directory for analysis library (for Spark-enabled full analysis mode)
old_cwd = os.getcwd()
tempdir = tempfile.TemporaryDirectory(suffix='_cheap_optical_flow_eval_analysis')
ALIB_SRC_DIR = tempdir.name
print("Putting analysis lib in %s" % ALIB_SRC_DIR)
os.chdir(ALIB_SRC_DIR)
!mkdir -p cheap_optical_flow_eval_analysis
!touch cheap_optical_flow_eval_analysis/__init__.py

%load_ext autoreload
%autoreload 2
sys.path.append(ALIB_SRC_DIR)


## Prepare a build of local psegs for inclusion
!cd /opt/psegs && python3 setup.py clean bdist_egg
PSEGS_EGG_PATH = '/opt/psegs/dist/psegs-0.0.1-py3.8.egg'
assert os.path.exists(PSEGS_EGG_PATH), "Build failed?"
sys.path.append('/opt/psegs')
import psegs


## Prepare Spark session with local PSegs and local Analysis Lib
from psegs.spark import NBSpark
NBSpark.SRC_ROOT = os.path.join(ALIB_SRC_DIR, 'cheap_optical_flow_eval_analysis')
NBSpark.CONF_KV.update({
    'spark.driver.maxResultSize': '2g',
    'spark.driver.memory': '16g',
    'spark.submit.pyFiles': PSEGS_EGG_PATH,
  })
spark = NBSpark.getOrCreate()

fixme installs


Putting analysis lib in /tmp/tmpdx2owrw9_cheap_optical_flow_eval_analysis
running clean
running bdist_egg
running egg_info
writing psegs.egg-info/PKG-INFO
writing dependency_links to psegs.egg-info/dependency_links.txt
writing top-level names to psegs.egg-info/top_level.txt
reading manifest file 'psegs.egg-info/SOURCES.txt'
writing manifest file 'psegs.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py

creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/psegs
copying build/lib/psegs/dummyrun.py -> build/bdist.linux-x86_64/egg/psegs
creating build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/kitti.py -> build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/idsutil.py -> build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datasets/kitti_360.py -> build/bdist.linux-x86_64/egg/psegs/datasets
copying build/lib/psegs/datase

2021-04-09 18:56:41,908	oarph 1608375 : Using source root /tmp/tmpdx2owrw9_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-09 18:56:41,908	oarph 1608375 : Using source root /tmp/tmpdx2owrw9_cheap_optical_flow_eval_analysis 
2021-04-09 18:56:41,952	oarph 1608375 : Generating egg to /tmp/tmppbvm6j6s_oarphpy_eggbuild ...
2021-04-09 18:56:41,964	oarph 1608375 : ... done.  Egg at /tmp/tmppbvm6j6s_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg


In [3]:
%%writefile cheap_optical_flow_eval_analysis/ofp.py

## Data Model & Utility Code

import attr
import cv2
import imageio
import math
import os
import PIL.Image
import six

import numpy as np

from oarphpy import plotting as op_plt
from oarphpy.spark import CloudpickeledCallable
img_to_data_uri = lambda x: op_plt.img_to_data_uri(x, format='png')

@attr.s(slots=True, eq=False, weakref_slot=False)
class OpticalFlowPair(object):
    """A flyweight for a pair of images with an optical flow field.
    Supports lazy-loading of large data attributes."""
    
    ## Core Attributes (Required for All Datasets)
    
    dataset = attr.ib(type=str, default='')
    """To which dataset does this pair belong?"""
    
    id1 = attr.ib(type=str, default='')
    """Identifier or URI for the first image"""
    
    id2 = attr.ib(type=str, default='')
    """Identifier or URI for the second image"""
    
    img1 = attr.ib(default=None)
    """URI or numpy array or CloudpickeledCallable for the first image (source image)"""

    img2 = attr.ib(default=None)
    """URI or numpy array or CloudpickeledCallable for the second image (target image)"""
    
    flow = attr.ib(default=None)
    """A numpy array or callable or CloudpickeledCallable representing optical flow from img1 -> img2"""
    
    ## Optional Attributes (For Select Datasets)
    
    diff_time_sec = attr.ib(type=float, default=0.0)
    """Difference in time (in seconds) between the views / poses depicted in `img1` and `img2`."""
    
    translation_meters = attr.ib(type=float, default=0.0)
    """Difference in ego translation (in meters) between the views / poses depicted in `img1` and `img2`."""
    
    
    
    # to add:
    # diff time seconds
    # semantic image for frame 1, frame 2 [could be painted by cuboids]
    # instance images for frame 1, frame 2 [could be painted by cuboids]
    #   -- for colored images, at first just pivot all oflow metrics by colors
    # get uvdviz1 uvdviz2 (scene flow)
    #   * for deepdeform, their load_flow will work
    #   * for kitti, we have to read their disparity images
    # get uvd1 uvd2 (lidar for nearest neighbor stuff)
    # depth image for frame 1, frame 2 [could be interpolated by cuboids]
    #   -- at first bucket the depth coarsely and pivot all oflow by colors
    
    def get_img1(self):
        if isinstance(self.img1, CloudpickeledCallable):
            self.img1 = self.img1()
        if isinstance(self.img1, six.string_types):
            self.img1 = imageio.imread(self.img1)
        return self.img1
    
    def get_img2(self):
        if isinstance(self.img2, CloudpickeledCallable):
            self.img2 = self.img2()
        if isinstance(self.img2, six.string_types):
            self.img2 = imageio.imread(self.img2)
        return self.img2
    
    def get_flow(self):
        if not isinstance(self.flow, (np.ndarray, np.generic)):
            self.flow = self.flow()
        return self.flow
    
    def to_html(self):
        im1 = self.get_img1()
        im2 = self.get_img2()
        flow = self.get_flow()
        fviz = draw_flow(im1, flow)
        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Dataset:</b> {dataset}</td></tr>
            
            <tr><td style="text-align:left"><b>Source Image:</b> {id1}</td></tr>
            <tr><td><img src="{im1}" /></td></tr>

            <tr><td style="text-align:left"><b>Target Image:</b> {id2}</td></tr>
            <tr><td><img src="{im2}" /></td></tr>

            <tr><td style="text-align:left"><b>Flow</b></td></tr>
            <tr><td><img src="{fviz}" /></td></tr>
            </table>
        """.format(
                dataset=self.dataset,
                id1=self.id1, id2=self.id2,
                im1=img_to_data_uri(im1), im2=img_to_data_uri(im2),
                fviz=img_to_data_uri(fviz))
        return html

def draw_flow(img, flow, step=8):
    """Based upon OpenCV sample: https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py"""
    h, w = img.shape[:2]
    y, x = np.mgrid[step/2:h:step, step/2:w:step].reshape(2,-1).astype(int)
    fx, fy = flow[y,x].T
    lines = np.vstack([x, y, x+fx, y+fy]).T.reshape(-1, 2, 2)
    lines = np.int32(lines + 0.5)
    vis = img.copy()
    cv2.polylines(vis, lines, 0, (0, 255, 0))
    for (x1, y1), (_x2, _y2) in lines:
        cv2.circle(vis, (x1, y1), 1, (0, 255, 0), -1)
    return vis


Writing cheap_optical_flow_eval_analysis/ofp.py


In [4]:
from cheap_optical_flow_eval_analysis.ofp import *

2021-04-09 18:56:45,170	oarph 1608375 : Source has changed! Rebuilding Egg ...
2021-04-09 18:56:45,171	oarph 1608375 : Using source root /tmp/tmpdx2owrw9_cheap_optical_flow_eval_analysis/cheap_optical_flow_eval_analysis 
2021-04-09 18:56:45,171	oarph 1608375 : Using source root /tmp/tmpdx2owrw9_cheap_optical_flow_eval_analysis 
2021-04-09 18:56:45,174	oarph 1608375 : Generating egg to /tmp/tmp0xm0f0t7_oarphpy_eggbuild ...
2021-04-09 18:56:45,181	oarph 1608375 : ... done.  Egg at /tmp/tmp0xm0f0t7_oarphpy_eggbuild/cheap_optical_flow_eval_analysis-0.0.0-py3.8.egg
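
`OpticalFlowPair` defers loading: each of `img1`, `img2`, and `flow` may hold a loader instead of the data, and the matching getter swaps in the result on first access. The pattern can be sketched with the standard library alone. `LazyPair`, `expensive_loader`, and `calls` are hypothetical names used only for illustration; this is not the class defined above.

```python
class LazyPair:
    """Hypothetical stand-in: `flow` is either a value or a zero-argument loader."""

    def __init__(self, flow):
        self.flow = flow

    def get_flow(self):
        # Resolve the loader once, then keep the result on the instance
        if callable(self.flow):
            self.flow = self.flow()
        return self.flow


calls = []  # records each time the loader actually runs

def expensive_loader():
    calls.append('loaded')
    return [[0.0, 0.0], [1.5, -2.0]]  # pretend (dx, dy) flow vectors


pair = LazyPair(flow=expensive_loader)
first = pair.get_flow()
second = pair.get_flow()
print(len(calls))       # 1: the loader ran only once
print(first is second)  # True: later calls return the cached object
```

Holding a loader keeps the object small until the data is actually needed, which is presumably why the notebook stores paths and callables rather than decoded arrays.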


## Middlebury Optical Flow



In [5]:
# Please unzip `other-color-allframes.zip` and `other-gt-flow.zip` to a directory and provide the target below:
MIDD_DATA_ROOT = '/opt/psegs/ext_data/middlebury-flow/'

# For the Middlebury Flow dataset, we only consider the real scenes
MIDD_SCENES = [
    {
        'input': 'other-data/Dimetrodon/frame10.png',
        'expected_out': 'other-data/Dimetrodon/frame11.png',
        'flow_gt': 'other-gt-flow/Dimetrodon/flow10.flo',
    },
        {
        'input': 'other-data/Hydrangea/frame10.png',
        'expected_out': 'other-data/Hydrangea/frame11.png',
        'flow_gt': 'other-gt-flow/Hydrangea/flow10.flo',
    },
        {
        'input': 'other-data/RubberWhale/frame10.png',
        'expected_out': 'other-data/RubberWhale/frame11.png',
        'flow_gt': 'other-gt-flow/RubberWhale/flow10.flo',
    },
]


In [6]:
%%writefile cheap_optical_flow_eval_analysis/midd.py

def midd_read_flow(path):
    import os
    import numpy as np
    # Based upon: https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py
    # compute colored image to visualize optical flow file .flo
    # Author: Johannes Oswald, Technical University Munich
    # Contact: johannes.oswald@tum.de
    # Date: 26/04/2017
    # For more information, check http://vision.middlebury.edu/flow/ 
    assert os.path.exists(path) and path.endswith('.flo'), path
    TAG_FLOAT = 202021.25
    with open(path, 'rb') as f:
        flo_number = np.fromfile(f, np.float32, count=1)[0]
        assert flo_number == TAG_FLOAT, 'Flow number %r incorrect.' % flo_number
        # Read the dimensions as plain ints so the element count below is a scalar
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)

    # Reshape data into a 3D array (rows, columns, bands)
    flow = data.reshape((h, w, 2))

    # We found that there are some invalid (?) (i.e. very large) flows, so we're going
    # to ignore those for this experiment.
    invalid = (flow >= 1666)
    flow[invalid] = 0

    return flow


Writing cheap_optical_flow_eval_analysis/midd.py
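
The reader above relies on the `.flo` layout: a float32 tag (`202021.25`), then int32 width and height, then `2 * w * h` float32 values stored as interleaved `(u, v)` pairs in row-major order. The following is a minimal round-trip sketch of that layout using only `struct`. `write_flo` and `read_flo` are hypothetical helpers written for this illustration, and little-endian byte order is assumed.

```python
import os
import struct
import tempfile

TAG_FLOAT = 202021.25  # magic number at the start of a .flo file

def write_flo(path, rows):
    """rows: list of h rows, each a list of w (u, v) tuples."""
    h, w = len(rows), len(rows[0])
    with open(path, 'wb') as f:
        f.write(struct.pack('<f', TAG_FLOAT))
        f.write(struct.pack('<ii', w, h))
        for row in rows:
            for u, v in row:
                f.write(struct.pack('<ff', u, v))

def read_flo(path):
    with open(path, 'rb') as f:
        (tag,) = struct.unpack('<f', f.read(4))
        assert tag == TAG_FLOAT, 'bad tag %r' % tag
        w, h = struct.unpack('<ii', f.read(8))
        values = struct.unpack('<%df' % (2 * w * h), f.read(8 * w * h))
    # Regroup the flat (u, v, u, v, ...) stream into h rows of w pairs
    return [
        [(values[2 * (r * w + c)], values[2 * (r * w + c) + 1]) for c in range(w)]
        for r in range(h)
    ]

flow_in = [[(1.0, -0.5), (0.25, 2.0), (0.0, 0.0)],
           [(3.0, 3.0), (-1.0, 0.5), (8.0, -8.0)]]
with tempfile.TemporaryDirectory() as d:
    flo_path = os.path.join(d, 'tiny.flo')
    write_flo(flo_path, flow_in)
    flow_out = read_flo(flo_path)
print(flow_out == flow_in)  # True
```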


In [7]:
from cheap_optical_flow_eval_analysis.midd import *

In [8]:
for i, scene in enumerate(MIDD_SCENES):
    p = OpticalFlowPair(
            dataset="Middlebury Optical Flow",
            id1=scene['input'],
            img1='file://' + os.path.join(MIDD_DATA_ROOT, scene['input']),
            id2=scene['expected_out'],
            img2='file://' + os.path.join(MIDD_DATA_ROOT, scene['expected_out']),
            # Bind the path as a default argument so each loader keeps its own scene
            flow=CloudpickeledCallable(
                lambda path=os.path.join(MIDD_DATA_ROOT, scene['flow_gt']): midd_read_flow(path)))
    
    if RUN_FULL_ANALYSIS:
        ALL_FPS.append(copy.deepcopy(p))
    
    if SHOW_DEMO_OUTPUT:
        show_html(p.to_html() + "<br/><br/><br/>")
        DEMO_FPS.append(p)
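
The flow loader in the loop above runs lazily, so it matters when `scene` is looked up. Python closures resolve free variables at call time, not at definition time. A short sketch of the difference, using the scene names from the list above and a default argument to bind the value eagerly:

```python
scenes = ['Dimetrodon', 'Hydrangea', 'RubberWhale']

# Late binding: every lambda looks up `name` only when it is finally called
late = [lambda: name for name in scenes]

# Eager binding: the default argument captures the value at definition time
eager = [lambda name=name: name for name in scenes]

late_results = [f() for f in late]
eager_results = [f() for f in eager]
print(late_results)   # ['RubberWhale', 'RubberWhale', 'RubberWhale']
print(eager_results)  # ['Dimetrodon', 'Hydrangea', 'RubberWhale']
```

Any loader that may run after the loop has advanced (for example, one that is deep-copied and evaluated later) is exposed to the late-binding behaviour.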

## DeepDeform

In [9]:
# Please extract deepdeform_v1.7z to a directory and provide the target below:
DD_DATA_ROOT = '/opt/psegs/ext_data/deepdeform_v1/'

DD_DEMO_SCENES = [
    {
        "input": "train/seq000/color/000000.jpg",
        "expected_out": "train/seq000/color/000200.jpg",
        "flow_gt": "train/seq000/optical_flow/blackdog_000000_000200.oflow",
    },
    
    {
        "input": "train/seq000/color/000000.jpg",
        "expected_out": "train/seq000/color/001200.jpg",
        "flow_gt": "train/seq000/optical_flow/blackdog_000000_001200.oflow",
    },
    
    {
        "input": "train/seq001/color/003400.jpg",
        "expected_out": "train/seq001/color/003600.jpg",
        "flow_gt": "train/seq001/optical_flow/lady_003400_003600.oflow",
    },
    
    {
        "input": "train/seq337/color/000050.jpg",
        "expected_out": "train/seq337/color/000350.jpg",
        "flow_gt": "train/seq337/optical_flow/adult_000050_000350.oflow",
    },
]


In [10]:
%%writefile cheap_optical_flow_eval_analysis/deepdeform.py

import attr

# Written as a functor to make it easier to pickle
@attr.s(slots=True, eq=False, weakref_slot=False)
class DDLoadFLow(object):
    path = attr.ib(default='')
    def __call__(self):
        path = self.path
        # Based upon https://github.com/AljazBozic/DeepDeform/blob/master/utils.py#L1
        import shutil
        import struct
        import os
        import numpy as np

        # Flow is stored row-wise in order [channels, height, width].
        assert os.path.isfile(path)

        flow_gt = None
        with open(path, 'rb') as fin:
            width = struct.unpack('I', fin.read(4))[0]
            height = struct.unpack('I', fin.read(4))[0]
            channels = struct.unpack('I', fin.read(4))[0]
            n_elems = height * width * channels

            flow = struct.unpack('f' * n_elems, fin.read(n_elems * 4))
            flow_gt = np.asarray(flow, dtype=np.float32).reshape([channels, height, width])

        # Match format used in this analysis
        flow_gt = np.moveaxis(flow_gt, 0, -1) # (h, w, 2)
        invalid_flow = flow_gt == -np.inf  # `np.Inf` was removed in NumPy 2.0
        flow_gt[invalid_flow] = 0.0
        return flow_gt


Writing cheap_optical_flow_eval_analysis/deepdeform.py
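
The `.oflow` reader above expects three unsigned 32-bit header fields (width, height, channels) followed by float32 values in `[channels, height, width]` order, which it then moves to `(h, w, 2)`. The index arithmetic behind that reordering can be sketched without NumPy. `pack_oflow` and `unpack_oflow` are hypothetical helpers, and treating channel 0 as the horizontal component is an assumption made for the example.

```python
import io
import struct

def pack_oflow(planes):
    """planes: `channels` planes, each a list of `height` rows of `width` floats."""
    channels, height, width = len(planes), len(planes[0]), len(planes[0][0])
    flat = [v for plane in planes for row in plane for v in row]
    header = struct.pack('III', width, height, channels)
    return header + struct.pack('%df' % len(flat), *flat)

def unpack_oflow(buf):
    fin = io.BytesIO(buf)
    width, height, channels = struct.unpack('III', fin.read(12))
    n = channels * height * width
    flat = struct.unpack('%df' % n, fin.read(4 * n))
    # Element (c, y, x) lives at index (c * height + y) * width + x;
    # regroup to (y, x, c) so each pixel carries its own flow vector
    return [
        [tuple(flat[(c * height + y) * width + x] for c in range(channels))
         for x in range(width)]
        for y in range(height)
    ]

dx = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]        # assumed horizontal component
dy = [[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]]  # assumed vertical component
pixels = unpack_oflow(pack_oflow([dx, dy]))
print(pixels[1][2])  # (6.0, -6.0)
```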


In [11]:
from cheap_optical_flow_eval_analysis.deepdeform import *

In [12]:
def dd_create_fp(scene):
    # Use the argument rather than relying on a global named `scene`
    return OpticalFlowPair(
            dataset="DeepDeform Semi-Synthetic Optical Flow",
            id1=scene['input'],
            img1='file://' + os.path.join(DD_DATA_ROOT, scene['input']),
            id2=scene['expected_out'],
            img2='file://' + os.path.join(DD_DATA_ROOT, scene['expected_out']),
            flow=DDLoadFLow(os.path.join(DD_DATA_ROOT, scene['flow_gt'])))

import json
with open(os.path.join(DD_DATA_ROOT, 'train_alignments.json')) as f:
    DD_ALIGNMENTS = json.load(f)
ALL_DD_SCENES = [
    {
        "input": ascene['source_color'],
        "expected_out": ascene['target_color'],
        "flow_gt": ascene['optical_flow'],
    }
    for ascene in DD_ALIGNMENTS
]

print("Found %s DeepDeform scenes" % len(ALL_DD_SCENES))
if SHOW_DEMO_OUTPUT:
    for scene in DD_DEMO_SCENES:
        p = dd_create_fp(scene)
        show_html(p.to_html())
        DEMO_FPS.append(p)

if RUN_FULL_ANALYSIS:
    for scene in ALL_DD_SCENES:
        p = dd_create_fp(scene)
        ALL_FPS.append(p)
        

## Kitti Scene Flow Benchmark (2015)


In [13]:
# Please unzip `data_scene_flow.zip` and `data_scene_flow_calib.zip` to a directory and provide that target below:
KITTI_SF15_DATA_ROOT = '/opt/psegs/ext_data/kitti_scene_flow_2015/'

# Scenes are keyed by the files in `training/flow_occ`; list that directory to see which are available
KITTI_SF15_DEMO_SCENES = [
    {
        'input': 'training/image_2/000000_10.png',
        'expected_out': 'training/image_2/000000_11.png',
        'flow_gt': 'training/flow_occ/000000_10.png',
    },
    {
        'input': 'training/image_2/000007_10.png',
        'expected_out': 'training/image_2/000007_11.png',
        'flow_gt': 'training/flow_occ/000007_10.png',
    },
    {
        'input': 'training/image_2/000023_10.png',
        'expected_out': 'training/image_2/000023_11.png',
        'flow_gt': 'training/flow_occ/000023_10.png',
    },
    {
        'input': 'training/image_2/000051_10.png',
        'expected_out': 'training/image_2/000051_11.png',
        'flow_gt': 'training/flow_occ/000051_10.png',
    },
    {
        'input': 'training/image_2/000003_10.png',
        'expected_out': 'training/image_2/000003_11.png',
        'flow_gt': 'training/flow_occ/000003_10.png',
    },
]

from oarphpy import util as oputil
KITTI_SF15_ALL_FLOW_OCC = [
    os.path.basename(p)
    for p in oputil.all_files_recursive(
        os.path.join(KITTI_SF15_DATA_ROOT, 'training/flow_occ'), pattern='*.png')
]
    
KITTI_SF15_ALL_SCENES = [
    {
        "input": 'training/image_2/%s' % fname,
        "expected_out": 'training/image_2/%s' % fname.replace('_10', '_11'),
        "flow_gt": 'training/flow_occ/%s' % fname,
    }
    for fname in KITTI_SF15_ALL_FLOW_OCC
]
print("Found %s KITTI SceneFlow 2015 scenes" % len(KITTI_SF15_ALL_SCENES))



Found 0 KITTI SceneFlow 2015 scenes


In [14]:
%%writefile cheap_optical_flow_eval_analysis/kittisf15.py

import attr

# Written as a functor to make it easier to pickle
@attr.s(slots=True, eq=False, weakref_slot=False)
class KITTISF15LoadFlowFromPng(object):
    path = attr.ib(default='')
    def __call__(self):
        path = self.path
        # Based upon https://github.com/liruoteng/OpticalFlowToolkit/blob/master/lib/flowlib.py#L559
        import png
        import numpy as np
        flow_object = png.Reader(filename=path)
        flow_direct = flow_object.asDirect()
        flow_data = list(flow_direct[2])
        w, h = flow_direct[3]['size']
        flow = np.zeros((h, w, 3), dtype=np.float64)
        for i in range(len(flow_data)):
            flow[i, :, 0] = flow_data[i][0::3]
            flow[i, :, 1] = flow_data[i][1::3]
            flow[i, :, 2] = flow_data[i][2::3]

        invalid_idx = (flow[:, :, 2] == 0)
        flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
        flow[invalid_idx, 0] = 0
        flow[invalid_idx, 1] = 0
        return flow[:, :, :2]


Writing cheap_optical_flow_eval_analysis/kittisf15.py
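
The loader above decodes each 16-bit PNG sample as `(sample - 2 ** 15) / 64.0` and zeroes pixels whose third channel is 0. The fixed-point mapping can be sketched on scalars. `decode_kitti` mirrors the expression in the loader; `encode_kitti` is a hypothetical inverse written for this illustration and is not taken from the notebook.

```python
def encode_kitti(flow_value):
    """Hypothetical inverse: map a flow component in pixels to a 16-bit sample."""
    return int(round(flow_value * 64.0 + 2 ** 15))

def decode_kitti(sample, valid):
    """Same arithmetic as the loader; invalid pixels are reported as zero flow."""
    if not valid:
        return 0.0
    return (sample - 2 ** 15) / 64.0

samples = [encode_kitti(v) for v in (0.0, 1.0, -12.5, 100.25)]
decoded = [decode_kitti(s, valid=1) for s in samples]
print(samples)  # [32768, 32832, 31968, 39184]
print(decoded)  # [0.0, 1.0, -12.5, 100.25]
print(decode_kitti(samples[1], valid=0))  # 0.0
```

With this mapping the resolution is 1/64 of a pixel, and an unsigned 16-bit sample covers roughly -512 to +512 pixels of displacement.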


In [15]:
from cheap_optical_flow_eval_analysis.kittisf15 import *

In [16]:
def kitti_sf15_create_fp(scene):
    # Use the argument rather than relying on a global named `scene`
    return OpticalFlowPair(
            dataset="KITTI Scene Flow 2015",
            id1=scene['input'],
            img1='file://' + os.path.join(KITTI_SF15_DATA_ROOT, scene['input']),
            id2=scene['expected_out'],
            img2='file://' + os.path.join(KITTI_SF15_DATA_ROOT, scene['expected_out']),
            flow=KITTISF15LoadFlowFromPng(os.path.join(KITTI_SF15_DATA_ROOT, scene['flow_gt'])))

if SHOW_DEMO_OUTPUT:
    for scene in KITTI_SF15_DEMO_SCENES:
        p = kitti_sf15_create_fp(scene)
        show_html(p.to_html())
        DEMO_FPS.append(p)

if RUN_FULL_ANALYSIS:
    for scene in KITTI_SF15_ALL_SCENES:
        p = kitti_sf15_create_fp(scene)
        ALL_FPS.append(p)

## PSegs Synthetic Flow from Fused Lidar

In [None]:
PSEGS_SYNTHFLOW_PARQUET_ROOT = '/outer_root/media/rocket4q/psegs_flow_records_short'

from psegs.exp.fused_lidar_flow import FlowRecTable

T = FlowRecTable(spark, PSEGS_SYNTHFLOW_PARQUET_ROOT)
synthflow_record_uris = T.get_record_uris()
print("Found %s PSegs SynthFlow records" % len(synthflow_record_uris))


PSEGS_SYNTHFLOW_DEMO_RECORD_URIS = (
  'psegs://dataset=kitti-360&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=4340,4339',
  'psegs://dataset=kitti-360&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=11219,11269',

  'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=40009,40010',
  'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=50013,50014',

  # 'psegs://dataset=kitti-360-fused&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=11103,11104',
  # 'psegs://dataset=kitti-360-fused&split=train&segment_id=2013_05_28_drive_0000_sync&extra.psegs_flow_sids=1181,1182',

  # 'psegs://dataset=nuscenes&split=train_detect&segment_id=scene-0002&extra.psegs_flow_sids=10016,10017',
  # 'psegs://dataset=nuscenes&split=train_detect&segment_id=scene-0582&extra.psegs_flow_sids=60035,60036',

  # 'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0393&extra.psegs_flow_sids=50017,50018',
  # 'psegs://dataset=nuscenes&split=train_track&segment_id=scene-0501&extra.psegs_flow_sids=40019,40020',
)


fr_samp_rdd = T.get_records_with_samples_rdd(
                    record_uris=[PSEGS_SYNTHFLOW_DEMO_RECORD_URIS[0]],
                    include_cameras=False,
                    include_cuboids=False,
                    include_point_clouds=False)
flow_rec = fr_samp_rdd.take(1)[0][0]
show_html(flow_rec.to_html())


PSEGS_SYNTHFLOW_DEMO_FPS_DO_CACHE = True
PSEGS_SYNTHFLOW_DEMO_FPS_CACHE_PATH = '/tmp/psegs_synthflow_demo.pkl'



Found 191 PSegs SynthFlow records


In [7]:
%%writefile cheap_optical_flow_eval_analysis/psegs_synthflow.py

from oarphpy.spark import CloudpickeledCallable

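from psegs.exp.fused_lidar_flow import FlowRecTable

# `flow_rec_to_fp` below constructs `OpticalFlowPair`, so this module must import it
from cheap_optical_flow_eval_analysis.ofp import OpticalFlowPair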

def flow_rec_to_fp(flow_rec, sample):
  fr = flow_rec

  uri_str_to_datum = sample.get_uri_str_to_datum()

  # Find the camera_images associated with `flow_rec`
  ci1_url_str = str(flow_rec.clouds[0].ci_uris[0])
  ci1_sd = uri_str_to_datum[ci1_url_str]
  ci1 = ci1_sd.camera_image

  ci2_url_str = str(flow_rec.clouds[1].ci_uris[0])
  ci2_sd = uri_str_to_datum[ci2_url_str]
  ci2 = ci2_sd.camera_image

  import numpy as np
  world_T1 = ci1.ego_pose.translation
  world_T2 = ci2.ego_pose.translation
  translation_meters = np.linalg.norm(world_T2 - world_T1)

  id1 = ci1_url_str + '&extra.psegs_flow_sids=' + str(fr.clouds[0].sample_id)
  id2 = ci2_url_str + '&extra.psegs_flow_sids=' + str(fr.clouds[1].sample_id)

  fp = OpticalFlowPair(
          dataset=fr.uri.dataset + '/' + fr.uri.split,
          id1=id1,
          id2=id2,
          img1=CloudpickeledCallable(lambda: ci1.image),
          img2=CloudpickeledCallable(lambda: ci2.image),
          flow=CloudpickeledCallable(lambda: fr.to_optical_flow()),

          diff_time_sec=abs(ci2_sd.uri.timestamp - ci1_sd.uri.timestamp),
          translation_meters=translation_meters)
  return fp

def psegs_synthflow_create_fps(
        spark,
        flow_record_pq_table_path,
        record_uris,
        include_cuboids=False,
        include_point_clouds=False):

  T = FlowRecTable(spark, flow_record_pq_table_path)
  rec_sample_rdd = T.get_records_with_samples_rdd(
                          record_uris=record_uris,
                          include_cameras=True,
                          include_cuboids=include_cuboids,
                          include_point_clouds=include_point_clouds)

  fps = [
    flow_rec_to_fp(flow_rec, sample)
    for flow_rec, sample in rec_sample_rdd.collect()
  ]

  return fps

def psegs_synthflow_iter_fp_rdds(
        spark,
        flow_record_pq_table_path,
        fps_per_rdd=100,
        include_cuboids=False,
        include_point_clouds=False):
  
  T = FlowRecTable(spark, flow_record_pq_table_path)
  ruris = T.get_record_uris()

  # Ensure a sort so that pairs from similar segments will load in the same
  # RDD -- that makes joins smaller and faster
  ruris = sorted(ruris)

  from oarphpy import util as oputil
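  for ruri_chunk in oputil.ichunked(ruris, fps_per_rdd):
    frec_sample_rdd = T.get_records_with_samples_rdd(
                          record_uris=ruri_chunk,
                          # Camera images are needed by `flow_rec_to_fp`,
                          # as in `psegs_synthflow_create_fps`
                          include_cameras=True,
                          include_cuboids=include_cuboids,
                          include_point_clouds=include_point_clouds)
    # Each element is a (flow_rec, sample) pair; unpack it for the two-argument converter
    fp_rdd = frec_sample_rdd.map(lambda rec_sample: flow_rec_to_fp(*rec_sample))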
    yield fp_rdd

Writing cheap_optical_flow_eval_analysis/psegs_synthflow.py
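
`psegs_synthflow_iter_fp_rdds` sorts the record URIs before chunking so that pairs from the same segment tend to land in the same RDD. That effect can be sketched with the standard library. `ichunked_sketch` is a hypothetical stand-in for the chunking helper, and the URIs below are made up for the example.

```python
import itertools

def ichunked_sketch(seq, n):
    """Yield successive lists of at most n items."""
    it = iter(seq)
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

# Made-up URIs: two segments, interleaved
uris = [
    'psegs://dataset=nuscenes&segment_id=scene-0501&n=2',
    'psegs://dataset=kitti-360&segment_id=drive_0000&n=1',
    'psegs://dataset=nuscenes&segment_id=scene-0501&n=1',
    'psegs://dataset=kitti-360&segment_id=drive_0000&n=2',
]

# Sorting first groups URIs that share a dataset/segment prefix into one chunk
chunks = list(ichunked_sketch(sorted(uris), 2))
for chunk in chunks:
    print(chunk)
```

Without the sort, each chunk here would mix both segments, so every per-chunk lookup would have to touch both.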


In [None]:
from cheap_optical_flow_eval_analysis.psegs_synthflow import *

In [5]:
fp = OpticalFlowPair(
                dataset="KITTI Scene Flow 2015",
                id1='yay',
                img1='yay',
                id2='yay',
                img2='yay',
                flow=np.array([1, 2, 3.]))

import pickle
data = pickle.dumps(fp, protocol=pickle.HIGHEST_PROTOCOL)
len(data)


# import sys
# sys.path.append('/opt/psegs')
# PSEGS_OFLOW_PKL_ROOT = '/opt/psegs/ext_data/psegs_oflow.parquet'

# # from oarphpy import util as oputil
# # PSEGS_OFLOW_PKL_PATHS = [
# #     os.path.abspath(p)
# #     for p in oputil.all_files_recursive(PSEGS_OFLOW_PKL_ROOT, pattern='*.pkl')
# # ]
# # print('len PSEGS_OFLOW_PKL_PATHS', len(PSEGS_OFLOW_PKL_PATHS))

# def psegs_load_flow(uvdij1_visible_uvdij2_visible):
#     uvdij_visible1 = uvdij1_visible_uvdij2_visible[:, :5]
#     uvdij_visible2 = uvdij1_visible_uvdij2_visible[:, 5:]
#     visible_both = ((uvdij_visible1[:, -1] == 1) & (uvdij_visible2[:, -1] == 1))
  
#     visboth_uv1 = uvdij_visible1[visible_both, :2]
#     visboth_uv2 = uvdij_visible2[visible_both, :2]
#     ij_flow = np.hstack([
#         uvdij_visible1[visible_both, 3:5], visboth_uv2 - visboth_uv1
#     ])
#     v2v_flow = np.zeros((h, w, 2))
#     xx = ij_flow[:, 0].astype(int)
#     yy = ij_flow[:, 1].astype(int)
#     v2v_flow[yy, xx] = ij_flow[:, 2:4]
    
#     return v2v_flow

# # def psegs_pkl_to_ofp_rowdata(pkl_path):
# #     import pickle
# #     with open(pkl_path, 'rb') as f:
# #         row = pickle.load(f)
    
# #     rowdata = {
# #         'ci1_uri': row['ci1_uri'],
# #         'ci2_uri': row['ci2_uri'],
# #         'flow': row['v2v_flow'],
        
# #     }
# #     from oarphpy.spark import RowAdapter
# #     return RowAdapter.to_row(rowdata)

# # from psegs.datasets.kitti_360 import KITTI360SDTable
# # df = KITTI360SDTable.as_df(spark)
# # print(df.rdd.map(lambda x: x).count())
# # assert False

# # def psegs_create_ofps_slow(pkl_paths):
# #     from psegs.exp.semantic_kitti import SemanticKITTISDTable
# #     from psegs.datasets.kitti_360 import KITTI360SDTable
# #     class KITTI360OurFusedClouds(KITTI360SDTable):
# #         INCLUDE_FISHEYES = False
# #         INCLUDE_FUSED_CLOUDS = False  # Use our own fused clouds
    
# #     for path in pkl_paths:
# #         import pickle
# #         with open(path, 'rb') as f:
# #             row = pickle.load(f)
        
# #         ci1_uri = row['ci1_uri']
# #         ci2_uri = row['ci2_uri']
# #         flow = row['v2v_flow']
    
# #         if ci1_uri.dataset == 'semantikitti':
# #             T = SemanticKITTISDTable
# #         elif ci1_uri.dataset == 'kitti-360':
# #             T = KITTI360SDTable
# #         else:
# #             raise ValueError(ci1_uri)
        
# #         s1 = T.get_sample(ci1_uri)
# #         ci1 = s1.camera_images[0]
# #         img1 = ci1.image
        
# #         s2 = T.get_sample(ci2_uri)
# #         ci2 = s2.camera_images[0]
# #         img2 = ci2.image
        
# #         yield OpticalFlowPair(
# #                 dataset=ci1_uri.dataset,
# #                 id1=str(ci1_uri),
# #                 img1=img1,
# #                 id2=str(ci2_uri),
# #                 img2=img2,
# #                 flow=flow)


# # if True: #SHOW_DEMO_OUTPUT:
# #     for p in psegs_create_ofps_slow(PSEGS_OFLOW_PKL_PATHS):
# #         show_html(p.to_html())
# #         ALL_FPS.append(p)
        
    


270
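
The cell above measures the pickled size of a flow pair. As a rough illustration of why that size matters for a flyweight, the sketch below compares pickling a record that holds only a path with one that holds raw pixel bytes. Plain dicts stand in for the pair, and the path and image size are made up for the example.

```python
import pickle

# Made-up records: one refers to the image by path, one embeds a blank 640x480 RGB image
by_reference = {'id1': 'frame10', 'img1': 'file:///data/frame10.png'}
by_value = {'id1': 'frame10', 'img1': bytes(640 * 480 * 3)}

ref_size = len(pickle.dumps(by_reference, protocol=pickle.HIGHEST_PROTOCOL))
val_size = len(pickle.dumps(by_value, protocol=pickle.HIGHEST_PROTOCOL))
print(ref_size, val_size)
```

The by-reference record stays well under a kilobyte, while the by-value record is slightly larger than the 921,600 bytes of pixel data it embeds.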

In [1]:
import os
import sys
sys.path.append('/opt/psegs')
from psegs.spark import NBSpark
NBSpark.SRC_ROOT = os.path.join(ALIB_SRC_DIR, 'cheap_optical_flow_eval_analysis')
NBSpark.CONF_KV.update({
    'spark.driver.maxResultSize': '2g',
    'spark.driver.memory': '16g',
  })
spark = NBSpark.getOrCreate()
spark.addPyFile()


PSEGS_OFLOW_PQ_ROOT = '/opt/psegs/ext_data/psegs_oflow.parquet'

oflow_df = spark.read.parquet(PSEGS_OFLOW_PQ_ROOT)
print(oflow_df.count())

print('segment_id', oflow_df.select('segment_id').distinct().count())


from psegs.table.sd_db import StampedDatumDB
# from psegs.exp.semantic_kitti import SemanticKITTISDTable
from psegs.dummyrun import KITTI360_OurFused
from psegs.dummyrun import KITTI360_KITTIFused
from psegs.dummyrun import NuscFlowSDTable
Ts = [
    KITTI360_OurFused,
    KITTI360_KITTIFused,
    NuscFlowSDTable,
]
db = StampedDatumDB(Ts, spark=spark)


from pyspark.sql import functions as F
COLS = ('dataset', 'split', 'segment_id', 'topic', 'timestamp')
def _build_uri_df(uri_colname):
    cols = [
        oflow_df[uri_colname + '.' + c]
        for c in COLS
    ]
    return oflow_df.select(*cols)
    
oflow_uri1_df = _build_uri_df('ci1_uri')
oflow_uri2_df = _build_uri_df('ci2_uri')
oflow_uri_df = oflow_uri1_df.union(oflow_uri2_df)
datum_df = db.get_datum_df(oflow_uri_df)
datum_df.show()
print(datum_df.count())



2021-04-03 01:19:00,797	oarph 1648501 : Using source root /opt/psegs/psegs 
2021-04-03 01:19:00,797	oarph 1648501 : Using source root /opt/psegs 
2021-04-03 01:19:00,984	oarph 1648501 : Generating egg to /tmp/tmpyqtmjnw9_oarphpy_eggbuild ...
2021-04-03 01:19:01,070	oarph 1648501 : ... done.  Egg at /tmp/tmpyqtmjnw9_oarphpy_eggbuild/psegs-0.0.0-py3.8.egg


100
segment_id 23
seg_uris [URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0000_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0002_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0003_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0004_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0005_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0006_sync', timestamp=0, topic='', extra={}, track_id='', sel_datums=[]), URI(dataset='kitti-360', split='train', segment_id='2013_05_28_drive_0007_sync', timestamp=0, topic='', extra={}, 

ValueError: No known table for psegs://dataset=nuscenes&split=train_detect&segment_id=scene-0283

In [3]:
oflow_uri_df.show()
from psegs.datum.uri import URI
URI.__slots__
# from psegs.table.sd_db import StampedDatumDB
# from psegs.exp.semantic_kitti import SemanticKITTISDTable
# from psegs.datasets.kitti_360 import KITTI360SDTable
# class KITTI360OurFusedClouds(KITTI360SDTable):
#     INCLUDE_FISHEYES = False
#     INCLUDE_FUSED_CLOUDS = False  # Use our own fused clouds

# db = StampedDatumDB([SemanticKITTISDTable, KITTI360SDTable], spark=spark)

+---------+------------+--------------------+--------------------+-------------------+
|  dataset|       split|          segment_id|               topic|          timestamp|
+---------+------------+--------------------+--------------------+-------------------+
| nuscenes|train_detect|          scene-0582|camera|CAM_FRONT_...|1537291129004799000|
| nuscenes|train_detect|          scene-0582|camera|CAM_BACK_LEFT|1537291132447405000|
| nuscenes|train_detect|          scene-0582|camera|CAM_BACK_LEFT|1537291140897405000|
| nuscenes| train_track|          scene-0129|     camera|CAM_BACK|1533112859187525000|
| nuscenes| train_track|          scene-0129|camera|CAM_BACK_R...|1533112852677893000|
| nuscenes| train_track|          scene-0129|camera|CAM_FRONT_...|1533112861170339000|
| nuscenes| train_track|          scene-0129|camera|CAM_FRONT_...|1533112864104844000|
| nuscenes|train_detect|          scene-0002|camera|CAM_BACK_LEFT|1531883719448129000|
| nuscenes|train_detect|          scene-000

('dataset',
 'split',
 'segment_id',
 'timestamp',
 'topic',
 'extra',
 'track_id',
 'sel_datums')

In [None]:
# rowdata_df = psegs_create_ofp_df(spark, PSEGS_OFLOW_PKL_PATHS[:2])


def psegs_pkl_to_ofp_rowdata(pkl_path):
    import pickle
    with open(pkl_path, 'rb') as f:
        row = pickle.load(f)
    
    rowdata = {
        'ci1_uri': row['ci1_uri'],
        'ci2_uri': row['ci2_uri'],
        'flow': row['v2v_flow'],
        
    }
    from oarphpy.spark import RowAdapter
    return RowAdapter.to_row(rowdata)


def psegs_create_ofps_slow(pkl_paths):
    
    
    from psegs.exp.semantic_kitti import SemanticKITTISDTable
    from psegs.datasets.kitti_360 import KITTI360SDTable
    class KITTI360OurFusedClouds(KITTI360SDTable):
        INCLUDE_FISHEYES = False
        INCLUDE_FUSED_CLOUDS = False  # Use our own fused clouds
    
    for path in pkl_paths:
        import pickle
        with open(path, 'rb') as f:
            row = pickle.load(f)
        
        ci1_uri = row['ci1_uri']
        ci2_uri = row['ci2_uri']
        flow = row['v2v_flow']
    
        if ci1_uri.dataset == 'semantikitti':
            T = SemanticKITTISDTable
        elif ci1_uri.dataset == 'kitti-360':
            T = KITTI360SDTable
        else:
            raise ValueError(ci1_uri)
        
        s1 = T.get_sample(ci1_uri)
        ci1 = s1.camera_images[0]
        img1 = ci1.image
        
        s2 = T.get_sample(ci2_uri)
        ci2 = s2.camera_images[0]
        img2 = ci2.image
        
        yield OpticalFlowPair(
                dataset=ci1_uri.dataset,
                id1=str(ci1_uri),
                img1=img1,
                id2=str(ci2_uri),
                img2=img2,
                flow=flow)


if True: #SHOW_DEMO_OUTPUT:
    for path in PSEGS_OFLOW_PKL_PATHS:
        import pickle
        with open(path, 'rb') as f:
            rowdata = pickle.load(f)
        
        from oarphpy.spark import RowAdapter
        row = RowAdapter.from_row(rowdata)
        
        ci1_uri = row['ci1_uri']
        ci2_uri = row['ci2_uri']
        flow = row['v2v_flow']
        
        s1 = db.get_sample(ci1_uri)
        ci1 = s1.camera_images[0]
        img1 = ci1.image
        
        s2 = db.get_sample(ci2_uri)
        ci2 = s2.camera_images[0]
        img2 = ci2.image
        
        p = OpticalFlowPair(
                dataset=ci1_uri.dataset,
                id1=str(ci1_uri),
                img1=img1,
                id2=str(ci2_uri),
                img2=img2,
                flow=flow)
        show_html(p.to_html())
        ALL_FPS.append(p)
    
#     for p in psegs_create_ofps_slow(PSEGS_OFLOW_PKL_PATHS):
#         show_html(p.to_html())
#         ALL_FPS.append(p)

# assert False

# rowdata_df.show()
# from oarphpy.spark import RowAdapter
# c = rowdata_df.select('ci1_uri').collect()
# print([str(RowAdapter.from_row(cc.ci1_uri)) for cc in c])

## Reconstruction via Optical Flow

In [21]:
## Reconstruction via Optical Flow

def zero_flow(flow):
    return (flow[:, :, :2] == np.array([0, 0])).all(axis=-1)

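def warp_flow_backwards(img, flow):
    """Given an image, apply the inverse of `flow`: each output pixel (x, y) is
    sampled from `img` at (x, y) - flow[y, x]."""
    h, w = flow.shape[:2]
    flow = -flow  # Negation makes a copy, so the caller's array is left untouched
    flow[:,:,0] += np.arange(w)
    flow[:,:,1] += np.arange(h)[:,np.newaxis]
    res = cv2.remap(img, flow.astype(np.float32), None, cv2.INTER_LINEAR)
    return res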
    
def warp_flow_forwards(img, flow):
    """Given an image, apply the given optical flow `flow`.  Returns not only the warped
    image, but also a `mask` indicating warped pixels (i.e. there was non-zero flow *into* these pixels).
    With some help from https://stackoverflow.com/questions/41703210/inverting-a-real-valued-index-grid/46009462#46009462
    """
    h, w = img.shape[:2]
    pts = flow.copy()
    pts[:, :, 0] += np.arange(w)
    pts[:, :, 1] += np.arange(h)[:, np.newaxis]
    exclude = zero_flow(flow)
    if exclude.all():
        # No flow anywhere!
        return img.copy(), np.zeros((h, w)).astype(bool)
    else:
        inpts = pts[~exclude]
    
    from scipy.interpolate import griddata
    inpts = np.reshape(inpts, [-1, 2])
    grid_y, grid_x = np.mgrid[:h, :w]
    chan_out = []
    for ch in range(img.shape[-1]):
        spts = img[:, :, ch][~exclude].reshape([-1, 1])
        mapped = griddata(inpts, spts, (grid_x, grid_y), method='linear')
        chan_out.append(mapped.astype(img.dtype))
    out = np.stack(chan_out, axis=-1)
    out = out.reshape([h, w, len(chan_out)])

    mask = np.reshape(inpts, [-1, 2])
    mask = np.rint(mask).astype(int)
    mask = mask[np.where((mask[:, 0] >= 0) & (mask[:, 0] < w) & (mask[:, 1] >= 0) & (mask[:, 1] < h))]
    valid_mask = np.zeros((h, w))
    valid_mask[mask[:, 1], mask[:, 0]] = 1
    
    return out, valid_mask.astype(bool)

# @attr.s(slots=True, eq=False, weakref_slot=False)
class FlowReconstructedImagePair(object):
    """A pair of reconstructed images using an input pair of images and optical
    flow field (i.e. an `OpticalFlowPair` instance)."""

    slots = (
        'opair',
        'img2_recon_fwd',
        'img2_recon_fwd_valid',
        'img1_recon_bkd',
        'img1_recon_bkd_valid'
    )
    
    def __init__(self, **kwargs):
        for k in self.slots:
            setattr(self, k, kwargs.get(k))
    
#     opair = attr.ib(default=OpticalFlowPair())
#     """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""
    
#     img2_recon_fwd = attr.ib(default=np.array([]))
#     """A Numpy image containing the result of FORWARDS-WARPING OpticalFlowPair::img1
#     via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img2"""

#     img2_recon_fwd_valid = attr.ib(default=np.array([]))
#     """A Numpy boolean mask indicating which pixels of `img2_recon_fwd` were modified via non-zero flow"""
    
#     img1_recon_bkd = attr.ib(default=np.array([]))
#     """A Numpy image containing the result of BACKWARDS-WARPING OpticalFlowPair::img2
#     via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img1"""

#     img1_recon_bkd_valid = attr.ib(default=np.array([]))
#     """A Numpy boolean mask indicating which pixels of `img1_recon_bkd` were modified via non-zero flow"""
        
    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        flow = oflow_pair.get_flow()
        
        # Forward Warp
        fwarped, fvalid = warp_flow_forwards(oflow_pair.get_img1(), flow)

        # Backwards Warp
        exclude = zero_flow(flow)
        bwarped = warp_flow_backwards(oflow_pair.get_img2(), -flow[:, :, :2])
        bvalid = ~exclude
        
        return FlowReconstructedImagePair(
                opair=oflow_pair,
                img2_recon_fwd=fwarped,
                img2_recon_fwd_valid=fvalid,
                img1_recon_bkd=bwarped,
                img1_recon_bkd_valid=bvalid)
    
    def to_html(self):
        # We use pixels from the destination image in order to make the reconstruction 
        # easier to interpret; we'll fade them in intensity so that they are more
        # conspicuous.        
        FADE_UNTOUCHED_PIXELS = 0.3
        
        viz_fwd = self.img2_recon_fwd.copy().astype(np.float32)
        im2 = self.opair.get_img2()
        if (~self.img2_recon_fwd_valid).any():
            viz_fwd[~self.img2_recon_fwd_valid] = im2[~self.img2_recon_fwd_valid]
            viz_fwd[~self.img2_recon_fwd_valid] *= FADE_UNTOUCHED_PIXELS
        else:
            # viz_fwd = im2.copy() * FADE_UNTOUCHED_PIXELS
            print('no invalids forward!')
        
        viz_bkd = self.img1_recon_bkd.copy().astype(np.float32)
        im1 = self.opair.get_img1()
        if (~self.img1_recon_bkd_valid).any():
            viz_bkd[~self.img1_recon_bkd_valid] = im1[~self.img1_recon_bkd_valid]
            viz_bkd[~self.img1_recon_bkd_valid] *= FADE_UNTOUCHED_PIXELS
        else:
            # viz_bkd = im1.copy() * FADE_UNTOUCHED_PIXELS
            print('no invalids backwards!')
        
        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Forwards Warped <i>(dark pixels unwarped)</i></b></td></tr>
            <tr><td><img src="{viz_fwd}" width="100%" /></td></tr>

            <tr><td style="text-align:left"><b>Backwards Warped <i>(dark pixels unwarped)</i></b></td></tr>
            <tr><td><img src="{viz_bkd}" width="100%" /></td></tr>

            </table>
        """.format(
                viz_fwd=img_to_data_uri(viz_fwd.astype(np.uint8)),
                viz_bkd=img_to_data_uri(viz_bkd.astype(np.uint8)))
        return html

        
if SHOW_DEMO_OUTPUT:
    DEMO_RECONS = []
    for p in DEMO_FPS:
        recon = FlowReconstructedImagePair.create_from(p)
        show_html(recon.to_html() + "<br/><br/><br/>")
        DEMO_RECONS.append(recon)


## Analysis: Demo

In [22]:
# Analysis Utils

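def mse(i1, i2, valid):
    # Cast to float first: uint8 images would otherwise wrap around on subtraction
    diff = i1[valid].astype(np.float64) - i2[valid].astype(np.float64)
    return np.mean(diff ** 2)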

def rmse(i1, i2, valid):
    return math.sqrt(mse(i1, i2, valid))

def psnr(i1, i2, valid):
    return 20 * math.log10(255) - 10 * math.log10(max((mse(i1, i2, valid), 1e-12)))

def ssim(i1, i2, valid):
    # SSIM implementations vary somewhat; see e.g.:
    # https://github.com/scikit-image/scikit-image/blob/master/skimage/metrics/_structural_similarity.py#L12-L232
    # https://github.com/nianticlabs/monodepth2/blob/13200ab2f29f2f10dec3aa5db29c32a23e29d376/layers.py#L218
    # https://cvnote.ddlee.cn/2019/09/12/psnr-ssim-python
    # We will just use SKImage for now ...
    from skimage.metrics import structural_similarity
    _, S = structural_similarity(i1, i2, win_size=11, multichannel=True, full=True)
    return np.mean(S[valid])

def to_edge_im(img):
    return np.stack([
        cv2.Laplacian(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, ksize=1),
        cv2.Sobel(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, 1, 0, ksize=3),
        cv2.Sobel(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.CV_32F, 0, 1, ksize=3),
    ], axis=-1)

def edges_mse(i1, i2, valid):
    return mse(to_edge_im(i1), to_edge_im(i2), valid)


def oflow_coverage(valid):
    return valid.sum() / (valid.shape[0] * valid.shape[1])

def oflow_magnitude_hist(flow, valid, bins=50):
    flow_l2s = np.sqrt( flow[valid][:, 0] ** 2 + flow[valid][:, 1] ** 2 )
    bin_counts, bin_edges = np.histogram(flow_l2s, bins=bins)
    return bin_edges, bin_counts


# Analysis Data Model

class OFlowReconErrors(object):
    """Various measures of reconstruction error for a `FlowReconstructedImagePair` instance.
    Encapsulated as two dictionaries of stats for easy interop with Spark SQL."""

    RECONSTRUCTION_ERR_METRICS = {
        'SSIM': ssim,
        'MSE': mse,
        'RMSE': rmse,
        'PSNR': psnr,
        'Edges_MSE': edges_mse,
    }
    
    def __init__(self, recon_pair: FlowReconstructedImagePair):
        im2 = recon_pair.opair.get_img2()
        img2_recon_fwd = recon_pair.img2_recon_fwd
        img2_recon_fwd_valid = recon_pair.img2_recon_fwd_valid
        self.forward_stats = dict(
            (name, func(im2, img2_recon_fwd, img2_recon_fwd_valid))
            for name, func in self.RECONSTRUCTION_ERR_METRICS.items())
        
        im1 = recon_pair.opair.get_img1()
        img1_recon_bkd = recon_pair.img1_recon_bkd
        img1_recon_bkd_valid = recon_pair.img1_recon_bkd_valid
        self.backward_stats = dict(
            (name, func(im1, img1_recon_bkd, img1_recon_bkd_valid))
            for name, func in self.RECONSTRUCTION_ERR_METRICS.items())

    def to_html(self):
        stat_names = self.RECONSTRUCTION_ERR_METRICS.keys()

        rows = [
            """
            <tr>
              <td style="text-align:left"><b>{name}</b></td>
              <td style="text-align:left">{fwd:.2f}</td>
              <td style="text-align:left">{bkd:.2f}</td>
            </tr>
            """.format(name=name, fwd=self.forward_stats[name], bkd=self.backward_stats[name])
            for name in stat_names
        ]
        
        
        html = """
            <table>
              <tr>
                  <th></th> <th><b>Forwards Warp</b></th> <th><b>Backwards Warp</b></th>
              </tr>

              {table_rows}

            </table>
        """.format(table_rows="".join(rows))
        
        return html
            
# @attr.s(slots=True, eq=False, weakref_slot=False)
class OFlowStats(object):
    """Stats on the optical flow of an `OpticalFlowPair` instance"""

    slots = (
        'opair',
        'coverage',
        'magnitude_hist',
    )
    
    def __init__(self, **kwargs):
        for k in self.slots:
            setattr(self, k, kwargs.get(k))
    
#     opair = attr.ib(default=OpticalFlowPair())
#     """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""
    
#     coverage = attr.ib(default=0)
#     """Fraction of the image with valid flow"""
    
#     magnitude_hist = attr.ib(default=[np.array([]), np.array([])])
#     """Histogram [bin edges, bin counts] of flow magnitudes"""
    
    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        flow = oflow_pair.get_flow()
        valid = ~zero_flow(flow)
        return OFlowStats(
                 opair=oflow_pair,
                 coverage=oflow_coverage(valid),
                 magnitude_hist=oflow_magnitude_hist(flow, valid))
                 
    def to_html(self):
        import matplotlib.pyplot as plt
        fig = plt.figure()
        bin_edges, bin_counts = self.magnitude_hist
        plt.bar(bin_edges[:-1], bin_counts)
        plt.title("Histogram of Flow Magnitudes")
        plt.xlabel('Flow Magnitude (pixels)')
        plt.ylabel('Count')

        hist_img = matplotlib_fig_to_img(fig)
        
        html = """
            <table>           
            <tr><td style="text-align:left"><b>Flow Coverage:</b> {coverage:.2f}% </td></tr>
            <tr><td><img src="{flow_hist}" width="100%" /></td></tr>
            </table>
        """.format(
                coverage=100. * self.coverage,
                flow_hist=img_to_data_uri(hist_img))
        return html


# Misc

def matplotlib_fig_to_img(fig):
    """Render a matplotlib figure to a numpy image array via an in-memory PNG."""
    import io
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)

    import imageio
    hist_img = imageio.imread(buf)
    buf.close()
    return hist_img


if SHOW_DEMO_OUTPUT:
    %matplotlib agg
    for recon in DEMO_RECONS:
        p = recon.opair
        errors = OFlowReconErrors(recon)
        err_html = errors.to_html()  
            
        fstats = OFlowStats.create_from(p)
        stats_html = fstats.to_html()
            
        title = "<b>{dataset} {id1} -> {id2}</b>".format(dataset=p.dataset, id1=p.id1, id2=p.id2)
        
        show_html(title + stats_html + err_html + "<br/><br/><br/>")
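
The reconstruction metrics in this cell are masked: only pixels where `valid` is true contribute. The sketch below works a tiny masked MSE and PSNR by hand on toy `uint8` data (the values are made up for illustration), and also shows why it matters to cast to float before subtracting, since `uint8` arithmetic wraps around. `masked_mse` and `masked_psnr` are hypothetical names, not the notebook's functions.

```python
import math
import numpy as np

# Toy uint8 "images" and a validity mask (values chosen for illustration)
a = np.array([[10, 200], [0, 50]], dtype=np.uint8)
b = np.array([[30, 190], [255, 50]], dtype=np.uint8)
valid = np.array([[True, True], [False, True]])

def masked_mse(i1, i2, valid):
    # Cast to float so that differences can be negative and squares can exceed 255
    d = i1[valid].astype(np.float64) - i2[valid].astype(np.float64)
    return float(np.mean(d ** 2))

def masked_psnr(i1, i2, valid, peak=255.0):
    return 20 * math.log10(peak) - 10 * math.log10(max(masked_mse(i1, i2, valid), 1e-12))

correct = masked_mse(a, b, valid)   # (20**2 + 10**2 + 0**2) / 3
naive = float(np.mean((a[valid] - b[valid]) ** 2))  # stays in uint8, so it wraps

print(correct, naive, masked_psnr(a, b, valid))
```

The masked-out pixel (0 vs. 255) does not affect the score, which is the point of restricting the metrics to pixels that actually received flow.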
            

## Analysis on Full Datasets

In [23]:
# import sys
# sys.path.append('/opt/psegs')

# from oarphpy.spark import NBSpark
# NBSpark.SRC_ROOT = os.path.join(ALIB_SRC_DIR, 'cheap_optical_flow_eval_analysis')
# NBSpark.CONF_KV.update({
#     'spark.driver.maxResultSize': '2g',
#     'spark.driver.memory': '16g',
#   })
# spark = NBSpark.getOrCreate()


from oarphpy.spark import RowAdapter

from pyspark import Row


def flow_pair_to_full_row(fp):
    fp_lite = copy.deepcopy(fp)
    recon = FlowReconstructedImagePair.create_from(fp)
    fstats = OFlowStats.create_from(fp)
    errors = OFlowReconErrors(recon)
    
    rowdata = dict(
            fp=fp_lite,
            flow_coverage=fstats.coverage,
    )
    rowdata.update(
        ('Forwards_' + k, float(v))
        for k, v in errors.forward_stats.items())
    rowdata.update(
        ('Backwards_' + k, float(v))
        for k, v in errors.backward_stats.items())
    return RowAdapter.to_row(rowdata)
  
if True:#RUN_FULL_ANALYSIS:
#     spark = Spark.getOrCreate()
    
#     for p in ALL_FPS:
#         import cloudpickle
#         try:
#             cloudpickle.dumps(p)
#         except Exception:
#             assert False, p
#     print('all good')
    
    import pickle
    fp_rdd = spark.sparkContext.parallelize(ALL_FPS, numSlices=200)
#     print(fp_rdd.count())
    df = spark.createDataFrame(fp_rdd.map(flow_pair_to_full_row)).persist()

    print(df.count())
    df.show(10)
    df.printSchema()



120
+-------------------+------------------+------------------+-----------------+-------------------+------------------+------------------+------------------+------------------+------------------+-------------------+--------------------+
|Backwards_Edges_MSE|     Backwards_MSE|    Backwards_PSNR|   Backwards_RMSE|     Backwards_SSIM|Forwards_Edges_MSE|      Forwards_MSE|     Forwards_PSNR|     Forwards_RMSE|     Forwards_SSIM|      flow_coverage|                  fp|
+-------------------+------------------+------------------+-----------------+-------------------+------------------+------------------+------------------+------------------+------------------+-------------------+--------------------+
|    12901.162109375| 44.07437546508901| 31.68894193815717|6.638853475193515| 0.4778988779442309| 3324.151611328125| 42.68944660650385| 31.82759835890608| 6.533716140643382|0.7317521046906023| 0.3409865359042553|[cheap_optical_fl...|
|   16117.0888671875| 36.85686426286956|32.465619776588206|6

In [24]:
df.columns

['Backwards_Edges_MSE',
 'Backwards_MSE',
 'Backwards_PSNR',
 'Backwards_RMSE',
 'Backwards_SSIM',
 'Forwards_Edges_MSE',
 'Forwards_MSE',
 'Forwards_PSNR',
 'Forwards_RMSE',
 'Forwards_SSIM',
 'flow_coverage',
 'fp']
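
The column names listed above follow from how `flow_pair_to_full_row` flattens the forward and backward stats dictionaries into one row with prefixed keys. A minimal plain-Python sketch of that flattening, using made-up stat values and only two metrics, might look like this (the names are sorted here only to mirror the listing; the sketch makes no claim about how Spark orders columns):

```python
# Hypothetical stat values, for illustration only
forward_stats = {'MSE': 42.0, 'SSIM': 0.7}
backward_stats = {'MSE': 44.0, 'SSIM': 0.5}

rowdata = {'flow_coverage': 0.3}
# Prefix each metric name with the warp direction, as flow_pair_to_full_row does
rowdata.update(('Forwards_' + k, float(v)) for k, v in forward_stats.items())
rowdata.update(('Backwards_' + k, float(v)) for k, v in backward_stats.items())

column_names = sorted(rowdata)
print(column_names)
```

Keeping one flat, scalar-valued column per metric and direction is what lets each column be histogrammed independently later on.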

In [25]:
from oarphpy import plotting as pl
class Plotter(pl.HistogramWithExamplesPlotter):
    NUM_BINS = 10
    ROWS_TO_DISPLAY_PER_BUCKET = 3
    SUB_PIVOT_COL = 'fp.dataset'

    def display_bucket(self, sub_pivot, bucket_id, irows):
        import itertools
        from oarphpy.spark import RowAdapter
        
        row_htmls = []
        for row in itertools.islice(irows, self.ROWS_TO_DISPLAY_PER_BUCKET):
            rowdata = RowAdapter.from_row(row)
            
            fp = rowdata['fp']
            recon = FlowReconstructedImagePair.create_from(fp)
            fstats = OFlowStats.create_from(fp)
            errors = OFlowReconErrors(recon)
            
            row_html = "<br/>".join((fp.to_html(), recon.to_html(), fstats.to_html(), errors.to_html()))
            row_htmls.append(row_html)
        
        HTML = """
        <b>Pivot: {spv} Bucket: {bucket_id} </b> <br/>
        
        {row_bodies}
        """.format(
              spv=sub_pivot,
              bucket_id=bucket_id,
              row_bodies="<br/><br/><br/>".join(row_htmls))
        
        return bucket_id, HTML

plotter = Plotter()

for col in df.columns:
    col = str(col)
    if col == 'fp':
        continue
    
    fig = plotter.run(df, col)
    pl.save_bokeh_fig(fig, '/tmp/%s.html' % col)

# from bokeh.io import output_notebook
# output_notebook()
# from bokeh.plotting import show
# show(fig)

2021-02-25 13:05:25,023	oarph 1504052 : Plotting histogram for Backwards_Edges_MSE of DataFrame[Backwards_Edges_MSE: double, Backwards_MSE: double, Backwards_PSNR: double, Backwards_RMSE: double, Backwards_SSIM: double, Forwards_Edges_MSE: double, Forwards_MSE: double, Forwards_PSNR: double, Forwards_RMSE: double, Forwards_SSIM: double, flow_coverage: double, fp: struct<__pyclass__:string,dataset:string,id1:string,id2:string,img1:struct<__pyclass__:string,shape:array<bigint>,dtype:string,order:string,values:array<bigint>,values_packed:binary>,img2:struct<__pyclass__:string,shape:array<bigint>,dtype:string,order:string,values:array<bigint>,values_packed:binary>,flow:struct<__pyclass__:string,shape:array<bigint>,dtype:string,order:string,values:array<double>,values_packed:binary>>] ...
2021-02-25 13:05:28,177	oarph 1504052 : ... building data source for ALL ...
2021-02-25 13:05:28,177	oarph 1504052 : ... histogramming ALL ...
2021-02-25 13:05:33,123	oarph 1504052 : ... display-ifying exa

2021-02-25 13:31:19,812	oarph 1504052 : ... building data source for ALL ...
2021-02-25 13:31:19,813	oarph 1504052 : ... histogramming ALL ...
2021-02-25 13:31:24,386	oarph 1504052 : ... display-ifying examples for ALL ...
2021-02-25 13:33:11,292	oarph 1504052 : ... building data source for kitti-360 ...
2021-02-25 13:33:11,295	oarph 1504052 : ... histogramming kitti-360 ...
2021-02-25 13:33:16,442	oarph 1504052 : ... display-ifying examples for kitti-360 ...
2021-02-25 13:35:09,714	oarph 1504052 : Wrote Bokeh figure to /tmp/Forwards_MSE.html
2021-02-25 13:35:09,715	oarph 1504052 : Plotting histogram for Forwards_PSNR of DataFrame[Backwards_Edges_MSE: double, Backwards_MSE: double, Backwards_PSNR: double, Backwards_RMSE: double, Backwards_SSIM: double, Forwards_Edges_MSE: double, Forwards_MSE: double, Forwards_PSNR: double, Forwards_RMSE: double, Forwards_SSIM: double, flow_coverage: double, fp: struct<__pyclass__:string,dataset:string,id1:string,id2:string,img1:struct<__pyclass__:stri