Permalink
Browse files

re-enabled dst ortho overwrite tests

  • Loading branch information...
1 parent 0b3562a commit f4c626170dd1c182135ac359a41d98778c7ad833 Matt Terry committed Jan 13, 2012
Showing with 16 additions and 9 deletions.
  1. +13 −5 scipy/fftpack/realtransforms.py
  2. +2 −2 scipy/fftpack/src/dst.c.src
  3. +1 −2 scipy/fftpack/tests/test_real_transforms.py
@@ -301,6 +301,11 @@ def dst(x, type=2, n=None, axis=-1, norm=None, overwrite_x=0):
y[k] = 2* sum x[n]*sin(pi*(k+1)*(n+0.5)/N), 0 <= k < N.
n=0
+ if ``norm='ortho'``, ``y[k]`` is multiplied by a scaling factor `f`::
+
+ f = sqrt(1/(4*N)) if k == 0
+ f = sqrt(1/(2*N)) otherwise.
+
type III
~~~~~~~~
@@ -313,15 +318,17 @@ def dst(x, type=2, n=None, axis=-1, norm=None, overwrite_x=0):
n=0
The (unnormalized) DCT-III is the inverse of the (unnormalized) DCT-II, up
- to a factor `2N`.
+ to a factor `2N`. The orthonormalized DST-III is exactly the inverse of
+ the orthonormalized DST-II.
References
----------
http://en.wikipedia.org/wiki/Discrete_sine_transform
"""
- if norm is not None:
- raise NotImplementedError('DST Orthonormalization not yet implemented')
+ if type == 1 and norm is not None:
+ raise NotImplementedError(
+ "Orthonormalization not yet supported for IDCT-I")
return _dst(x, type, n, axis, normalize=norm, overwrite_x=overwrite_x)
def idst(x, type=2, n=None, axis=-1, norm=None, overwrite_x=0):
@@ -361,8 +368,9 @@ def idst(x, type=2, n=None, axis=-1, norm=None, overwrite_x=0):
types, see `dst`.
"""
- if norm is not None:
- raise NotImplementedError('idst orthonormalization not yet supported')
+ if type == 1 and norm is not None:
+ raise NotImplementedError(
+ "Orthonormalization not yet supported for IDCT-I")
# Inverse/forward type table
_TP = {1:1, 2:3, 3:2}
return _dst(x, _TP[type], n, axis, normalize=norm, overwrite_x=overwrite_x)
@@ -58,8 +58,8 @@ void @pref@dst1(@type@ * inout, int n, int howmany, int normalize)
#if 0
case DST_NORMALIZE_ORTHONORMAL:
ptr = inout;
- n1 = sqrt(0.5 / (n-1));
- n2 = sqrt(1. / (n-1));
+ n1 = sqrt(0.5 / (n+1));
+ n2 = sqrt(1. / (n+1));
for (i = 0; i < howmany; ++i, ptr+=n) {
ptr[0] *= n1;
for (j = 1; j < n-1; ++j) {
@@ -340,8 +340,7 @@ def _check_1d(self, routine, dtype, shape, axis, overwritable_dtypes):
for type in [1, 2, 3]:
for overwrite_x in [True, False]:
-# for norm in [None, 'ortho']:
- for norm in [None]:
+ for norm in [None, 'ortho']:
if type == 1 and norm == 'ortho':
continue

0 comments on commit f4c6261

Please sign in to comment.