Skip to content

Commit

Permalink
ENH: special/generate_ufuncs: check at compile-time that function sig…
Browse files Browse the repository at this point in the history
…nature is what was assumed
  • Loading branch information
pv committed Oct 8, 2012
1 parent d776abc commit 0dfe2b6
Showing 1 changed file with 54 additions and 15 deletions.
69 changes: 54 additions & 15 deletions scipy/special/generate_ufuncs.py
Expand Up @@ -329,7 +329,7 @@ def errprint(inflag=None):
import textwrap
import add_newdocs

C_TYPES = {
CY_TYPES = {
'f': 'float',
'd': 'double',
'g': 'long double',
Expand All @@ -341,6 +341,18 @@ def errprint(inflag=None):
'v': 'void',
}

C_TYPES = {
'f': 'np.npy_float',
'd': 'np.npy_double',
'g': 'np.npy_longdouble',
'F': 'np.npy_cfloat',
'D': 'np.npy_cdouble',
'G': 'np.npy_clongdouble',
'i': 'np.npy_int',
'l': 'np.npy_long',
'v': 'void',
}

TYPE_NAMES = {
'f': 'np.NPY_FLOAT',
'd': 'np.NPY_DOUBLE',
Expand All @@ -366,8 +378,8 @@ def generate_loop(func_inputs, func_outputs, func_retval,
func_inputs, func_outputs, func_retval : str
Signature of the function to call, given as type codes of the
input, output and return value arguments. These 1-character
codes are given according to the C_TYPES and TYPE_NAMES lists
above.
codes are given according to the CY_TYPES and TYPE_NAMES
lists above.
The corresponding C function signature to be called is:
Expand Down Expand Up @@ -415,19 +427,21 @@ def generate_loop(func_inputs, func_outputs, func_retval,
fvars = []
outtypecodes = []
for j in range(len(func_inputs)):
ftypes.append(C_TYPES[func_inputs[j]])
fvars.append("<%s>(<%s*>ip%d)[0]" % (C_TYPES[func_inputs[j]], C_TYPES[ufunc_inputs[j]], j))
ftypes.append(CY_TYPES[func_inputs[j]])
fvars.append("<%s>(<%s*>ip%d)[0]" % (
CY_TYPES[func_inputs[j]],
CY_TYPES[ufunc_inputs[j]], j))

if len(func_outputs)+1 == len(ufunc_outputs):
func_joff = 1
outtypecodes.append(func_retval)
body += " cdef %s ov0\n" % (C_TYPES[func_retval],)
body += " cdef %s ov0\n" % (CY_TYPES[func_retval],)
else:
func_joff = 0

for j, outtype in enumerate(func_outputs):
body += " cdef %s ov%d\n" % (C_TYPES[outtype], j+func_joff)
ftypes.append("%s *" % C_TYPES[outtype])
body += " cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff)
ftypes.append("%s *" % CY_TYPES[outtype])
fvars.append("&ov%d" % (j+func_joff))
outtypecodes.append(outtype)

Expand All @@ -438,11 +452,12 @@ def generate_loop(func_inputs, func_outputs, func_retval,
rv = ""

body += " %s(<%s(*)(%s) nogil>func)(%s)\n" % (
rv, C_TYPES[func_retval], ", ".join(ftypes), ", ".join(fvars))
rv, CY_TYPES[func_retval],
", ".join(ftypes), ", ".join(fvars))

for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)):
body += " (<%s *>op%d)[0] = <%s>ov%d\n" % (
C_TYPES[outtype], j, C_TYPES[outtype], j)
CY_TYPES[outtype], j, CY_TYPES[outtype], j)
for j in range(len(ufunc_inputs)):
body += " ip%d += steps[%d]\n" % (j, j)
for j in range(len(ufunc_outputs)):
Expand Down Expand Up @@ -551,6 +566,19 @@ def add_variant(func_name, inarg, outarg, ret, inp, outp):
add_variant(func_name, inarg, outarg, ret, inp2, outp2)
return variants, inarg_num, outarg_num

def get_prototypes(self):
prototypes = []
for func_name, inarg, outarg, ret in self.signatures:
ret = ret.replace('*', '')
c_args = ([C_TYPES[x] for x in inarg]
+ [C_TYPES[x] + ' *' for x in outarg])
cy_args = ([CY_TYPES[x] for x in inarg]
+ [CY_TYPES[x] + ' *' for x in outarg])
c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args))
cy_proto = "%s (*)(%s)" % (CY_TYPES[ret], ", ".join(cy_args))
prototypes.append((func_name, c_proto, cy_proto))
return prototypes

def generate(self, all_loops):
toplevel = ""

Expand Down Expand Up @@ -589,7 +617,7 @@ def generate(self, all_loops):
inarg_num, outarg_num)
).replace('@', self.name)

return toplevel, list(set(datas))
return toplevel

def generate(filename, ufunc_str, extra_code):
ufuncs = []
Expand All @@ -615,22 +643,33 @@ def generate(filename, ufunc_str, extra_code):

ufuncs.sort(key=lambda u: u.name)
for ufunc in ufuncs:
t, cfuncs = ufunc.generate(all_loops)
t = ufunc.generate(all_loops)
toplevel += t + "\n"

cfuncs = ufunc.get_prototypes()

hdrs = headers.get(ufunc.name, ['cephes.h'])
if len(hdrs) == 1:
hdrs = [hdrs[0]] * len(cfuncs)
elif len(hdrs) != len(cfuncs):
raise ValueError("%s: wrong number of headers" % ufunc.name)

for cfunc, header in zip(cfuncs, hdrs):
for (c_name, c_proto, cy_proto), header in zip(cfuncs, hdrs):
if header.endswith('.pxd'):
defs += "from %s cimport %s as _func_%s\n" % (header[:-4], cfunc, cfunc)
defs += "from %s cimport %s as _func_%s\n" % (header[:-4], c_name, c_name)

# check function signature at compile time
proto_name = '_proto_%s_t' % c_name
defs += "ctypedef %s\n" % (cy_proto.replace('(*)', proto_name))
defs += "cdef %s *%s_var\n" % (proto_name, proto_name)
defs += "%s_var = &_func_%s\n" % (proto_name, c_name)
else:
# redeclare the function, so that the assumed
# signature is checked at compile time
defs += "cdef extern from \"%s\":\n" % header
defs += " void _func_%s \"%s\"()\n" % (cfunc, cfunc)
defs += " pass\n"
new_name = "_func_%s \"%s\"" % (c_name, c_name)
defs += "cdef extern %s\n" % (c_proto.replace('(*)', new_name))

toplevel = "\n".join(all_loops.values() + [defs, toplevel])

Expand Down

0 comments on commit 0dfe2b6

Please sign in to comment.