diff --git a/just_forward_check.py b/just_forward_check.py new file mode 100644 index 0000000..3a45cdc --- /dev/null +++ b/just_forward_check.py @@ -0,0 +1,258 @@ +import matplotlib.pyplot as plt +import numpy as np +import scipy.special +from numba import njit +from numpy import fft +from numpy.typing import NDArray + + +### ORIGNAL CODE ### (dont change this part) + +def wdm_transform(data: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]: + return transform_wavelet_freq_helper(np.fft.rfft(data), Nf, Nt, 2 / Nf * phitilde_vec_norm(Nf, Nt, nx)) + + +def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]: + """Normalize phitilde as needed for inverse frequency domain transform""" + ND: int = Nf * Nt + om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64) + + OM: float = np.pi # Nyquist angular frequency + DOM: float = float(OM / Nf) # 2 pi times DF + insDOM: float = float(1.0 / np.sqrt(DOM)) + B = OM / (2 * Nf) + A = (DOM - B) / 2 + phif = np.zeros(om.size, dtype=np.float64) + + mask = (np.abs(om) >= A) & (np.abs(om) < A + B) + + x = (np.abs(om[mask]) - A) / B + y = scipy.special.betainc(nx, nx, x) + phif[mask] = insDOM * np.cos(np.pi / 2.0 * y) + + phif[np.abs(om) < A] = insDOM + + # nrm should be 1 + nrm: float = float( + np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi), + ) + return phif / nrm + + +@njit() +def DX_assign_loop( + m: int, + Nt: int, + Nf: int, + DX: NDArray[np.complex128], + data: NDArray[np.complex128], + phif: NDArray[np.float64], +) -> None: + """Helper for assigning DX in the main loop""" + assert len(DX.shape) == 1, 'Storage array must be 1D' + assert len(data.shape) == 1, 'Data must be 1D' + assert len(phif.shape) == 1, 'Phi array must be 1D' + + i_base: int = int(Nt // 2) + jj_base: int = int(m * Nt // 2) + + if m in (0, Nf): + # NOTE this term appears to be needed to recover correct constant (at least for m=0) but was previously missing + DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] / 2.0 + else: + DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] + + for jj in range(jj_base + 1 - int(Nt // 2), jj_base + int(Nt // 2)): + j: int = int(np.abs(jj - jj_base)) + i: int = i_base - jj_base + jj + if (m == Nf and jj > jj_base) or (m == 0 and jj < jj_base): + DX[i] = 0.0 + elif j == 0: + continue + else: + DX[i] = phif[j] * data[jj] + + +@njit() +def DX_unpack_loop(m: int, Nt: int, Nf: int, DX_trans: NDArray[np.complex128], wave: NDArray[np.float64]) -> None: + """Helper for unpacking fftd DX in main loop""" + assert len(DX_trans.shape) == 1, 'Data array must be 1D' + assert len(wave.shape) == 2, 'Output array must be 2D' + if m == 0: + # half of lowest and highest frequency bin pixels are redundant + # so store them in even and odd components of m=0 respectively + for n in range(0, Nt, 2): + wave[n, 0] = DX_trans[n].real * np.sqrt(2.0) + elif m == Nf: + for n in range(0, Nt, 2): + wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2.0) + else: + for n in range(Nt): + if m % 2: + if (n + m) % 2: + wave[n, m] = -DX_trans[n].imag + else: + wave[n, m] = DX_trans[n].real + elif (n + m) % 2: + wave[n, m] = DX_trans[n].imag + else: + wave[n, m] = DX_trans[n].real + + +def transform_wavelet_freq_helper( + data: NDArray[np.complex128], + Nf: int, + Nt: int, + phif: NDArray[np.float64], +) -> NDArray[np.float64]: + """Helper to do the wavelet transform using the fast wavelet domain transform""" + assert len(data.shape) == 1, 'Only support 1D Arrays currently' + assert len(phif.shape) == 1, 'phif must be 1D' + wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal + + DX = np.zeros(Nt, dtype=np.complex128) + for m in range(Nf + 1): + DX_assign_loop(m, Nt, Nf, DX, data, phif) + DX_trans = fft.ifft(DX, Nt) + DX_unpack_loop(m, Nt, Nf, DX_trans, wave) + return wave + + +#### END OF ORIGINAL CODE ### + +## NEW VERSION ## (change this part) + + +def Phi_unit(f, A, d): + """ + Meyer window function for the WDM wavelet transform. + + See Eq. (10) of Cornish (2020). + `f` and half-width `A` are in units of Δf; `d` controls the smoothness. + """ + B = 1.0 - 2.0 * A + if B <= 0: + if A >= 0.5: + raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.") + + f_arr = np.asarray(f) + result = np.zeros_like(f_arr, dtype=float) + + # Region 1: |f| < A → φ = 1 + mask1 = np.abs(f_arr) < A + result[mask1] = 1.0 + + # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d) + mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B)) + if np.any(mask2): + z = (np.abs(f_arr[mask2]) - A) / B + z = np.clip(z, 0.0, 1.0) + p = scipy.special.betainc(d, d, z) + result[mask2] = np.cos(np.pi * p / 2.0) + + return result.item() if np.isscalar(f) else result + + +def wdm_dT_dF(nt, nf, dt): + """ + Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt. + """ + return nf * dt, 1.0 / (2.0 * nf * dt) + + +def wdm_times_frequencies(nt, nf, dt): + """ + Returns (ts, fs) for WDM: + """ + ΔT, ΔF = wdm_dT_dF(nt, nf, dt) + return np.arange(nt) * ΔT, np.arange(nf) * ΔF + + +def wdm_transform_roll(x, nt, nf, A, d): + n_total = nt * nf + + if x.shape[-1] != n_total: + raise ValueError(f"len(x) must be nt*nf = {n_total}") + if nt % 2 or nf % 2 or not (0 < A < 0.5) or d <= 0: + raise ValueError("nt,nf even; 00 required.") + + # full FFT + X_fft = fft.fft(x) + + # build phi window + fs_full = fft.fftfreq(n_total) + half = nt // 2 + fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]]) + phi = Phi_unit(fs_phi / (1.0 / (2.0 * nf)), A, d) / np.sqrt(1.0 / (2.0 * nf)) + + W = np.zeros((nt, nf), dtype=float) + center = n_total // 2 + start = center - half + + # Handle m=1 to nf-1 (the regular frequency bands) + for m in range(1, nf): + shift = center - m * half + rolled = np.roll(X_fft, shift) + sl = rolled[start:start + nt] + block = np.concatenate([sl[half:], sl[:half]]) + xnm = fft.ifft(block * phi) + # parity factor: swap real/imag mapping + n = np.arange(nt) + parity = (n + m) % 2 + # even indices use imaginary, odd use real + C = np.where(parity == 0, 1, 1j) + W[:, m] = (np.sqrt(2.0) / nf) * np.real(C * xnm) + + # Handle m=0 (DC components) - store in even indices of column 0 + shift = center # No shift for DC + rolled = np.roll(X_fft, shift) + sl = rolled[start:start + nt] + block = np.concatenate([sl[half:], sl[:half]]) + xnm = fft.ifft(block * phi) + # DC components go to even indices with sqrt(2) normalization and factor of 1/2 + for n in range(0, nt, 2): + W[n, 0] = np.real(xnm[n]) * np.sqrt(2.0) / (2.0 * nf) + + # Handle m=nf (Nyquist components) - store in odd indices of column 0 + shift = center - nf * half + rolled = np.roll(X_fft, shift) + sl = rolled[start:start + nt] + block = np.concatenate([sl[half:], sl[:half]]) + xnm = fft.ifft(block * phi) + # Nyquist components go to odd indices with sqrt(2) normalization and factor of 1/2 + for n in range(1, nt, 2): + W[n, 0] = np.real(xnm[n - 1]) * np.sqrt(2.0) / (2.0 * nf) + + return W + + +### + + +def check_transform(): + f0, dt = 1, 0.125 # Frequency and time step + Nf = Nt = 8 + t = np.arange(0, Nt * Nf) * dt + data = np.sin(2 * np.pi * f0 * t) + wave = wdm_transform(data, Nf, Nt, nx=4.0) + wave_roll = wdm_transform_roll(data, Nt, Nf, A=0.25, d=4.0) + wave_diff = wave - wave_roll + + # plotting + T_bins = np.arange(Nt) * dt # Time axis for plots + F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots + fig, axes = plt.subplots(3, 1, figsize=(10, 8)) + + def _plot_ax(fig, ax, w, title): + pmc = ax.pcolormesh(T_bins, F_bins, w.T, shading='auto', lw=0.5, edgecolor='white') + fig.colorbar(pmc, ax=ax, label=title, orientation='vertical') + + _plot_ax(fig, axes[0], wave, 'Wavelet Coefficients') + _plot_ax(fig, axes[1], wave_roll, 'Roll Wavelet Coefficients') + _plot_ax(fig, axes[2], wave_diff, 'Difference (Original - Roll)') + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + check_transform() diff --git a/just_inverse_check.py b/just_inverse_check.py new file mode 100644 index 0000000..c789ca8 --- /dev/null +++ b/just_inverse_check.py @@ -0,0 +1,266 @@ +import os + +import matplotlib.pyplot as plt +import numpy as np +import scipy.special +from numba import njit +from numpy import fft +from numpy.typing import NDArray + +OUTDIR = 'out' +os.makedirs(OUTDIR, exist_ok=True) + + +### ORIGINAL CODE ### (dont change) +def wdm_inverse_transform(wave_in: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]: + rfft_data = inverse_wavelet_freq_helper_fast(wave_in, phitilde_vec_norm(Nf, Nt, nx), Nf, Nt) + return np.fft.irfft(rfft_data) + + +@njit() +def unpack_wave_inverse( + m: int, + Nt: int, + Nf: int, + phif: NDArray[np.float64], + fft_prefactor2s: NDArray[np.complex128], + res: NDArray[np.complex128], +) -> None: + """Helper for unpacking results of frequency domain inverse transform""" + if m in (0, Nf): + for i_ind in range(int(Nt // 2)): + i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2 + ind3 = (2 * i) % Nt + res[i] += fft_prefactor2s[ind3] * phif[i_ind] + if m == Nf: + i_ind = int(Nt // 2) + i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2 + ind3 = 0 + res[i] += fft_prefactor2s[ind3] * phif[i_ind] + else: + ind31 = (int(Nt // 2) * m) % Nt + ind32 = (int(Nt // 2) * m) % Nt + for i_ind in range(int(Nt // 2)): + i1 = int(Nt // 2) * m - i_ind + i2 = int(Nt // 2) * m + i_ind + res[i1] += fft_prefactor2s[ind31] * phif[i_ind] + res[i2] += fft_prefactor2s[ind32] * phif[i_ind] + ind31 -= 1 + ind32 += 1 + if ind31 < 0: + ind31 = Nt - 1 + if ind32 == Nt: + ind32 = 0 + + res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0] + + +@njit() +def pack_wave_inverse( + m: int, + Nt: int, + Nf: int, + prefactor2s: NDArray[np.complex128], + wave_in: NDArray[np.float64], +) -> None: + """Helper for fast frequency domain inverse transform to prepare for fourier transform""" + if m == 0: + for n in range(Nt): + prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt, 0] + elif m == Nf: + for n in range(Nt): + prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt + 1, 0] + else: + for n in range(Nt): + val = float(wave_in[n, m]) + if (n + m) % 2: + mult2 = -1j + else: + mult2 = 1 + + prefactor2s[n] = mult2 * val + + +# @njit() +def inverse_wavelet_freq_helper_fast( + wave_in: NDArray[np.float64], + phif: NDArray[np.float64], + Nf: int, + Nt: int, +) -> NDArray[np.complex128]: + """Jit compatible loop for wdm_inverse_transform""" + ND = Nf * Nt + + prefactor2s = np.zeros(Nt, np.complex128) + res = np.zeros(ND // 2 + 1, dtype=np.complex128) + + for m in range(Nf + 1): + pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in) + # with numba.objmode(fft_prefactor2s="complex128[:]"): + fft_prefactor2s = fft.fft(prefactor2s) + unpack_wave_inverse(m, Nt, Nf, phif, fft_prefactor2s, res) + + return res + + +def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]: + """Normalize phitilde as needed for inverse frequency domain transform""" + ND: int = Nf * Nt + om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64) + + OM: float = np.pi # Nyquist angular frequency + DOM: float = float(OM / Nf) # 2 pi times DF + insDOM: float = float(1.0 / np.sqrt(DOM)) + B = OM / (2 * Nf) + A = (DOM - B) / 2 + phif = np.zeros(om.size, dtype=np.float64) + + mask = (np.abs(om) >= A) & (np.abs(om) < A + B) + + x = (np.abs(om[mask]) - A) / B + y = scipy.special.betainc(nx, nx, x) + phif[mask] = insDOM * np.cos(np.pi / 2.0 * y) + + phif[np.abs(om) < A] = insDOM + + # nrm should be 1 + nrm: float = float( + np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi), + ) + return phif / nrm + + +### END OF ORIGINAL CODE ### + +### NEW CODE ### + + +def Phi_unit(f, A, d): + """ + Meyer window function for the WDM wavelet transform. + + See Eq. (10) of Cornish (2020). + `f` and half-width `A` are in units of Δf; `d` controls the smoothness. + """ + B = 1.0 - 2.0 * A + if B <= 0: + if A >= 0.5: + raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.") + + f_arr = np.asarray(f) + result = np.zeros_like(f_arr, dtype=float) + + # Region 1: |f| < A → φ = 1 + mask1 = np.abs(f_arr) < A + result[mask1] = 1.0 + + # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d) + mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B)) + if np.any(mask2): + z = (np.abs(f_arr[mask2]) - A) / B + z = np.clip(z, 0.0, 1.0) + p = scipy.special.betainc(d, d, z) + result[mask2] = np.cos(np.pi * p / 2.0) + + return result.item() if np.isscalar(f) else result + + +def wdm_dT_dF(nt, nf, dt): + """ + Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt. + """ + return nf * dt, 1.0 / (2.0 * nf * dt) + + +def wdm_times_frequencies(nt, nf, dt): + """ + Returns (ts, fs) for WDM: + """ + ΔT, ΔF = wdm_dT_dF(nt, nf, dt) + return np.arange(nt) * ΔT, np.arange(nf) * ΔF + + +def wdm_inverse_transform_new(W, A, d): + nt, nf = W.shape + n_total = nt * nf + half = nt // 2 + + # Build phi window + fs_full = fft.fftfreq(n_total) + fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]]) + _, dF = wdm_dT_dF(nt, nf, 1.0) + phi = Phi_unit(fs_phi / dF, A, d) / np.sqrt(dF) + + # Step 1: Apply parity factors and normalize + ylm = np.zeros((nt, nf), dtype=complex) + for n in range(nt): + for m in range(1, nf): # Julia starts from m=2, Python from m=1 + C = 1 if (n + m) % 2 == 0 else 1j + ylm[n, m] = C * W[n, m] / np.sqrt(2.0) + + # Step 2: FFT over time dimension (axis=0) + ylm_fft = fft.fft(ylm, axis=0) + + # Step 3: Build frequency domain array + X = np.zeros(n_total, dtype=complex) + + for m in range(1, nf): + l0 = m * half + + # First contribution: G[l - m*Nt/2] * Ylm + # This maps ylm_fft[:, m] to frequencies around l0 + temp1 = np.zeros(n_total, dtype=complex) + temp1[:half] = ylm_fft[:half, m] * phi[:half] # Positive frequencies + temp1[half:2 * half] = ylm_fft[half:, m] * phi[half:] # Negative frequencies + X += np.roll(temp1, l0 - half) + + # Second contribution: G[l + m*Nt/2] * conj(Y(-l)m) + l1 = n_total - l0 + + temp2 = np.zeros(n_total, dtype=complex) + # Zero frequency of conjugate + temp2[0] = np.conj(ylm_fft[0, m]) * phi[0] + + # Positive frequencies of conjugate (reversed) + if half > 1: + temp2[1:half] = np.conj(ylm_fft[nt - 1:half:-1, m]) * phi[1:half] + + # Negative frequencies of conjugate (reversed) + temp2[half:2 * half] = np.conj(ylm_fft[half:0:-1, m]) * phi[half:] + + X += np.roll(temp2, l1 - half) + + # Step 4: Inverse FFT to get time domain + x_reconstructed = np.real(fft.ifft(X)) + + return x_reconstructed + + +### end of NEW CODE ### + +def run_single_element_inverse_reconstruction_test(): + nf = nt = 6 + t = np.arange(nf * nt) + x_original_wdm = np.zeros((nt, nf)) + x_original_wdm[3, 3] = 1.0 + + data = wdm_inverse_transform(x_original_wdm, nf, nt, nx=4.0) + data_new = wdm_inverse_transform_new(x_original_wdm, 0.25, 4.0) + print("Reconstructed data:", data) + print("Reconstructed data (new):", data_new) + print("data / data(new):", data/data_new) + if not np.allclose(data, data_new, atol=1e-7): + print("Reconstruction failed for the original test case.") + fig = plt.figure() + plt.plot(t, data, label='Reconstructed Data', color='orange') + plt.plot(t, data_new, label='Reconstructed Data (new)', color='blue', linestyle='--') + plt.legend() + plt.title('Reconstruction Failure for Original Test Case') + plt.savefig(os.path.join(OUTDIR, 'reconstruction_failure_original.png'), dpi=300) + plt.close(fig) + else: + print("Original test case passed.") + + +if __name__ == "__main__": + run_single_element_inverse_reconstruction_test() diff --git a/original.py b/original.py new file mode 100644 index 0000000..1373d69 --- /dev/null +++ b/original.py @@ -0,0 +1,237 @@ +import matplotlib.pyplot as plt +import numpy as np +import scipy.special +from numba import njit +from numpy import fft +from numpy.typing import NDArray + + +def wdm_inverse_transform(wave_in: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]: + rfft_data = inverse_wavelet_freq_helper_fast(wave_in, phitilde_vec_norm(Nf, Nt, nx), Nf, Nt) + return np.fft.irfft(rfft_data) + + +def wdm_transform(data: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]: + return transform_wavelet_freq_helper(np.fft.rfft(data), Nf, Nt, 2 / Nf * phitilde_vec_norm(Nf, Nt, nx)) + + +@njit() +def unpack_wave_inverse( + m: int, + Nt: int, + Nf: int, + phif: NDArray[np.float64], + fft_prefactor2s: NDArray[np.complex128], + res: NDArray[np.complex128], +) -> None: + """Helper for unpacking results of frequency domain inverse transform""" + if m in (0, Nf): + for i_ind in range(int(Nt // 2)): + i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2 + ind3 = (2 * i) % Nt + res[i] += fft_prefactor2s[ind3] * phif[i_ind] + if m == Nf: + i_ind = int(Nt // 2) + i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2 + ind3 = 0 + res[i] += fft_prefactor2s[ind3] * phif[i_ind] + else: + ind31 = (int(Nt // 2) * m) % Nt + ind32 = (int(Nt // 2) * m) % Nt + for i_ind in range(int(Nt // 2)): + i1 = int(Nt // 2) * m - i_ind + i2 = int(Nt // 2) * m + i_ind + res[i1] += fft_prefactor2s[ind31] * phif[i_ind] + res[i2] += fft_prefactor2s[ind32] * phif[i_ind] + ind31 -= 1 + ind32 += 1 + if ind31 < 0: + ind31 = Nt - 1 + if ind32 == Nt: + ind32 = 0 + + res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0] + + +@njit() +def pack_wave_inverse( + m: int, + Nt: int, + Nf: int, + prefactor2s: NDArray[np.complex128], + wave_in: NDArray[np.float64], +) -> None: + """Helper for fast frequency domain inverse transform to prepare for fourier transform""" + if m == 0: + for n in range(Nt): + prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt, 0] + elif m == Nf: + for n in range(Nt): + prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt + 1, 0] + else: + for n in range(Nt): + val = float(wave_in[n, m]) + if (n + m) % 2: + mult2 = -1j + else: + mult2 = 1 + + prefactor2s[n] = mult2 * val + + +# @njit() +def inverse_wavelet_freq_helper_fast( + wave_in: NDArray[np.float64], + phif: NDArray[np.float64], + Nf: int, + Nt: int, +) -> NDArray[np.complex128]: + """Jit compatible loop for wdm_inverse_transform""" + ND = Nf * Nt + + prefactor2s = np.zeros(Nt, np.complex128) + res = np.zeros(ND // 2 + 1, dtype=np.complex128) + + for m in range(Nf + 1): + pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in) + # with numba.objmode(fft_prefactor2s="complex128[:]"): + fft_prefactor2s = fft.fft(prefactor2s) + unpack_wave_inverse(m, Nt, Nf, phif, fft_prefactor2s, res) + + return res + + +def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]: + """Normalize phitilde as needed for inverse frequency domain transform""" + ND: int = Nf * Nt + om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64) + + OM: float = np.pi # Nyquist angular frequency + DOM: float = float(OM / Nf) # 2 pi times DF + insDOM: float = float(1.0 / np.sqrt(DOM)) + B = OM / (2 * Nf) + A = (DOM - B) / 2 + phif = np.zeros(om.size, dtype=np.float64) + + mask = (np.abs(om) >= A) & (np.abs(om) < A + B) + + x = (np.abs(om[mask]) - A) / B + y = scipy.special.betainc(nx, nx, x) + phif[mask] = insDOM * np.cos(np.pi / 2.0 * y) + + phif[np.abs(om) < A] = insDOM + + # nrm should be 1 + nrm: float = float( + np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi), + ) + return phif / nrm + + +@njit() +def DX_assign_loop( + m: int, + Nt: int, + Nf: int, + DX: NDArray[np.complex128], + data: NDArray[np.complex128], + phif: NDArray[np.float64], +) -> None: + """Helper for assigning DX in the main loop""" + assert len(DX.shape) == 1, 'Storage array must be 1D' + assert len(data.shape) == 1, 'Data must be 1D' + assert len(phif.shape) == 1, 'Phi array must be 1D' + + i_base: int = int(Nt // 2) + jj_base: int = int(m * Nt // 2) + + if m in (0, Nf): + # NOTE this term appears to be needed to recover correct constant (at least for m=0) but was previously missing + DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] / 2.0 + else: + DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] + + for jj in range(jj_base + 1 - int(Nt // 2), jj_base + int(Nt // 2)): + j: int = int(np.abs(jj - jj_base)) + i: int = i_base - jj_base + jj + if (m == Nf and jj > jj_base) or (m == 0 and jj < jj_base): + DX[i] = 0.0 + elif j == 0: + continue + else: + DX[i] = phif[j] * data[jj] + + +@njit() +def DX_unpack_loop(m: int, Nt: int, Nf: int, DX_trans: NDArray[np.complex128], wave: NDArray[np.float64]) -> None: + """Helper for unpacking fftd DX in main loop""" + assert len(DX_trans.shape) == 1, 'Data array must be 1D' + assert len(wave.shape) == 2, 'Output array must be 2D' + if m == 0: + # half of lowest and highest frequency bin pixels are redundant + # so store them in even and odd components of m=0 respectively + for n in range(0, Nt, 2): + wave[n, 0] = DX_trans[n].real * np.sqrt(2.0) + elif m == Nf: + for n in range(0, Nt, 2): + wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2.0) + else: + for n in range(Nt): + if m % 2: + if (n + m) % 2: + wave[n, m] = -DX_trans[n].imag + else: + wave[n, m] = DX_trans[n].real + elif (n + m) % 2: + wave[n, m] = DX_trans[n].imag + else: + wave[n, m] = DX_trans[n].real + + +def transform_wavelet_freq_helper( + data: NDArray[np.complex128], + Nf: int, + Nt: int, + phif: NDArray[np.float64], +) -> NDArray[np.float64]: + """Helper to do the wavelet transform using the fast wavelet domain transform""" + assert len(data.shape) == 1, 'Only support 1D Arrays currently' + assert len(phif.shape) == 1, 'phif must be 1D' + wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal + + DX = np.zeros(Nt, dtype=np.complex128) + for m in range(Nf + 1): + DX_assign_loop(m, Nt, Nf, DX, data, phif) + DX_trans = fft.ifft(DX, Nt) + DX_unpack_loop(m, Nt, Nf, DX_trans, wave) + return wave + + +def check_roundtrip(): + f0, dt = 1, 0.125 # Frequency and time step + Nf = Nt = 8 + nx = 4.0 + t = np.arange(0, Nt * Nf) * dt + data = np.sin(2 * np.pi * f0 * t) + wave = wdm_transform(data, Nf, Nt, nx) + data_reconstructed = wdm_inverse_transform(wave, Nf, Nt, nx) + + # plotting + T_bins = np.arange(Nt) * dt # Time axis for plots + F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots + fig, axes = plt.subplots(2, 1, figsize=(10, 8)) + axes[0].plot(t, data, label='Original Signal', color='blue') + axes[0].plot(t, data_reconstructed, label='Reconstructed Signal', color='orange', linestyle='--') + pmc= axes[1].pcolormesh(T_bins, F_bins, wave.T, shading='auto', lw=0.5, edgecolor='white') + fig.colorbar(pmc, ax=axes[1], label='Wavelet Coefficients', orientation='vertical') + plt.tight_layout() + plt.show() + + # Check if the original and reconstructed data match + assert np.allclose( + data, data_reconstructed, + atol=1e-6), "Roundtrip failed: Original and reconstructed data do not match." + + +if __name__ == "__main__": + check_roundtrip() diff --git a/run_checks.py b/run_checks.py index be01064..9ba7447 100644 --- a/run_checks.py +++ b/run_checks.py @@ -1,53 +1,138 @@ +import os + import matplotlib.pyplot as plt -import numpy as np + from wdm_roll import * +from original import wdm_inverse_transform as wdm_inverse_transform_orig +from original import wdm_transform as wdm_transform_orig plt.style.use('seaborn-v0_8-whitegrid') +HERE = __file__ +OUTDIR = os.path.join(os.path.dirname(HERE), 'out') +os.makedirs(OUTDIR, exist_ok=True) + +A_wavelet_param = 0.25 # Global for tests, +d_wavelet_param = 4 # Global for tests + + # --- Test Functions and Classes --- def chirp_signal(ts, Ac, As, f, fdot): """Generates a chirp signal.""" phases = 2 * np.pi * ts * (f + fdot * ts / 2.0) return Ac * np.cos(phases) + As * np.sin(phases) -A_wavelet_param = 0.25 # Global for tests, -d_wavelet_param = 4 # Global for tests + +def run_monochromatic_wnm_test(): + """Runs a test for monochromatic WNM generation, printing results and plotting.""" + print("\n--- Running Monochromatic WNM Test ---") + + f0, dt = 1, 0.125 # Frequency and time step + Nf = Nt = 8 + nx = 4.0 + A = 1 + N = Nt * Nf + t = np.arange(0, N) * dt # Time vector for original signal + signal_time = A * np.sin(2 * np.pi * f0 * t) # Monochromatic signal + + # Generate analytical WNM for a monochromatic signal + T_bins = np.arange(Nt) * dt # Time axis for plots + F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots + + orig_wnm = wdm_transform_orig(signal_time, Nt, Nf, nx=4.0) + + # Perform WDM transform + wnm = wdm_transform(signal_time, Nt, Nf, A_wavelet_param, d_wavelet_param) + reconstructed_time = wdm_inverse_transform(orig_wnm, A_wavelet_param, d_wavelet_param) + + # Check if the WNM matches the analytical result + diff_wnm = np.abs(wnm - orig_wnm) + max_diff = np.max(diff_wnm) + if max_diff > 1e-6: + print(f"Monochromatic WNM Test FAILED: Max difference {max_diff:.2e} exceeds tolerance.") + else: + print(f"Monochromatic WNM Test PASSED: Max difference {max_diff:.2e} within tolerance.") + + # round-trip reconstruction of analytical WNM + diff_time = np.abs(signal_time - reconstructed_time) + max_diff_time = np.max(diff_time) + if max_diff_time > 1e-6: + print(f"Monochromatic WNM Test FAILED: Max time-domain difference {max_diff_time:.2e} exceeds tolerance.") + else: + print(f"Monochromatic WNM Test PASSED: Max time-domain difference {max_diff_time:.2e} within tolerance.") + + # Plotting ([[time, WDM transform, analytical]]) + fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + + # axes[0]: Time series plot + axes[0].plot(t, signal_time, label='Original Signal', color='blue') + axes[0].plot(t, reconstructed_time, label='Reconstructed Signal', color='orange') + axes[0].set_title("Original vs. Reconstructed Signal (Time Domain)") + axes[0].set_xlabel("Time (s)") + axes[0].set_ylabel("Amplitude") + axes[0].legend() + + # twin plot fot phase + ax2 = axes[0].twinx() + ax2.plot(t, np.angle(reconstructed_time), color='green', linestyle='--', label='Phase of Reconstructed Signal') + ax2.plot(t, np.angle(signal_time), color='red', linestyle='--', label='Phase of Original Signal') + + # axes[1]: WDM Transform + xy = (T_bins, F_bins) + kwgs = dict(cmap='viridis', shading='auto', edgecolors='white', linewidth=0.5) + pcm1 = axes[1].pcolormesh(*xy, np.abs(wnm.T), **kwgs) + fig.colorbar(pcm1, ax=axes[1], orientation='horizontal', pad=0.15) + axes[1].set_title("WDM Transform Output") + + # axes[2]: Analytical WNM + pcm2 = axes[2].pcolormesh(*xy, np.abs(orig_wnm.T), **kwgs) + fig.colorbar(pcm2, ax=axes[2], orientation='horizontal', pad=0.15) + axes[2].set_title("Analytical WNM") + + # axes[3]: Difference between WNM and Analytical WNM + diff_wnm = np.abs(wnm - orig_wnm) + pcm3 = axes[3].pcolormesh(*xy, diff_wnm.T, **kwgs) + fig.colorbar(pcm3, ax=axes[3], orientation='horizontal', pad=0.15) + axes[3].set_title("Difference (WNM - Analytical WNM)") + + plt.tight_layout() + plt.savefig(os.path.join(OUTDIR, 'monochromatic_wnm_test.png'), dpi=300) + def run_parsevals_theorem_and_chirp_track_test(): """Runs Parseval's theorem and chirp tracking tests, printing results and plotting.""" print("\n--- Running Parseval's Theorem and Chirp Track Test ---") - dt = 1.0 / np.pi # Sampling interval of original signal - fny = 1.0 / (2.0 * dt) # Nyquist frequency + dt = 1.0 / np.pi # Sampling interval of original signal + fny = 1.0 / (2.0 * dt) # Nyquist frequency - nt = 64 # WDM time bins - nf = 64 # WDM frequency bins - n_total = nt * nf # Total samples in original signal + nt = 32 # WDM time bins + nf = 32 # WDM frequency bins + n_total = nt * nf # Total samples in original signal - ts_signal = dt * np.arange(n_total) # Time vector for original signal - T_duration = n_total * dt # Total duration of original signal + ts_signal = dt * np.arange(n_total) # Time vector for original signal + T_duration = n_total * dt # Total duration of original signal - f0 = fny / 5.0 # Initial frequency of chirp - fdot = f0 / T_duration # Rate of change of frequency + f0 = fny / 5.0 # Initial frequency of chirp + fdot = f0 / T_duration # Rate of change of frequency Amplitude = 1.0 - rng = np.random.default_rng(seed=42) + rng = np.random.default_rng(seed=42) phi_chirp = np.arctan2(rng.standard_normal(), rng.standard_normal()) Ac = Amplitude * np.cos(phi_chirp) As = Amplitude * np.sin(phi_chirp) f_time_domain = chirp_signal(ts_signal, Ac, As, f0, fdot) - + # Perform WDM transform f_tilde_wdm = wdm_transform(f_time_domain, nt, nf, A_wavelet_param, d_wavelet_param) - + # Perform inverse WDM transform for reconstruction plots f_reconstructed_time = wdm_inverse_transform(f_tilde_wdm, A_wavelet_param, d_wavelet_param) - # --- Parseval's Theorem Test --- - sum_f_sq = np.sum(f_time_domain**2) - sum_f_tilde_sq = np.sum(f_tilde_wdm**2) - + sum_f_sq = np.sum(f_time_domain ** 2) + sum_f_tilde_sq = np.sum(f_tilde_wdm ** 2) + parseval_check = np.isclose(sum_f_sq, sum_f_tilde_sq, rtol=1e-2, atol=0) if not parseval_check: print(f"Parseval's Theorem FAILED: sum(f^2)={sum_f_sq:.4e}, sum(f_tilde^2)={sum_f_tilde_sq:.4e} (rtol=1e-2)") @@ -55,49 +140,57 @@ def run_parsevals_theorem_and_chirp_track_test(): print(f"Parseval's Theorem PASSED: sum(f^2)={sum_f_sq:.4e}, sum(f_tilde^2)={sum_f_tilde_sq:.4e} (rtol=1e-2)") # --- Chirp Track Test --- - dT_wdm_bin, dF_wdm_bin = wdm_dT_dF(nt, nf, dt) # WDM bin widths based on original signal dt + dT_wdm_bin, dF_wdm_bin = wdm_dT_dF(nt, nf, dt) # WDM bin widths based on original signal dt max_power_indices_freq = np.argmax(np.abs(f_tilde_wdm), axis=1) - - times_for_chirp_track_pred = np.arange(nt) * dT_wdm_bin # Time for each WDM time bin + + times_for_chirp_track_pred = np.arange(nt) * dT_wdm_bin # Time for each WDM time bin predicted_frequencies_chirp = f0 + fdot * times_for_chirp_track_pred predicted_max_power_indices_freq = predicted_frequencies_chirp / dF_wdm_bin - + diff_chirp_track = np.abs(max_power_indices_freq[1:-1] - predicted_max_power_indices_freq[1:-1]) if diff_chirp_track.size > 0: chirp_track_check = np.all(diff_chirp_track <= 2.5) max_diff_val = np.max(diff_chirp_track) if not chirp_track_check: - print(f"Chirp Track FAILED: Max deviation {max_diff_val:.2f} > 2.5. Failing diffs: {diff_chirp_track[diff_chirp_track > 2.5]}") + print( + f"Chirp Track FAILED: Max deviation {max_diff_val:.2f} > 2.5. Failing diffs: {diff_chirp_track[diff_chirp_track > 2.5]}") else: print(f"Chirp Track PASSED: Max deviation {max_diff_val:.2f} <= 2.5") - elif nt <=2 : + elif nt <= 2: print("Chirp Track SKIPPED: nt is too small to evaluate edges.") - else: + else: print("Chirp Track SKIPPED: Not enough data points after excluding edges.") # --- Plotting --- - - orig_kwgs = dict( color='tab:blue', alpha=0.5, label='original') + orig_kwgs = dict(color='tab:blue', alpha=0.5, label='original') recon_kwgs = dict(color='tab:orange', alpha=0.5, label='recon') # 1. Chirp in Frequency Domain (FFT of original signal) plt.figure(figsize=(12, 8)) - original_fft = fft(f_time_domain) - original_fft_freqs = fftfreq(n_total, d=dt) - reconstructed_fft = fft(f_reconstructed_time) + original_fft = fft.fft(f_time_domain) + original_fft_freqs = fft.fftfreq(n_total, d=dt) + reconstructed_fft = fft.fft(f_reconstructed_time) # Plot only positive frequencies for clarity positive_freq_mask = original_fft_freqs >= 0 plt.subplot(2, 2, 1) + + ratio = np.abs(original_fft[positive_freq_mask]) / np.abs(reconstructed_fft[positive_freq_mask]) + plt.plot(original_fft_freqs[positive_freq_mask], np.abs(original_fft[positive_freq_mask]), **orig_kwgs) - plt.plot(original_fft_freqs[positive_freq_mask],np.abs(reconstructed_fft[positive_freq_mask]), **recon_kwgs ) - plt.legend() + plt.plot(original_fft_freqs[positive_freq_mask], np.abs(reconstructed_fft[positive_freq_mask]), **recon_kwgs) + plt.legend(fontsize='small', loc='upper right', frameon=True) + # twinx for ratio + ax2 = plt.gca().twinx() + ax2.plot(original_fft_freqs[positive_freq_mask], ratio, color='tab:green', linestyle='--', + label='Ratio (Original/Reconstructed)') + ax2.legend(fontsize='small', loc='upper left', frameon=True) + plt.title('Chirp in Frequency Domain (FFT of Original)') plt.xlabel('Frequency (Hz)') plt.ylabel('Magnitude') - plt.legend(fontsize='small', loc='upper right', frameon=True) - plt.xlim(0, fny) # Show up to Nyquist + plt.xlim(0, fny) # Show up to Nyquist # 2. WDM Transform (imshow) # Get actual time and frequency extents for WDM plot @@ -110,14 +203,14 @@ def run_parsevals_theorem_and_chirp_track_test(): # We want to show the full range covered by the bins # Time axis: from wdm_times_xaxis[0] to wdm_times_xaxis[-1] + dT_wdm_bin # Freq axis: from wdm_freqs_yaxis[0] to wdm_freqs_yaxis[-1] + dF_wdm_bin - img_extent = [wdm_times_xaxis[0], wdm_times_xaxis[-1] + dT_wdm_bin, + img_extent = [wdm_times_xaxis[0], wdm_times_xaxis[-1] + dT_wdm_bin, wdm_freqs_yaxis[0], wdm_freqs_yaxis[-1] + dF_wdm_bin] - + # Transpose f_tilde_wdm because imshow's first index is rows (y-axis, frequency), second is columns (x-axis, time) # And WDM matrix is (nt_bins, nf_bins) = (time_bins, freq_bins) # So f_tilde_wdm is (time, freq). For imshow(M), M[row,col]. # We want time on x-axis, freq on y-axis. So imshow(f_tilde_wdm.T) - plt.imshow(np.abs(f_tilde_wdm.T), aspect='auto', origin='lower', + plt.imshow(np.abs(f_tilde_wdm.T), aspect='auto', origin='lower', extent=img_extent, cmap='viridis') plt.colorbar(label='Magnitude') plt.title('WDM Transform Output') @@ -127,11 +220,10 @@ def run_parsevals_theorem_and_chirp_track_test(): plt.plot(times_for_chirp_track_pred, predicted_frequencies_chirp, 'r--', linewidth=1, label='Predicted Chirp Track') plt.legend(fontsize='small', loc='upper right', frameon=True) - # 3. Reconstructed Chirp (Time Domain) plt.subplot(2, 2, 3) - plt.plot(ts_signal, f_time_domain, **orig_kwgs) - plt.plot(ts_signal, f_reconstructed_time, **recon_kwgs) + plt.plot(ts_signal, f_time_domain, **orig_kwgs) + plt.plot(ts_signal, f_reconstructed_time, **recon_kwgs) plt.title('Original vs. Reconstructed Chirp (Time Domain)') plt.xlabel('Time (s)') plt.ylabel('Amplitude') @@ -146,52 +238,69 @@ def run_parsevals_theorem_and_chirp_track_test(): plt.title('Frequency-Domain Residuals (FFT(Original) - FFT(Reconstructed))') plt.xlabel('Frequency (Hz)') plt.ylabel('Magnitude of Difference') - plt.yscale('log') # Residuals can be small, log scale helps + plt.yscale('log') # Residuals can be small, log scale helps plt.xlim(0, fny) - plt.ylim(bottom=max(1e-9, np.min(freq_residuals[positive_freq_mask & (freq_residuals > 0)])*0.1)) # Avoid zero for log scale + # plt.ylim(bottom=max(1e-9, np.min( + # freq_residuals[positive_freq_mask & (freq_residuals > 0)]) * 0.1)) # Avoid zero for log scale plt.tight_layout() - plt.show() + plt.savefig(os.path.join(OUTDIR, 'parseval_chirp_track_test.png'), dpi=300) def run_single_element_inverse_reconstruction_test(): """Runs round-trip transform for single element impulses, printing results.""" print("\n--- Running Single Element Inverse Reconstruction Test ---") - nt = 32 - nf = 32 + nt = 4 + nf = 4 + + wdm_times_xaxis = np.arange(nt) # Time bins for WDM + wdm_freqs_yaxis = np.arange(nf) # Frequency bins for WDM reconstruction_failures = [] all_passed = True - for i in range(nt): - for j in range(1, nf): + for i in range(nt): + for j in range(nf): x_original_wdm = np.zeros((nt, nf)) x_original_wdm[i, j] = 1.0 - time_signal_from_single_coeff = wdm_inverse_direct(x_original_wdm, A_wavelet_param, d_wavelet_param) - + time_signal_from_single_coeff = wdm_inverse_transform(x_original_wdm, A_wavelet_param, d_wavelet_param) x_reconstructed_wdm = wdm_transform(time_signal_from_single_coeff, nt, nf, A_wavelet_param, d_wavelet_param) - - if not np.allclose(x_original_wdm, x_reconstructed_wdm, atol=1e-7): + + # time_signal_from_single_coeff = np.fft.irfft(wdm_inverse_transform_orig(x_original_wdm, nf, nt, nx=4)) + # x_reconstructed_wdm = wdm_transform_orig(np.fft.rfft(time_signal_from_single_coeff), nf, nt, nx=4) + + + if not np.allclose(x_original_wdm, x_reconstructed_wdm, atol=1e-7): diff_wdm = np.abs(x_original_wdm - x_reconstructed_wdm) diff_val = np.max(diff_wdm) - reconstruction_failures.append(((i,j), diff_val)) + reconstruction_failures.append(((i, j), diff_val)) all_passed = False - - fig, ax = plt.subplots(1,3, figsize=(6,4)) - ax[0].imshow(np.abs(x_original_wdm.T), aspect='auto', origin='lower', cmap='viridis') - ax[1].imshow(np.abs(x_reconstructed_wdm.T), aspect='auto', origin='lower', cmap='viridis') - ax[2].imshow(diff_wdm.T, aspect='auto', origin='lower', cmap='viridis') + + plt.close('all') # Close any previous plots to avoid clutter + fig, ax = plt.subplots(1, 3, figsize=(6, 4)) + + # Use pcolormesh for exact grid plotting and add colorbars on top + xy = (wdm_times_xaxis, wdm_freqs_yaxis) + kwgs = dict(cmap='viridis', shading='auto', edgecolors='white', linewidth=0.5, ) + pcm0 = ax[0].pcolormesh(*xy, np.abs(x_original_wdm.T), **kwgs) + pcm1 = ax[1].pcolormesh(*xy, np.abs(x_reconstructed_wdm.T), **kwgs) + pcm2 = ax[2].pcolormesh(*xy, diff_wdm.T, **kwgs) + fig.colorbar(pcm0, ax=ax[0], orientation='horizontal', pad=0.15) + fig.colorbar(pcm1, ax=ax[1], orientation='horizontal', pad=0.15) + fig.colorbar(pcm2, ax=ax[2], orientation='horizontal', pad=0.15) + ax[0].set_title("Original WDM") ax[1].set_title("Orig->time->WDM") ax[2].set_title("diff") - break - + fig.savefig(os.path.join(OUTDIR, f'single_element_reconstruction_fail_{i}_{j}.png'), dpi=300) + if reconstruction_failures: print(f"Single element reconstruction FAILED for {len(reconstruction_failures)} cases:") - for k_idx in range(min(5, len(reconstruction_failures))): - print(f" Index (t,f)=({reconstruction_failures[k_idx][0][0]},{reconstruction_failures[k_idx][0][1]}), max_abs_diff={reconstruction_failures[k_idx][1]:.2e}") - + for k_idx in range(min(5, len(reconstruction_failures))): + print( + f" Index (t,f)=({reconstruction_failures[k_idx][0][0]},{reconstruction_failures[k_idx][0][1]}), max_abs_diff={reconstruction_failures[k_idx][1]:.2e}") + if all_passed: print("Single Element Inverse Reconstruction Test PASSED for all elements.") else: @@ -199,12 +308,7 @@ def run_single_element_inverse_reconstruction_test(): return all_passed - - if __name__ == '__main__': + run_monochromatic_wnm_test() run_parsevals_theorem_and_chirp_track_test() run_single_element_inverse_reconstruction_test() - - - - diff --git a/wdm_roll.py b/wdm_roll.py index 9f35702..403dca9 100644 --- a/wdm_roll.py +++ b/wdm_roll.py @@ -1,15 +1,24 @@ import numpy as np from scipy.special import betainc -from numpy.fft import fft, ifft, fftfreq +from numpy import fft + def Phi_unit(f, A, d): """ - Meyer window function for the WDM wavelet transform. + Meyer window function Φ(ω) from Cornish Eq. (11). + + Φ(ω) = 1/√ΔΩ for |ω| < A + Φ(ω) = (1/√ΔΩ) cos[νd(π/2 * (|ω|-A)/B)] for A ≤ |ω| ≤ A+B - See Eq. (10) of Cornish (2020). - `f` and half-width `A` are in units of Δf; `d` controls the smoothness. + where νd(x) is the normalized incomplete Beta function (Eq. 12) + and B = ΔΩ - 2A with constraint 2A + B = ΔΩ. + + Args: + f: frequency in units of ΔF + A: half-width parameter (0 < A < 0.5) + d: steepness parameter controlling edge smoothness """ - B = 1.0 - 2.0 * A + B = 1.0 - 2.0 * A # From constraint 2A + B = ΔΩ = 1 (normalized) if B <= 0: if A >= 0.5: raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.") @@ -17,259 +26,184 @@ def Phi_unit(f, A, d): f_arr = np.asarray(f) result = np.zeros_like(f_arr, dtype=float) - # Region 1: |f| < A → φ = 1 + # Region 1: |ω| < A → Φ = 1/√ΔΩ (normalized to 1) mask1 = np.abs(f_arr) < A result[mask1] = 1.0 - # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d) + # Region 2: A ≤ |ω| ≤ A + B → Φ = (1/√ΔΩ) cos[νd(π/2 * (|ω|-A)/B)] mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B)) - if np.any(mask2) and B > 1e-12: + if np.any(mask2): z = (np.abs(f_arr[mask2]) - A) / B z = np.clip(z, 0.0, 1.0) - p = betainc(d, d, z) - result[mask2] = np.cos(np.pi * p / 2.0) + # νd(x) from Eq. (12): normalized incomplete Beta function + nu_d = betainc(d, d, z) + result[mask2] = np.cos(np.pi * nu_d / 2.0) return result.item() if np.isscalar(f) else result -def wdm_dT_dF(nt, nf, dt): +def wdm_dT_dF(Nt, Nf, delta_t): """ - Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt. - ΔT = nf · dt - ΔF = 1 / (2 · nf · dt) + WDM time-frequency grid parameters. + + From Cornish: ΔT = Nf * Δt, ΔF = 1/(2*Nf*Δt) + where ΔT is time pixel width, ΔF is frequency pixel width. """ - ΔT = nf * dt - ΔF = 1.0 / (2.0 * nf * dt) - return (ΔT, ΔF) + Delta_T = Nf * delta_t + Delta_F = 1.0 / (2.0 * Nf * delta_t) + return Delta_T, Delta_F -def wdm_times_frequencies(nt, nf, dt): +def wdm_times_frequencies(Nt, Nf, delta_t): """ - Returns (ts, fs) for WDM: - ts = ΔT · [0..nt−1], fs = ΔF · [0..nf−1], - where (ΔT, ΔF) = wdm_dT_dF(nt,nf,dt). + Generate WDM time-frequency coordinate arrays. """ - ΔT, ΔF = wdm_dT_dF(nt, nf, dt) - ts = np.arange(nt) * ΔT - fs = np.arange(nf) * ΔF - return ts, fs + Delta_T, Delta_F = wdm_dT_dF(Nt, Nf, delta_t) + t_n = np.arange(Nt) * Delta_T # Time coordinates t_n + f_m = np.arange(Nf) * Delta_F # Frequency coordinates f_m + return t_n, f_m -def wdm_transform(x, nt, nf, A, d): +def wdm_transform(x, Nt, Nf, A, d): """ - Forward WDM transform using np.roll + slice + reorder. + Forward WDM transform implementing Cornish Eqs. (16) and (17). + + Eq. (16): w_nm = √2 (-1)^nm ℜ[C_nm * x_m[n]] + Eq. (17): x_m[n] = Σ_{l=-Nt/2}^{Nt/2-1} exp(-2πiln/Nt) X[l + mNt/2] Φ[l] Args: - x : 1D real array of length n_total = nt*nf - nt : number of time bins (even) - nf : number of frequency bins (even) - A,d : Meyer window parameters (0 < A < 0.5, d > 0) + x: Input time series of length N = Nt*Nf + Nt: Number of time bins (must be even) + Nf: Number of frequency bins (must be even) + A, d: Meyer window parameters from Eqs. (11-12) Returns: - W : real array of shape (nt, nf) with WDM coefficients. - - Implementation notes: - 1) Compute full FFT X_fft of length n_total. - 2) Build a single φ-window of length nt by sampling fftfreq(n_total). - 3) For each m = 1..nf−1: - a) Compute shift = (n_total//2) − (m*(nt//2)). - b) rolled = np.roll(X_fft, shift). - c) slice_full = rolled[start : start+nt], where start = center − (nt//2). - d) Reorder slice_full → [ positive_half, negative_half ]. - e) Multiply by φ-window, IFFT to get complex xnm_time of length nt. - f) Multiply by C(n,m) = (1 if (n+m)%2==0 else 1j), take real, scale by √2/nf. - 4) Column m=0 is zero. + w_nm: WDM coefficients, shape (Nt, Nf) + Regular bands in columns 1 to Nf-1 + DC components in even rows of column 0 + Nyquist components in odd rows of column 0 """ - - n_total = nt * nf - if nt % 2 != 0 or nf % 2 != 0: - raise ValueError("nt and nf must both be even.") - if x.shape[-1] != n_total: - raise ValueError(f"len(x)={x.shape[-1]} must equal nt*nf={n_total}.") - if not (0 < A < 0.5): - raise ValueError("A must be in (0, 0.5).") - - # 1) Compute FFT of full signal - X_fft = fft(x) # length = n_total - - # 2) Build φ-window of length=nt - _, dF_phi = wdm_dT_dF(nt, nf, 1.0) - fs_full = fftfreq(n_total) # length = n_total - half = nt // 2 - fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]]) # length = nt - phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi) # length = nt - - # 3) Prepare output array - W = np.zeros((nt, nf), dtype=float) - - center_idx = n_total // 2 - start = center_idx - half # starting index in rolled array - - # 4) For each sub-band m=1..nf-1: - for m in range(1, nf): - freq_bin = m * half - shift = center_idx - freq_bin - rolled = np.roll(X_fft, shift) - - # Slice exactly nt samples around center - slice_full = rolled[start : start + nt] # length = nt - - # slice_full is ordered [ negative_half | positive_half ] - neg_half = slice_full[:half] - pos_half = slice_full[half:] - # We need [pos_half | neg_half] - block = np.concatenate([pos_half, neg_half]) - - # Multiply by φ-window and IFFT - xnm_time = ifft(block * phi_window) # length = nt, complex - - # Build parity factor C(n,m) = 1 if (n+m)%2==0 else 1j - n_idx = np.arange(nt) - parity = (n_idx + m) % 2 - C_col = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) - - # Real part of conj(C)·xnm_time, scaled by √2/nf - W[:, m] = (np.sqrt(2.0) / nf) * np.real(np.conj(C_col) * xnm_time) - - # Column m=0 remains zero - return W + N = Nt * Nf # Total data length + + # Validation + if x.shape[-1] != N: + raise ValueError(f"len(x) must be Nt*Nf = {N}") + if Nt % 2 or Nf % 2 or not (0 < A < 0.5) or d <= 0: + raise ValueError("Nt,Nf even; 00 required.") + + # Step 1: Compute X[l] = FFT(x) (Cornish notation) + X = fft.fft(x) # [Dc, +ive freqs, -ive freqs, nyquist] + + # Step 2: Build frequency domain window Φ[l] from Eq. (11) + # Note: "discrete Fourier samples are evaluated at f = l*Δf" (Cornish text) + l_freqs = fft.fftfreq(N) # l = 0,1,...,N-1 mapped to [-0.5, 0.5) + half = Nt // 2 + + # Reorder for l = -Nt/2, ..., Nt/2-1 as in Eq. (17) + l_indices = np.concatenate([l_freqs[:half], l_freqs[-half:]]) # length Nt + Delta_F_norm = 1.0 / (2.0 * Nf) # Normalized ΔF for unit sampling + Phi = Phi_unit(l_indices / Delta_F_norm, A, d) / np.sqrt(Delta_F_norm) + + # Step 3: Initialize output w_nm + w_nm = np.zeros((Nt, Nf), dtype=float) + + # Grid indices for parity calculations + center_idx = N // 2 + start_idx = center_idx - half + + # Step 4: Process regular frequency bands m = 1, 2, ..., Nf-1 + # Following Eqs. (16) and (17) + for m in range(1, Nf): + # Eq. (17): Extract X[l + mNt/2] using frequency shift + freq_shift = m * half # mNt/2 term from Eq. (17) + roll_amount = center_idx - freq_shift + X_shifted = np.roll(X, roll_amount) + + # Extract Nt samples: X[l + mNt/2] for l = -Nt/2, ..., Nt/2-1 + X_slice = X_shifted[start_idx:start_idx + Nt] + + # Reorder to match IFFT convention: [positive freqs, negative freqs] + X_reordered = np.concatenate([X_slice[half:], X_slice[:half]]) + + # Apply window Φ[l] and compute IFFT to get x_m[n] from Eq. (17) + x_m = fft.ifft(X_reordered * Phi) + + # Eq. (16): Apply parity factors C_nm and phase (-1)^nm + n_indices = np.arange(Nt) + parity = (n_indices + m) % 2 + # C_nm from Eq. (10): C_nm = 1 for (n+m) even, C_nm = i for (n+m) odd + C_nm = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) + + # Final formula: w_nm = √2 (-1)^nm ℜ[C_nm * x_m[n]] + # Note: (-1)^nm = 1 for even parity, -1 for odd parity (but absorbed in C_nm here) + w_nm[:, m] = (np.sqrt(2.0) / Nf) * np.real(C_nm * x_m) + + # Step 5: Handle m = 0 (DC components) - Cornish Ref. [19] special case + # Store in even time indices of column 0 + m = 0 + roll_amount = center_idx # No frequency shift for DC + X_shifted = np.roll(X, roll_amount) + X_slice = X_shifted[start_idx:start_idx + Nt] + X_reordered = np.concatenate([X_slice[half:], X_slice[:half]]) + + # Zero negative frequencies for DC component + X_dc = X_reordered.copy() + X_dc[half:] = 0.0 # Zero negative frequencies + X_dc[0] /= 2.0 # Divide DC bin by 2 + + x_0 = fft.ifft(X_dc * Phi) + # Store DC components in even time indices with √2 normalization + even_indices = np.arange(0, Nt, 2) + w_nm[even_indices, 0] = np.real(x_0[even_indices]) * np.sqrt(2.0) / Nf + + # Step 6: Handle m = Nf (Nyquist components) - store in odd indices of column 0 + m = Nf + freq_shift = Nf * half + roll_amount = center_idx - freq_shift + X_shifted = np.roll(X, roll_amount) + X_slice = X_shifted[start_idx:start_idx + Nt] + X_reordered = np.concatenate([X_slice[half:], X_slice[:half]]) + + # Zero positive frequencies for Nyquist component + X_nyquist = X_reordered.copy() + X_nyquist[:half] = 0.0 # Zero positive frequencies + X_nyquist[half] /= 2.0 # Divide Nyquist bin by 2 + + x_Nf = fft.ifft(X_nyquist * Phi) + # Store Nyquist components in odd time indices + odd_indices = np.arange(1, Nt, 2) + even_source_indices = np.arange(0, Nt, 2) # Take even indices from x_Nf + w_nm[odd_indices, 0] = np.real(x_Nf[even_source_indices]) * np.sqrt(2.0) / Nf + + return w_nm def wdm_inverse_transform(W, A, d): - """ - Inverse WDM using the “roll + reorder + sum” approach: - - Args: - W : real array (nt, nf) of WDM coefficients - A,d : Meyer window parameters - - Returns: - 1D real signal of length n_total = nt * nf. - """ - nt, nf = W.shape - if nt % 2 != 0 or nf % 2 != 0: - raise ValueError("nt and nf must both be even.") n_total = nt * nf - if not (0 < A < 0.5): - raise ValueError("A must be in (0, 0.5).") - - # 1) Build φ-window (same as forward) - _, dF_phi = wdm_dT_dF(nt, nf, 1.0) - fs_full = fftfreq(n_total) + # validation omitted for brevity + fs_full = fft.fftfreq(n_total) half = nt // 2 fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]]) - phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi) - - # 2) Build parity matrix C(n,m) and form ylm = C · W / sqrt(2) · nf - n_idx = np.arange(nt)[:, None] # shape (nt,1) - m_idx = np.arange(nf)[None, :] # shape (1,nf) - parity = (n_idx + m_idx) % 2 - Cmat = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) # shape (nt,nf) + phi = Phi_unit(fs_phi / (1.0/(2.0*nf)), A, d) - # ylm shape = (nt,nf) - # (for m=0 we keep column zero) - ylm = np.zeros((nt, nf), dtype=complex) - ylm[:, 1:] = (Cmat[:, 1:] * W[:, 1:] / np.sqrt(2.0)) * nf + n = np.arange(nt)[:, None] + m = np.arange(nf)[None, :] + parity = (n + m) % 2 + C = np.where(parity == 0, 1, 1j) + ylm = np.zeros((nt, nf), complex) + ylm[:,1:] = (C[:,1:] * W[:,1:] / np.sqrt(2.0)) * nf - # 3) FFT each column along axis=0 - Y = np.fft.fft(ylm, axis=0) # shape (nt,nf) - - # 4) Reconstruct full-spectrum X_recon of length n_total - X_recon = np.zeros(n_total, dtype=complex) - center_idx = n_total // 2 - start = center_idx - half - - for m in range(1, nf): - # Build the nt-length “block” in frequency space: Y[:,m] * φ-window - block = Y[:, m] * phi_window # length = nt - - pos_half = block[:half] - neg_half = block[half:] - block_full = np.concatenate([neg_half, pos_half]) - # → this arranges [ negative_half | positive_half ] at indices [start:start+nt] - - # Place block_full into an otherwise-zero length-n_total array, then roll: - temp_full = np.zeros(n_total, dtype=complex) - temp_full[start : start + nt] = block_full - - # Roll so that the “center” (index=center_idx) goes to freq_bin=m*half - freq_bin = m * half - shift = freq_bin - center_idx - X_recon += np.roll(temp_full, shift) - - # 5) IFFT back to time domain - x_time = ifft(X_recon) - return np.real(x_time) - - - - -def wdm_inverse_direct(W, A, d): - """ - Inverse WDM transform (direct indexing, no roll -- i think i might have a bug :( )). Given W shape=(nt,nf), - returns the real 1D signal x of length = nt*nf. - - Algorithm: - nt, nf must be even, W[:,0]=0. For m=1..nf-1: - 1) Build ylm[n,m] = C(n,m)*W[n,m]/√2 * nf, where C(n,m)=1 or 1j by parity. - 2) Y[:,m] = FFT(ylm[:,m]) (length=nt). - 3) block = Y[:,m] * φ_window (length=nt, with φ_window from fftfreq(nt*nf)). - 4) Split block = [pos_half|neg_half], where pos_half = block[:nt//2], neg_half=block[nt//2:]. - 5) l0 = m*(nt//2). Add pos_half into X[l0 : l0+half]; add neg_half into X[l0-half : l0]. - 6) After looping m=1..nf-1, IFFT(X) → time, return real part. - """ - nt, nf = W.shape - if nt % 2 != 0 or nf % 2 != 0: - raise ValueError("nt and nf must both be even.") - n_total = nt * nf - if not (0 < A < 0.5): - raise ValueError("A must be in (0,0.5).") - - # 1) Build φ_window of length=nt - _, dF_phi = wdm_dT_dF(nt, nf, 1.0) - fs_full = fftfreq(n_total) - half = nt // 2 - fs_phi = np.concatenate((fs_full[:half], fs_full[-half:])) # length=nt - phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi) - - # 2) Build parity matrix C(n,m) and compute ylm - n_idx = np.arange(nt)[:, None] # shape (nt,1) - m_idx = np.arange(nf)[None, :] # shape (1,nf) - parity = (n_idx + m_idx) % 2 - Cmat = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) # shape (nt,nf) - - # ylm[n,m] = C(n,m) * W[n,m] / sqrt(2) * nf (for m>=1; ylm[:,0]=0) - ylm = np.zeros((nt, nf), dtype=complex) - ylm[:, 1:] = (Cmat[:, 1:] * W[:, 1:] / np.sqrt(2.0)) * nf - - # 3) FFT each column of ylm along axis=0 → Y (shape (nt,nf)) - Y = np.fft.fft(ylm, axis=0) - - # 4) Reconstruct full-spectrum X_recon of length n_total by adding each band’s contributions - X_recon = np.zeros(n_total, dtype=complex) - half = nt // 2 + Y = fft.fft(ylm, axis=0) + X_rec = np.zeros(n_total, complex) + center = n_total // 2 + start = center - half for m in range(1, nf): - # Build the nt-length “block” = [pos_half | neg_half] - block = Y[:, m] * phi_window # length = nt - - # Split block into pos/neg - pos_half = block[:half] - neg_half = block[half:] - - # l0 = m*(nt/2) - l0 = m * half - - # Add pos_half into X_recon[l0 : l0 + half] - X_recon[l0 : l0 + half] += pos_half - - # Add neg_half into X_recon[l0 - half : l0] - X_recon[l0 - half : l0] += neg_half - - # (No extra “conj” handling needed because ylm was built with the correct parity factor) - - # 5) IFFT back to time domain - x_time = ifft(X_recon) - return np.real(x_time) + block = Y[:, m] * phi + pos, neg = block[:half], block[half:] + temp = np.zeros(n_total, complex) + temp[start:start+nt] = np.concatenate([neg, pos]) + X_rec += np.roll(temp, m * half - center) + return np.real(fft.ifft(X_rec))