Skip to content

Commit

Permalink
TST: Save long double reference data in .mat file
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Aug 15, 2019
1 parent 4a70a14 commit 6d08b7e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
11 changes: 8 additions & 3 deletions scipy/fft/_pocketfft/tests/test_real_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
assert_array_almost_equal, assert_equal, assert_allclose)
import pytest
from pytest import raises as assert_raises
from scipy.io import loadmat

from scipy.fft._pocketfft.realtransforms import (
dct, idct, dst, idst, dctn, idctn, dstn, idstn)
Expand All @@ -31,13 +32,18 @@ def get_reference_data():
# * for every type (1, 2, 3, 4) and every size, the array dct_type_size
# contains the output of the DCT applied to the input np.linspace(0, size-1,
# size)
FFTWDATA_LONGDOUBLE = np.load(join(fftpack_test_dir, 'fftw_longdouble_ref.npz'))
FFTWDATA_DOUBLE = np.load(join(fftpack_test_dir, 'fftw_double_ref.npz'))
FFTWDATA_SINGLE = np.load(join(fftpack_test_dir, 'fftw_single_ref.npz'))
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']

assert len(FFTWDATA_SIZES) == FFTWDATA_COUNT

# not binary compatible on 32-bit systems so cannot use np.load
FFTWDATA_LONGDOUBLE = loadmat(
join(fftpack_test_dir, 'fftw_longdouble_ref.mat'))
FFTWDATA_LONGDOUBLE = {k: v.reshape((-1,))
for k,v in FFTWDATA_LONGDOUBLE.items()
if isinstance(v, np.ndarray)}

ref = {
'FFTWDATA_LONGDOUBLE': FFTWDATA_LONGDOUBLE,
'FFTWDATA_DOUBLE': FFTWDATA_DOUBLE,
Expand All @@ -46,7 +52,6 @@ def get_reference_data():
'X': X,
'Y': Y
}

globals()['__reference_data'] = ref
return ref

Expand Down
Binary file added scipy/fftpack/tests/fftw_longdouble_ref.mat
Binary file not shown.
Binary file removed scipy/fftpack/tests/fftw_longdouble_ref.npz
Binary file not shown.
5 changes: 3 additions & 2 deletions scipy/fftpack/tests/gen_fftw_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from subprocess import Popen, PIPE, STDOUT

import numpy as np
from scipy.io import savemat

SZ = [2, 3, 4, 8, 12, 15, 16, 17, 32, 64, 128, 256, 512, 1024]

Expand Down Expand Up @@ -62,7 +63,7 @@ def gen_data(dt):

# generate long double precision data
data = gen_data(np.float128)
filename = 'fftw_longdouble_ref'
filename = 'fftw_longdouble_ref.mat'
# Save ref data into npz format
d = {'sizes': SZ}
for type in [1, 2, 3, 4]:
Expand All @@ -73,4 +74,4 @@ def gen_data(dt):
for type in [5, 6, 7, 8]:
for sz in SZ:
d['dst_%d_%d' % (type-4, sz)] = data[type][sz]
np.savez(filename, **d)
savemat(filename, d)

0 comments on commit 6d08b7e

Please sign in to comment.