diff --git a/just_forward_check.py b/just_forward_check.py
new file mode 100644
index 0000000..3a45cdc
--- /dev/null
+++ b/just_forward_check.py
@@ -0,0 +1,258 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.special
+from numba import njit
+from numpy import fft
+from numpy.typing import NDArray
+
+
+### ORIGNAL CODE ### (dont change this part)
+
+def wdm_transform(data: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]:
+ return transform_wavelet_freq_helper(np.fft.rfft(data), Nf, Nt, 2 / Nf * phitilde_vec_norm(Nf, Nt, nx))
+
+
+def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]:
+ """Normalize phitilde as needed for inverse frequency domain transform"""
+ ND: int = Nf * Nt
+ om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64)
+
+ OM: float = np.pi # Nyquist angular frequency
+ DOM: float = float(OM / Nf) # 2 pi times DF
+ insDOM: float = float(1.0 / np.sqrt(DOM))
+ B = OM / (2 * Nf)
+ A = (DOM - B) / 2
+ phif = np.zeros(om.size, dtype=np.float64)
+
+ mask = (np.abs(om) >= A) & (np.abs(om) < A + B)
+
+ x = (np.abs(om[mask]) - A) / B
+ y = scipy.special.betainc(nx, nx, x)
+ phif[mask] = insDOM * np.cos(np.pi / 2.0 * y)
+
+ phif[np.abs(om) < A] = insDOM
+
+ # nrm should be 1
+ nrm: float = float(
+ np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi),
+ )
+ return phif / nrm
+
+
+@njit()
+def DX_assign_loop(
+ m: int,
+ Nt: int,
+ Nf: int,
+ DX: NDArray[np.complex128],
+ data: NDArray[np.complex128],
+ phif: NDArray[np.float64],
+) -> None:
+ """Helper for assigning DX in the main loop"""
+ assert len(DX.shape) == 1, 'Storage array must be 1D'
+ assert len(data.shape) == 1, 'Data must be 1D'
+ assert len(phif.shape) == 1, 'Phi array must be 1D'
+
+ i_base: int = int(Nt // 2)
+ jj_base: int = int(m * Nt // 2)
+
+ if m in (0, Nf):
+ # NOTE this term appears to be needed to recover correct constant (at least for m=0) but was previously missing
+ DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] / 2.0
+ else:
+ DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)]
+
+ for jj in range(jj_base + 1 - int(Nt // 2), jj_base + int(Nt // 2)):
+ j: int = int(np.abs(jj - jj_base))
+ i: int = i_base - jj_base + jj
+ if (m == Nf and jj > jj_base) or (m == 0 and jj < jj_base):
+ DX[i] = 0.0
+ elif j == 0:
+ continue
+ else:
+ DX[i] = phif[j] * data[jj]
+
+
+@njit()
+def DX_unpack_loop(m: int, Nt: int, Nf: int, DX_trans: NDArray[np.complex128], wave: NDArray[np.float64]) -> None:
+ """Helper for unpacking fftd DX in main loop"""
+ assert len(DX_trans.shape) == 1, 'Data array must be 1D'
+ assert len(wave.shape) == 2, 'Output array must be 2D'
+ if m == 0:
+ # half of lowest and highest frequency bin pixels are redundant
+ # so store them in even and odd components of m=0 respectively
+ for n in range(0, Nt, 2):
+ wave[n, 0] = DX_trans[n].real * np.sqrt(2.0)
+ elif m == Nf:
+ for n in range(0, Nt, 2):
+ wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2.0)
+ else:
+ for n in range(Nt):
+ if m % 2:
+ if (n + m) % 2:
+ wave[n, m] = -DX_trans[n].imag
+ else:
+ wave[n, m] = DX_trans[n].real
+ elif (n + m) % 2:
+ wave[n, m] = DX_trans[n].imag
+ else:
+ wave[n, m] = DX_trans[n].real
+
+
+def transform_wavelet_freq_helper(
+ data: NDArray[np.complex128],
+ Nf: int,
+ Nt: int,
+ phif: NDArray[np.float64],
+) -> NDArray[np.float64]:
+ """Helper to do the wavelet transform using the fast wavelet domain transform"""
+ assert len(data.shape) == 1, 'Only support 1D Arrays currently'
+ assert len(phif.shape) == 1, 'phif must be 1D'
+ wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal
+
+ DX = np.zeros(Nt, dtype=np.complex128)
+ for m in range(Nf + 1):
+ DX_assign_loop(m, Nt, Nf, DX, data, phif)
+ DX_trans = fft.ifft(DX, Nt)
+ DX_unpack_loop(m, Nt, Nf, DX_trans, wave)
+ return wave
+
+
+#### END OF ORIGINAL CODE ###
+
+## NEW VERSION ## (change this part)
+
+
+def Phi_unit(f, A, d):
+ """
+ Meyer window function for the WDM wavelet transform.
+
+ See Eq. (10) of Cornish (2020).
+ `f` and half-width `A` are in units of Δf; `d` controls the smoothness.
+ """
+ B = 1.0 - 2.0 * A
+ if B <= 0:
+ if A >= 0.5:
+ raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.")
+
+ f_arr = np.asarray(f)
+ result = np.zeros_like(f_arr, dtype=float)
+
+ # Region 1: |f| < A → φ = 1
+ mask1 = np.abs(f_arr) < A
+ result[mask1] = 1.0
+
+ # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d)
+ mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B))
+ if np.any(mask2):
+ z = (np.abs(f_arr[mask2]) - A) / B
+ z = np.clip(z, 0.0, 1.0)
+ p = scipy.special.betainc(d, d, z)
+ result[mask2] = np.cos(np.pi * p / 2.0)
+
+ return result.item() if np.isscalar(f) else result
+
+
+def wdm_dT_dF(nt, nf, dt):
+ """
+ Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt.
+ """
+ return nf * dt, 1.0 / (2.0 * nf * dt)
+
+
+def wdm_times_frequencies(nt, nf, dt):
+ """
+ Returns (ts, fs) for WDM:
+ """
+ ΔT, ΔF = wdm_dT_dF(nt, nf, dt)
+ return np.arange(nt) * ΔT, np.arange(nf) * ΔF
+
+
+def wdm_transform_roll(x, nt, nf, A, d):
+ n_total = nt * nf
+
+ if x.shape[-1] != n_total:
+ raise ValueError(f"len(x) must be nt*nf = {n_total}")
+ if nt % 2 or nf % 2 or not (0 < A < 0.5) or d <= 0:
+ raise ValueError("nt,nf even; 00 required.")
+
+ # full FFT
+ X_fft = fft.fft(x)
+
+ # build phi window
+ fs_full = fft.fftfreq(n_total)
+ half = nt // 2
+ fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]])
+ phi = Phi_unit(fs_phi / (1.0 / (2.0 * nf)), A, d) / np.sqrt(1.0 / (2.0 * nf))
+
+ W = np.zeros((nt, nf), dtype=float)
+ center = n_total // 2
+ start = center - half
+
+ # Handle m=1 to nf-1 (the regular frequency bands)
+ for m in range(1, nf):
+ shift = center - m * half
+ rolled = np.roll(X_fft, shift)
+ sl = rolled[start:start + nt]
+ block = np.concatenate([sl[half:], sl[:half]])
+ xnm = fft.ifft(block * phi)
+ # parity factor: swap real/imag mapping
+ n = np.arange(nt)
+ parity = (n + m) % 2
+ # even indices use imaginary, odd use real
+ C = np.where(parity == 0, 1, 1j)
+ W[:, m] = (np.sqrt(2.0) / nf) * np.real(C * xnm)
+
+ # Handle m=0 (DC components) - store in even indices of column 0
+ shift = center # No shift for DC
+ rolled = np.roll(X_fft, shift)
+ sl = rolled[start:start + nt]
+ block = np.concatenate([sl[half:], sl[:half]])
+ xnm = fft.ifft(block * phi)
+ # DC components go to even indices with sqrt(2) normalization and factor of 1/2
+ for n in range(0, nt, 2):
+ W[n, 0] = np.real(xnm[n]) * np.sqrt(2.0) / (2.0 * nf)
+
+ # Handle m=nf (Nyquist components) - store in odd indices of column 0
+ shift = center - nf * half
+ rolled = np.roll(X_fft, shift)
+ sl = rolled[start:start + nt]
+ block = np.concatenate([sl[half:], sl[:half]])
+ xnm = fft.ifft(block * phi)
+ # Nyquist components go to odd indices with sqrt(2) normalization and factor of 1/2
+ for n in range(1, nt, 2):
+ W[n, 0] = np.real(xnm[n - 1]) * np.sqrt(2.0) / (2.0 * nf)
+
+ return W
+
+
+###
+
+
+def check_transform():
+ f0, dt = 1, 0.125 # Frequency and time step
+ Nf = Nt = 8
+ t = np.arange(0, Nt * Nf) * dt
+ data = np.sin(2 * np.pi * f0 * t)
+ wave = wdm_transform(data, Nf, Nt, nx=4.0)
+ wave_roll = wdm_transform_roll(data, Nt, Nf, A=0.25, d=4.0)
+ wave_diff = wave - wave_roll
+
+ # plotting
+ T_bins = np.arange(Nt) * dt # Time axis for plots
+ F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots
+ fig, axes = plt.subplots(3, 1, figsize=(10, 8))
+
+ def _plot_ax(fig, ax, w, title):
+ pmc = ax.pcolormesh(T_bins, F_bins, w.T, shading='auto', lw=0.5, edgecolor='white')
+ fig.colorbar(pmc, ax=ax, label=title, orientation='vertical')
+
+ _plot_ax(fig, axes[0], wave, 'Wavelet Coefficients')
+ _plot_ax(fig, axes[1], wave_roll, 'Roll Wavelet Coefficients')
+ _plot_ax(fig, axes[2], wave_diff, 'Difference (Original - Roll)')
+ plt.tight_layout()
+ plt.show()
+
+
+if __name__ == "__main__":
+ check_transform()
diff --git a/just_inverse_check.py b/just_inverse_check.py
new file mode 100644
index 0000000..c789ca8
--- /dev/null
+++ b/just_inverse_check.py
@@ -0,0 +1,266 @@
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.special
+from numba import njit
+from numpy import fft
+from numpy.typing import NDArray
+
+OUTDIR = 'out'
+os.makedirs(OUTDIR, exist_ok=True)
+
+
+### ORIGINAL CODE ### (dont change)
+def wdm_inverse_transform(wave_in: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]:
+ rfft_data = inverse_wavelet_freq_helper_fast(wave_in, phitilde_vec_norm(Nf, Nt, nx), Nf, Nt)
+ return np.fft.irfft(rfft_data)
+
+
+@njit()
+def unpack_wave_inverse(
+ m: int,
+ Nt: int,
+ Nf: int,
+ phif: NDArray[np.float64],
+ fft_prefactor2s: NDArray[np.complex128],
+ res: NDArray[np.complex128],
+) -> None:
+ """Helper for unpacking results of frequency domain inverse transform"""
+ if m in (0, Nf):
+ for i_ind in range(int(Nt // 2)):
+ i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2
+ ind3 = (2 * i) % Nt
+ res[i] += fft_prefactor2s[ind3] * phif[i_ind]
+ if m == Nf:
+ i_ind = int(Nt // 2)
+ i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2
+ ind3 = 0
+ res[i] += fft_prefactor2s[ind3] * phif[i_ind]
+ else:
+ ind31 = (int(Nt // 2) * m) % Nt
+ ind32 = (int(Nt // 2) * m) % Nt
+ for i_ind in range(int(Nt // 2)):
+ i1 = int(Nt // 2) * m - i_ind
+ i2 = int(Nt // 2) * m + i_ind
+ res[i1] += fft_prefactor2s[ind31] * phif[i_ind]
+ res[i2] += fft_prefactor2s[ind32] * phif[i_ind]
+ ind31 -= 1
+ ind32 += 1
+ if ind31 < 0:
+ ind31 = Nt - 1
+ if ind32 == Nt:
+ ind32 = 0
+
+ res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
+
+
+@njit()
+def pack_wave_inverse(
+ m: int,
+ Nt: int,
+ Nf: int,
+ prefactor2s: NDArray[np.complex128],
+ wave_in: NDArray[np.float64],
+) -> None:
+ """Helper for fast frequency domain inverse transform to prepare for fourier transform"""
+ if m == 0:
+ for n in range(Nt):
+ prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt, 0]
+ elif m == Nf:
+ for n in range(Nt):
+ prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt + 1, 0]
+ else:
+ for n in range(Nt):
+ val = float(wave_in[n, m])
+ if (n + m) % 2:
+ mult2 = -1j
+ else:
+ mult2 = 1
+
+ prefactor2s[n] = mult2 * val
+
+
+# @njit()
+def inverse_wavelet_freq_helper_fast(
+ wave_in: NDArray[np.float64],
+ phif: NDArray[np.float64],
+ Nf: int,
+ Nt: int,
+) -> NDArray[np.complex128]:
+ """Jit compatible loop for wdm_inverse_transform"""
+ ND = Nf * Nt
+
+ prefactor2s = np.zeros(Nt, np.complex128)
+ res = np.zeros(ND // 2 + 1, dtype=np.complex128)
+
+ for m in range(Nf + 1):
+ pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
+ # with numba.objmode(fft_prefactor2s="complex128[:]"):
+ fft_prefactor2s = fft.fft(prefactor2s)
+ unpack_wave_inverse(m, Nt, Nf, phif, fft_prefactor2s, res)
+
+ return res
+
+
+def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]:
+ """Normalize phitilde as needed for inverse frequency domain transform"""
+ ND: int = Nf * Nt
+ om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64)
+
+ OM: float = np.pi # Nyquist angular frequency
+ DOM: float = float(OM / Nf) # 2 pi times DF
+ insDOM: float = float(1.0 / np.sqrt(DOM))
+ B = OM / (2 * Nf)
+ A = (DOM - B) / 2
+ phif = np.zeros(om.size, dtype=np.float64)
+
+ mask = (np.abs(om) >= A) & (np.abs(om) < A + B)
+
+ x = (np.abs(om[mask]) - A) / B
+ y = scipy.special.betainc(nx, nx, x)
+ phif[mask] = insDOM * np.cos(np.pi / 2.0 * y)
+
+ phif[np.abs(om) < A] = insDOM
+
+ # nrm should be 1
+ nrm: float = float(
+ np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi),
+ )
+ return phif / nrm
+
+
+### END OF ORIGINAL CODE ###
+
+### NEW CODE ###
+
+
+def Phi_unit(f, A, d):
+ """
+ Meyer window function for the WDM wavelet transform.
+
+ See Eq. (10) of Cornish (2020).
+ `f` and half-width `A` are in units of Δf; `d` controls the smoothness.
+ """
+ B = 1.0 - 2.0 * A
+ if B <= 0:
+ if A >= 0.5:
+ raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.")
+
+ f_arr = np.asarray(f)
+ result = np.zeros_like(f_arr, dtype=float)
+
+ # Region 1: |f| < A → φ = 1
+ mask1 = np.abs(f_arr) < A
+ result[mask1] = 1.0
+
+ # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d)
+ mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B))
+ if np.any(mask2):
+ z = (np.abs(f_arr[mask2]) - A) / B
+ z = np.clip(z, 0.0, 1.0)
+ p = scipy.special.betainc(d, d, z)
+ result[mask2] = np.cos(np.pi * p / 2.0)
+
+ return result.item() if np.isscalar(f) else result
+
+
+def wdm_dT_dF(nt, nf, dt):
+ """
+ Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt.
+ """
+ return nf * dt, 1.0 / (2.0 * nf * dt)
+
+
+def wdm_times_frequencies(nt, nf, dt):
+ """
+ Returns (ts, fs) for WDM:
+ """
+ ΔT, ΔF = wdm_dT_dF(nt, nf, dt)
+ return np.arange(nt) * ΔT, np.arange(nf) * ΔF
+
+
+def wdm_inverse_transform_new(W, A, d):
+ nt, nf = W.shape
+ n_total = nt * nf
+ half = nt // 2
+
+ # Build phi window
+ fs_full = fft.fftfreq(n_total)
+ fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]])
+ _, dF = wdm_dT_dF(nt, nf, 1.0)
+ phi = Phi_unit(fs_phi / dF, A, d) / np.sqrt(dF)
+
+ # Step 1: Apply parity factors and normalize
+ ylm = np.zeros((nt, nf), dtype=complex)
+ for n in range(nt):
+ for m in range(1, nf): # Julia starts from m=2, Python from m=1
+ C = 1 if (n + m) % 2 == 0 else 1j
+ ylm[n, m] = C * W[n, m] / np.sqrt(2.0)
+
+ # Step 2: FFT over time dimension (axis=0)
+ ylm_fft = fft.fft(ylm, axis=0)
+
+ # Step 3: Build frequency domain array
+ X = np.zeros(n_total, dtype=complex)
+
+ for m in range(1, nf):
+ l0 = m * half
+
+ # First contribution: G[l - m*Nt/2] * Ylm
+ # This maps ylm_fft[:, m] to frequencies around l0
+ temp1 = np.zeros(n_total, dtype=complex)
+ temp1[:half] = ylm_fft[:half, m] * phi[:half] # Positive frequencies
+ temp1[half:2 * half] = ylm_fft[half:, m] * phi[half:] # Negative frequencies
+ X += np.roll(temp1, l0 - half)
+
+ # Second contribution: G[l + m*Nt/2] * conj(Y(-l)m)
+ l1 = n_total - l0
+
+ temp2 = np.zeros(n_total, dtype=complex)
+ # Zero frequency of conjugate
+ temp2[0] = np.conj(ylm_fft[0, m]) * phi[0]
+
+ # Positive frequencies of conjugate (reversed)
+ if half > 1:
+ temp2[1:half] = np.conj(ylm_fft[nt - 1:half:-1, m]) * phi[1:half]
+
+ # Negative frequencies of conjugate (reversed)
+ temp2[half:2 * half] = np.conj(ylm_fft[half:0:-1, m]) * phi[half:]
+
+ X += np.roll(temp2, l1 - half)
+
+ # Step 4: Inverse FFT to get time domain
+ x_reconstructed = np.real(fft.ifft(X))
+
+ return x_reconstructed
+
+
+### end of NEW CODE ###
+
+def run_single_element_inverse_reconstruction_test():
+ nf = nt = 6
+ t = np.arange(nf * nt)
+ x_original_wdm = np.zeros((nt, nf))
+ x_original_wdm[3, 3] = 1.0
+
+ data = wdm_inverse_transform(x_original_wdm, nf, nt, nx=4.0)
+ data_new = wdm_inverse_transform_new(x_original_wdm, 0.25, 4.0)
+ print("Reconstructed data:", data)
+ print("Reconstructed data (new):", data_new)
+ print("data / data(new):", data/data_new)
+ if not np.allclose(data, data_new, atol=1e-7):
+ print("Reconstruction failed for the original test case.")
+ fig = plt.figure()
+ plt.plot(t, data, label='Reconstructed Data', color='orange')
+ plt.plot(t, data_new, label='Reconstructed Data (new)', color='blue', linestyle='--')
+ plt.legend()
+ plt.title('Reconstruction Failure for Original Test Case')
+ plt.savefig(os.path.join(OUTDIR, 'reconstruction_failure_original.png'), dpi=300)
+ plt.close(fig)
+ else:
+ print("Original test case passed.")
+
+
+if __name__ == "__main__":
+ run_single_element_inverse_reconstruction_test()
diff --git a/original.py b/original.py
new file mode 100644
index 0000000..1373d69
--- /dev/null
+++ b/original.py
@@ -0,0 +1,237 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.special
+from numba import njit
+from numpy import fft
+from numpy.typing import NDArray
+
+
+def wdm_inverse_transform(wave_in: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]:
+ rfft_data = inverse_wavelet_freq_helper_fast(wave_in, phitilde_vec_norm(Nf, Nt, nx), Nf, Nt)
+ return np.fft.irfft(rfft_data)
+
+
+def wdm_transform(data: NDArray[np.float64], Nf: int, Nt: int, nx: float = 4.0) -> NDArray[np.float64]:
+ return transform_wavelet_freq_helper(np.fft.rfft(data), Nf, Nt, 2 / Nf * phitilde_vec_norm(Nf, Nt, nx))
+
+
+@njit()
+def unpack_wave_inverse(
+ m: int,
+ Nt: int,
+ Nf: int,
+ phif: NDArray[np.float64],
+ fft_prefactor2s: NDArray[np.complex128],
+ res: NDArray[np.complex128],
+) -> None:
+ """Helper for unpacking results of frequency domain inverse transform"""
+ if m in (0, Nf):
+ for i_ind in range(int(Nt // 2)):
+ i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2
+ ind3 = (2 * i) % Nt
+ res[i] += fft_prefactor2s[ind3] * phif[i_ind]
+ if m == Nf:
+ i_ind = int(Nt // 2)
+ i = int(np.abs(m * int(Nt // 2) - i_ind)) # i_off+i_min2
+ ind3 = 0
+ res[i] += fft_prefactor2s[ind3] * phif[i_ind]
+ else:
+ ind31 = (int(Nt // 2) * m) % Nt
+ ind32 = (int(Nt // 2) * m) % Nt
+ for i_ind in range(int(Nt // 2)):
+ i1 = int(Nt // 2) * m - i_ind
+ i2 = int(Nt // 2) * m + i_ind
+ res[i1] += fft_prefactor2s[ind31] * phif[i_ind]
+ res[i2] += fft_prefactor2s[ind32] * phif[i_ind]
+ ind31 -= 1
+ ind32 += 1
+ if ind31 < 0:
+ ind31 = Nt - 1
+ if ind32 == Nt:
+ ind32 = 0
+
+ res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
+
+
+@njit()
+def pack_wave_inverse(
+ m: int,
+ Nt: int,
+ Nf: int,
+ prefactor2s: NDArray[np.complex128],
+ wave_in: NDArray[np.float64],
+) -> None:
+ """Helper for fast frequency domain inverse transform to prepare for fourier transform"""
+ if m == 0:
+ for n in range(Nt):
+ prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt, 0]
+ elif m == Nf:
+ for n in range(Nt):
+ prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt + 1, 0]
+ else:
+ for n in range(Nt):
+ val = float(wave_in[n, m])
+ if (n + m) % 2:
+ mult2 = -1j
+ else:
+ mult2 = 1
+
+ prefactor2s[n] = mult2 * val
+
+
+# @njit()
+def inverse_wavelet_freq_helper_fast(
+ wave_in: NDArray[np.float64],
+ phif: NDArray[np.float64],
+ Nf: int,
+ Nt: int,
+) -> NDArray[np.complex128]:
+ """Jit compatible loop for wdm_inverse_transform"""
+ ND = Nf * Nt
+
+ prefactor2s = np.zeros(Nt, np.complex128)
+ res = np.zeros(ND // 2 + 1, dtype=np.complex128)
+
+ for m in range(Nf + 1):
+ pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
+ # with numba.objmode(fft_prefactor2s="complex128[:]"):
+ fft_prefactor2s = fft.fft(prefactor2s)
+ unpack_wave_inverse(m, Nt, Nf, phif, fft_prefactor2s, res)
+
+ return res
+
+
+def phitilde_vec_norm(Nf: int, Nt: int, nx: float) -> NDArray[np.float64]:
+ """Normalize phitilde as needed for inverse frequency domain transform"""
+ ND: int = Nf * Nt
+ om: NDArray[np.float64] = np.asarray(2 * np.pi / ND * np.arange(0, Nt // 2 + 1), dtype=np.float64)
+
+ OM: float = np.pi # Nyquist angular frequency
+ DOM: float = float(OM / Nf) # 2 pi times DF
+ insDOM: float = float(1.0 / np.sqrt(DOM))
+ B = OM / (2 * Nf)
+ A = (DOM - B) / 2
+ phif = np.zeros(om.size, dtype=np.float64)
+
+ mask = (np.abs(om) >= A) & (np.abs(om) < A + B)
+
+ x = (np.abs(om[mask]) - A) / B
+ y = scipy.special.betainc(nx, nx, x)
+ phif[mask] = insDOM * np.cos(np.pi / 2.0 * y)
+
+ phif[np.abs(om) < A] = insDOM
+
+ # nrm should be 1
+ nrm: float = float(
+ np.sqrt((2 * np.sum(phif[1:] ** 2) + phif[0] ** 2) * 2 * np.pi / ND) / (np.pi ** (3 / 2) / np.pi),
+ )
+ return phif / nrm
+
+
+@njit()
+def DX_assign_loop(
+ m: int,
+ Nt: int,
+ Nf: int,
+ DX: NDArray[np.complex128],
+ data: NDArray[np.complex128],
+ phif: NDArray[np.float64],
+) -> None:
+ """Helper for assigning DX in the main loop"""
+ assert len(DX.shape) == 1, 'Storage array must be 1D'
+ assert len(data.shape) == 1, 'Data must be 1D'
+ assert len(phif.shape) == 1, 'Phi array must be 1D'
+
+ i_base: int = int(Nt // 2)
+ jj_base: int = int(m * Nt // 2)
+
+ if m in (0, Nf):
+ # NOTE this term appears to be needed to recover correct constant (at least for m=0) but was previously missing
+ DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)] / 2.0
+ else:
+ DX[Nt // 2] = phif[0] * data[int(m * Nt // 2)]
+
+ for jj in range(jj_base + 1 - int(Nt // 2), jj_base + int(Nt // 2)):
+ j: int = int(np.abs(jj - jj_base))
+ i: int = i_base - jj_base + jj
+ if (m == Nf and jj > jj_base) or (m == 0 and jj < jj_base):
+ DX[i] = 0.0
+ elif j == 0:
+ continue
+ else:
+ DX[i] = phif[j] * data[jj]
+
+
+@njit()
+def DX_unpack_loop(m: int, Nt: int, Nf: int, DX_trans: NDArray[np.complex128], wave: NDArray[np.float64]) -> None:
+ """Helper for unpacking fftd DX in main loop"""
+ assert len(DX_trans.shape) == 1, 'Data array must be 1D'
+ assert len(wave.shape) == 2, 'Output array must be 2D'
+ if m == 0:
+ # half of lowest and highest frequency bin pixels are redundant
+ # so store them in even and odd components of m=0 respectively
+ for n in range(0, Nt, 2):
+ wave[n, 0] = DX_trans[n].real * np.sqrt(2.0)
+ elif m == Nf:
+ for n in range(0, Nt, 2):
+ wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2.0)
+ else:
+ for n in range(Nt):
+ if m % 2:
+ if (n + m) % 2:
+ wave[n, m] = -DX_trans[n].imag
+ else:
+ wave[n, m] = DX_trans[n].real
+ elif (n + m) % 2:
+ wave[n, m] = DX_trans[n].imag
+ else:
+ wave[n, m] = DX_trans[n].real
+
+
+def transform_wavelet_freq_helper(
+ data: NDArray[np.complex128],
+ Nf: int,
+ Nt: int,
+ phif: NDArray[np.float64],
+) -> NDArray[np.float64]:
+ """Helper to do the wavelet transform using the fast wavelet domain transform"""
+ assert len(data.shape) == 1, 'Only support 1D Arrays currently'
+ assert len(phif.shape) == 1, 'phif must be 1D'
+ wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal
+
+ DX = np.zeros(Nt, dtype=np.complex128)
+ for m in range(Nf + 1):
+ DX_assign_loop(m, Nt, Nf, DX, data, phif)
+ DX_trans = fft.ifft(DX, Nt)
+ DX_unpack_loop(m, Nt, Nf, DX_trans, wave)
+ return wave
+
+
+def check_roundtrip():
+ f0, dt = 1, 0.125 # Frequency and time step
+ Nf = Nt = 8
+ nx = 4.0
+ t = np.arange(0, Nt * Nf) * dt
+ data = np.sin(2 * np.pi * f0 * t)
+ wave = wdm_transform(data, Nf, Nt, nx)
+ data_reconstructed = wdm_inverse_transform(wave, Nf, Nt, nx)
+
+ # plotting
+ T_bins = np.arange(Nt) * dt # Time axis for plots
+ F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots
+ fig, axes = plt.subplots(2, 1, figsize=(10, 8))
+ axes[0].plot(t, data, label='Original Signal', color='blue')
+ axes[0].plot(t, data_reconstructed, label='Reconstructed Signal', color='orange', linestyle='--')
+ pmc= axes[1].pcolormesh(T_bins, F_bins, wave.T, shading='auto', lw=0.5, edgecolor='white')
+ fig.colorbar(pmc, ax=axes[1], label='Wavelet Coefficients', orientation='vertical')
+ plt.tight_layout()
+ plt.show()
+
+ # Check if the original and reconstructed data match
+ assert np.allclose(
+ data, data_reconstructed,
+ atol=1e-6), "Roundtrip failed: Original and reconstructed data do not match."
+
+
+if __name__ == "__main__":
+ check_roundtrip()
diff --git a/run_checks.py b/run_checks.py
index be01064..9ba7447 100644
--- a/run_checks.py
+++ b/run_checks.py
@@ -1,53 +1,138 @@
+import os
+
import matplotlib.pyplot as plt
-import numpy as np
+
from wdm_roll import *
+from original import wdm_inverse_transform as wdm_inverse_transform_orig
+from original import wdm_transform as wdm_transform_orig
plt.style.use('seaborn-v0_8-whitegrid')
+HERE = __file__
+OUTDIR = os.path.join(os.path.dirname(HERE), 'out')
+os.makedirs(OUTDIR, exist_ok=True)
+
+A_wavelet_param = 0.25 # Global for tests,
+d_wavelet_param = 4 # Global for tests
+
+
# --- Test Functions and Classes ---
def chirp_signal(ts, Ac, As, f, fdot):
"""Generates a chirp signal."""
phases = 2 * np.pi * ts * (f + fdot * ts / 2.0)
return Ac * np.cos(phases) + As * np.sin(phases)
-A_wavelet_param = 0.25 # Global for tests,
-d_wavelet_param = 4 # Global for tests
+
+def run_monochromatic_wnm_test():
+ """Runs a test for monochromatic WNM generation, printing results and plotting."""
+ print("\n--- Running Monochromatic WNM Test ---")
+
+ f0, dt = 1, 0.125 # Frequency and time step
+ Nf = Nt = 8
+ nx = 4.0
+ A = 1
+ N = Nt * Nf
+ t = np.arange(0, N) * dt # Time vector for original signal
+ signal_time = A * np.sin(2 * np.pi * f0 * t) # Monochromatic signal
+
+ # Generate analytical WNM for a monochromatic signal
+ T_bins = np.arange(Nt) * dt # Time axis for plots
+ F_bins = np.arange(Nf) / (2 * Nf * dt) # Frequency axis for plots
+
+ orig_wnm = wdm_transform_orig(signal_time, Nt, Nf, nx=4.0)
+
+ # Perform WDM transform
+ wnm = wdm_transform(signal_time, Nt, Nf, A_wavelet_param, d_wavelet_param)
+ reconstructed_time = wdm_inverse_transform(orig_wnm, A_wavelet_param, d_wavelet_param)
+
+ # Check if the WNM matches the analytical result
+ diff_wnm = np.abs(wnm - orig_wnm)
+ max_diff = np.max(diff_wnm)
+ if max_diff > 1e-6:
+ print(f"Monochromatic WNM Test FAILED: Max difference {max_diff:.2e} exceeds tolerance.")
+ else:
+ print(f"Monochromatic WNM Test PASSED: Max difference {max_diff:.2e} within tolerance.")
+
+ # round-trip reconstruction of analytical WNM
+ diff_time = np.abs(signal_time - reconstructed_time)
+ max_diff_time = np.max(diff_time)
+ if max_diff_time > 1e-6:
+ print(f"Monochromatic WNM Test FAILED: Max time-domain difference {max_diff_time:.2e} exceeds tolerance.")
+ else:
+ print(f"Monochromatic WNM Test PASSED: Max time-domain difference {max_diff_time:.2e} within tolerance.")
+
+ # Plotting ([[time, WDM transform, analytical]])
+ fig, axes = plt.subplots(1, 4, figsize=(20, 5))
+
+ # axes[0]: Time series plot
+ axes[0].plot(t, signal_time, label='Original Signal', color='blue')
+ axes[0].plot(t, reconstructed_time, label='Reconstructed Signal', color='orange')
+ axes[0].set_title("Original vs. Reconstructed Signal (Time Domain)")
+ axes[0].set_xlabel("Time (s)")
+ axes[0].set_ylabel("Amplitude")
+ axes[0].legend()
+
+ # twin plot fot phase
+ ax2 = axes[0].twinx()
+ ax2.plot(t, np.angle(reconstructed_time), color='green', linestyle='--', label='Phase of Reconstructed Signal')
+ ax2.plot(t, np.angle(signal_time), color='red', linestyle='--', label='Phase of Original Signal')
+
+ # axes[1]: WDM Transform
+ xy = (T_bins, F_bins)
+ kwgs = dict(cmap='viridis', shading='auto', edgecolors='white', linewidth=0.5)
+ pcm1 = axes[1].pcolormesh(*xy, np.abs(wnm.T), **kwgs)
+ fig.colorbar(pcm1, ax=axes[1], orientation='horizontal', pad=0.15)
+ axes[1].set_title("WDM Transform Output")
+
+ # axes[2]: Analytical WNM
+ pcm2 = axes[2].pcolormesh(*xy, np.abs(orig_wnm.T), **kwgs)
+ fig.colorbar(pcm2, ax=axes[2], orientation='horizontal', pad=0.15)
+ axes[2].set_title("Analytical WNM")
+
+ # axes[3]: Difference between WNM and Analytical WNM
+ diff_wnm = np.abs(wnm - orig_wnm)
+ pcm3 = axes[3].pcolormesh(*xy, diff_wnm.T, **kwgs)
+ fig.colorbar(pcm3, ax=axes[3], orientation='horizontal', pad=0.15)
+ axes[3].set_title("Difference (WNM - Analytical WNM)")
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(OUTDIR, 'monochromatic_wnm_test.png'), dpi=300)
+
def run_parsevals_theorem_and_chirp_track_test():
"""Runs Parseval's theorem and chirp tracking tests, printing results and plotting."""
print("\n--- Running Parseval's Theorem and Chirp Track Test ---")
- dt = 1.0 / np.pi # Sampling interval of original signal
- fny = 1.0 / (2.0 * dt) # Nyquist frequency
+ dt = 1.0 / np.pi # Sampling interval of original signal
+ fny = 1.0 / (2.0 * dt) # Nyquist frequency
- nt = 64 # WDM time bins
- nf = 64 # WDM frequency bins
- n_total = nt * nf # Total samples in original signal
+ nt = 32 # WDM time bins
+ nf = 32 # WDM frequency bins
+ n_total = nt * nf # Total samples in original signal
- ts_signal = dt * np.arange(n_total) # Time vector for original signal
- T_duration = n_total * dt # Total duration of original signal
+ ts_signal = dt * np.arange(n_total) # Time vector for original signal
+ T_duration = n_total * dt # Total duration of original signal
- f0 = fny / 5.0 # Initial frequency of chirp
- fdot = f0 / T_duration # Rate of change of frequency
+ f0 = fny / 5.0 # Initial frequency of chirp
+ fdot = f0 / T_duration # Rate of change of frequency
Amplitude = 1.0
- rng = np.random.default_rng(seed=42)
+ rng = np.random.default_rng(seed=42)
phi_chirp = np.arctan2(rng.standard_normal(), rng.standard_normal())
Ac = Amplitude * np.cos(phi_chirp)
As = Amplitude * np.sin(phi_chirp)
f_time_domain = chirp_signal(ts_signal, Ac, As, f0, fdot)
-
+
# Perform WDM transform
f_tilde_wdm = wdm_transform(f_time_domain, nt, nf, A_wavelet_param, d_wavelet_param)
-
+
# Perform inverse WDM transform for reconstruction plots
f_reconstructed_time = wdm_inverse_transform(f_tilde_wdm, A_wavelet_param, d_wavelet_param)
-
# --- Parseval's Theorem Test ---
- sum_f_sq = np.sum(f_time_domain**2)
- sum_f_tilde_sq = np.sum(f_tilde_wdm**2)
-
+ sum_f_sq = np.sum(f_time_domain ** 2)
+ sum_f_tilde_sq = np.sum(f_tilde_wdm ** 2)
+
parseval_check = np.isclose(sum_f_sq, sum_f_tilde_sq, rtol=1e-2, atol=0)
if not parseval_check:
print(f"Parseval's Theorem FAILED: sum(f^2)={sum_f_sq:.4e}, sum(f_tilde^2)={sum_f_tilde_sq:.4e} (rtol=1e-2)")
@@ -55,49 +140,57 @@ def run_parsevals_theorem_and_chirp_track_test():
print(f"Parseval's Theorem PASSED: sum(f^2)={sum_f_sq:.4e}, sum(f_tilde^2)={sum_f_tilde_sq:.4e} (rtol=1e-2)")
# --- Chirp Track Test ---
- dT_wdm_bin, dF_wdm_bin = wdm_dT_dF(nt, nf, dt) # WDM bin widths based on original signal dt
+ dT_wdm_bin, dF_wdm_bin = wdm_dT_dF(nt, nf, dt) # WDM bin widths based on original signal dt
max_power_indices_freq = np.argmax(np.abs(f_tilde_wdm), axis=1)
-
- times_for_chirp_track_pred = np.arange(nt) * dT_wdm_bin # Time for each WDM time bin
+
+ times_for_chirp_track_pred = np.arange(nt) * dT_wdm_bin # Time for each WDM time bin
predicted_frequencies_chirp = f0 + fdot * times_for_chirp_track_pred
predicted_max_power_indices_freq = predicted_frequencies_chirp / dF_wdm_bin
-
+
diff_chirp_track = np.abs(max_power_indices_freq[1:-1] - predicted_max_power_indices_freq[1:-1])
if diff_chirp_track.size > 0:
chirp_track_check = np.all(diff_chirp_track <= 2.5)
max_diff_val = np.max(diff_chirp_track)
if not chirp_track_check:
- print(f"Chirp Track FAILED: Max deviation {max_diff_val:.2f} > 2.5. Failing diffs: {diff_chirp_track[diff_chirp_track > 2.5]}")
+ print(
+ f"Chirp Track FAILED: Max deviation {max_diff_val:.2f} > 2.5. Failing diffs: {diff_chirp_track[diff_chirp_track > 2.5]}")
else:
print(f"Chirp Track PASSED: Max deviation {max_diff_val:.2f} <= 2.5")
- elif nt <=2 :
+ elif nt <= 2:
print("Chirp Track SKIPPED: nt is too small to evaluate edges.")
- else:
+ else:
print("Chirp Track SKIPPED: Not enough data points after excluding edges.")
# --- Plotting ---
-
- orig_kwgs = dict( color='tab:blue', alpha=0.5, label='original')
+ orig_kwgs = dict(color='tab:blue', alpha=0.5, label='original')
recon_kwgs = dict(color='tab:orange', alpha=0.5, label='recon')
# 1. Chirp in Frequency Domain (FFT of original signal)
plt.figure(figsize=(12, 8))
- original_fft = fft(f_time_domain)
- original_fft_freqs = fftfreq(n_total, d=dt)
- reconstructed_fft = fft(f_reconstructed_time)
+ original_fft = fft.fft(f_time_domain)
+ original_fft_freqs = fft.fftfreq(n_total, d=dt)
+ reconstructed_fft = fft.fft(f_reconstructed_time)
# Plot only positive frequencies for clarity
positive_freq_mask = original_fft_freqs >= 0
plt.subplot(2, 2, 1)
+
+ ratio = np.abs(original_fft[positive_freq_mask]) / np.abs(reconstructed_fft[positive_freq_mask])
+
plt.plot(original_fft_freqs[positive_freq_mask], np.abs(original_fft[positive_freq_mask]), **orig_kwgs)
- plt.plot(original_fft_freqs[positive_freq_mask],np.abs(reconstructed_fft[positive_freq_mask]), **recon_kwgs )
- plt.legend()
+ plt.plot(original_fft_freqs[positive_freq_mask], np.abs(reconstructed_fft[positive_freq_mask]), **recon_kwgs)
+ plt.legend(fontsize='small', loc='upper right', frameon=True)
+ # twinx for ratio
+ ax2 = plt.gca().twinx()
+ ax2.plot(original_fft_freqs[positive_freq_mask], ratio, color='tab:green', linestyle='--',
+ label='Ratio (Original/Reconstructed)')
+ ax2.legend(fontsize='small', loc='upper left', frameon=True)
+
plt.title('Chirp in Frequency Domain (FFT of Original)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude')
- plt.legend(fontsize='small', loc='upper right', frameon=True)
- plt.xlim(0, fny) # Show up to Nyquist
+ plt.xlim(0, fny) # Show up to Nyquist
# 2. WDM Transform (imshow)
# Get actual time and frequency extents for WDM plot
@@ -110,14 +203,14 @@ def run_parsevals_theorem_and_chirp_track_test():
# We want to show the full range covered by the bins
# Time axis: from wdm_times_xaxis[0] to wdm_times_xaxis[-1] + dT_wdm_bin
# Freq axis: from wdm_freqs_yaxis[0] to wdm_freqs_yaxis[-1] + dF_wdm_bin
- img_extent = [wdm_times_xaxis[0], wdm_times_xaxis[-1] + dT_wdm_bin,
+ img_extent = [wdm_times_xaxis[0], wdm_times_xaxis[-1] + dT_wdm_bin,
wdm_freqs_yaxis[0], wdm_freqs_yaxis[-1] + dF_wdm_bin]
-
+
# Transpose f_tilde_wdm because imshow's first index is rows (y-axis, frequency), second is columns (x-axis, time)
# And WDM matrix is (nt_bins, nf_bins) = (time_bins, freq_bins)
# So f_tilde_wdm is (time, freq). For imshow(M), M[row,col].
# We want time on x-axis, freq on y-axis. So imshow(f_tilde_wdm.T)
- plt.imshow(np.abs(f_tilde_wdm.T), aspect='auto', origin='lower',
+ plt.imshow(np.abs(f_tilde_wdm.T), aspect='auto', origin='lower',
extent=img_extent, cmap='viridis')
plt.colorbar(label='Magnitude')
plt.title('WDM Transform Output')
@@ -127,11 +220,10 @@ def run_parsevals_theorem_and_chirp_track_test():
plt.plot(times_for_chirp_track_pred, predicted_frequencies_chirp, 'r--', linewidth=1, label='Predicted Chirp Track')
plt.legend(fontsize='small', loc='upper right', frameon=True)
-
# 3. Reconstructed Chirp (Time Domain)
plt.subplot(2, 2, 3)
- plt.plot(ts_signal, f_time_domain, **orig_kwgs)
- plt.plot(ts_signal, f_reconstructed_time, **recon_kwgs)
+ plt.plot(ts_signal, f_time_domain, **orig_kwgs)
+ plt.plot(ts_signal, f_reconstructed_time, **recon_kwgs)
plt.title('Original vs. Reconstructed Chirp (Time Domain)')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
@@ -146,52 +238,69 @@ def run_parsevals_theorem_and_chirp_track_test():
plt.title('Frequency-Domain Residuals (FFT(Original) - FFT(Reconstructed))')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude of Difference')
- plt.yscale('log') # Residuals can be small, log scale helps
+ plt.yscale('log') # Residuals can be small, log scale helps
plt.xlim(0, fny)
- plt.ylim(bottom=max(1e-9, np.min(freq_residuals[positive_freq_mask & (freq_residuals > 0)])*0.1)) # Avoid zero for log scale
+ # plt.ylim(bottom=max(1e-9, np.min(
+ # freq_residuals[positive_freq_mask & (freq_residuals > 0)]) * 0.1)) # Avoid zero for log scale
plt.tight_layout()
- plt.show()
+ plt.savefig(os.path.join(OUTDIR, 'parseval_chirp_track_test.png'), dpi=300)
def run_single_element_inverse_reconstruction_test():
"""Runs round-trip transform for single element impulses, printing results."""
print("\n--- Running Single Element Inverse Reconstruction Test ---")
- nt = 32
- nf = 32
+ nt = 4
+ nf = 4
+
+ wdm_times_xaxis = np.arange(nt) # Time bins for WDM
+ wdm_freqs_yaxis = np.arange(nf) # Frequency bins for WDM
reconstruction_failures = []
all_passed = True
- for i in range(nt):
- for j in range(1, nf):
+ for i in range(nt):
+ for j in range(nf):
x_original_wdm = np.zeros((nt, nf))
x_original_wdm[i, j] = 1.0
- time_signal_from_single_coeff = wdm_inverse_direct(x_original_wdm, A_wavelet_param, d_wavelet_param)
-
+ time_signal_from_single_coeff = wdm_inverse_transform(x_original_wdm, A_wavelet_param, d_wavelet_param)
x_reconstructed_wdm = wdm_transform(time_signal_from_single_coeff, nt, nf, A_wavelet_param, d_wavelet_param)
-
- if not np.allclose(x_original_wdm, x_reconstructed_wdm, atol=1e-7):
+
+ # time_signal_from_single_coeff = np.fft.irfft(wdm_inverse_transform_orig(x_original_wdm, nf, nt, nx=4))
+ # x_reconstructed_wdm = wdm_transform_orig(np.fft.rfft(time_signal_from_single_coeff), nf, nt, nx=4)
+
+
+ if not np.allclose(x_original_wdm, x_reconstructed_wdm, atol=1e-7):
diff_wdm = np.abs(x_original_wdm - x_reconstructed_wdm)
diff_val = np.max(diff_wdm)
- reconstruction_failures.append(((i,j), diff_val))
+ reconstruction_failures.append(((i, j), diff_val))
all_passed = False
-
- fig, ax = plt.subplots(1,3, figsize=(6,4))
- ax[0].imshow(np.abs(x_original_wdm.T), aspect='auto', origin='lower', cmap='viridis')
- ax[1].imshow(np.abs(x_reconstructed_wdm.T), aspect='auto', origin='lower', cmap='viridis')
- ax[2].imshow(diff_wdm.T, aspect='auto', origin='lower', cmap='viridis')
+
+ plt.close('all') # Close any previous plots to avoid clutter
+ fig, ax = plt.subplots(1, 3, figsize=(6, 4))
+
+ # Use pcolormesh for exact grid plotting and add colorbars on top
+ xy = (wdm_times_xaxis, wdm_freqs_yaxis)
+ kwgs = dict(cmap='viridis', shading='auto', edgecolors='white', linewidth=0.5, )
+ pcm0 = ax[0].pcolormesh(*xy, np.abs(x_original_wdm.T), **kwgs)
+ pcm1 = ax[1].pcolormesh(*xy, np.abs(x_reconstructed_wdm.T), **kwgs)
+ pcm2 = ax[2].pcolormesh(*xy, diff_wdm.T, **kwgs)
+ fig.colorbar(pcm0, ax=ax[0], orientation='horizontal', pad=0.15)
+ fig.colorbar(pcm1, ax=ax[1], orientation='horizontal', pad=0.15)
+ fig.colorbar(pcm2, ax=ax[2], orientation='horizontal', pad=0.15)
+
ax[0].set_title("Original WDM")
ax[1].set_title("Orig->time->WDM")
ax[2].set_title("diff")
- break
-
+ fig.savefig(os.path.join(OUTDIR, f'single_element_reconstruction_fail_{i}_{j}.png'), dpi=300)
+
if reconstruction_failures:
print(f"Single element reconstruction FAILED for {len(reconstruction_failures)} cases:")
- for k_idx in range(min(5, len(reconstruction_failures))):
- print(f" Index (t,f)=({reconstruction_failures[k_idx][0][0]},{reconstruction_failures[k_idx][0][1]}), max_abs_diff={reconstruction_failures[k_idx][1]:.2e}")
-
+ for k_idx in range(min(5, len(reconstruction_failures))):
+ print(
+ f" Index (t,f)=({reconstruction_failures[k_idx][0][0]},{reconstruction_failures[k_idx][0][1]}), max_abs_diff={reconstruction_failures[k_idx][1]:.2e}")
+
if all_passed:
print("Single Element Inverse Reconstruction Test PASSED for all elements.")
else:
@@ -199,12 +308,7 @@ def run_single_element_inverse_reconstruction_test():
return all_passed
-
-
if __name__ == '__main__':
+ run_monochromatic_wnm_test()
run_parsevals_theorem_and_chirp_track_test()
run_single_element_inverse_reconstruction_test()
-
-
-
-
diff --git a/wdm_roll.py b/wdm_roll.py
index 9f35702..403dca9 100644
--- a/wdm_roll.py
+++ b/wdm_roll.py
@@ -1,15 +1,24 @@
import numpy as np
from scipy.special import betainc
-from numpy.fft import fft, ifft, fftfreq
+from numpy import fft
+
def Phi_unit(f, A, d):
"""
- Meyer window function for the WDM wavelet transform.
+ Meyer window function Φ(ω) from Cornish Eq. (11).
+
+ Φ(ω) = 1/√ΔΩ for |ω| < A
+ Φ(ω) = (1/√ΔΩ) cos[νd(π/2 * (|ω|-A)/B)] for A ≤ |ω| ≤ A+B
- See Eq. (10) of Cornish (2020).
- `f` and half-width `A` are in units of Δf; `d` controls the smoothness.
+ where νd(x) is the normalized incomplete Beta function (Eq. 12)
+ and B = ΔΩ - 2A with constraint 2A + B = ΔΩ.
+
+ Args:
+ f: frequency in units of ΔF
+ A: half-width parameter (0 < A < 0.5)
+ d: steepness parameter controlling edge smoothness
"""
- B = 1.0 - 2.0 * A
+ B = 1.0 - 2.0 * A # From constraint 2A + B = ΔΩ = 1 (normalized)
if B <= 0:
if A >= 0.5:
raise ValueError("A must be < 0.5 so that B = 1 − 2A > 0.")
@@ -17,259 +26,184 @@ def Phi_unit(f, A, d):
f_arr = np.asarray(f)
result = np.zeros_like(f_arr, dtype=float)
- # Region 1: |f| < A → φ = 1
+ # Region 1: |ω| < A → Φ = 1/√ΔΩ (normalized to 1)
mask1 = np.abs(f_arr) < A
result[mask1] = 1.0
- # Region 2: A ≤ |f| < A + B → φ = cos(π/2 · p), p = I((|f| − A)/B; d, d)
+ # Region 2: A ≤ |ω| ≤ A + B → Φ = (1/√ΔΩ) cos[νd(π/2 * (|ω|-A)/B)]
mask2 = (np.abs(f_arr) >= A) & (np.abs(f_arr) < (A + B))
- if np.any(mask2) and B > 1e-12:
+ if np.any(mask2):
z = (np.abs(f_arr[mask2]) - A) / B
z = np.clip(z, 0.0, 1.0)
- p = betainc(d, d, z)
- result[mask2] = np.cos(np.pi * p / 2.0)
+ # νd(x) from Eq. (12): normalized incomplete Beta function
+ nu_d = betainc(d, d, z)
+ result[mask2] = np.cos(np.pi * nu_d / 2.0)
return result.item() if np.isscalar(f) else result
-def wdm_dT_dF(nt, nf, dt):
+def wdm_dT_dF(Nt, Nf, delta_t):
"""
- Returns (ΔT, ΔF) for WDM with nt time bins, nf freq bins, and input sampling dt.
- ΔT = nf · dt
- ΔF = 1 / (2 · nf · dt)
+ WDM time-frequency grid parameters.
+
+ From Cornish: ΔT = Nf * Δt, ΔF = 1/(2*Nf*Δt)
+ where ΔT is time pixel width, ΔF is frequency pixel width.
"""
- ΔT = nf * dt
- ΔF = 1.0 / (2.0 * nf * dt)
- return (ΔT, ΔF)
+ Delta_T = Nf * delta_t
+ Delta_F = 1.0 / (2.0 * Nf * delta_t)
+ return Delta_T, Delta_F
-def wdm_times_frequencies(nt, nf, dt):
+def wdm_times_frequencies(Nt, Nf, delta_t):
"""
- Returns (ts, fs) for WDM:
- ts = ΔT · [0..nt−1], fs = ΔF · [0..nf−1],
- where (ΔT, ΔF) = wdm_dT_dF(nt,nf,dt).
+ Generate WDM time-frequency coordinate arrays.
"""
- ΔT, ΔF = wdm_dT_dF(nt, nf, dt)
- ts = np.arange(nt) * ΔT
- fs = np.arange(nf) * ΔF
- return ts, fs
+ Delta_T, Delta_F = wdm_dT_dF(Nt, Nf, delta_t)
+ t_n = np.arange(Nt) * Delta_T # Time coordinates t_n
+ f_m = np.arange(Nf) * Delta_F # Frequency coordinates f_m
+ return t_n, f_m
-def wdm_transform(x, nt, nf, A, d):
+def wdm_transform(x, Nt, Nf, A, d):
"""
- Forward WDM transform using np.roll + slice + reorder.
+ Forward WDM transform implementing Cornish Eqs. (16) and (17).
+
+ Eq. (16): w_nm = √2 (-1)^nm ℜ[C_nm * x_m[n]]
+ Eq. (17): x_m[n] = Σ_{l=-Nt/2}^{Nt/2-1} exp(-2πiln/Nt) X[l + mNt/2] Φ[l]
Args:
- x : 1D real array of length n_total = nt*nf
- nt : number of time bins (even)
- nf : number of frequency bins (even)
- A,d : Meyer window parameters (0 < A < 0.5, d > 0)
+ x: Input time series of length N = Nt*Nf
+ Nt: Number of time bins (must be even)
+ Nf: Number of frequency bins (must be even)
+ A, d: Meyer window parameters from Eqs. (11-12)
Returns:
- W : real array of shape (nt, nf) with WDM coefficients.
-
- Implementation notes:
- 1) Compute full FFT X_fft of length n_total.
- 2) Build a single φ-window of length nt by sampling fftfreq(n_total).
- 3) For each m = 1..nf−1:
- a) Compute shift = (n_total//2) − (m*(nt//2)).
- b) rolled = np.roll(X_fft, shift).
- c) slice_full = rolled[start : start+nt], where start = center − (nt//2).
- d) Reorder slice_full → [ positive_half, negative_half ].
- e) Multiply by φ-window, IFFT to get complex xnm_time of length nt.
- f) Multiply by C(n,m) = (1 if (n+m)%2==0 else 1j), take real, scale by √2/nf.
- 4) Column m=0 is zero.
+ w_nm: WDM coefficients, shape (Nt, Nf)
+ Regular bands in columns 1 to Nf-1
+ DC components in even rows of column 0
+ Nyquist components in odd rows of column 0
"""
-
- n_total = nt * nf
- if nt % 2 != 0 or nf % 2 != 0:
- raise ValueError("nt and nf must both be even.")
- if x.shape[-1] != n_total:
- raise ValueError(f"len(x)={x.shape[-1]} must equal nt*nf={n_total}.")
- if not (0 < A < 0.5):
- raise ValueError("A must be in (0, 0.5).")
-
- # 1) Compute FFT of full signal
- X_fft = fft(x) # length = n_total
-
- # 2) Build φ-window of length=nt
- _, dF_phi = wdm_dT_dF(nt, nf, 1.0)
- fs_full = fftfreq(n_total) # length = n_total
- half = nt // 2
- fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]]) # length = nt
- phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi) # length = nt
-
- # 3) Prepare output array
- W = np.zeros((nt, nf), dtype=float)
-
- center_idx = n_total // 2
- start = center_idx - half # starting index in rolled array
-
- # 4) For each sub-band m=1..nf-1:
- for m in range(1, nf):
- freq_bin = m * half
- shift = center_idx - freq_bin
- rolled = np.roll(X_fft, shift)
-
- # Slice exactly nt samples around center
- slice_full = rolled[start : start + nt] # length = nt
-
- # slice_full is ordered [ negative_half | positive_half ]
- neg_half = slice_full[:half]
- pos_half = slice_full[half:]
- # We need [pos_half | neg_half]
- block = np.concatenate([pos_half, neg_half])
-
- # Multiply by φ-window and IFFT
- xnm_time = ifft(block * phi_window) # length = nt, complex
-
- # Build parity factor C(n,m) = 1 if (n+m)%2==0 else 1j
- n_idx = np.arange(nt)
- parity = (n_idx + m) % 2
- C_col = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j)
-
- # Real part of conj(C)·xnm_time, scaled by √2/nf
- W[:, m] = (np.sqrt(2.0) / nf) * np.real(np.conj(C_col) * xnm_time)
-
- # Column m=0 remains zero
- return W
+ N = Nt * Nf # Total data length
+
+ # Validation
+ if x.shape[-1] != N:
+ raise ValueError(f"len(x) must be Nt*Nf = {N}")
+ if Nt % 2 or Nf % 2 or not (0 < A < 0.5) or d <= 0:
+ raise ValueError("Nt,Nf even; 00 required.")
+
+ # Step 1: Compute X[l] = FFT(x) (Cornish notation)
+ X = fft.fft(x) # [Dc, +ive freqs, -ive freqs, nyquist]
+
+ # Step 2: Build frequency domain window Φ[l] from Eq. (11)
+ # Note: "discrete Fourier samples are evaluated at f = l*Δf" (Cornish text)
+ l_freqs = fft.fftfreq(N) # l = 0,1,...,N-1 mapped to [-0.5, 0.5)
+ half = Nt // 2
+
+ # Reorder for l = -Nt/2, ..., Nt/2-1 as in Eq. (17)
+ l_indices = np.concatenate([l_freqs[:half], l_freqs[-half:]]) # length Nt
+ Delta_F_norm = 1.0 / (2.0 * Nf) # Normalized ΔF for unit sampling
+ Phi = Phi_unit(l_indices / Delta_F_norm, A, d) / np.sqrt(Delta_F_norm)
+
+ # Step 3: Initialize output w_nm
+ w_nm = np.zeros((Nt, Nf), dtype=float)
+
+ # Grid indices for parity calculations
+ center_idx = N // 2
+ start_idx = center_idx - half
+
+ # Step 4: Process regular frequency bands m = 1, 2, ..., Nf-1
+ # Following Eqs. (16) and (17)
+ for m in range(1, Nf):
+ # Eq. (17): Extract X[l + mNt/2] using frequency shift
+ freq_shift = m * half # mNt/2 term from Eq. (17)
+ roll_amount = center_idx - freq_shift
+ X_shifted = np.roll(X, roll_amount)
+
+ # Extract Nt samples: X[l + mNt/2] for l = -Nt/2, ..., Nt/2-1
+ X_slice = X_shifted[start_idx:start_idx + Nt]
+
+ # Reorder to match IFFT convention: [positive freqs, negative freqs]
+ X_reordered = np.concatenate([X_slice[half:], X_slice[:half]])
+
+ # Apply window Φ[l] and compute IFFT to get x_m[n] from Eq. (17)
+ x_m = fft.ifft(X_reordered * Phi)
+
+ # Eq. (16): Apply parity factors C_nm and phase (-1)^nm
+ n_indices = np.arange(Nt)
+ parity = (n_indices + m) % 2
+ # C_nm from Eq. (10): C_nm = 1 for (n+m) even, C_nm = i for (n+m) odd
+ C_nm = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j)
+
+ # Final formula: w_nm = √2 (-1)^nm ℜ[C_nm * x_m[n]]
+ # Note: (-1)^nm = 1 for even parity, -1 for odd parity (but absorbed in C_nm here)
+ w_nm[:, m] = (np.sqrt(2.0) / Nf) * np.real(C_nm * x_m)
+
+ # Step 5: Handle m = 0 (DC components) - Cornish Ref. [19] special case
+ # Store in even time indices of column 0
+ m = 0
+ roll_amount = center_idx # No frequency shift for DC
+ X_shifted = np.roll(X, roll_amount)
+ X_slice = X_shifted[start_idx:start_idx + Nt]
+ X_reordered = np.concatenate([X_slice[half:], X_slice[:half]])
+
+ # Zero negative frequencies for DC component
+ X_dc = X_reordered.copy()
+ X_dc[half:] = 0.0 # Zero negative frequencies
+ X_dc[0] /= 2.0 # Divide DC bin by 2
+
+ x_0 = fft.ifft(X_dc * Phi)
+ # Store DC components in even time indices with √2 normalization
+ even_indices = np.arange(0, Nt, 2)
+ w_nm[even_indices, 0] = np.real(x_0[even_indices]) * np.sqrt(2.0) / Nf
+
+ # Step 6: Handle m = Nf (Nyquist components) - store in odd indices of column 0
+ m = Nf
+ freq_shift = Nf * half
+ roll_amount = center_idx - freq_shift
+ X_shifted = np.roll(X, roll_amount)
+ X_slice = X_shifted[start_idx:start_idx + Nt]
+ X_reordered = np.concatenate([X_slice[half:], X_slice[:half]])
+
+ # Zero positive frequencies for Nyquist component
+ X_nyquist = X_reordered.copy()
+ X_nyquist[:half] = 0.0 # Zero positive frequencies
+ X_nyquist[half] /= 2.0 # Divide Nyquist bin by 2
+
+ x_Nf = fft.ifft(X_nyquist * Phi)
+ # Store Nyquist components in odd time indices
+ odd_indices = np.arange(1, Nt, 2)
+ even_source_indices = np.arange(0, Nt, 2) # Take even indices from x_Nf
+ w_nm[odd_indices, 0] = np.real(x_Nf[even_source_indices]) * np.sqrt(2.0) / Nf
+
+ return w_nm
def wdm_inverse_transform(W, A, d):
- """
- Inverse WDM using the “roll + reorder + sum” approach:
-
- Args:
- W : real array (nt, nf) of WDM coefficients
- A,d : Meyer window parameters
-
- Returns:
- 1D real signal of length n_total = nt * nf.
- """
-
nt, nf = W.shape
- if nt % 2 != 0 or nf % 2 != 0:
- raise ValueError("nt and nf must both be even.")
n_total = nt * nf
- if not (0 < A < 0.5):
- raise ValueError("A must be in (0, 0.5).")
-
- # 1) Build φ-window (same as forward)
- _, dF_phi = wdm_dT_dF(nt, nf, 1.0)
- fs_full = fftfreq(n_total)
+ # validation omitted for brevity
+ fs_full = fft.fftfreq(n_total)
half = nt // 2
fs_phi = np.concatenate([fs_full[:half], fs_full[-half:]])
- phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi)
-
- # 2) Build parity matrix C(n,m) and form ylm = C · W / sqrt(2) · nf
- n_idx = np.arange(nt)[:, None] # shape (nt,1)
- m_idx = np.arange(nf)[None, :] # shape (1,nf)
- parity = (n_idx + m_idx) % 2
- Cmat = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) # shape (nt,nf)
+ phi = Phi_unit(fs_phi / (1.0/(2.0*nf)), A, d)
- # ylm shape = (nt,nf)
- # (for m=0 we keep column zero)
- ylm = np.zeros((nt, nf), dtype=complex)
- ylm[:, 1:] = (Cmat[:, 1:] * W[:, 1:] / np.sqrt(2.0)) * nf
+ n = np.arange(nt)[:, None]
+ m = np.arange(nf)[None, :]
+ parity = (n + m) % 2
+ C = np.where(parity == 0, 1, 1j)
+ ylm = np.zeros((nt, nf), complex)
+ ylm[:,1:] = (C[:,1:] * W[:,1:] / np.sqrt(2.0)) * nf
- # 3) FFT each column along axis=0
- Y = np.fft.fft(ylm, axis=0) # shape (nt,nf)
-
- # 4) Reconstruct full-spectrum X_recon of length n_total
- X_recon = np.zeros(n_total, dtype=complex)
- center_idx = n_total // 2
- start = center_idx - half
-
- for m in range(1, nf):
- # Build the nt-length “block” in frequency space: Y[:,m] * φ-window
- block = Y[:, m] * phi_window # length = nt
-
- pos_half = block[:half]
- neg_half = block[half:]
- block_full = np.concatenate([neg_half, pos_half])
- # → this arranges [ negative_half | positive_half ] at indices [start:start+nt]
-
- # Place block_full into an otherwise-zero length-n_total array, then roll:
- temp_full = np.zeros(n_total, dtype=complex)
- temp_full[start : start + nt] = block_full
-
- # Roll so that the “center” (index=center_idx) goes to freq_bin=m*half
- freq_bin = m * half
- shift = freq_bin - center_idx
- X_recon += np.roll(temp_full, shift)
-
- # 5) IFFT back to time domain
- x_time = ifft(X_recon)
- return np.real(x_time)
-
-
-
-
-def wdm_inverse_direct(W, A, d):
- """
- Inverse WDM transform (direct indexing, no roll -- i think i might have a bug :( )). Given W shape=(nt,nf),
- returns the real 1D signal x of length = nt*nf.
-
- Algorithm:
- nt, nf must be even, W[:,0]=0. For m=1..nf-1:
- 1) Build ylm[n,m] = C(n,m)*W[n,m]/√2 * nf, where C(n,m)=1 or 1j by parity.
- 2) Y[:,m] = FFT(ylm[:,m]) (length=nt).
- 3) block = Y[:,m] * φ_window (length=nt, with φ_window from fftfreq(nt*nf)).
- 4) Split block = [pos_half|neg_half], where pos_half = block[:nt//2], neg_half=block[nt//2:].
- 5) l0 = m*(nt//2). Add pos_half into X[l0 : l0+half]; add neg_half into X[l0-half : l0].
- 6) After looping m=1..nf-1, IFFT(X) → time, return real part.
- """
- nt, nf = W.shape
- if nt % 2 != 0 or nf % 2 != 0:
- raise ValueError("nt and nf must both be even.")
- n_total = nt * nf
- if not (0 < A < 0.5):
- raise ValueError("A must be in (0,0.5).")
-
- # 1) Build φ_window of length=nt
- _, dF_phi = wdm_dT_dF(nt, nf, 1.0)
- fs_full = fftfreq(n_total)
- half = nt // 2
- fs_phi = np.concatenate((fs_full[:half], fs_full[-half:])) # length=nt
- phi_window = Phi_unit(fs_phi / dF_phi, A, d) / np.sqrt(dF_phi)
-
- # 2) Build parity matrix C(n,m) and compute ylm
- n_idx = np.arange(nt)[:, None] # shape (nt,1)
- m_idx = np.arange(nf)[None, :] # shape (1,nf)
- parity = (n_idx + m_idx) % 2
- Cmat = np.where(parity == 0, 1.0 + 0.0j, 0.0 + 1.0j) # shape (nt,nf)
-
- # ylm[n,m] = C(n,m) * W[n,m] / sqrt(2) * nf (for m>=1; ylm[:,0]=0)
- ylm = np.zeros((nt, nf), dtype=complex)
- ylm[:, 1:] = (Cmat[:, 1:] * W[:, 1:] / np.sqrt(2.0)) * nf
-
- # 3) FFT each column of ylm along axis=0 → Y (shape (nt,nf))
- Y = np.fft.fft(ylm, axis=0)
-
- # 4) Reconstruct full-spectrum X_recon of length n_total by adding each band’s contributions
- X_recon = np.zeros(n_total, dtype=complex)
- half = nt // 2
+ Y = fft.fft(ylm, axis=0)
+ X_rec = np.zeros(n_total, complex)
+ center = n_total // 2
+ start = center - half
for m in range(1, nf):
- # Build the nt-length “block” = [pos_half | neg_half]
- block = Y[:, m] * phi_window # length = nt
-
- # Split block into pos/neg
- pos_half = block[:half]
- neg_half = block[half:]
-
- # l0 = m*(nt/2)
- l0 = m * half
-
- # Add pos_half into X_recon[l0 : l0 + half]
- X_recon[l0 : l0 + half] += pos_half
-
- # Add neg_half into X_recon[l0 - half : l0]
- X_recon[l0 - half : l0] += neg_half
-
- # (No extra “conj” handling needed because ylm was built with the correct parity factor)
-
- # 5) IFFT back to time domain
- x_time = ifft(X_recon)
- return np.real(x_time)
+ block = Y[:, m] * phi
+ pos, neg = block[:half], block[half:]
+ temp = np.zeros(n_total, complex)
+ temp[start:start+nt] = np.concatenate([neg, pos])
+ X_rec += np.roll(temp, m * half - center)
+ return np.real(fft.ifft(X_rec))