# RFI Flagger Demo

Demonstration of two-phase iterative sigma-clipping for Radio Frequency Interference flagging in waterfall data.

In [None]:
import sys
sys.path.insert(0, "..")

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from MomentRFI import load_waterfall, validate_waterfall, IterativeSurfaceFitter
from MomentRFI.plotting import (
    plot_waterfall, plot_mask, plot_residuals,
    plot_convergence, plot_summary
)

## 1. Load & Validate Data

In [None]:
# Observation file to analyse, given relative to the notebook's directory.
DATA_PATH = "../2025-12-02_16-54-49_obs.hd5f"

# load_waterfall yields the data array plus its frequency and time axes;
# validate_waterfall returns a mapping of basic sanity-check results.
waterfall, freqs, times = load_waterfall(DATA_PATH)
info = validate_waterfall(waterfall)

# Assemble the whole report first, then emit it in one go (one line per entry).
report = [
    f"Shape: {info['shape']}",
    f"Value range: {info['min']:.4e} to {info['max']:.4e}",
    f"NaN: {info['has_nan']}, Zeros: {info['has_zero']}, Negative: {info['has_negative']}",
    f"Frequency range: {freqs[0]:.2f} - {freqs[-1]:.2f} MHz",
    f"Time range: {times[0]:.1f} - {times[-1]:.1f} s",
]
print("\n".join(report))

## 2. Visualize Raw Waterfall

In [None]:
# One wide axis for the unprocessed data (plt.subplots defaults to a 1x1 grid).
fig, ax = plt.subplots(figsize=(14, 5))
plot_waterfall(waterfall, freqs, times, title="Raw Waterfall", ax=ax)
plt.show()

## 3. Run Default Fitter (Phase 1 + Phase 2)

In [None]:
# Settings for the two-phase run, collected in one place so they are easy to tweak.
fitter_settings = {
    "sigma_threshold": 4.0,
    "phase1_degree": 5,
    "phase2_degree_freq": 10,
    "phase2_degree_time": 5,
}
fitter = IterativeSurfaceFitter(**fitter_settings)

# fit() returns the flag mask; later cells also read fitter.mask and
# fitter.history from the fitted object.
mask = fitter.fit(waterfall)

## 4. Summary: Original, Surface, Cleaned, Residuals, Mask

In [None]:
# Draw the summary figure for the fitted run; plot_summary receives the raw
# waterfall, the fitted fitter object and both axes, and returns the figure.
fig = plot_summary(waterfall, fitter, freqs, times)
plt.show()

## 5. Convergence Diagnostics

In [None]:
# Three side-by-side panels; plot_convergence is handed the whole axes array
# together with the history recorded by the default fitter.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
plot_convergence(fitter.history, ax=axes)
plt.show()

## 6. Zoom into Flagged Regions

In [None]:
from matplotlib.colors import LogNorm

# Rank frequency channels by flag count (summing the mask along the time axis)
# and keep the five worst, most-flagged first.
flags_per_freq = fitter.mask.sum(axis=0)
top_freq_idx = np.argsort(flags_per_freq)[::-1][:5]

print("Top 5 most-flagged frequency channels:")
for chan in top_freq_idx:
    print(f"  Channel {chan} ({freqs[chan]:.3f} MHz): {flags_per_freq[chan]} flags")

# Window of up to +/- half_w channels around the worst channel, clipped to the
# valid channel range of the waterfall.
center = top_freq_idx[0]
half_w = 200
first_chan = max(0, center - half_w)
stop_chan = min(waterfall.shape[1], center + half_w)
sl = slice(first_chan, stop_chan)

# Time runs top-to-bottom, hence the reversed time limits in the extent.
extent_zoom = [freqs[first_chan], freqs[stop_chan - 1], times[-1], times[0]]
shared_image_opts = dict(aspect="auto", extent=extent_zoom, interpolation="none")

fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)
ax_data, ax_mask = axes

# Upper panel: the data itself on a logarithmic colour scale.
ax_data.imshow(waterfall[:, sl], norm=LogNorm(), cmap="viridis", **shared_image_opts)
ax_data.set_title(f"Zoomed waterfall around {freqs[center]:.2f} MHz")
ax_data.set_ylabel("Time [s]")

# Lower panel: the boolean mask rendered as 0/1 so flagged pixels show in red.
ax_mask.imshow(fitter.mask[:, sl].astype(float), cmap="Reds", vmin=0, vmax=1,
               **shared_image_opts)
ax_mask.set_title("RFI Mask (zoomed)")
ax_mask.set_xlabel("Frequency [MHz]")
ax_mask.set_ylabel("Time [s]")

plt.tight_layout()
plt.show()

## 7. Noise Estimator Comparison: MAD vs Lower-Tail

The default **MAD** (Median Absolute Deviation) estimator is robust up to ~50% contamination.
But if the RFI fraction exceeds 50%, MAD breaks down because the median itself
gets pulled by outliers.

The **lower-tail** estimator exploits the fact that RFI *adds* power, so it only
inflates the *upper* tail of the residual distribution. The lower tail should be
clean Gaussian noise. We histogram the bottom 20% of residuals and fit a zero-mean
Gaussian $A\exp\left(-x^2/(2\sigma^2)\right)$ via linear regression of $\log(\text{counts})$ vs
$x^2$ — the fitted slope equals $-1/(2\sigma^2)$, giving an analytical closed-form solution with no iterative optimisation.

**Trade-off:** On this relatively clean dataset, the lower tail includes both
noise *and* genuine spectral structure that the polynomial cannot capture, so the
fitted Gaussian is wider than the true noise. The lower-tail estimator therefore
gives a *larger* sigma and flags *fewer* pixels than MAD. On a heavily contaminated
dataset (>50% RFI), the situation reverses: MAD is corrupted while the lower tail
remains clean.

In [None]:
# Two fitters that differ only in how the noise level is estimated.
fitter_mad = IterativeSurfaceFitter(noise_estimator="mad", verbose=False)
fitter_lt = IterativeSurfaceFitter(
    noise_estimator="lower_tail",
    lower_tail_fraction=0.2,
    verbose=False,
)

# Fit the same waterfall with each estimator.
mask_mad = fitter_mad.fit(waterfall)
mask_lt = fitter_lt.fit(waterfall)

# Count the flags once per mask and reuse the counts in the report lines.
n_flagged_mad = mask_mad.sum()
n_flagged_lt = mask_lt.sum()

print(f"MAD estimator:        {n_flagged_mad:>8d} flagged ({n_flagged_mad/mask_mad.size:.4%}),  sigma_floor = {fitter_mad.sigma_floor:.6f}")
print(f"Lower-tail estimator: {n_flagged_lt:>8d} flagged ({n_flagged_lt/mask_lt.size:.4%}),  sigma_floor = {fitter_lt.sigma_floor:.6f}")

In [None]:
# Side-by-side convergence comparison of the MAD and lower-tail runs.
# Each entry is (legend label, fitter history, matplotlib marker). The history
# is indexed by "phase1" / "phase2", each a sequence of per-iteration records.
runs = [
    ("MAD", fitter_mad.history, "o"),
    ("Lower-tail", fitter_lt.history, "s"),
]

# Only two quantities are plotted (sigma evolution and flag fraction), so create
# exactly two panels instead of leaving an empty third axis in the figure.
fig, axes = plt.subplots(1, 2, figsize=(16, 4))
ax_sigma, ax_flags = axes

for label, h, marker in runs:
    # Sigma evolution: solid line for phase 1, dashed line for phase 2.
    if h["phase1"]:
        ax_sigma.plot([x["iteration"] for x in h["phase1"]],
                      [x["sigma"] for x in h["phase1"]],
                      f"{marker}-", label=f"{label} P1")
    if h["phase2"]:
        ax_sigma.plot([x["iteration"] for x in h["phase2"]],
                      [x["sigma_used"] for x in h["phase2"]],
                      f"{marker}--", label=f"{label} P2")
        # Flag fraction is only plotted for phase 2, converted to percent.
        ax_flags.plot([x["iteration"] for x in h["phase2"]],
                      [x["flag_fraction"]*100 for x in h["phase2"]],
                      f"{marker}-", label=label)

ax_sigma.set_xlabel("Iteration")
ax_sigma.set_ylabel("Sigma")
ax_sigma.set_title("Sigma Evolution")

ax_flags.set_xlabel("Iteration")
ax_flags.set_ylabel("Flagged (%)")
ax_flags.set_title("Phase 2 Flag Fraction")

# Only add a legend when something was actually drawn; calling legend() on an
# axis without labelled artists just emits a warning.
if ax_sigma.get_legend_handles_labels()[0]:
    ax_sigma.legend(fontsize=8)
if ax_flags.get_legend_handles_labels()[0]:
    ax_flags.legend()

# Lay out this figure explicitly: plt.tight_layout() below only affects the
# figure that is current at that point (the mask comparison).
fig.tight_layout()

# Mask comparison
fig2, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(16, 5))
plot_mask(mask_mad, freqs, times, ax=ax_a, title="MAD mask")
plot_mask(mask_lt, freqs, times, ax=ax_b, title="Lower-tail mask")
fig2.tight_layout()
plt.show()

### Discussion

For this dataset (~1% true RFI), the two estimators behave differently:

- **MAD** converges in ~12 iterations to sigma ~0.27 and flags ~0.6%. The sigma floor
  prevents runaway in Phase 2.
- **Lower-tail** fits a wider Gaussian (sigma ~0.56) because the lower tail of the
  residuals includes unfitted spectral structure, not just noise. It converges in just
  2 iterations, flagging only ~0.01%.

**When to use `lower_tail`:** Datasets where RFI contaminates >50% of pixels. In that
regime MAD is corrupted, but the lower tail (clean by construction — RFI adds power)
still gives a valid sigma. You may want to lower `sigma_threshold` to compensate for the
wider sigma estimate.

**When to stick with `mad`:** Datasets with moderate RFI (<50%), like this one. MAD is
more sensitive to faint RFI and the two-phase sigma floor already prevents runaway.

In [None]:
# Visualise: compare the two masks and highlight prior pixels.
# NOTE: this cell reads mask_noprior and mask_prior, which are created by the
# cell that builds the prior mask and runs both fitters -- run that cell first.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

plot_mask(mask_noprior, freqs, times, ax=axes[0], title="Mask — no prior")
plot_mask(mask_prior,   freqs, times, ax=axes[1], title="Mask — with prior")

# Third panel: pixels added *only* by the prior (present in mask_prior but not mask_noprior)
prior_only = mask_prior & ~mask_noprior
axes[2].imshow(prior_only.astype(float), aspect="auto",
               extent=[freqs[0], freqs[-1], times[-1], times[0]],
               cmap="Blues", vmin=0, vmax=1, interpolation="none")
axes[2].set_title("Pixels added by prior mask only")
axes[2].set_xlabel("Frequency [MHz]")
axes[2].set_ylabel("Time [s]")

plt.tight_layout()
plt.show()

# Shape-mismatch validation: a prior mask whose shape differs from the waterfall
# is expected to be rejected with ValueError. The check runs on a throwaway
# fitter so that, whatever fit() does with the bad input, the fitted
# fitter_prior object from the comparison above is left untouched.
bad_prior = np.zeros((10, 10), dtype=bool)
try:
    IterativeSurfaceFitter(verbose=False).fit(waterfall, prior_mask=bad_prior)
except ValueError as e:
    print(f"\nShape mismatch correctly raises ValueError:\n  {e}")
else:
    # Without this branch a missing validation would pass silently.
    print("\nWARNING: mismatched prior_mask was accepted; no ValueError raised")

In [None]:
# A priori mask: flag every time integration of the five most-flagged frequency
# channels (top_freq_idx, computed in Section 6) so they are pre-excluded from
# the fit. The mask-comparison plot cell relies on the names defined here.
prior = np.zeros(waterfall.shape, dtype=bool)
prior[:, top_freq_idx] = True

n_prior = prior.sum()
print(f"Prior mask: {n_prior} pixels pre-flagged "
      f"({n_prior/prior.size:.4%} of waterfall)")

# Two identically configured fitters; only the second one is given the prior.
fitter_noprior = IterativeSurfaceFitter(verbose=False)
mask_noprior = fitter_noprior.fit(waterfall)

fitter_prior = IterativeSurfaceFitter(verbose=False)
mask_prior = fitter_prior.fit(waterfall, prior_mask=prior)

# Sanity check: no prior-flagged pixel may come back unflagged.
assert mask_prior[prior].all(), "Prior pixels must always be flagged"

n_noprior = mask_noprior.sum()
n_with_prior = mask_prior.sum()
n_clipped = (mask_prior & ~prior).sum()  # flagged pixels outside the prior

print(f"\nWithout prior mask: {n_noprior:>8d} flagged ({n_noprior/mask_noprior.size:.4%})")
print(f"With    prior mask: {n_with_prior:>8d} flagged ({n_with_prior/mask_prior.size:.4%})")
print(f"  of which {n_prior} are from prior, "
      f"{n_clipped} from iterative sigma-clipping")

## 8. A Priori Mask (`prior_mask`)

If you already know certain channels or time integrations are bad (e.g., known persistent interference, a dead receiver element, or an observatory-flagged band edge), you can supply that information upfront via `prior_mask`. Prior-flagged pixels are:

- **excluded from the polynomial surface fit** in both phases, so they cannot bias the smooth background model.
- **always `True` in the returned mask**, regardless of their residual value — they are never re-admitted by sigma-clipping.

The sigma floor and convergence logic are unaffected: they operate only on the unflagged population.