Skip to content

Commit

Permalink
[POC] Cython FFI support for tuple, slice, ellipsis
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Dec 23, 2019
1 parent e6ff3f7 commit ddd9323
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 9 deletions.
34 changes: 34 additions & 0 deletions benchmark_ffi.py
@@ -0,0 +1,34 @@
import timeit
import tvm
nop = tvm._api_internal._nop

setup = """
import tvm
nop = tvm._api_internal._nop
"""
timer = timeit.Timer(setup=setup,
stmt='nop((None,..., slice(0, 100, 2)))')
timer.timeit(1)
num_repeat = 1000
print("tvm.tuple_slice_ellipsis_combo:", timer.timeit(num_repeat) / num_repeat)


setup = """
import numpy as np
"""

timer = timeit.Timer(setup=setup,
stmt='np.empty((1,2,1))')
timer.timeit(1)
print("numpy.emmpty:", timer.timeit(num_repeat) / num_repeat)


setup = """
import tvm
nop = tvm._api_internal._nop
"""
timer = timeit.Timer(setup=setup,
stmt='nop("mystr")')
timer.timeit(1)
num_repeat = 1000
print("tvm.str_arg:", timer.timeit(num_repeat) / num_repeat)
4 changes: 3 additions & 1 deletion include/tvm/runtime/container.h
Expand Up @@ -35,6 +35,7 @@
namespace tvm {
namespace runtime {

class ADTBuilder;
/*!
* \brief Base template for classes with array like memory layout.
*
Expand Down Expand Up @@ -114,6 +115,7 @@ class InplaceArrayBase {
}

protected:
friend class ADTBuilder;
/*!
* \brief Construct a value in place with the arguments.
*
Expand Down Expand Up @@ -165,7 +167,7 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
/*! \brief The tag representing the constructor used. */
uint32_t tag;
/*! \brief Number of fields in the ADT object. */
uint32_t size;
uint32_t size{0};
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
Expand Down
111 changes: 111 additions & 0 deletions include/tvm/runtime/ffi_helper.h
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file ffi_helper
* \brief Helper class to support additional objects in FFI.
*/
#ifndef TVM_RUNTIME_FFI_HELPER_H_
#define TVM_RUNTIME_FFI_HELPER_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
#include <limits>

namespace tvm {
namespace runtime {

/*! \brief Ellipsis. */
class EllipsisObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kEllipsis;
static constexpr const char* _type_key = "vm.Ellipsis";
TVM_DECLARE_FINAL_OBJECT_INFO(EllipsisObj, Object);
};

inline ObjectRef CreateEllipsis() {
return ObjectRef(make_object<EllipsisObj>());
}

/*! \brief Slice. */
class SliceObj : public Object {
public:
int64_t start;
int64_t stop;
int64_t step;

static constexpr const uint32_t _type_index = TypeIndex::kSlice;
static constexpr const char* _type_key = "vm.Slice";
TVM_DECLARE_FINAL_OBJECT_INFO(SliceObj, Object);
};

class Slice : public ObjectRef {
public:
explicit Slice(int64_t start, int64_t stop, int64_t step,
ObjectPtr<SliceObj>&& data = make_object<SliceObj>()) {
data->start = start;
data->stop = stop;
data->step = step;
data_ = std::move(data);
}

explicit Slice(int64_t stop)
: Slice(kNoneValue, stop, kNoneValue) {
}

// constant to represent None.
static constexpr int64_t kNoneValue = std::numeric_limits<int64_t>::min();

TVM_DEFINE_OBJECT_REF_METHODS(Slice, ObjectRef, SliceObj);
};

int64_t SliceNoneValue() {
return Slice::kNoneValue;
}

// Helper functions for fast FFI implementations
/*!
* \brief A builder class that helps to incrementally build ADT.
*/
class ADTBuilder {
public:
/*! \brief default constructor */
ADTBuilder() = default;

explicit ADTBuilder(uint32_t tag, uint32_t size)
: data_(make_inplace_array_object<ADTObj, ObjectRef>(size)) {
}

template <typename... Args>
void EmplaceInit(size_t idx, Args&&... args) {
data_->EmplaceInit(idx, std::forward<Args>(args)...);
}

ADT Get() {
return ADT(std::move(data_));
}

private:
friend class ADT;
ObjectPtr<ADTObj> data_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_FFI_HELPER_H_
2 changes: 2 additions & 0 deletions include/tvm/runtime/object.h
Expand Up @@ -54,6 +54,8 @@ enum TypeIndex {
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kEllipsis = 5,
kSlice = 6,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
1 change: 1 addition & 0 deletions python/setup.py
Expand Up @@ -96,6 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include",
],
extra_compile_args=["-std=c++11"],
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/_cython/base.pxi
Expand Up @@ -19,7 +19,7 @@ from ..base import get_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t, uint32_t
import ctypes

cdef enum TVMTypeCode:
Expand Down
80 changes: 80 additions & 0 deletions python/tvm/_ffi/_cython/convert.pxi
@@ -0,0 +1,80 @@
"""Fast object conversion API in cython."""

cdef extern from "tvm/runtime/object.h" namespace "tvm::runtime":
cdef cppclass Object:
pass

cdef cppclass ObjectPtr[T]:
ObjectPtr()

cdef cppclass ObjectRef:
ObjectRef()
ObjectRef(ObjectPtr[Object])
Object* get()


cdef ObjectPtr[T] GetObjectPtr[T](T* ptr)


cdef extern from "tvm/runtime/container.h" namespace "tvm::runtime":
cdef cppclass ADT(ObjectRef):
ADT()

cdef extern from "tvm/runtime/ffi_helper.h" namespace "tvm::runtime":
cdef cppclass ADTBuilder:
ADTBuilder()
ADTBuilder(uint32_t tag, uint32_t size)
void EmplaceInit(size_t idx, ObjectRef)
ADT Get()

cdef ObjectRef CreateEllipsis()

cdef cppclass Slice(ObjectRef):
Slice()
Slice(int)
Slice(int, int, int)

cdef int64_t SliceNoneValue()



cdef extern from "tvm/runtime/memory.h" namespace "tvm::runtime":
cdef ObjectPtr[T] make_object[T]()


cdef extern from "tvm/expr.h" namespace "tvm":
cdef cppclass Integer(ObjectRef):
Integer(int value);


## Implementations
cdef inline ADT convert_tuple(tuple src_tuple) except +:
cdef uint32_t size = len(src_tuple)
cdef ADTBuilder builder = ADTBuilder(0, size);

for i in range(size):
builder.EmplaceInit(i, convert_object(src_tuple[i]))

return builder.Get()


cdef inline Slice convert_slice(slice slice_obj) except +:
cdef int64_t kNoneValue = SliceNoneValue()
return Slice(<int>(slice_obj.start) if slice_obj.start is not None else kNoneValue,
<int>(slice_obj.stop) if slice_obj.stop is not None else kNoneValue,
<int>(slice_obj.step) if slice_obj.step is not None else kNoneValue)


cdef inline ObjectRef convert_object(object src_obj) except +:
if isinstance(src_obj, int):
return Integer(<int>src_obj)
elif isinstance(src_obj, tuple):
return convert_tuple(src_obj)
elif src_obj is Ellipsis:
return CreateEllipsis()
elif isinstance(src_obj, slice):
return convert_slice(src_obj)
elif src_obj is None:
return ObjectRef()
else:
raise TypeError("Don't know how to convert type %s" % type(src_obj))
3 changes: 1 addition & 2 deletions python/tvm/_ffi/_cython/core.pyx
Expand Up @@ -17,7 +17,6 @@

include "./base.pxi"
include "./object.pxi"
# include "./node.pxi"
include "./convert.pxi"
include "./function.pxi"
include "./ndarray.pxi"

21 changes: 16 additions & 5 deletions python/tvm/_ffi/_cython/function.pxi
Expand Up @@ -35,6 +35,7 @@ cdef int tvm_callback(TVMValue* args,
void* fhandle) with gil:
cdef list pyargs
cdef TVMValue value
cdef ObjectRef temp_obj
cdef int tcode
local_pyfunc = <object>(fhandle)
pyargs = []
Expand Down Expand Up @@ -62,7 +63,7 @@ cdef int tvm_callback(TVMValue* args,
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
make_arg(rv, &value, &tcode, temp_args)
make_arg(rv, &value, &tcode, &temp_obj, temp_args)
CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
return 0

Expand Down Expand Up @@ -94,16 +95,22 @@ def convert_to_tvm_func(object pyfunc):
cdef inline int make_arg(object arg,
TVMValue* value,
int* tcode,
ObjectRef* temp_objs,
list temp_args) except -1:
"""Pack arguments into c args tvm call accept"""
cdef unsigned long long ptr

if isinstance(arg, ObjectBase):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kNDArrayContainer if
not (<NDArrayBase>arg).c_is_view else kArrayHandle)
elif isinstance(arg, tuple):
temp_objs[0] = convert_tuple(<tuple>arg)
value[0].v_handle = (<void*>(temp_objs[0].get()))
tcode[0] = kObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
Expand Down Expand Up @@ -221,10 +228,11 @@ cdef inline int FuncCall3(void* chandle,
int* ret_tcode) except -1:
cdef TVMValue[3] values
cdef int[3] tcodes
cdef ObjectRef[3] temp_objs
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], &temp_objs[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
return 0
Expand All @@ -241,11 +249,14 @@ cdef inline int FuncCall(void* chandle,

cdef vector[TVMValue] values
cdef vector[int] tcodes
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
cdef vector[ObjectRef] temp_objs
values.resize(nargs)
tcodes.resize(nargs)
temp_objs.resize(nargs)

temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], &temp_objs[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
return 0
Expand Down

0 comments on commit ddd9323

Please sign in to comment.