Skip to content

Commit

Permalink
ENH: special/generate_ufuncs: check at compile-time that function sig…
Browse files Browse the repository at this point in the history
…nature is what was assumed
  • Loading branch information
pv committed Oct 8, 2012
1 parent d776abc commit 0dfe2b6
Showing 1 changed file with 54 additions and 15 deletions.
69 changes: 54 additions & 15 deletions scipy/special/generate_ufuncs.py
Expand Up @@ -329,7 +329,7 @@ def errprint(inflag=None):
import textwrap import textwrap
import add_newdocs import add_newdocs


C_TYPES = { CY_TYPES = {
'f': 'float', 'f': 'float',
'd': 'double', 'd': 'double',
'g': 'long double', 'g': 'long double',
Expand All @@ -341,6 +341,18 @@ def errprint(inflag=None):
'v': 'void', 'v': 'void',
} }


C_TYPES = {
'f': 'np.npy_float',
'd': 'np.npy_double',
'g': 'np.npy_longdouble',
'F': 'np.npy_cfloat',
'D': 'np.npy_cdouble',
'G': 'np.npy_clongdouble',
'i': 'np.npy_int',
'l': 'np.npy_long',
'v': 'void',
}

TYPE_NAMES = { TYPE_NAMES = {
'f': 'np.NPY_FLOAT', 'f': 'np.NPY_FLOAT',
'd': 'np.NPY_DOUBLE', 'd': 'np.NPY_DOUBLE',
Expand All @@ -366,8 +378,8 @@ def generate_loop(func_inputs, func_outputs, func_retval,
func_inputs, func_outputs, func_retval : str func_inputs, func_outputs, func_retval : str
Signature of the function to call, given as type codes of the Signature of the function to call, given as type codes of the
input, output and return value arguments. These 1-character input, output and return value arguments. These 1-character
codes are given according to the C_TYPES and TYPE_NAMES lists codes are given according to the CY_TYPES and TYPE_NAMES
above. lists above.
The corresponding C function signature to be called is: The corresponding C function signature to be called is:
Expand Down Expand Up @@ -415,19 +427,21 @@ def generate_loop(func_inputs, func_outputs, func_retval,
fvars = [] fvars = []
outtypecodes = [] outtypecodes = []
for j in range(len(func_inputs)): for j in range(len(func_inputs)):
ftypes.append(C_TYPES[func_inputs[j]]) ftypes.append(CY_TYPES[func_inputs[j]])
fvars.append("<%s>(<%s*>ip%d)[0]" % (C_TYPES[func_inputs[j]], C_TYPES[ufunc_inputs[j]], j)) fvars.append("<%s>(<%s*>ip%d)[0]" % (
CY_TYPES[func_inputs[j]],
CY_TYPES[ufunc_inputs[j]], j))


if len(func_outputs)+1 == len(ufunc_outputs): if len(func_outputs)+1 == len(ufunc_outputs):
func_joff = 1 func_joff = 1
outtypecodes.append(func_retval) outtypecodes.append(func_retval)
body += " cdef %s ov0\n" % (C_TYPES[func_retval],) body += " cdef %s ov0\n" % (CY_TYPES[func_retval],)
else: else:
func_joff = 0 func_joff = 0


for j, outtype in enumerate(func_outputs): for j, outtype in enumerate(func_outputs):
body += " cdef %s ov%d\n" % (C_TYPES[outtype], j+func_joff) body += " cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff)
ftypes.append("%s *" % C_TYPES[outtype]) ftypes.append("%s *" % CY_TYPES[outtype])
fvars.append("&ov%d" % (j+func_joff)) fvars.append("&ov%d" % (j+func_joff))
outtypecodes.append(outtype) outtypecodes.append(outtype)


Expand All @@ -438,11 +452,12 @@ def generate_loop(func_inputs, func_outputs, func_retval,
rv = "" rv = ""


body += " %s(<%s(*)(%s) nogil>func)(%s)\n" % ( body += " %s(<%s(*)(%s) nogil>func)(%s)\n" % (
rv, C_TYPES[func_retval], ", ".join(ftypes), ", ".join(fvars)) rv, CY_TYPES[func_retval],
", ".join(ftypes), ", ".join(fvars))


for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)): for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)):
body += " (<%s *>op%d)[0] = <%s>ov%d\n" % ( body += " (<%s *>op%d)[0] = <%s>ov%d\n" % (
C_TYPES[outtype], j, C_TYPES[outtype], j) CY_TYPES[outtype], j, CY_TYPES[outtype], j)
for j in range(len(ufunc_inputs)): for j in range(len(ufunc_inputs)):
body += " ip%d += steps[%d]\n" % (j, j) body += " ip%d += steps[%d]\n" % (j, j)
for j in range(len(ufunc_outputs)): for j in range(len(ufunc_outputs)):
Expand Down Expand Up @@ -551,6 +566,19 @@ def add_variant(func_name, inarg, outarg, ret, inp, outp):
add_variant(func_name, inarg, outarg, ret, inp2, outp2) add_variant(func_name, inarg, outarg, ret, inp2, outp2)
return variants, inarg_num, outarg_num return variants, inarg_num, outarg_num


def get_prototypes(self):
prototypes = []
for func_name, inarg, outarg, ret in self.signatures:
ret = ret.replace('*', '')
c_args = ([C_TYPES[x] for x in inarg]
+ [C_TYPES[x] + ' *' for x in outarg])
cy_args = ([CY_TYPES[x] for x in inarg]
+ [CY_TYPES[x] + ' *' for x in outarg])
c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args))
cy_proto = "%s (*)(%s)" % (CY_TYPES[ret], ", ".join(cy_args))
prototypes.append((func_name, c_proto, cy_proto))
return prototypes

def generate(self, all_loops): def generate(self, all_loops):
toplevel = "" toplevel = ""


Expand Down Expand Up @@ -589,7 +617,7 @@ def generate(self, all_loops):
inarg_num, outarg_num) inarg_num, outarg_num)
).replace('@', self.name) ).replace('@', self.name)


return toplevel, list(set(datas)) return toplevel


def generate(filename, ufunc_str, extra_code): def generate(filename, ufunc_str, extra_code):
ufuncs = [] ufuncs = []
Expand All @@ -615,22 +643,33 @@ def generate(filename, ufunc_str, extra_code):


ufuncs.sort(key=lambda u: u.name) ufuncs.sort(key=lambda u: u.name)
for ufunc in ufuncs: for ufunc in ufuncs:
t, cfuncs = ufunc.generate(all_loops) t = ufunc.generate(all_loops)
toplevel += t + "\n" toplevel += t + "\n"


cfuncs = ufunc.get_prototypes()


hdrs = headers.get(ufunc.name, ['cephes.h']) hdrs = headers.get(ufunc.name, ['cephes.h'])
if len(hdrs) == 1: if len(hdrs) == 1:
hdrs = [hdrs[0]] * len(cfuncs) hdrs = [hdrs[0]] * len(cfuncs)
elif len(hdrs) != len(cfuncs): elif len(hdrs) != len(cfuncs):
raise ValueError("%s: wrong number of headers" % ufunc.name) raise ValueError("%s: wrong number of headers" % ufunc.name)


for cfunc, header in zip(cfuncs, hdrs): for (c_name, c_proto, cy_proto), header in zip(cfuncs, hdrs):
if header.endswith('.pxd'): if header.endswith('.pxd'):
defs += "from %s cimport %s as _func_%s\n" % (header[:-4], cfunc, cfunc) defs += "from %s cimport %s as _func_%s\n" % (header[:-4], c_name, c_name)

# check function signature at compile time
proto_name = '_proto_%s_t' % c_name
defs += "ctypedef %s\n" % (cy_proto.replace('(*)', proto_name))
defs += "cdef %s *%s_var\n" % (proto_name, proto_name)
defs += "%s_var = &_func_%s\n" % (proto_name, c_name)
else: else:
# redeclare the function, so that the assumed
# signature is checked at compile time
defs += "cdef extern from \"%s\":\n" % header defs += "cdef extern from \"%s\":\n" % header
defs += " void _func_%s \"%s\"()\n" % (cfunc, cfunc) defs += " pass\n"
new_name = "_func_%s \"%s\"" % (c_name, c_name)
defs += "cdef extern %s\n" % (c_proto.replace('(*)', new_name))


toplevel = "\n".join(all_loops.values() + [defs, toplevel]) toplevel = "\n".join(all_loops.values() + [defs, toplevel])


Expand Down

0 comments on commit 0dfe2b6

Please sign in to comment.