Skip to content

Commit

Permalink
Reformat test and util files source code according to pep8/pyflakes h…
Browse files Browse the repository at this point in the history
…ints
  • Loading branch information
nigma committed Jul 21, 2012
1 parent 28ff680 commit 553b6cb
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 42 deletions.
42 changes: 25 additions & 17 deletions tests/test_matlab_compatibility.py
Expand Up @@ -8,59 +8,67 @@
and reproducibility.
"""

import math

import numpy
from numpy import asarray, float64

try:
from mlabwrap import mlab
from mlabwrap import mlab
except:
print "To run this test you need to have MathWorks MATLAB, MathWorks " \
"Wavelet Toolbox and mlabwrap Python extension installed."
raise SystemExit

import math
import numpy

import pywt
from numpy import asarray, float64


def mse(ar1, ar2):
"""Mean squared error"""
ar1 = asarray(ar1, dtype=float64)
ar2 = asarray(ar2, dtype=float64)
dif = ar1 - ar2
dif *= dif
return dif.sum()/len(ar1)
return dif.sum() / len(ar1)


def rms(ar1, ar2):
"""Root mean squared error"""
return math.sqrt(mse(ar1, ar2))


def test_accuracy(families, wavelets, modes, epsilon=1.0e-10):
print "Testing decomposition".upper()

for pmode, mmode in modes:
for wavelet in wavelets:
print "Wavelet: %-8s Mode: %s" % (wavelet, pmode)

w = pywt.Wavelet(wavelet)
data_size = range(w.dec_len, 40) + [100, 200, 500, 1000, 50000]

for N in data_size:
data = numpy.random.random(N)

# PyWavelets result
pa, pd = pywt.dwt(data, wavelet, pmode)

# Matlab result
ma, md = mlab.dwt(data, wavelet, 'mode', mmode, nout=2)
ma = ma.flat; md = md.flat
ma = ma.flat
md = md.flat

# calculate error measures
mse_a, mse_d = mse(pa, ma), mse(pd, md)
rms_a, rms_d = math.sqrt(mse_a), math.sqrt(mse_d)

if rms_a > epsilon:
print '[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a)

print '[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '\
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a)

if rms_d > epsilon:
print '[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d)
print '[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '\
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d)


if __name__ == '__main__':
Expand All @@ -69,8 +77,8 @@ def test_accuracy(families, wavelets, modes, epsilon=1.0e-10):
wavelets = sum([pywt.wavelist(name) for name in families], [])
# list of mode names in pywt and matlab
modes = [('zpd', 'zpd'), ('cpd', 'sp0'), ('sym', 'sym'),
('ppd', 'ppd'), ('sp1', 'sp1'), ('per', 'per')]
('ppd', 'ppd'), ('sp1', 'sp1'), ('per', 'per')]
# max RMSE
epsilon = 1.0e-10

test_accuracy(families, wavelets, modes, epsilon)
33 changes: 23 additions & 10 deletions tests/test_perfect_reconstruction.py
Expand Up @@ -5,29 +5,34 @@
"""

import math

import numpy
import pywt
from numpy import asarray, float64, float32

import pywt


def mse(ar1, ar2):
"""Mean squared error"""
ar1 = asarray(ar1, dtype=float64)
ar2 = asarray(ar2, dtype=float64)
dif = ar1 - ar2
dif *= dif
return dif.sum()/len(ar1)
return dif.sum() / len(ar1)


def rms(ar1, ar2):
"""Root mean squared error"""
return math.sqrt(mse(ar1, ar2))


def test_perfect_reconstruction(families, wavelets, modes, epsilon, dtype):
for wavelet in wavelets:
for pmode, mmode in modes:
print "Wavelet: %-8s Mode: %s" % (wavelet, pmode),

w = pywt.Wavelet(wavelet)
data_size = range(2, 40) + [100, 200, 500, 1000, 2000, 10000, 50000, 100000]
data_size = range(2, 40) + [100, 200, 500, 1000, 2000, 10000,
50000, 100000]

ok, over = 0, 0
for N in data_size:
Expand All @@ -46,12 +51,15 @@ def test_perfect_reconstruction(families, wavelets, modes, epsilon, dtype):
if rms_rec > epsilon:
if not over:
print
print '[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec, )
print '[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, ' \
'Length: %d, rms=%.3g' % (
pmode, wavelet, len(data), rms_rec)
over += 1
else:
ok += 1
if not over:
print "- RMSE for all %d cases was under %s" % (len(data_size), epsilon)
print "- RMSE for all %d cases was under %s" % (
len(data_size), epsilon)

if __name__ == '__main__':

Expand All @@ -62,8 +70,13 @@ def test_perfect_reconstruction(families, wavelets, modes, epsilon, dtype):
('ppd', 'ppd'), ('sp1', 'sp1'), ('per', 'per')]

print "Testing perfect reconstruction".upper()
for dtype, name, epsilon in [(float32, "float32", 1.0e-7), (float64, "float64", 0.5e-10)][::-1]:
print "#"*80 + "\nPrecision: %s, max RMSE: %s\n" % (name, epsilon) + "#"*80 + "\n"
test_perfect_reconstruction(families, wavelets, modes, epsilon=epsilon, dtype=dtype)
for dtype, name, epsilon in [
(float32, "float32", 1.0e-7),
(float64, "float64", 0.5e-10)
][::-1]:
print "#" * 80
print "Precision: %s, max RMSE: %s" % (name, epsilon)
print "#" * 80 + "\n"
test_perfect_reconstruction(families, wavelets, modes, epsilon=epsilon,
dtype=dtype)
print

9 changes: 5 additions & 4 deletions util/commands.py
Expand Up @@ -15,6 +15,7 @@

base_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..")


def replace_extension(path, newext):
return os.path.splitext(path)[0] + newext

Expand Down Expand Up @@ -75,12 +76,12 @@ def validate_templates_expanded(self):

if not os.path.exists(destination_file):
raise DistutilsClassError(
"Expanded file '{0}' not found. " \
"Expanded file '{0}' not found. "
"Run build first.".format(destination_file))

if templating.needs_update(template_file, destination_file):
raise DistutilsClassError(
"Expanded file '{0}' seems out of date compared to '{1}'. "\
"Expanded file '{0}' seems out of date compared to '{1}'. "
"Run build first.".format(destination_file, template_file))

def validate_pyx_expanded(self):
Expand All @@ -89,12 +90,12 @@ def validate_pyx_expanded(self):

if not os.path.exists(c_file):
raise DistutilsClassError(
"C-source file '{0}' not found. " \
"C-source file '{0}' not found. "
"Run build first.".format(c_file))

if is_newer(pyx_file, c_file):
raise DistutilsClassError(
"C-source file '{0}' seems out of date compared to '{1}'. "\
"C-source file '{0}' seems out of date compared to '{1}'. "
"Run build first.".format(c_file, pyx_file))

def run(self):
Expand Down
29 changes: 18 additions & 11 deletions util/templating.py
Expand Up @@ -15,10 +15,10 @@
pattern_for = re.compile(r"""(?P<for>
^\s*
(?:/{2,})? # optional C comment
\s*
\s*
\#{2} # two hashes
\s*
\s*
(FOR)
\s+ (?P<variable>[\w$][\d\w$]*) \s+
(IN)
Expand Down Expand Up @@ -48,16 +48,17 @@
)
""", re.X | re.M | re.S | re.I)


def expand_template(s):
"""
Currently it only does a simple repeat-and-replace in a loop:
FOR $variable$ IN (value1, value2, ...):
... start block ...
$variable$
... end block ...
ENDFOR $variable$
The above will repeat the block for every value from the list each time
substituting the $variable$ with the current value.
Expand All @@ -70,7 +71,7 @@ def expand_template(s):
... ## ENDFOR $y$
... ## ENDFOR $x$'''
>>> print expand_template(s)
w = 9
print 7, "{"
print 7, 1
Expand All @@ -81,19 +82,24 @@ def expand_template(s):
m = pattern_for.search(s)
if not m:
break

new_body = ''
for value in [v.strip() for v in m.group('values').split(',') if v.strip()]:
for value in [
v.strip() for v in m.group('values').split(',') if v.strip()
]:
new_body += m.group('body').replace(m.group('variable'), value)

s = s[:m.start()] + new_body + s[m.end():]

return s


def get_destination_filepath(source):
root, template_name = os.path.split(source)
destination_name, base_ext = os.path.splitext(template_name) # main extension

# main extension
destination_name, base_ext = os.path.splitext(template_name)

while os.path.extsep in destination_name:
# remove .template extension for files like file.template.c
destination_name = os.path.splitext(destination_name)[0]
Expand All @@ -113,7 +119,8 @@ def expand_files(glob_pattern, force_update=False):
for template_path in files:
destination_path = get_destination_filepath(template_path)
if force_update or needs_update(template_path, destination_path):
print "expanding template: %s -> %s" % (template_path, destination_path)
print "expanding template: %s -> %s" % (
template_path, destination_path)
content = expand_template(open(template_path, "rb").read())
new_file = open(destination_path, "wb")
new_file.write(content)
Expand Down

0 comments on commit 553b6cb

Please sign in to comment.