Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add orthogonalize argument to DCT/DST #15102

Merged
merged 2 commits into from Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 22 additions & 12 deletions scipy/fft/_pocketfft/pypocketfft.cxx
Expand Up @@ -231,7 +231,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_,

template<typename T> py::array dct_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
size_t nthreads)
size_t nthreads, bool ortho)
{
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in));
Expand All @@ -244,24 +244,27 @@ template<typename T> py::array dct_internal(const py::array &in,
py::gil_scoped_release release;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, -1)
: norm_fct<T>(inorm, dims, axes, 2);
bool ortho = inorm == 1;
pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}

py::array dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, py::object &out_, size_t nthreads, const py::object & ortho_obj)
{
bool ortho=inorm==1;
if (!ortho_obj.is_none())
ortho=ortho_obj.cast<bool>();

if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, out_,
nthreads))
nthreads, ortho))
}

template<typename T> py::array dst_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
size_t nthreads)
size_t nthreads, bool ortho)
{
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in));
Expand All @@ -274,19 +277,22 @@ template<typename T> py::array dst_internal(const py::array &in,
py::gil_scoped_release release;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, 1)
: norm_fct<T>(inorm, dims, axes, 2);
bool ortho = inorm == 1;
pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}

py::array dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, py::object &out_, size_t nthreads, const py::object &ortho_obj)
{
bool ortho=inorm==1;
if (!ortho_obj.is_none())
ortho=ortho_obj.cast<bool>();

if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm,
out_, nthreads))
out_, nthreads, ortho))
}

template<typename T> py::array c2r_internal(const py::array &in,
Expand Down Expand Up @@ -622,7 +628,7 @@ axes : list of integers
inorm : int
Normalization type
0 : no normalization
1 : make transform orthogonal and divide by sqrt(N)
1 : divide by sqrt(N)
2 : divide by N
where N is the product of n_i for every transformed axis i.
n_i is 2*(<axis_length>-1 for type 1 and 2*<axis length>
Expand All @@ -640,6 +646,8 @@ out : numpy.ndarray (same shape and data type as `a`)
nthreads : int
Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable).
ortho: bool
Orthogonalize transform (defaults to ``inorm==1``)

Returns
-------
Expand All @@ -661,7 +669,7 @@ axes : list of integers
inorm : int
Normalization type
0 : no normalization
1 : make transform orthogonal and divide by sqrt(N)
1 : divide by sqrt(N)
2 : divide by N
where N is the product of n_i for every transformed axis i.
n_i is 2*(<axis_length>+1 for type 1 and 2*<axis length>
Expand All @@ -678,6 +686,8 @@ out : numpy.ndarray (same shape and data type as `a`)
nthreads : int
Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable).
ortho: bool
Orthogonalize transform (defaults to ``inorm==1``)

Returns
-------
Expand Down Expand Up @@ -721,9 +731,9 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a,
"axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"out"_a=None, "nthreads"_a=1, "ortho"_a=None);
m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"out"_a=None, "nthreads"_a=1, "ortho"_a=None);

static PyMethodDef good_size_meth[] =
{{"good_size", (PyCFunction)good_size,
Expand Down
8 changes: 4 additions & 4 deletions scipy/fft/_pocketfft/realtransforms.py
Expand Up @@ -6,7 +6,7 @@


def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
overwrite_x=False, workers=None):
overwrite_x=False, workers=None, orthogonalize=None):
"""Forward or backward 1-D DCT/DST

Parameters
Expand Down Expand Up @@ -43,7 +43,7 @@ def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
transform(tmp.imag, type, (axis,), norm, out.imag, workers)
return out

return transform(tmp, type, (axis,), norm, out, workers)
return transform(tmp, type, (axis,), norm, out, workers, orthogonalize)


dct = functools.partial(_r2r, True, pfft.dct)
Expand All @@ -58,7 +58,7 @@ def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,


def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
overwrite_x=False, workers=None):
overwrite_x=False, workers=None, orthogonalize=None):
"""Forward or backward nd DCT/DST

Parameters
Expand Down Expand Up @@ -96,7 +96,7 @@ def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
transform(tmp.imag, type, axes, norm, out.imag, workers)
return out

return transform(tmp, type, axes, norm, out, workers)
return transform(tmp, type, axes, norm, out, workers, orthogonalize)


dctn = functools.partial(_r2rn, True, pfft.dct)
Expand Down