In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pywt
import os
import pandas as pd
import statsmodels.api as sm

# Synthetic example spectrum: Gaussian peak + white noise + constant floor.
# NOTE: `freqs` and `spectrum` are overwritten below by the first row of the
# processed parquet data; this example is kept for experimenting without it.
# A seeded generator makes the example identical on every run.
rng = np.random.default_rng(seed=0)
freqs = np.linspace(0, 500, 1000)  # 1000 evenly spaced frequencies, 0 to 500 inclusive
signal = np.exp(-((freqs - 200) ** 2) / (2 * 30 ** 2))  # unit-height Gaussian peak at 200, std 30
noise = rng.normal(0, 0.02, size=freqs.shape)  # additive white Gaussian noise, std 0.02
spectrum = signal + noise + 0.05  # signal + noise on a constant 0.05 baseline floor

# Load the processed spectra and take the first row's frequency axis and
# spectrum; these replace the synthetic example above.
# The path is relative to the current working directory (it names a file,
# despite the variable name, which is kept for compatibility with later cells).
directory_path = os.path.join("Data", "processed_df.parquet")
df = pd.read_parquet(directory_path)
if df.empty:
    raise ValueError(f"No rows found in {directory_path!r}")
first_row = df.iloc[0]
# Normalise the frequency axis to a float ndarray so LOWESS and plotting never
# receive a plain list or an object-dtype array.
freqs = np.asarray(first_row['freqs'], dtype=float)
# Fresh, writable copy of the spectrum (arrays handed back by the parquet
# reader may be read-only — presumably the reason for the original copy).
spectrum = np.array(first_row['spectrum'], copy=True)
if freqs.shape != spectrum.shape:
    raise ValueError(
        f"'freqs' shape {freqs.shape} does not match 'spectrum' shape {spectrum.shape}"
    )

# Plot the original spectrum
# plt.figure(figsize=(12, 6))
# plt.plot(freqs, spectrum, label="Original Spectrum", alpha=0.7)
# plt.xlabel("Frequency (Hz)")
# plt.ylabel("Magnitude")
# plt.title("Original Spectrum")
# plt.legend()
# plt.show()

# LOWESS baseline: a locally weighted regression of the spectrum on frequency,
# plotted below as an alternative noise-floor estimate. statsmodels' default
# robustifying iterations down-weight outlying points such as spectral peaks.
lowess = sm.nonparametric.lowess
# statsmodels `frac`: fraction of all points used for each local fit (0 < frac <= 1).
# Smaller values track the spectrum more closely; larger values give a smoother
# floor. Given its own name so it is not confused with the noise `sigma`
# estimated in the wavelet step.
lowess_frac = 0.08
# return_sorted=False returns only the fitted values, in the same order as the
# inputs, so `fit` lines up element-for-element with `spectrum`.
fit = lowess(spectrum, freqs, frac=lowess_frac, return_sorted=False)

# Wavelet decomposition: the coarsest approximation coefficients carry the
# slowly varying baseline, the detail coefficients carry peaks and noise.
wavelet = 'sym5'  # Symlet-5 (near-symmetric variant of the Daubechies family)
level = pywt.dwt_max_level(len(spectrum), wavelet)  # deepest useful level for this length/filter
if level < 1:
    raise ValueError(
        f"Spectrum of length {len(spectrum)} is too short for a '{wavelet}' decomposition"
    )
coeffs = pywt.wavedec(spectrum, wavelet, level=level)  # [cA_level, cD_level, ..., cD_1]

# Soft-threshold the detail coefficients with the universal (VisuShrink)
# threshold. `sigma` is the noise standard deviation, estimated from the finest
# detail band with the robust MAD estimator median(|cD_1|) / 0.6745; unlike
# np.std it is not inflated by the few large coefficients sharp features leave there.
sigma = np.median(np.abs(coeffs[-1])) / 0.6745
threshold = sigma * np.sqrt(2 * np.log(len(spectrum)))
denoised_coeffs = [coeffs[0]] + [pywt.threshold(c, threshold, mode='soft') for c in coeffs[1:]]

# Noise floor = reconstruction from the approximation coefficients alone. All
# detail bands are zeroed here, so the thresholding above does not affect it
# (`denoised_coeffs` remains available for a denoised-spectrum reconstruction).
# waverec can return one sample more than the input (odd lengths); trim to match.
noise_floor = pywt.waverec(
    [denoised_coeffs[0]] + [np.zeros_like(c) for c in denoised_coeffs[1:]], wavelet
)[: len(spectrum)]

# Overlay both noise-floor estimates (wavelet approximation and LOWESS) on the
# spectrum for visual comparison. Each estimate is sliced to the length of the
# frequency axis in case it is longer (pywt.waverec can add a sample).
n_points = len(freqs)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(freqs, spectrum, label="Original Spectrum", alpha=0.7)
ax.plot(freqs, noise_floor[:n_points], label="Wavelet Noise Floor", linestyle='--', linewidth=2)
ax.plot(freqs, fit[:n_points], label="Lowess Noise Floor", linestyle='--', linewidth=2)
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Magnitude")
ax.set_title("Noise Floor Estimation: Wavelet vs. LOWESS")
ax.legend()
plt.show()
