Skip to content

Commit

Permalink
feature(cstruct) (#110)
Browse files Browse the repository at this point in the history
* add cstruct

* try to fix

* unify the types of functions

* update

* fix loader

* add MOBULA_FUNC template

* fix

* type
  • Loading branch information
wkcn committed Aug 6, 2020
1 parent 6b6abbb commit df25b88
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 75 deletions.
32 changes: 26 additions & 6 deletions mobula/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib
import warnings
from . import glue
from .internal.dtype import DType, TemplateType, UnknownCType
from .internal.dtype import DType, CStruct, TemplateType, UnknownCType
from .building.build_utils import config


Expand Down Expand Up @@ -61,6 +61,11 @@ def is_const(self):
return self.ptype.is_const


class CStructArg:
def __init__(self, var):
self.var = var


def _wait_to_read(var):
if hasattr(var, 'wait_to_read'):
var.wait_to_read()
Expand All @@ -81,6 +86,9 @@ def _get_raw_pointer(arg, const_vars, mutable_vars):
else:
mutable_vars.append((arg.var, v))
return p
if isinstance(arg, CStructArg):
const_vars.append(arg.var)
return ctypes.byref(const_vars[-1])
return arg


Expand Down Expand Up @@ -144,10 +152,13 @@ def __call__(self, arg_datas, arg_types, dev_id, glue_mod=None, using_async=Fals
const_vars = []
mutable_vars = []
raw_pointers = _get_raw_pointers(arg_datas, const_vars, mutable_vars)
if self.func_kind == self.KERNEL:
if self.func_kind == CFuncDef.KERNEL:
out = func(dev_id, *raw_pointers)
else:
elif self.func_kind == CFuncDef.FUNC:
out = func(*raw_pointers)
else:
raise TypeError(
'Unsupported func kind: {}'.format(self.func_kind))
for target, value in mutable_vars:
target[:] = value
return out
Expand Down Expand Up @@ -202,9 +213,17 @@ def __call__(self, *args, **kwargs):
try:
for var, ptype in zip(args, self.func.arg_types):
if ptype.is_pointer:
# The type of `var` is Tensor.
data, var_dev_id, ctype = self._get_tensor_info(
var, ptype, template_mapping, using_async)
if hasattr(ptype, 'constructor'):
var_dev_id = None
ctype = ctypes.POINTER(ptype.cstruct)
try:
data = CStructArg(ptype.constructor(var))
except TypeError:
data = CStructArg(ptype.constructor(*var))
else:
# The type of `var` is Tensor.
data, var_dev_id, ctype = self._get_tensor_info(
var, ptype, template_mapping, using_async)
else:
# The type of `var` is Scalar.
data, var_dev_id, ctype = self._get_scalar_info(var, ptype)
Expand All @@ -214,6 +233,7 @@ def __call__(self, *args, **kwargs):
ctype.is_const = ptype.is_const
arg_types.append(ctype)
else:
# pointer
arg_types.append(DType(ctype, is_const=ptype.is_const))

# update `dev_id`
Expand Down
2 changes: 1 addition & 1 deletion mobula/glue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import backend
from .common import register, CUSTOM_OP_LIST
from .common import register, register_cstruct, CUSTOM_OP_LIST
from . import common
common.backend = backend
9 changes: 9 additions & 0 deletions mobula/glue/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_varnames(func):

CUSTOM_OP_LIST = dict()
OP_MODULE_GLOBALS = None
CSTRUCT_CONSTRUCTOR = dict()


def get_in_data(*args, **kwargs):
Expand Down Expand Up @@ -193,6 +194,14 @@ def wrapper(*args, **kwargs):
return wrapper


def register_cstruct(name, cstruct, constructor=None):
if constructor is None:
constructor = cstruct
assert callable(
constructor), 'constructor {} should be callable'.format(name)
CSTRUCT_CONSTRUCTOR[name] = (cstruct, constructor)


def register(op_name=None, **attrs):
"""Regiseter a custom operator
1. @register
Expand Down
35 changes: 34 additions & 1 deletion mobula/internal/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
CTYPE_INTS = [ctypes.c_short, ctypes.c_int, ctypes.c_long, ctypes.c_longlong]
CTYPE_UINTS = [ctypes.c_ushort, ctypes.c_uint,
ctypes.c_ulong, ctypes.c_ulonglong]
CTYPENAME2CTYPE = {
'bool': ctypes.c_bool,
'char': ctypes.c_char,
'char*': ctypes.c_char_p,
'double': ctypes.c_double,
'float': ctypes.c_float,
'int': ctypes.c_int,
'int8_t': ctypes.c_int8,
'int16_t': ctypes.c_int16,
'int32_t': ctypes.c_int32,
'int64_t': ctypes.c_int64,
'long': ctypes.c_long,
'longlong': ctypes.c_longlong,
'short': ctypes.c_short,
'void*': ctypes.c_void_p,
'void': None,
None: None,
}


def get_ctype_name(ctype):
Expand All @@ -11,7 +29,10 @@ def get_ctype_name(ctype):
return 'int{}_t'.format(ctypes.sizeof(ctype) * 8)
if ctype in CTYPE_UINTS[2:]:
return 'uint{}_t'.format(ctypes.sizeof(ctype) * 8)
return ctype.__name__[2:]
name = ctype.__name__
if name.startswith('c_'):
name = name[2:]
return name


class DType:
Expand Down Expand Up @@ -50,6 +71,18 @@ def __call__(self, value):
return self.ctype(value)


class CStruct:
def __init__(self, name, is_const, cstruct, constructor):
self.cname = name + '*'
self.cstruct = cstruct
self.constructor = constructor
self.is_pointer = True
self.is_const = is_const

def __call__(self, *args, **kwargs):
return self.constructor(*args, **kwargs)


class UnknownCType:
def __init__(self, tname):
self.tname = tname
Expand Down
98 changes: 34 additions & 64 deletions mobula/op/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from ..func import CFuncDef, bind, get_func_idcode, get_idcode_hash
from ..building.build import source_to_so_ctx, build_context, file_is_changed, ENV_PATH
from ..utils import get_git_hash, makedirs
from ..internal.dtype import DType, TemplateType
from ..internal.dtype import DType, CStruct, TemplateType, CTYPENAME2CTYPE
from ..version import OP_LOAD_MODULE_BUILD_VERSION
from ..glue.common import CSTRUCT_CONSTRUCTOR
from ..glue.backend import get_glue_modules
from .gen_code import get_gen_rel_code

Expand Down Expand Up @@ -142,6 +143,10 @@ def parse_parameter_decl(decl):
ctype = ctypes.POINTER(ctype)
return DType(ctype, is_const=is_const), var_name

if type_name in CSTRUCT_CONSTRUCTOR:
info = CSTRUCT_CONSTRUCTOR[type_name]
return CStruct(type_name, is_const=is_const, cstruct=info[0], constructor=info[1]), var_name

# template type
return TemplateType(tname=type_name, is_pointer=is_pointer, is_const=is_const), var_name

Expand Down Expand Up @@ -305,39 +310,16 @@ def _generate_func_code(func_idcode_hash, rtn_type, arg_types, arg_names, func_n
) for dtype, name in zip(arg_types, arg_names)])
args_inst = ', '.join(arg_names)

code = '''
MOBULA_DLL %s %s(%s) {
''' % (rtn_type, func_idcode_hash, args_def)
if rtn_type != 'void':
code += ' return '
code += '%s(%s);\n}\n' % (func_name, args_inst)
return code

code = gen_code('./templates/func_code.cpp')(
return_value=rtn_type,
return_statement='' if rtn_type == 'void' else 'return',
func_idcode_hash=func_idcode_hash,
args_def=args_def,
func_name=func_name,
args_inst=args_inst,
)

def _generate_ordinary_code(cpp_info):
code_buffer = ''
# generate ordinary functions code
for func_name, ord_cfunc in cpp_info.function_args.items():
if ord_cfunc.template_list:
continue
func_idcode = get_func_idcode(func_name, ord_cfunc.arg_types)
func_idcode_hash = get_idcode_hash(func_idcode)
func_kind = ord_cfunc.func_kind
if func_kind == CFuncDef.KERNEL:
code_buffer += _generate_kernel_code(
func_idcode_hash, ord_cfunc.arg_types, ord_cfunc.arg_names, '{}_kernel'.format(func_name))
code_buffer += '\n'
return code_buffer


def _get_ordinary_functions(cpp_info):
res = list()
for func_name, ord_cfunc in cpp_info.function_args.items():
if ord_cfunc.template_list:
continue
func_idcode = get_func_idcode(func_name, ord_cfunc.arg_types)
res.append(func_idcode)
return res
return code


def _update_template_inst_map(idcode, template_functions, cfunc, arg_types):
Expand All @@ -364,7 +346,8 @@ def _update_template_inst_map(idcode, template_functions, cfunc, arg_types):

template_inst = [template_mapping[tname]
for tname in cfunc.template_list]
template_post = '<%s>' % (', '.join(template_inst))
template_post = '<%s>' % (', '.join(template_inst)
) if template_inst else ''
rtn_type = cfunc.rtn_type
if rtn_type in template_mapping:
rtn_type = template_mapping[rtn_type]
Expand All @@ -376,15 +359,18 @@ def _update_template_inst_map(idcode, template_functions, cfunc, arg_types):
else:
code = _generate_func_code(
func_idcode_hash, rtn_type, arg_types, cfunc.arg_names, func_name + template_post)
template_functions[idcode] = code
template_functions[idcode] = (code, rtn_type)


def _add_function(func_map, func_idcode, cpp_info, dll_fname):
def _add_function(func_map, func_idcode, rtn_type, cpp_info, dll_fname):
func_idcode_hash = get_idcode_hash(func_idcode)
func = getattr(cpp_info.dll, func_idcode_hash, None)
assert func is not None,\
Exception('No function `{}` in DLL {}'.format(
func_idcode, dll_fname))
func.restype = CTYPENAME2CTYPE[rtn_type]
if func is None:
functions = [name for name in dir(
cpp_info.dll) if not name.startswith('_')]
raise NameError('No function `{}` in DLL {}, current functions: {}'.format(
func_idcode, dll_fname, functions))

old_func = func_map.get(func_idcode, None)
if old_func is not None:
Expand Down Expand Up @@ -459,14 +445,10 @@ def __init__(self, cfunc, arg_types, ctx, cpp_info):
is_old_version = map_data.get(
'version') < OP_LOAD_MODULE_BUILD_VERSION
# load the information of template functions
ORDINARY_FUNCTION_NAME = 'ordinary_functions'
TEMPLATE_FUNCTION_NAME = 'template_functions'
TEMPLATE_FUNCTION_NAME = 'functions'
if is_old_version:
ordinary_functions = list()
template_functions = dict()
else:
ordinary_functions = map_data.get(
ORDINARY_FUNCTION_NAME, list())
template_functions = map_data.get(
TEMPLATE_FUNCTION_NAME, dict())

Expand All @@ -479,7 +461,7 @@ def __init__(self, cfunc, arg_types, ctx, cpp_info):

file_changed = file_is_changed(cpp_fname)
dll_existed = os.path.exists(dll_fname)
func_existed = idcode in template_functions or idcode in ordinary_functions
func_existed = idcode in template_functions

if file_changed or not dll_existed or not func_existed or is_old_version:
# Rebuild DLL file
Expand All @@ -503,14 +485,12 @@ def __init__(self, cfunc, arg_types, ctx, cpp_info):
build_id += 1
dll_fname = dll_fname_format.format(build_id=build_id)
# build code
code_buffer = _generate_ordinary_code(cpp_info)
ordinary_functions = _get_ordinary_functions(cpp_info)
if use_template:
if idcode not in template_functions:
_update_template_inst_map(
idcode, template_functions, cfunc, arg_types)
# add template instances code into code_buffer
code_buffer += ''.join(template_functions.values())
if idcode not in template_functions:
_update_template_inst_map(
idcode, template_functions, cfunc, arg_types)
# collects template instances code into code_buffer
code_buffer = ''.join([v[0]
for v in template_functions.values()])

with build_context():
try:
Expand All @@ -522,7 +502,6 @@ def __init__(self, cfunc, arg_types, ctx, cpp_info):
# update template_functions
map_data = dict(version=OP_LOAD_MODULE_BUILD_VERSION,
build_id=build_id)
map_data[ORDINARY_FUNCTION_NAME] = ordinary_functions
map_data[TEMPLATE_FUNCTION_NAME] = template_functions
# clear the old context and write json data
build_info_fs.seek(0)
Expand All @@ -536,18 +515,9 @@ def __init__(self, cfunc, arg_types, ctx, cpp_info):
cpp_info.load_dll(dll_fname)

# import all functions
# ordinary functions
for func_name, ord_cfunc in cpp_info.function_args.items():
if not ord_cfunc.template_list:
func_idcode = get_func_idcode(
func_name, ord_cfunc.arg_types)
_add_function(func_map,
func_idcode, cpp_info, dll_fname)

# template functions
for func_idcode in template_functions.keys():
_add_function(func_map,
func_idcode, cpp_info, dll_fname)
func_idcode, template_functions[func_idcode][1], cpp_info, dll_fname)

self.func = func_map[idcode].func
self.cpp_info = func_map[idcode].cpp_info
Expand Down
3 changes: 3 additions & 0 deletions mobula/op/templates/func_code.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MOBULA_DLL ${return_value} ${func_idcode_hash}(${args_def}) {
${return_statement} ${func_name}(${args_inst});
}
3 changes: 3 additions & 0 deletions mobula/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def check_value(data, other):
# If the shapes don't match, raise AssertionError and print the shapes
assert a.shape == b.shape,\
AssertionError('Unmatched Shape: {} vs {}'.format(a.shape, b.shape))
if len(a.shape) == 0:
a = a.reshape((1, ))
b = b.reshape((1, ))

# Compute Absolute Error |a - b|
error = a - b
Expand Down
2 changes: 1 addition & 1 deletion mobula/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""version information"""
__version__ = 2.32
__version__ = 2.4

OP_LOAD_MODULE_BUILD_VERSION = __version__
9 changes: 9 additions & 0 deletions tests/test_building/MyStruct/MyStruct.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
struct MyStruct {
int hello;
float mobula;
};

MOBULA_FUNC float hello(MyStruct *hi) {
LOG(INFO) << "Hello Mobula: " << hi->hello << ", " << hi->mobula;
return hi->hello + hi->mobula;
}
25 changes: 25 additions & 0 deletions tests/test_building/test_building.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import ctypes
import os

import mobula
from mobula.testing import assert_almost_equal, gradcheck

# [TODO] change BUILD_PATH


def test_custom_struct():
class MyStruct(ctypes.Structure):
_fields_ = [
('hello', ctypes.c_int),
('mobula', ctypes.c_float),
]

mobula.glue.register_cstruct('MyStruct', MyStruct)
mobula.op.load('MyStruct', os.path.dirname(__file__))

res = mobula.func.hello((42, 39))
assert_almost_equal(res, 42 + 39)


if __name__ == '__main__':
test_custom_struct()
Loading

0 comments on commit df25b88

Please sign in to comment.