Permalink
Browse files

ENH: misc: Don't use plain assert for argument validation.

  • Loading branch information...
1 parent 0a5ca08 commit c2bb0e70e3abf3ac7ff06b6c4344b9809a2306e9 warren.weckesser committed Jan 29, 2011
Showing with 22 additions and 12 deletions.
  1. +12 −6 scipy/misc/common.py
  2. +10 −6 scipy/misc/pilutil.py
View
@@ -217,7 +217,7 @@ def comb(N,k,exact=0):
sv = special.errprint(sv)
return where(cond, vals, 0.0)
-def central_diff_weights(Np,ndiv=1):
+def central_diff_weights(Np, ndiv=1):
"""
Return weights for an Np-point central derivative of order ndiv
assuming equally-spaced function points.
@@ -230,8 +230,10 @@ def central_diff_weights(Np,ndiv=1):
Can be inaccurate for large number of points.
"""
- assert (Np >= ndiv+1), "Number of points must be at least the derivative order + 1."
- assert (Np % 2 == 1), "Odd-number of points only."
+ if Np < ndiv + 1:
+ raise ValueError("Number of points must be at least the derivative order + 1.")
+ if Np % 2 == 0:
+ raise ValueError("The number of points must be odd.")
from scipy import linalg
ho = Np >> 1
x = arange(-ho,ho+1.0)
@@ -242,7 +244,7 @@ def central_diff_weights(Np,ndiv=1):
w = product(arange(1,ndiv+1),axis=0)*linalg.inv(X)[ndiv]
return w
-def derivative(func,x0,dx=1.0,n=1,args=(),order=3):
+def derivative(func, x0, dx=1.0, n=1, args=(), order=3):
"""
Find the n-th derivative of a function at point x0.
@@ -277,8 +279,12 @@ def derivative(func,x0,dx=1.0,n=1,args=(),order=3):
4.0
"""
- assert (order >= n+1), "Number of points must be at least the derivative order + 1."
- assert (order % 2 == 1), "Odd number of points only."
+ if order < n + 1:
+ raise ValueError("'order' (the number of points used to compute the derivative), "
+ "must be at least the derivative order 'n' + 1.")
+ if order % 2 == 0:
+ raise ValueError("'order' (the number of points used to compute the derivative) "
+ "must be odd.")
# pre-computed for n=1 and 2 and low-order for speed.
if n==1:
if order == 3:
View
@@ -147,8 +147,9 @@ def fromimage(im, flatten=0):
return array(im)
_errstr = "Mode is unknown or incompatible with input array shape."
-def toimage(arr,high=255,low=0,cmin=None,cmax=None,pal=None,
- mode=None,channel_axis=None):
+
+def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None,
+ mode=None, channel_axis=None):
"""Takes a numpy array and returns a PIL image. The mode of the
PIL image depends on the array shape, the pal keyword, and the mode
keyword.
@@ -171,7 +172,8 @@ def toimage(arr,high=255,low=0,cmin=None,cmax=None,pal=None,
shape = list(data.shape)
valid = len(shape)==2 or ((len(shape)==3) and \
((3 in shape) or (4 in shape)))
- assert valid, "Not a suitable array shape for any mode."
+ if not valid:
+ raise ValueError("'arr' does not have a suitable array shape for any mode.")
if len(shape) == 2:
shape = (shape[1],shape[0]) # columns show up first
if mode == 'F':
@@ -242,11 +244,13 @@ def toimage(arr,high=255,low=0,cmin=None,cmax=None,pal=None,
raise ValueError(_errstr)
if mode in ['RGB', 'YCbCr']:
- assert numch == 3, "Invalid array shape for mode."
+ if numch != 3:
+ raise ValueError("Invalid array shape for mode.")
if mode in ['RGBA', 'CMYK']:
- assert numch == 4, "Invalid array shape for mode."
+ if numch != 4:
+ raise ValueError("Invalid array shape for mode.")
- # Here we know data and mode is coorect
+ # Here we know data and mode is correct
image = Image.fromstring(mode, shape, strdata)
return image

0 comments on commit c2bb0e7

Please sign in to comment.