In [1]:
import torch
import numpy as np
from scipy.signal import welch
from sbi import utils as sbi_utils
from sbi.inference import SNPE
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt

# Load your trained generator and posterior (already done)
from tensorflow.keras.models import load_model
generator = load_model("saved_models/qpo_cgan_phy_generator.keras")

# Load pre-trained SBI posterior (already trained)
# Assume: posterior = inference.build_posterior(density_estimator)

posterior = torch.load("trained_sbi_posterior.pt")

  posterior = torch.load("trained_sbi_posterior.pt")


In [2]:
from scipy.optimize import curve_fit

def lorentzian(f, A, f0, gamma):
    return A / (1 + ((f - f0) / gamma) ** 2)


In [3]:
from scipy.signal import welch
from scipy.optimize import curve_fit

# def compute_lorentzian_q(series, fs=1, f_window=(0.001, 0.5)):
#     f, Pxx = welch(series.squeeze(), fs=fs, nperseg=256)

#     # Filter peak region
#     mask = (f > f_window[0]) & (f < f_window[1])
#     f_peak = f[mask]
#     Pxx_peak = Pxx[mask]

#     try:
#         # Initial guess for A, f0, gamma
#         p0 = [np.max(Pxx_peak), f_peak[np.argmax(Pxx_peak)], 0.01]
#         popt, _ = curve_fit(lorentzian, f_peak, Pxx_peak, p0=p0, maxfev=5000)
#         A, f0, gamma = popt
#         Q = f0 / gamma if gamma != 0 else 0
#         return Q
#     except:
#         return 0.0  # Fallback if fit fails
    

def compute_lorentzian_q(series, fs=1, f_window=(0.001, 0.5)):
    f, Pxx = welch(series.squeeze(), fs=fs, nperseg=256)

    mask = (f > f_window[0]) & (f < f_window[1])
    f_peak = f[mask]
    Pxx_peak = Pxx[mask]

    try:
        p0 = [np.max(Pxx_peak), f_peak[np.argmax(Pxx_peak)], 0.01]
        popt, _ = curve_fit(lorentzian, f_peak, Pxx_peak, p0=p0, maxfev=5000)
        A, f0, gamma = popt

        if gamma <= 0 or f0 <= 0:
            return 0.0  # ❗invalid fit → flat
        return f0 / gamma
    except:
        return 0.0  # ❗fallback



In [4]:
def detect_qpo_from_real_band_q(curve, posterior, clf):
    f, Pxx = welch(curve.squeeze(), fs=1, nperseg=256)
    x_obs = torch.tensor(Pxx, dtype=torch.float32)

    samples = posterior.sample((500,), x=x_obs, show_progress_bars=False)
    fc_mean = samples[:, 0].mean().item()
    amp_mean = samples[:, 1].mean().item()
    fc_std = samples[:, 0].std().item()
    Q = compute_lorentzian_q(curve)

    qpo_label = clf.predict([[fc_mean, amp_mean, fc_std, Q]])[0]

    return {
        "fc_mean": fc_mean,
        "amp_mean": amp_mean,
        "fc_std": fc_std,
        "Q": Q,
        "qpo": bool(qpo_label),
        "samples": samples
    }


In [5]:
def generate_sbi_training_data_with_q(num_simulations=2000, latent_dim=100):
    X, y = [], []

    for _ in range(num_simulations // 2):
        for is_qpo in [1, 0]:  # QPO and non-QPO
            fc = np.random.uniform(0.01, 1.0)
            amp = np.random.uniform(0.1, 1.0)
            label = np.array([[fc, amp, is_qpo]], dtype=np.float32)
            z = np.random.randn(1, latent_dim)
            generated = generator([z, label], training=False).numpy().squeeze()
            f, Pxx = welch(generated, fs=1, nperseg=256)
            x = torch.tensor(Pxx, dtype=torch.float32)
            samples = posterior.sample((500,), x=x, show_progress_bars=False)
            fc_mean = samples[:, 0].mean().item()
            amp_mean = samples[:, 1].mean().item()
            fc_std = samples[:, 0].std().item()
            Q = compute_lorentzian_q(generated)
            X.append([fc_mean, amp_mean, fc_std, Q])
            y.append(is_qpo)
    return np.array(X), np.array(y)

# Train classifier
X_train, y_train = generate_sbi_training_data_with_q()
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)


In [15]:
def generate_sbi_training_data_q_filtered(num_simulations=2000, latent_dim=100, include_power=True):
    from scipy.signal import welch
    X, y = [], []

    for _ in range(num_simulations // 2):
        for is_qpo in [1, 0]:
            fc = np.random.uniform(0.01, 1.0)
            amp = np.random.uniform(0.1, 1.0)
            label = np.array([[fc, amp, is_qpo]], dtype=np.float32)
            z = np.random.randn(1, latent_dim)
            generated = generator([z, label], training=False).numpy().squeeze()

            # Compute PSD
            f, Pxx = welch(generated, fs=1, nperseg=256)
            x = torch.tensor(Pxx, dtype=torch.float32)

            # Get SBI posterior samples
            samples = posterior.sample((500,), x=x, show_progress_bars=False)
            fc_mean = samples[:, 0].mean().item()
            amp_mean = samples[:, 1].mean().item()
            fc_std = samples[:, 0].std().item()

            # Lorentzian Q
            Q = compute_lorentzian_q(generated)
            peak_power = np.max(Pxx)

            # ✅ STRONG QPO ONLY
            if is_qpo == 1 and Q < 4:
                continue  # skip weak QPO samples

            # Build feature vector
            if include_power:
                X.append([fc_mean, amp_mean, fc_std, Q, peak_power])
            else:
                X.append([fc_mean, amp_mean, fc_std, Q])
            y.append(is_qpo)

    return np.array(X), np.array(y)


In [None]:
from sklearn.ensemble import RandomForestClassifier

# Generate data
X_train, y_train = generate_sbi_training_data_q_filtered()

# Train classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42, class_weight={0: 1, 1: 2})
clf.fit(X_train, y_train)


In [None]:
def detect_qpo_with_q_power(curve, posterior, clf):
    f, Pxx = welch(curve.squeeze(), fs=1, nperseg=256)
    x_obs = torch.tensor(Pxx, dtype=torch.float32)
    samples = posterior.sample((500,), x=x_obs, show_progress_bars=False)

    fc_mean = samples[:, 0].mean().item()
    amp_mean = samples[:, 1].mean().item()
    fc_std = samples[:, 0].std().item()
    Q = compute_lorentzian_q(curve)
    peak_power = np.max(Pxx)

    proba = clf.predict_proba([[fc_mean, amp_mean, fc_std, Q, peak_power]])[0][1]

    return {
        "fc_mean": fc_mean,
        "amp_mean": amp_mean,
        "fc_std": fc_std,
        "Q": Q,
        "peak_power": peak_power,
        "qpo_proba": proba,
        "qpo": proba > 0.45,
        "samples": samples
    }


In [None]:
data = np.loadtxt("ltcrv4bands_rej_dt100.dat")
bands = [data[:, i] for i in range(4)]

for i, band in enumerate(bands):
    result = detect_qpo_with_q_power(band, posterior, clf)
    print(f"\n🎧 Band {i+1}:")
    print(f"→ fc_mean: {result['fc_mean']:.3f}, fc_std: {result['fc_std']:.3f}")
    print(f"→ amp_mean: {result['amp_mean']:.3f}, Q: {result['Q']:.2f}, Peak Power: {result['peak_power']:.3f}")
    print(f"→ QPO Probability: {result['qpo_proba']:.2f}")
    print(f"→ QPO Detected? {'✅ YES' if result['qpo'] else '❌ NO'}")



🎧 Band 1:
→ fc_mean: 0.494, fc_std: 0.239
→ amp_mean: 0.556, Q: 2.71, Peak Power: 11.788
→ QPO Probability: 0.00
→ QPO Detected? ❌ NO

🎧 Band 2:
→ fc_mean: 0.512, fc_std: 0.240
→ amp_mean: 0.560, Q: 5.52, Peak Power: 0.260
→ QPO Probability: 0.62
→ QPO Detected? ✅ YES

🎧 Band 3:
→ fc_mean: 0.491, fc_std: 0.241
→ amp_mean: 0.555, Q: 5.39, Peak Power: 0.163
→ QPO Probability: 0.68
→ QPO Detected? ✅ YES

🎧 Band 4:
→ fc_mean: 0.504, fc_std: 0.241
→ amp_mean: 0.567, Q: -25.79, Peak Power: 0.002
→ QPO Probability: 0.03
→ QPO Detected? ❌ NO


In [45]:
data = np.loadtxt("ltcrv4bands_rej_dt100.dat")
bands = [data[:, i] for i in range(4)]

for i, band in enumerate(bands):
    result = detect_qpo_from_real_band_q(band, posterior, clf)
    print(f"\n🎧 Band {i+1}:")
    print(f"→ fc_mean: {result['fc_mean']:.3f}, fc_std: {result['fc_std']:.3f}")
    print(f"→ amp_mean: {result['amp_mean']:.3f}, Lorentzian Q: {result['Q']:.2f}")
    print(f"→ QPO Detected? {'✅ YES' if result['qpo'] else '❌ NO'}")



🎧 Band 1:
→ fc_mean: 0.487, fc_std: 0.243
→ amp_mean: 0.549, Lorentzian Q: 2.71
→ QPO Detected? ❌ NO

🎧 Band 2:
→ fc_mean: 0.507, fc_std: 0.249
→ amp_mean: 0.557, Lorentzian Q: 5.52
→ QPO Detected? ❌ NO

🎧 Band 3:
→ fc_mean: 0.512, fc_std: 0.242
→ amp_mean: 0.561, Lorentzian Q: 5.39
→ QPO Detected? ❌ NO

🎧 Band 4:
→ fc_mean: 0.516, fc_std: 0.254
→ amp_mean: 0.545, Lorentzian Q: 0.00
→ QPO Detected? ❌ NO
