Browse files

ENH: special/generate_ufuncs: check at compile-time that function sig…

…nature is what was assumed
  • Loading branch information...
1 parent d776abc commit 0dfe2b67028e1eaf31c2bd10514e895f62c9a40c @pv committed Oct 8, 2012
Showing with 54 additions and 15 deletions.
  1. +54 −15 scipy/special/generate_ufuncs.py
View
69 scipy/special/generate_ufuncs.py
@@ -329,7 +329,7 @@ def errprint(inflag=None):
import textwrap
import add_newdocs
-C_TYPES = {
+CY_TYPES = {
'f': 'float',
'd': 'double',
'g': 'long double',
@@ -341,6 +341,18 @@ def errprint(inflag=None):
'v': 'void',
}
+C_TYPES = {
+ 'f': 'np.npy_float',
+ 'd': 'np.npy_double',
+ 'g': 'np.npy_longdouble',
+ 'F': 'np.npy_cfloat',
+ 'D': 'np.npy_cdouble',
+ 'G': 'np.npy_clongdouble',
+ 'i': 'np.npy_int',
+ 'l': 'np.npy_long',
+ 'v': 'void',
+}
+
TYPE_NAMES = {
'f': 'np.NPY_FLOAT',
'd': 'np.NPY_DOUBLE',
@@ -366,8 +378,8 @@ def generate_loop(func_inputs, func_outputs, func_retval,
func_inputs, func_outputs, func_retval : str
Signature of the function to call, given as type codes of the
input, output and return value arguments. These 1-character
- codes are given according to the C_TYPES and TYPE_NAMES lists
- above.
+ codes are given according to the CY_TYPES and TYPE_NAMES
+ lists above.
The corresponding C function signature to be called is:
@@ -415,19 +427,21 @@ def generate_loop(func_inputs, func_outputs, func_retval,
fvars = []
outtypecodes = []
for j in range(len(func_inputs)):
- ftypes.append(C_TYPES[func_inputs[j]])
- fvars.append("<%s>(<%s*>ip%d)[0]" % (C_TYPES[func_inputs[j]], C_TYPES[ufunc_inputs[j]], j))
+ ftypes.append(CY_TYPES[func_inputs[j]])
+ fvars.append("<%s>(<%s*>ip%d)[0]" % (
+ CY_TYPES[func_inputs[j]],
+ CY_TYPES[ufunc_inputs[j]], j))
if len(func_outputs)+1 == len(ufunc_outputs):
func_joff = 1
outtypecodes.append(func_retval)
- body += " cdef %s ov0\n" % (C_TYPES[func_retval],)
+ body += " cdef %s ov0\n" % (CY_TYPES[func_retval],)
else:
func_joff = 0
for j, outtype in enumerate(func_outputs):
- body += " cdef %s ov%d\n" % (C_TYPES[outtype], j+func_joff)
- ftypes.append("%s *" % C_TYPES[outtype])
+ body += " cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff)
+ ftypes.append("%s *" % CY_TYPES[outtype])
fvars.append("&ov%d" % (j+func_joff))
outtypecodes.append(outtype)
@@ -438,11 +452,12 @@ def generate_loop(func_inputs, func_outputs, func_retval,
rv = ""
body += " %s(<%s(*)(%s) nogil>func)(%s)\n" % (
- rv, C_TYPES[func_retval], ", ".join(ftypes), ", ".join(fvars))
+ rv, CY_TYPES[func_retval],
+ ", ".join(ftypes), ", ".join(fvars))
for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)):
body += " (<%s *>op%d)[0] = <%s>ov%d\n" % (
- C_TYPES[outtype], j, C_TYPES[outtype], j)
+ CY_TYPES[outtype], j, CY_TYPES[outtype], j)
for j in range(len(ufunc_inputs)):
body += " ip%d += steps[%d]\n" % (j, j)
for j in range(len(ufunc_outputs)):
@@ -551,6 +566,19 @@ def add_variant(func_name, inarg, outarg, ret, inp, outp):
add_variant(func_name, inarg, outarg, ret, inp2, outp2)
return variants, inarg_num, outarg_num
+ def get_prototypes(self):
+ prototypes = []
+ for func_name, inarg, outarg, ret in self.signatures:
+ ret = ret.replace('*', '')
+ c_args = ([C_TYPES[x] for x in inarg]
+ + [C_TYPES[x] + ' *' for x in outarg])
+ cy_args = ([CY_TYPES[x] for x in inarg]
+ + [CY_TYPES[x] + ' *' for x in outarg])
+ c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args))
+ cy_proto = "%s (*)(%s)" % (CY_TYPES[ret], ", ".join(cy_args))
+ prototypes.append((func_name, c_proto, cy_proto))
+ return prototypes
+
def generate(self, all_loops):
toplevel = ""
@@ -589,7 +617,7 @@ def generate(self, all_loops):
inarg_num, outarg_num)
).replace('@', self.name)
- return toplevel, list(set(datas))
+ return toplevel
def generate(filename, ufunc_str, extra_code):
ufuncs = []
@@ -615,22 +643,33 @@ def generate(filename, ufunc_str, extra_code):
ufuncs.sort(key=lambda u: u.name)
for ufunc in ufuncs:
- t, cfuncs = ufunc.generate(all_loops)
+ t = ufunc.generate(all_loops)
toplevel += t + "\n"
+ cfuncs = ufunc.get_prototypes()
hdrs = headers.get(ufunc.name, ['cephes.h'])
if len(hdrs) == 1:
hdrs = [hdrs[0]] * len(cfuncs)
elif len(hdrs) != len(cfuncs):
raise ValueError("%s: wrong number of headers" % ufunc.name)
- for cfunc, header in zip(cfuncs, hdrs):
+ for (c_name, c_proto, cy_proto), header in zip(cfuncs, hdrs):
if header.endswith('.pxd'):
- defs += "from %s cimport %s as _func_%s\n" % (header[:-4], cfunc, cfunc)
+ defs += "from %s cimport %s as _func_%s\n" % (header[:-4], c_name, c_name)
+
+ # check function signature at compile time
+ proto_name = '_proto_%s_t' % c_name
+ defs += "ctypedef %s\n" % (cy_proto.replace('(*)', proto_name))
+ defs += "cdef %s *%s_var\n" % (proto_name, proto_name)
+ defs += "%s_var = &_func_%s\n" % (proto_name, c_name)
else:
+ # redeclare the function, so that the assumed
+ # signature is checked at compile time
defs += "cdef extern from \"%s\":\n" % header
- defs += " void _func_%s \"%s\"()\n" % (cfunc, cfunc)
+ defs += " pass\n"
+ new_name = "_func_%s \"%s\"" % (c_name, c_name)
+ defs += "cdef extern %s\n" % (c_proto.replace('(*)', new_name))
toplevel = "\n".join(all_loops.values() + [defs, toplevel])

0 comments on commit 0dfe2b6

Please sign in to comment.