Skip to content

Commit

Permalink
Add read_flo/write_flo
Browse files Browse the repository at this point in the history
  • Loading branch information
willprice committed May 6, 2019
1 parent 1686947 commit d032c37
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
8 changes: 8 additions & 0 deletions src/flowty/cv/c_core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from libcpp cimport bool
from libcpp.vector cimport vector
from libcpp.memory cimport shared_ptr
from libcpp.string cimport string


cdef extern from "opencv2/core.hpp" nogil:
Expand Down Expand Up @@ -106,6 +107,13 @@ cdef extern from "opencv2/core.hpp" namespace "cv" nogil:
unsigned char* dataend
int cols, rows, flags, dims

cdef cppclass String:
String(String&)
String(String&, size_t, size_t)
String(char *)
String(char *, size_t)
String(string&) except +

int getNumThreads()
void setNumThreads(int)
void setUseOptimized(bool)
Expand Down
7 changes: 4 additions & 3 deletions src/flowty/cv/c_optflow.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# cython: language_level = 3
from libcpp cimport bool
from .c_core cimport Ptr, InputArray, InputOutputArray
from .c_core cimport Ptr, InputArray, InputOutputArray, String, Mat


cdef extern from "opencv2/video/tracking.hpp" namespace "cv" nogil:
Expand Down Expand Up @@ -74,14 +74,15 @@ cdef extern from "opencv2/video/tracking.hpp" namespace "cv" nogil:
void setOmega(float)
void setSorIterations(int)

bool writeOpticalFlow(String&, InputArray)
Mat readOpticalFlow(String&)

cdef extern from "opencv2/video/tracking.hpp" namespace "cv::DISOpticalFlow":
enum:
PRESET_ULTRAFAST
PRESET_FAST
PRESET_MEDIUM


cdef extern from "opencv2/optflow.hpp" namespace "cv::optflow" nogil:
cdef cppclass DualTVL1OpticalFlow(DenseOpticalFlow):
@staticmethod
Expand Down Expand Up @@ -115,4 +116,4 @@ cdef extern from "opencv2/optflow.hpp" namespace "cv::optflow" nogil:

cdef cppclass DenseRLOFOpticalFlow(DenseOpticalFlow):
@staticmethod
Ptr[DenseRLOFOpticalFlow] create() except +
Ptr[DenseRLOFOpticalFlow] create() except +
19 changes: 16 additions & 3 deletions src/flowty/cv/optflow.pyx
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# cython: language_level=3
from pathlib import Path

from cython.operator cimport dereference as deref
from libcpp cimport bool
from ..cv.c_core cimport Ptr, Mat as c_Mat, InputArray, OutputArray, \
from libcpp.string cimport string
from ..cv.c_core cimport Ptr, String, Mat as c_Mat, InputArray, OutputArray, \
InputOutputArray, CV_32FC2, CV_32FC1
from ..cv.c_imgproc cimport cvtColor, ColorConversionCodes
from ..cv.core cimport Mat
from ..cv.core import get_num_threads, set_num_threads
from ..cv.c_optflow cimport DenseOpticalFlow as c_DenseOpticalFlow, \
DualTVL1OpticalFlow as c_DualTVL1OpticalFlow, \
FarnebackOpticalFlow as c_FarnebackOpticalFlow, \
DISOpticalFlow as c_DISOpticalFlow, PRESET_ULTRAFAST, PRESET_FAST, PRESET_MEDIUM, \
VariationalRefinement as c_VariationalRefinement, \
DenseRLOFOpticalFlow as c_DenseRLOFOpticalFlow
writeOpticalFlow, readOpticalFlow


cdef compute_flow(Ptr[c_DenseOpticalFlow] algorithm,
Expand Down Expand Up @@ -424,3 +425,15 @@ cdef class DenseInverseSearchOpticalFlow:
@variational_refinement_iterations.setter
def variational_refinement_iterations(self, int iterations):
deref(self.alg).setVariationalRefinementIterations(iterations)

def read_flo(str path) -> Mat:
if isinstance(path, Path):
path = str(path)
cdef c_Mat flow = readOpticalFlow(String(<string> path.encode('utf8')))
return Mat.from_mat(flow, copy=True)

def write_flo(Mat flow, path):
if isinstance(path, Path):
path = str(path)
return writeOpticalFlow(String(<string> path.encode('utf8')), <InputArray>
flow.c_mat)
18 changes: 16 additions & 2 deletions tests/unit/cv/test_optflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import ABC

import numpy as np
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_array_equal
from pytest import approx

from flowty.cv.optflow import TvL1OpticalFlow, FarnebackOpticalFlow, \
DenseInverseSearchOpticalFlow, VariationalRefinementOpticalFlow
DenseInverseSearchOpticalFlow, VariationalRefinementOpticalFlow, read_flo, write_flo
from flowty.cv.core import Mat, CV_32FC2
import pytest

Expand Down Expand Up @@ -146,3 +146,17 @@ def get_flow_algorithm(self):
])
def test_property(self, property, expected_value):
assert getattr(self.get_flow_algorithm(), property) == expected_value


def test_read_write_flo(tmpdir):
flow_path = str(tmpdir / 'test.flo')
rows = 20
cols = 30
channels = 2
flow_np = (np.random.rand(rows, cols, channels) * 20).astype(np.float32)
flow = Mat.fromarray(flow_np, copy=True)

write_flo(flow, flow_path)
flow2 = np.array(read_flo(flow_path))

assert_array_equal(flow2, flow)

0 comments on commit d032c37

Please sign in to comment.