In [None]:
from scipy.signal import hilbert, find_peaks, savgol_filter
from scipy.ndimage import gaussian_filter1d, median_filter
from scipy.ndimage import grey_closing

"""
Imports and configuration
"""
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import glob
import math
from scipy.interpolate import CubicSpline
from typing import Tuple, Optional

# Parameters (edit these to match your data locations).
# Raw strings keep the Windows backslashes literal; Path handles globbing below.
sig_path = Path(r"C:\Users\sann7609\Documents\Oxford\ChannelAnalysis_CALA_Sept25_minisforum\vanilla_interferometry\for_vanilla_80mbar_Bdel582_t0\Interferometry2\sig")
bg_path  = Path(r"C:\Users\sann7609\Documents\Oxford\ChannelAnalysis_CALA_Sept25_minisforum\vanilla_interferometry\for_vanilla_80mbar_Bdel582_t0\Interferometry2\bg")
SIG_HEADER = '.tiff'   # substring common to signal image filenames
BG_HEADER  = '.tiff'   # substring common to background image filenames

# Upper limit on the number of shots to process; None keeps every file found.
NUM_SHOTS: Optional[int] = 1

# Display scaling for final image
PHASE_DISPLAY_CLIM = (-0.05, 0.05)

# Locate files. Both lists are sorted; shot i later reads sig_paths[i] and bg_paths[i].
sig_paths = sorted(sig_path.glob(f"*{SIG_HEADER}*"))
bg_paths  = sorted(bg_path.glob(f"*{BG_HEADER}*"))

print(f"Found {len(sig_paths)} signal files in {sig_path} and {len(bg_paths)} background files in {bg_path}")

# Optionally limit number of shots
if NUM_SHOTS is not None:
    sig_paths = sig_paths[:NUM_SHOTS]
    bg_paths  = bg_paths[:NUM_SHOTS]

# Sanity checks: both sets are needed. Without the background check an empty
# background folder would produce zero-length arrays and only fail later, at
# the preview plot, with an unrelated IndexError.
if not sig_paths:
    raise FileNotFoundError(f"No signal files found using pattern *{SIG_HEADER}* in {sig_path}")
if not bg_paths:
    raise FileNotFoundError(f"No background files found using pattern *{BG_HEADER}* in {bg_path}")


# Read a single image to get the frame shape (handles color/grayscale)
from imageio import imread

first_img = imread(sig_paths[0])
print("First image shape (as read):", first_img.shape)

# For a 3-D image only the first two axes define the frame; the channel axis
# is collapsed to grayscale when each file is loaded below.
if first_img.ndim == 3:
    frameshape = (first_img.shape[0], first_img.shape[1])
else:
    frameshape = first_img.shape

print("Frame shape (H, W):", frameshape)

# Read all files into arrays (shots, H, W). Only complete sig/bg pairs are
# loaded; NUM_SHOTS has already been applied to both lists above.
num_files = min(len(sig_paths), len(bg_paths))

RawInterferogramssig = np.zeros((num_files, frameshape[0], frameshape[1]), dtype=float)
RawInterferogramsbg  = np.zeros((num_files, frameshape[0], frameshape[1]), dtype=float)

for i in range(num_files):
    s = imread(sig_paths[i]).astype(float)
    b = imread(bg_paths[i]).astype(float)
    # Convert color to grayscale if needed (weighted sum of the first three channels)
    if s.ndim == 3:
        s = 0.2126 * s[...,0] + 0.7152 * s[...,1] + 0.0722 * s[...,2]
    if b.ndim == 3:
        b = 0.2126 * b[...,0] + 0.7152 * b[...,1] + 0.0722 * b[...,2]

    RawInterferogramssig[i,:,:] = s
    RawInterferogramsbg[i,:,:]  = b

print(f"Loaded {num_files} shots into arrays: sig {RawInterferogramssig.shape}")

# Quick visual check
plt.figure(figsize=(6,6))
plt.imshow(RawInterferogramssig[0,:,:], cmap='gray')
plt.title('Example signal interferogram (shot 0)')
plt.axis('off')
plt.show()

def _savgol_or_passthrough(values, window_length, polyorder):
    """Savitzky-Golay smooth `values`; return them unchanged if the settings are rejected."""
    try:
        return savgol_filter(values, window_length=window_length, polyorder=polyorder)
    except (ValueError, TypeError):
        # e.g. polyorder >= window_length, or a window longer than the data
        return values


def robust_envelope(signal,
                    use_hilbert=True,
                    smooth_sigma=2.0,
                    savgol_win=11,
                    savgol_poly=3,
                    med_k=5,
                    morph_size=9,
                    peak_prom_frac=0.2,
                    min_peak_distance=5):
    """
    Return (upper_env, lower_env) for a 1D `signal`.

    Strategy:
      1) If use_hilbert: take the analytic amplitude via the Hilbert transform
         of the mean-subtracted signal.
      2) Smooth that amplitude (gaussian -> median -> savgol).
      3) Low-pass the raw signal to get a centerline and build the envelopes as
         centerline +/- smoothed amplitude.
      4) If no usable amplitude is available (Hilbert disabled, all-NaN, or an
         essentially constant amplitude), fall back to envelopes interpolated
         through detected peaks/valleys, or to the global max/min when fewer
         than two peaks or valleys are found.
      5) Apply a grey closing and a final low-order savgol pass to both envelopes.

    Parameters:
      signal            : 1D array-like of samples.
      use_hilbert       : enable the analytic-amplitude path.
      smooth_sigma      : gaussian sigma for the amplitude; the centerline uses 2x this.
      savgol_win        : Savitzky-Golay window; shrunk automatically when it is
                          not smaller than the signal length.
      savgol_poly       : Savitzky-Golay polynomial order (capped at 2 in the final pass).
      med_k             : median filter size.
      morph_size        : structuring-element size for the grey closing.
      peak_prom_frac    : peak prominence as a fraction of the signal range (fallback path).
      min_peak_distance : minimum sample distance between peaks (fallback path).

    Every smoothing stage is best-effort: if a filter rejects its settings
    (typically on very short signals) that stage is skipped instead of raising.
    """
    signal = np.asarray(signal, dtype=float)
    n = len(signal)

    # 1) quick detrend/center to make the amplitude estimate stable
    sig_centered = signal - np.nanmean(signal)

    # 2) amplitude from the analytic signal (Hilbert)
    amp = np.abs(hilbert(sig_centered)) if use_hilbert else None

    if amp is None or np.all(np.isnan(amp)) or np.nanmax(amp) - np.nanmin(amp) < 1e-8:
        # unusable amplitude -> peak-based fallback in step 5
        amp = None

    # 3) smooth the amplitude if available
    if amp is not None:
        amp_s = gaussian_filter1d(amp, sigma=smooth_sigma, mode='reflect')
        amp_s = median_filter(amp_s, size=med_k)
        # savgol needs a window shorter than the data; shrink to an odd value >= 3
        if savgol_win >= n:
            savgol_win = max(3, (n // 2) // 2 * 2 + 1)
        amp_s = _savgol_or_passthrough(amp_s, savgol_win, savgol_poly)
    else:
        amp_s = None

    # 4) estimate the local centerline by low-pass filtering the signal.
    # The savgol call is guarded: after the window has been shrunk to 3 the
    # default polyorder of 3 is invalid and previously raised ValueError here.
    center_lp = gaussian_filter1d(signal, sigma=smooth_sigma*2, mode='reflect')
    center_lp = median_filter(center_lp, size=med_k)
    if savgol_win < n:
        center_lp = _savgol_or_passthrough(center_lp, savgol_win, savgol_poly)

    # 5) construct upper/lower using amp if available
    if amp_s is not None:
        upper = center_lp + amp_s
        lower = center_lp - amp_s
    else:
        # Peak-based fallback
        prom = peak_prom_frac * (signal.max() - signal.min())
        peaks, _ = find_peaks(sig_centered, prominence=prom, distance=min_peak_distance)
        valleys, _ = find_peaks(-sig_centered, prominence=prom, distance=min_peak_distance)

        # need at least two points of each kind to interpolate
        t_inds = np.arange(n)
        if len(peaks) >= 2 and len(valleys) >= 2:
            upper = np.interp(t_inds, peaks, signal[peaks])
            lower = np.interp(t_inds, valleys, signal[valleys])
            # smooth the interpolated envelopes
            upper = gaussian_filter1d(upper, sigma=smooth_sigma)
            lower = gaussian_filter1d(lower, sigma=smooth_sigma)
            upper = median_filter(upper, size=med_k)
            lower = median_filter(lower, size=med_k)
        else:
            # very coarse fallback: flat envelopes at the global extremes
            upper = np.full_like(signal, np.nanmax(signal))
            lower = np.full_like(signal, np.nanmin(signal))

    # 6) morphological closing (fills local dips narrower than morph_size).
    # NOTE(review): closing is applied to the lower envelope as well, which
    # raises its local minima; confirm that an opening was not intended there.
    try:
        upper = grey_closing(upper, size=morph_size)
        lower = grey_closing(lower, size=morph_size)
    except (ValueError, RuntimeError, TypeError):
        pass

    # 7) final Savitzky-Golay pass: window is the largest odd number below the
    # length, capped at savgol_win; polynomial order capped at 2.
    odd_cap = n - 1 if (n - 1) % 2 == 1 else n - 2
    final_win = min(odd_cap, savgol_win)
    final_poly = min(savgol_poly, 2)
    upper = _savgol_or_passthrough(upper, final_win, final_poly)
    lower = _savgol_or_passthrough(lower, final_win, final_poly)

    return upper, lower

def fit_envelope_spline(t, upper, lower):
    """Interpolate both envelopes with cubic splines and sample them back on `t`.

    The inputs are returned untouched when either envelope contains a NaN, or
    when building/evaluating a spline raises.
    """
    has_nan = np.isnan(upper).any() or np.isnan(lower).any()
    if has_nan:
        return upper, lower
    try:
        fitted = tuple(CubicSpline(t, envelope)(t) for envelope in (upper, lower))
    except Exception:
        # spline construction failed: hand back the arrays we were given
        return upper, lower
    return fitted


def remove_offset(signal: np.ndarray) -> Tuple[np.ndarray, float]:
    """Subtract the NaN-ignoring mean from `signal`.

    Returns a pair (signal minus its mean, the mean as a Python float).
    """
    mean_level = float(np.nanmean(signal))
    centered = signal - mean_level
    return centered, mean_level


def fit_envelopes(t: np.ndarray, upper_envelope: np.ndarray, lower_envelope: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Pass cubic splines through both envelopes and evaluate them on `t`.

    When either envelope contains a NaN the two input arrays are returned
    as-is. Any error raised by CubicSpline propagates to the caller.
    """
    for envelope in (upper_envelope, lower_envelope):
        if np.isnan(envelope).any():
            # NaNs present: skip the spline fit entirely
            return upper_envelope, lower_envelope

    fitted_upper = CubicSpline(t, upper_envelope)(t)
    fitted_lower = CubicSpline(t, lower_envelope)(t)
    return fitted_upper, fitted_lower


def extract_phase_from_column(signal: np.ndarray, plot_flag: bool = False) -> np.ndarray:
    """Recover an unwrapped phase from one 1D interferogram line by envelope demodulation.

    Outline:
    - estimate upper/lower envelopes and subtract their midline from the signal
    - repeat the midline removal 8 more times on the residual
    - divide by the final upper envelope and clip to [-1, 1]
    - take arccos, then reflect every sample that follows a negative step
      (phi -> 2*pi - phi)
    - unwrap the non-NaN samples

    With plot_flag=True a 3x2 diagnostic figure is shown. Returns the phase array.
    """
    t = np.arange(signal.size)

    # First-pass envelopes and their midline
    upper_env, lower_env = robust_envelope(signal)
    upper_fit, lower_fit = fit_envelopes(t, upper_env, lower_env)
    centerline = 0.5 * (upper_fit + lower_fit)

    # Repeated midline removal to flatten residual slow modulation
    signal_iter = signal - centerline
    for _ in range(8):
        env_hi, env_lo = robust_envelope(signal_iter)
        fit_hi, fit_lo = fit_envelopes(t, env_hi, env_lo)
        signal_iter = signal_iter - 0.5 * (fit_hi + fit_lo)

    # Normalise by the final upper envelope; near-zero envelope values become
    # NaN so the division cannot blow up
    upper_e_final, lower_e_final = robust_envelope(signal_iter)
    envelope_is_tiny = np.abs(upper_e_final) < 1e-6
    safe_upper = np.where(envelope_is_tiny, np.nan, upper_e_final)
    I_norm = np.clip(signal_iter / safe_upper, -1.0, 1.0)

    # arccos gives values in [0, pi]
    phi = np.arccos(I_norm)

    # Reflect the sample after every negative step
    falling = np.flatnonzero(np.diff(phi) < 0)
    phi_mod = phi.copy()
    phi_mod[falling + 1] = 2 * np.pi - phi_mod[falling + 1]

    # Unwrap only the samples that are not NaN
    valid = ~np.isnan(phi_mod)
    phi_mod[valid] = np.unwrap(phi_mod[valid])

    if plot_flag:
        fig, axs = plt.subplots(3, 2, figsize=(15, 8))
        ax_raw, ax_flat, ax_norm = axs[0, 0], axs[1, 0], axs[2, 0]
        ax_phase, ax_unwrapped = axs[1, 1], axs[2, 1]

        ax_raw.plot(t, signal, label='Original')
        ax_raw.plot(t, upper_env, 'r--', label='Upper env')
        ax_raw.plot(t, lower_env, 'b--', label='Lower env')
        ax_raw.plot(t, centerline, 'm-', label='Centerline')
        ax_raw.legend()
        ax_raw.set_ylabel('pixel value')

        ax_flat.plot(t, signal_iter, label='Iteratively flattened')
        ax_flat.plot(t, upper_e_final, 'r--')
        ax_flat.plot(t, lower_e_final, 'b--')
        ax_flat.set_ylabel('pixel value')

        ax_norm.plot(t, I_norm, label='Normalized signal')
        ax_norm.set_ylabel('normalized')

        ax_phase.plot(phi, label='raw arccos')
        ax_phase.plot(phi_mod, label='after slope correction')
        ax_phase.legend()

        ax_unwrapped.plot(phi_mod)
        ax_unwrapped.set_title('Unwrapped phase (after)')
        plt.tight_layout()
        plt.show()

    return phi_mod

# Main processing loop: compute avg phase shift image (signal - background).
# Each image is processed one row at a time: the 1D line handed to
# extract_phase_from_column is RawInterferograms*[shot, row, :] (the variable
# names say "col", but the slice runs along the width axis).
shots_to_process = RawInterferogramssig.shape[0]
H = frameshape[0]
W = frameshape[1]

# One phase-difference image per shot
phase_shifts = np.zeros((shots_to_process, H, W), dtype=float)

for shot_idx in range(shots_to_process):
    phase_img = np.zeros((H, W), dtype=float)
    for row_idx in range(H):
        plot_flag = (shot_idx == 0 and row_idx == 0)  # debug plot only for row 0 of shot 0
        sig_col = RawInterferogramssig[shot_idx, row_idx, :]
        bg_col  = RawInterferogramsbg[shot_idx, row_idx, :]

        phase_sig = extract_phase_from_column(sig_col, plot_flag=plot_flag)
        phase_bg  = extract_phase_from_column(bg_col, plot_flag=False)

        # Defensive guard: extract_phase_from_column always returns an array
        # (possibly containing NaNs), so the None branch is not expected to run.
        if phase_sig is None or phase_bg is None:
            phase_img[row_idx, :] = np.nan
        else:
            phase_img[row_idx, :] = phase_sig - phase_bg

    phase_shifts[shot_idx] = phase_img
    print(f"Finished processing shot {shot_idx + 1}/{shots_to_process}")

# Average across shots, ignoring NaN samples
avg_phase = np.nanmean(phase_shifts, axis=0)

# Display with the colour limits configured in PHASE_DISPLAY_CLIM
plt.figure(figsize=(8,6))
plt.imshow(avg_phase, cmap='RdBu', vmin=PHASE_DISPLAY_CLIM[0], vmax=PHASE_DISPLAY_CLIM[1])
plt.colorbar(label='Phase shift (rad)')
plt.title('Average phase shift image')
plt.xlabel('pixel')
plt.ylabel('line')
plt.show()

In [None]:
# Improved column-phase extraction with adaptive bandpass + Hilbert
# Paste into your notebook cell after reading images (keeps RawInterferogramssig, RawInterferogramsbg, frameshape)

import numpy as np
from scipy.signal import hilbert, find_peaks, savgol_filter, butter, filtfilt
from scipy.ndimage import gaussian_filter1d, median_filter, grey_closing
from scipy.interpolate import CubicSpline
import matplotlib.pyplot as plt

# ------------------------
# Parameters you can tune
# ------------------------
# Extraction method used by the main loop below: "hilbert_direct" selects the
# bandpass + Hilbert extractor; any other value selects the envelope/arccos one.
MODE = "hilbert_direct"   # "hilbert_direct" (recommended) or "envelope_arccos"
POST_PHASE_SMOOTH_SIGMA = 1.0   # gaussian smoothing on unwrapped phase (pixels); 0 to disable
# Fringe periods outside [MIN_PERIOD, MAX_PERIOD] are not accepted by estimate_fringe_period
MIN_PERIOD = 4          # minimum fringe period (pixels) to consider valid
MAX_PERIOD = 200        # maximum fringe period (pixels) to consider valid
# Total bandpass width = BANDWIDTH_FACTOR * carrier frequency, centred on the carrier
BANDWIDTH_FACTOR = 0.6  # bandpass width fraction relative to central freq (0.6 -> ±30%)
# envelope smoothing fractions (relative to period), consumed by adaptive_env_params_from_period
ENV_GAUSS_SIGMA_FRAC = 0.35
ENV_MEDIAN_K_FRAC    = 0.15
ENV_SAVGOL_WIN_FRAC  = 1.5  # window length ~ 1.5 * period (odd)
# safe defaults used by adaptive_env_params_from_period when no period was found
DEFAULT_SAVGOL_WIN = 11
DEFAULT_GAUSS_SIGMA = 2.0

# ------------------------
# Utility functions
# ------------------------
def estimate_fringe_period(signal, min_period=None, max_period=None):
    """
    Estimate the dominant fringe period (in pixels) of a 1D signal using an FFT.

    The signal is mean-subtracted and Hann-windowed, and the strongest spectral
    peak is searched only among frequencies whose period lies inside
    [min_period, max_period]. Restricting the search to that band means a strong
    slow background (period > max_period) can no longer hide a valid fringe peak.

    Parameters:
      signal     : 1D array-like of samples.
      min_period : smallest accepted period; defaults to the module-level MIN_PERIOD.
      max_period : largest accepted period; defaults to the module-level MAX_PERIOD.

    Returns the period as a float, or None when the signal has fewer than 8
    samples, contains non-finite values (the FFT would be all-NaN), has no
    frequency bin inside the band, or carries no power inside the band.
    """
    if min_period is None:
        min_period = MIN_PERIOD
    if max_period is None:
        max_period = MAX_PERIOD

    s = np.asarray(signal, dtype=float)
    n = s.size
    if n < 8:
        return None
    # A single NaN/inf turns the whole spectrum into NaN; argmax would then
    # pick an arbitrary bin, so give up explicitly instead.
    if not np.all(np.isfinite(s)):
        return None

    # simple detrend, then window to reduce leakage
    s0 = s - s.mean()
    spectrum = np.fft.rfft(s0 * np.hanning(n))

    # drop the DC bin so every remaining frequency is > 0 (cycles / pixel)
    power = np.abs(spectrum[1:]) ** 2
    periods = 1.0 / np.fft.rfftfreq(n, d=1.0)[1:]

    in_band = np.flatnonzero((periods >= min_period) & (periods <= max_period))
    if in_band.size == 0:
        return None

    peak = in_band[np.argmax(power[in_band])]
    if not power[peak] > 0.0:
        # flat signal: nothing to lock on to
        return None
    return float(periods[peak])

def butter_bandpass(low, high, fs=1.0, order=3):
    """Design a Butterworth band-pass filter and return its (b, a) coefficients.

    `low` and `high` are band edges in cycles/pixel for a sampling rate `fs`.
    The edges are normalised by the Nyquist frequency and clamped to
    [1e-6, 0.999999]; if the clamped band is empty, (None, None) is returned.
    """
    nyquist = fs / 2.0

    lower_edge = low / nyquist
    if lower_edge < 1e-6:
        lower_edge = 1e-6

    upper_edge = high / nyquist
    if upper_edge > 0.999999:
        upper_edge = 0.999999

    if lower_edge >= upper_edge:
        # no valid pass band
        return None, None

    numerator, denominator = butter(order, [lower_edge, upper_edge], btype='band')
    return numerator, denominator

def bandpass_filter_1d(signal, low, high, order=3):
    """Band-pass `signal` between `low` and `high` (cycles/pixel) with zero phase shift.

    Returns a copy of the input when no valid filter can be designed. If
    filtfilt raises, the result is instead a light gaussian smoothing
    (sigma=1) of the input.
    """
    numerator, denominator = butter_bandpass(low, high, fs=1.0, order=order)
    if numerator is None:
        return signal.copy()

    try:
        # forward-backward filtering cancels the filter's phase delay
        filtered = filtfilt(numerator, denominator, signal, method="pad")
    except Exception:
        filtered = gaussian_filter1d(signal, sigma=1.0, mode='reflect')
    return filtered

def adaptive_env_params_from_period(period):
    """Derive envelope-smoothing settings from the fringe period.

    Returns a dict with keys "savgol_win", "gauss_sigma" and "median_k".
    Without a period the module defaults are used (median size 5). Otherwise
    the Savitzky-Golay window and the median size are the period scaled by the
    ENV_* fractions, rounded, forced to be odd and at least 3; the gaussian
    sigma is the scaled period with a floor of 0.5.
    """
    if period is None:
        return {
            "savgol_win": DEFAULT_SAVGOL_WIN,
            "gauss_sigma": DEFAULT_GAUSS_SIGMA,
            "median_k": 5,
        }

    def _odd_at_least_three(value):
        size = int(max(3, round(value)))
        return size + 1 if size % 2 == 0 else size

    window = _odd_at_least_three(ENV_SAVGOL_WIN_FRAC * period)
    median_size = _odd_at_least_three(ENV_MEDIAN_K_FRAC * period)
    sigma = max(0.5, ENV_GAUSS_SIGMA_FRAC * period)
    # the window is additionally capped at four periods
    capped_window = max(3, min(window, period*4))
    return {"savgol_win": capped_window, "gauss_sigma": sigma, "median_k": median_size}

def smooth_envelope(amp, savgol_win, savgol_poly=3, med_k=5, gauss_sigma=2.0):
    """
    Smooth an amplitude trace in three stages: median -> gaussian -> Savitzky-Golay.

    Each stage is best-effort: if it raises, its input is carried forward
    unchanged. The Savitzky-Golay window is replaced by a short odd window
    when it is not smaller than the trace, and the polynomial order is capped
    at 3. The input array itself is never modified.
    """
    smoothed = amp.copy()

    pre_filters = (
        lambda data: median_filter(data, size=med_k),
        lambda data: gaussian_filter1d(data, sigma=gauss_sigma, mode='reflect'),
    )
    for apply_filter in pre_filters:
        try:
            smoothed = apply_filter(smoothed)
        except Exception:
            pass

    # keep the savgol window below the trace length
    window = savgol_win
    if window >= len(smoothed):
        window = max(3, (len(smoothed) // 2) // 2 * 2 + 1)

    try:
        smoothed = savgol_filter(smoothed, window_length=window, polyorder=min(3, savgol_poly))
    except Exception:
        pass
    return smoothed

# ------------------------
# Phase extraction functions
# ------------------------
def extract_phase_hilbert_column(signal, do_plot=False):
    """
    Bandpass + Hilbert phase extraction for one 1D line.

    The signal is band-passed around the estimated fringe frequency (or over a
    wide 0.005-0.45 cycles/pixel band when no period is found), converted to an
    analytic signal, and the unwrapped instantaneous phase is returned, with
    optional gaussian smoothing controlled by POST_PHASE_SMOOTH_SIGMA.
    With do_plot=True a three-panel diagnostic figure is shown.
    """
    n = len(signal)
    period = estimate_fringe_period(signal)
    params = adaptive_env_params_from_period(period)

    # choose the pass band
    if period is not None:
        carrier = 1.0 / period
        half_width = carrier * BANDWIDTH_FACTOR / 2.0
        bp_low = max(0.0001, carrier - half_width)
        bp_high = min(0.4999, carrier + half_width)
    else:
        # no carrier found: very wide band
        bp_low, bp_high = 0.005, 0.45

    # isolate the carrier, then build the analytic signal
    sig_bp = bandpass_filter_1d(signal, bp_low, bp_high, order=3)
    analytic = hilbert(sig_bp - np.nanmean(sig_bp))

    # wrapped phase in [-pi, pi] -> unwrapped
    phase = np.unwrap(np.angle(analytic))

    # optional smoothing of the unwrapped phase
    if POST_PHASE_SMOOTH_SIGMA > 0:
        phase = gaussian_filter1d(phase, sigma=POST_PHASE_SMOOTH_SIGMA, mode='reflect')

    if do_plot:
        t = np.arange(n)
        amp = np.abs(analytic)
        amp_s = smooth_envelope(amp, params["savgol_win"], med_k=params["median_k"], gauss_sigma=params["gauss_sigma"])
        plt.figure(figsize=(10, 6))

        plt.subplot(3,1,1)
        plt.plot(t, signal, label='raw')
        plt.plot(t, sig_bp, label='bandpassed')
        plt.legend()

        plt.subplot(3,1,2)
        plt.plot(t, amp, label='hilbert amp (raw)')
        plt.plot(t, amp_s, label='amp smoothed')
        plt.legend()

        plt.subplot(3,1,3)
        plt.plot(t, phase, label='inst phase unwrapped')
        plt.legend()

        plt.tight_layout()
        plt.show()

    return phase

def extract_phase_envelope_arccos_column(signal, do_plot=False):
    """
    Envelope/arccos phase extraction for one 1D line.

      - band-pass around the estimated carrier and take the analytic amplitude
      - smooth that amplitude and use it as the envelope half-width around a
        low-passed centerline of the raw signal
      - normalise the centered raw signal by the envelope, clip to [-1, 1]
      - arccos, reflect samples that follow a negative step, unwrap non-NaN samples
      - optional gaussian smoothing (POST_PHASE_SMOOTH_SIGMA)

    With do_plot=True a three-panel diagnostic figure is shown.
    """
    n_samples = len(signal)
    fringe_period = estimate_fringe_period(signal)
    env_cfg = adaptive_env_params_from_period(fringe_period)

    # pass band used only for the amplitude estimate
    if fringe_period is not None:
        carrier = 1.0 / fringe_period
        half_width = carrier * BANDWIDTH_FACTOR / 2.0
        band = (max(0.0001, carrier - half_width), min(0.4999, carrier + half_width))
    else:
        band = (0.005, 0.45)

    sig_bp = bandpass_filter_1d(signal, band[0], band[1], order=3)
    analytic = hilbert(sig_bp - np.nanmean(sig_bp))
    amp = np.abs(analytic)
    amp_s = smooth_envelope(amp, env_cfg["savgol_win"], med_k=env_cfg["median_k"], gauss_sigma=env_cfg["gauss_sigma"])

    # centerline: slowly varying part of the raw (not band-passed) signal
    center_lp = gaussian_filter1d(signal, sigma=max(1, env_cfg["gauss_sigma"]*2), mode='reflect')
    center_lp = median_filter(center_lp, size=env_cfg["median_k"])
    try:
        length = len(center_lp)
        largest_odd_below = length - 1 if (length - 1) % 2 == 1 else length - 2
        center_lp = savgol_filter(center_lp, window_length=min(largest_odd_below, env_cfg["savgol_win"]), polyorder=3)
    except Exception:
        pass

    # envelopes from the smoothed amplitude
    upper = center_lp + amp_s
    lower = center_lp - amp_s

    # morphological closing on both envelopes
    try:
        upper = grey_closing(upper, size=7)
        lower = grey_closing(lower, size=7)
    except Exception:
        pass

    # normalise the centered raw signal; near-zero half-widths become NaN
    half_width_env = upper - center_lp
    safe_upper = np.where(np.abs(half_width_env) < 1e-9, np.nan, half_width_env)
    I_norm = np.clip((signal - center_lp) / safe_upper, -1.0, 1.0)

    phi = np.arccos(I_norm)
    falling = np.flatnonzero(np.diff(phi) < 0)
    phi_mod = phi.copy()
    phi_mod[falling + 1] = 2*np.pi - phi_mod[falling + 1]

    valid = ~np.isnan(phi_mod)
    if valid.any():
        phi_mod[valid] = np.unwrap(phi_mod[valid])

    if POST_PHASE_SMOOTH_SIGMA > 0:
        phi_mod = gaussian_filter1d(phi_mod, sigma=POST_PHASE_SMOOTH_SIGMA, mode='reflect')

    if do_plot:
        t = np.arange(n_samples)
        plt.figure(figsize=(10,6))

        plt.subplot(3,1,1)
        plt.plot(t, signal, label='raw')
        plt.plot(t, upper, '--', label='upper')
        plt.plot(t, lower, '--', label='lower')
        plt.legend()

        plt.subplot(3,1,2)
        plt.plot(t, amp, label='hilbert amp (raw)')
        plt.plot(t, amp_s, label='amp smoothed')
        plt.legend()

        plt.subplot(3,1,3)
        plt.plot(t, phi_mod, label='unwrapped phase')
        plt.legend()

        plt.tight_layout()
        plt.show()

    return phi_mod

# ------------------------
# Main processing loop using improved extraction
# ------------------------
# Relies on RawInterferogramssig, RawInterferogramsbg and frameshape from the
# loading cell. Each image is processed one row at a time: the 1D line is
# RawInterferograms*[shot, row, :] (named "col" below, but it runs along the width).
shots_to_process = RawInterferogramssig.shape[0]
H = frameshape[0]
W = frameshape[1]
phase_shifts = np.zeros((shots_to_process, H, W), dtype=float)

for shot_idx in range(shots_to_process):
    phase_img = np.zeros((H, W), dtype=float)
    for row_idx in range(H):
        # debug-plot only first row of first shot if you need to inspect
        plot_flag = (shot_idx == 0 and row_idx == 0)
        sig_col = RawInterferogramssig[shot_idx, row_idx, :].astype(float)
        bg_col  = RawInterferogramsbg[shot_idx, row_idx, :].astype(float)

        # Any MODE other than "hilbert_direct" selects the envelope/arccos extractor
        if MODE == "hilbert_direct":
            phase_sig = extract_phase_hilbert_column(sig_col, do_plot=plot_flag)
            phase_bg  = extract_phase_hilbert_column(bg_col, do_plot=False)
        else:
            phase_sig = extract_phase_envelope_arccos_column(sig_col, do_plot=plot_flag)
            phase_bg  = extract_phase_envelope_arccos_column(bg_col, do_plot=False)

        # subtract (signal - background); rows whose two phases cannot be
        # subtracted (missing result or shape mismatch) are filled with NaN
        if phase_sig is None or phase_bg is None or phase_sig.shape != phase_bg.shape:
            phase_img[row_idx, :] = np.nan
        else:
            phase_img[row_idx, :] = phase_sig - phase_bg

    phase_shifts[shot_idx] = phase_img
    print(f"Finished shot {shot_idx+1}/{shots_to_process}")

# Average across shots, ignoring NaN samples
avg_phase = np.nanmean(phase_shifts, axis=0)

# Plot result (colour limits are hard-coded here to -0.05 .. 0.05 rad)
plt.figure(figsize=(8,6))
plt.imshow(avg_phase, cmap='RdBu', vmin=-0.05, vmax=0.05)
plt.colorbar(label='Phase shift (rad)')
plt.title('Average phase shift (improved)')
plt.show()


In [None]:
# Pipeline with column magnitude rejection + nearest-good-column replacement
# Paste into a Jupyter cell and run.

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from imageio import imread
from scipy.signal import hilbert, savgol_filter, butter, filtfilt
from scipy.ndimage import gaussian_filter1d, median_filter, grey_closing, gaussian_filter
from scipy.interpolate import SmoothBivariateSpline
from math import ceil

# ------------------------
# USER PARAMETERS (edit)
# ------------------------
sig_path = Path(r"C:\Users\sann7609\Documents\Oxford\ChannelAnalysis_CALA_Sept25_minisforum\vanilla_interferometry\for_vanilla_80mbar_Bdel582_t0\Interferometry2\sig")
bg_path  = Path(r"C:\Users\sann7609\Documents\Oxford\ChannelAnalysis_CALA_Sept25_minisforum\vanilla_interferometry\for_vanilla_80mbar_Bdel582_t0\Interferometry2\bg")
# Substrings used to build the glob patterns "*<pattern>*" for each folder
SIG_PATTERN = ".tiff"
BG_PATTERN  = ".tiff"
FILES_TO_LOAD = 5          # how many sig/bg pairs to load/process
# NOTE: MODE is captured as the default of extract_column_phase(mode=MODE) when
# that function is defined, so change it here *before* running the definitions.
MODE = "hilbert_direct"    # "hilbert_direct" or "envelope_arccos"
# NOTE(review): the extraction functions below take their own
# POST_PHASE_SMOOTH_SIGMA keyword (default 1.0) and extract_column_phase does
# not forward this global, so editing it has no effect on them as written.
POST_PHASE_SMOOTH_SIGMA = 1.0

# The channel masks you provided (rows to mask out for fitting)
CHANNEL_MASKS = [(200, 370), (650, 820)]

# Spline fit degree/smoothing control
SPLINE_KX = 3
SPLINE_KY = 3
SPLINE_SMTH_MULTIPLIER = 0.8  # larger => smoother surface

# Ridge fallback degree & regularization
POLY_DEGREE_FALLBACK = 2
RIDGE_LAMBDA = 1e-6

# Display options
COLORMAP = 'RdBu'
PHASE_VMIN, PHASE_VMAX = -0.4, 0.4
SHOW_PER_SHOT_PLOTS = True
SHOW_AVERAGE_PLOTS  = True

# Tolerances for column validation (existing)
# MIN_VALID_POINTS_PER_COLUMN is also the threshold extract_column_phase uses
# to decide whether to switch to the plain-Hilbert fallback.
MIN_VALID_POINTS_PER_COLUMN = 10   # minimum non-NaN points in a vertical column to be considered valid
MIN_STD_FOR_VALID_COLUMN = 1e-6    # if std < this, consider column invalid (likely constant/failed)

# NEW: column magnitude based rejection
COLUMN_MAG_THRESHOLD = 0.5        # approximate maximum expected absolute phase (radians)
COLUMN_MAG_FACTOR_MAX = 1.2       # hard limit: if any |value| > threshold * factor_max -> reject
COLUMN_MEDIAN_FACTOR = 0.9        # if median(|col|) > threshold * COLUMN_MEDIAN_FACTOR -> reject (looser)

# ------------------------
# Utilities and extraction routines (unchanged)
# ------------------------
def to_gray_if_needed(arr):
    """Collapse a colour image to grayscale; pass anything else through.

    An array with three axes and at least three channels on the last axis is
    reduced to a weighted sum of its first three channels
    (0.2126, 0.7152, 0.0722). Every other input is returned as a numpy array
    without further change.
    """
    image = np.asarray(arr)
    has_colour_channels = image.ndim == 3 and image.shape[2] >= 3
    if not has_colour_channels:
        return image
    red, green, blue = image[..., 0], image[..., 1], image[..., 2]
    return 0.2126*red + 0.7152*green + 0.0722*blue

def estimate_fringe_period(signal, MIN_PERIOD=4, MAX_PERIOD=200):
    """Estimate the dominant fringe period (pixels) of a 1D signal using an FFT.

    The signal is mean-subtracted and Hann-windowed, and the strongest spectral
    peak is searched only among frequencies whose period lies inside
    [MIN_PERIOD, MAX_PERIOD]. Restricting the search to that band means a strong
    slow background (period > MAX_PERIOD) can no longer hide a valid fringe peak.

    Returns the period as a float, or None when the signal has fewer than 8
    samples, contains non-finite values (the FFT would be all-NaN), has no
    frequency bin inside the band, or carries no power inside the band.
    """
    s = np.asarray(signal, dtype=float)
    n = s.size
    if n < 8:
        return None
    # A single NaN/inf turns the whole spectrum into NaN; argmax would then
    # pick an arbitrary bin, so give up explicitly instead.
    if not np.all(np.isfinite(s)):
        return None

    s0 = s - s.mean()
    spectrum = np.fft.rfft(s0 * np.hanning(n))

    # drop the DC bin so every remaining frequency is > 0 (cycles / pixel)
    power = np.abs(spectrum[1:]) ** 2
    periods = 1.0 / np.fft.rfftfreq(n, d=1.0)[1:]

    in_band = np.flatnonzero((periods >= MIN_PERIOD) & (periods <= MAX_PERIOD))
    if in_band.size == 0:
        return None

    peak = in_band[np.argmax(power[in_band])]
    if not power[peak] > 0.0:
        # flat signal: nothing to lock on to
        return None
    return float(periods[peak])

def butter_bandpass(low, high, fs=1.0, order=3):
    """Return Butterworth band-pass coefficients (b, a), or (None, None).

    `low`/`high` are band edges for a sampling rate `fs`. They are divided by
    the Nyquist frequency and clamped to [1e-9, 0.999999]; an empty band after
    clamping yields (None, None).
    """
    nyquist = fs / 2.0
    band = [max(low / nyquist, 1e-9), min(high / nyquist, 0.999999)]
    if band[0] >= band[1]:
        return None, None
    # `butter` is the module-level import from scipy.signal
    coefficients = butter(order, band, btype='band')
    return coefficients[0], coefficients[1]

def bandpass_filter_1d(signal, low, high, order=3):
    """Zero-phase band-pass of `signal` between `low` and `high` (cycles/pixel).

    A copy of the input is returned when the filter cannot be designed; if
    filtfilt raises, a gaussian smoothing (sigma=1) of the input is returned.
    """
    coeffs = butter_bandpass(low, high, fs=1.0, order=order)
    b, a = coeffs
    if b is None:
        return signal.copy()
    try:
        result = filtfilt(b, a, signal, method='pad')
    except Exception:
        result = gaussian_filter1d(signal, sigma=1.0, mode='reflect')
    return result

def adaptive_env_params_from_period(period, ENV_GAUSS_SIGMA_FRAC=0.35, ENV_MEDIAN_K_FRAC=0.15, ENV_SAVGOL_WIN_FRAC=1.5,
                                   DEFAULT_SAVGOL_WIN=11, DEFAULT_GAUSS_SIGMA=2.0):
    """Derive envelope-smoothing settings from the fringe period.

    Returns a dict with keys "savgol_win", "gauss_sigma" and "median_k".
    With period=None the DEFAULT_* values are used (median size 5). Otherwise
    the window and median sizes are the period times the matching fraction,
    rounded, made odd and at least 3 (the window is also capped at four
    periods); the gaussian sigma is the scaled period, at least 0.5.
    """
    if period is None:
        return {"savgol_win": DEFAULT_SAVGOL_WIN, "gauss_sigma": DEFAULT_GAUSS_SIGMA, "median_k": 5}

    def _odd_size(scaled):
        size = int(max(3, round(scaled)))
        if size % 2 == 0:
            size += 1
        return size

    window = _odd_size(ENV_SAVGOL_WIN_FRAC * period)
    median_size = _odd_size(ENV_MEDIAN_K_FRAC * period)
    sigma = max(0.5, ENV_GAUSS_SIGMA_FRAC * period)
    window = max(3, min(window, int(period*4)))
    return {"savgol_win": window, "gauss_sigma": sigma, "median_k": median_size}

def smooth_envelope(amp, savgol_win, savgol_poly=3, med_k=5, gauss_sigma=2.0):
    """Smooth an amplitude trace: median, then gaussian, then Savitzky-Golay.

    Each stage is best-effort; a stage that raises leaves the trace as it was
    before that stage. The Savitzky-Golay window is replaced by a short odd
    window when it is not smaller than the trace, and the polynomial order is
    capped at 3. The input array is not modified.
    """
    def _attempt(stage, data):
        try:
            return stage(data)
        except Exception:
            return data

    env = amp.copy()
    env = _attempt(lambda d: median_filter(d, size=med_k), env)
    env = _attempt(lambda d: gaussian_filter1d(d, sigma=gauss_sigma, mode='reflect'), env)

    window = max(3, (len(env) // 2) // 2 * 2 + 1) if savgol_win >= len(env) else savgol_win
    return _attempt(lambda d: savgol_filter(d, window_length=window, polyorder=min(3, savgol_poly)), env)

# primary extraction (bandpass + hilbert)
def extract_hilbert_col_main(signal, BANDWIDTH_FACTOR=0.6, POST_PHASE_SMOOTH_SIGMA=1.0):
    """Primary per-column extraction: band-pass around the carrier, then Hilbert phase.

    Returns a dict with keys "phase" (unwrapped, optionally gaussian-smoothed
    instantaneous phase), "bandpassed", "analytic", "amp_s" (smoothed analytic
    amplitude) and "period" (estimated fringe period, or None).
    """
    n = len(signal)
    period = estimate_fringe_period(signal)
    params = adaptive_env_params_from_period(period)

    # pass band: +/- BANDWIDTH_FACTOR/2 around the carrier, or a wide default
    if period is not None:
        carrier = 1.0 / period
        half_width = carrier * BANDWIDTH_FACTOR / 2.0
        bp_low = max(0.0001, carrier - half_width)
        bp_high = min(0.4999, carrier + half_width)
    else:
        bp_low, bp_high = 0.005, 0.45

    sig_bp = bandpass_filter_1d(signal, bp_low, bp_high, order=3)
    analytic = hilbert(sig_bp - np.nanmean(sig_bp))

    phase = np.unwrap(np.angle(analytic))
    if POST_PHASE_SMOOTH_SIGMA > 0:
        phase = gaussian_filter1d(phase, sigma=POST_PHASE_SMOOTH_SIGMA, mode='reflect')

    smoothed_amp = smooth_envelope(
        np.abs(analytic),
        params["savgol_win"],
        med_k=params["median_k"],
        gauss_sigma=params["gauss_sigma"],
    )
    return {
        "phase": phase,
        "bandpassed": sig_bp,
        "analytic": analytic,
        "amp_s": smoothed_amp,
        "period": period,
    }

# fallback: direct hilbert (no bandpass)
def extract_hilbert_col_fallback(signal, POST_PHASE_SMOOTH_SIGMA=1.0):
    """Fallback extraction: Hilbert phase of the mean-subtracted signal, no band-pass.

    Returns the same dict layout as the primary extractor; "bandpassed" is the
    input signal itself, "amp_s" is the unsmoothed analytic amplitude and
    "period" is None.
    """
    centred = signal - np.nanmean(signal)
    analytic = hilbert(centred)

    phase = np.unwrap(np.angle(analytic))
    if POST_PHASE_SMOOTH_SIGMA > 0:
        phase = gaussian_filter1d(phase, sigma=POST_PHASE_SMOOTH_SIGMA, mode='reflect')

    return {
        "phase": phase,
        "bandpassed": signal,
        "analytic": analytic,
        "amp_s": np.abs(analytic),
        "period": None,
    }

def extract_envelope_arccos_col(signal, BANDWIDTH_FACTOR=0.6, POST_PHASE_SMOOTH_SIGMA=1.0):
    """Envelope/arccos phase extraction for one 1D line.

    Steps: band-pass around the estimated carrier, take the analytic amplitude,
    smooth it, build upper/lower envelopes around a low-passed centerline of the
    raw signal, normalise the centered raw signal by the envelope, take arccos,
    reflect samples that follow a negative step, unwrap, and optionally smooth.

    Returns a dict with keys "phase", "bandpassed", "analytic", "amp", "amp_s",
    "upper", "lower", "center_lp", "I_norm" and "period".
    """
    n = len(signal)
    period = estimate_fringe_period(signal)
    params = adaptive_env_params_from_period(period)
    # Pass band: wide default when no carrier was found, otherwise
    # +/- BANDWIDTH_FACTOR/2 around the carrier frequency, clamped to (0.0001, 0.4999)
    if period is None:
        bp_low, bp_high = 0.005, 0.45
    else:
        f0 = 1.0 / period
        bw = f0 * BANDWIDTH_FACTOR
        bp_low = max(0.0001, f0 - bw/2.0)
        bp_high = min(0.4999, f0 + bw/2.0)
    # The band-passed signal is used only to estimate the fringe amplitude
    sig_bp = bandpass_filter_1d(signal, bp_low, bp_high, order=3)
    analytic = hilbert(sig_bp - np.nanmean(sig_bp))
    amp = np.abs(analytic)
    amp_s = smooth_envelope(amp, params["savgol_win"], med_k=params["median_k"], gauss_sigma=params["gauss_sigma"])
    # Centerline: low-passed raw (not band-passed) signal
    center_lp = gaussian_filter1d(signal, sigma=max(1, params["gauss_sigma"]*2), mode='reflect')
    center_lp = median_filter(center_lp, size=params["median_k"])
    try:
        # window = largest odd number below the length, capped at savgol_win;
        # if savgol_filter rejects the settings the centerline is left as-is
        center_lp = savgol_filter(center_lp, window_length=min(len(center_lp)-1 if (len(center_lp)-1)%2==1 else len(center_lp)-2, params["savgol_win"]), polyorder=3)
    except Exception:
        pass
    # Envelopes from the smoothed amplitude around the centerline
    upper = center_lp + amp_s
    lower = center_lp - amp_s
    try:
        # grey closing (size 7) on both envelopes
        upper = grey_closing(upper, size=7)
        lower = grey_closing(lower, size=7)
    except Exception:
        pass
    # Normalise the centered raw signal by the (closed) upper half-width;
    # half-widths below 1e-9 become NaN to avoid dividing by ~0
    signal_centered = signal - center_lp
    safe_upper = np.where(np.abs(upper - center_lp) < 1e-9, np.nan, (upper - center_lp))
    I_norm = signal_centered / safe_upper
    I_norm = np.clip(I_norm, -1.0, 1.0)
    # arccos yields [0, pi]; the sample after each negative step is mapped to 2*pi - phi
    phi = np.arccos(I_norm)
    dphi = np.diff(phi)
    neg_idx = np.where(dphi < 0)[0]
    phi_mod = phi.copy()
    if neg_idx.size > 0:
        phi_mod[neg_idx + 1] = 2*np.pi - phi_mod[neg_idx + 1]
    # Unwrap only the non-NaN samples
    valid = ~np.isnan(phi_mod)
    if np.any(valid):
        phi_mod[valid] = np.unwrap(phi_mod[valid])
    # Optional gaussian smoothing of the unwrapped phase
    if POST_PHASE_SMOOTH_SIGMA > 0:
        phi_mod = gaussian_filter1d(phi_mod, sigma=POST_PHASE_SMOOTH_SIGMA, mode='reflect')
    return {"phase": phi_mod, "bandpassed": sig_bp, "analytic": analytic, "amp": amp, "amp_s": amp_s, "upper": upper, "lower": lower, "center_lp": center_lp, "I_norm": I_norm, "period": period}

# wrapper per-column that applies fallback automatically
def extract_column_phase(signal, mode=MODE):
    """Extract the phase of one image column, with an automatic fallback.

    The primary extractor is ``extract_hilbert_col_main`` when ``mode`` is
    ``"hilbert_direct"`` and ``extract_envelope_arccos_col`` for any other
    value. Its result dict is returned as-is when it holds a ``'phase'``
    array with the same length as ``signal`` and at least
    ``MIN_VALID_POINTS_PER_COLUMN`` non-NaN samples; otherwise the result of
    ``extract_hilbert_col_fallback(signal)`` is returned instead.
    """
    if mode == "hilbert_direct":
        primary = extract_hilbert_col_main
    else:
        primary = extract_envelope_arccos_col

    result = primary(signal)
    phase = result.get('phase', None)

    # A result is usable only if it has a phase of the right length with
    # enough finite samples; anything else triggers the fallback extractor.
    usable = (
        phase is not None
        and phase.shape[0] == signal.shape[0]
        and np.sum(~np.isnan(phase)) >= MIN_VALID_POINTS_PER_COLUMN
    )
    if usable:
        return result
    return extract_hilbert_col_fallback(signal)

# ------------------------
# Load files and compute phase maps (vertical columns)
# ------------------------
# Collect signal and background image paths whose names contain the
# configured patterns; sorting makes the order deterministic so that the
# i-th signal file is paired with the i-th background file below.
sig_files = sorted([p for p in sig_path.glob(f"*{SIG_PATTERN}*")])
bg_files  = sorted([p for p in bg_path.glob(f"*{BG_PATTERN}*")])
if len(sig_files) == 0 or len(bg_files) == 0:
    raise FileNotFoundError("No signal/background files found — check sig_path/bg_path and patterns.")

# Never load more pairs than requested or than exist on either side.
files_to_use = min(FILES_TO_LOAD, len(sig_files), len(bg_files))
print(f"Loading {files_to_use} sig/bg pairs")

# Read the first signal image only to learn the frame size (H rows, W columns).
first_img = to_gray_if_needed(imread(sig_files[0])).astype(float)
if first_img.ndim == 2:
    H, W = first_img.shape
else:
    H, W = first_img.shape[0], first_img.shape[1]

# Image stacks indexed as [shot, row, column].
# NOTE(review): assumes every file has the same (H, W) as the first one —
# the per-shot assignment below fails otherwise.
RawInterferogramssig = np.zeros((files_to_use, H, W), dtype=float)
RawInterferogramsbg  = np.zeros((files_to_use, H, W), dtype=float)

for i in range(files_to_use):
    s = to_gray_if_needed(imread(sig_files[i]).astype(float))
    b = to_gray_if_needed(imread(bg_files[i]).astype(float))
    RawInterferogramssig[i] = s
    RawInterferogramsbg[i] = b

print("Loaded shapes:", RawInterferogramssig.shape)

# Compute phase maps
# One (H, W) map per shot, pre-filled with NaN so columns that cannot be
# processed stay marked as missing.
phase_maps = np.full((files_to_use, H, W), np.nan, dtype=float)

for sidx in range(files_to_use):
    ph_map = np.full((H, W), np.nan, dtype=float)
    for col_idx in range(W):
        # Each vertical column is processed independently for signal and background.
        sig_col = RawInterferogramssig[sidx, :, col_idx].astype(float)
        bg_col  = RawInterferogramsbg[sidx, :, col_idx].astype(float)
        out_sig = extract_column_phase(sig_col, mode=MODE)
        out_bg  = extract_column_phase(bg_col, mode=MODE)
        phs = out_sig.get('phase', None)
        phb = out_bg.get('phase', None)
        if phs is None or phb is None or phs.shape != phb.shape:
            # No usable phase pair for this column: leave it as NaN.
            ph_map[:, col_idx] = np.nan
        else:
            # Stored value is signal phase minus background phase.
            ph_map[:, col_idx] = phs - phb
    phase_maps[sidx] = ph_map
    print(f"Computed phase map for shot {sidx+1}/{files_to_use}")

# ------------------------
# Build mask from CHANNEL_MASKS (True = masked/excluded from fit)
# ------------------------
# Boolean (H, W) mask; True marks pixels excluded from the background fit.
mask = np.zeros((H, W), dtype=bool)
for (ystart, yend) in CHANNEL_MASKS:
    # Clamp each (ystart, yend) row band to the image and mask it across all
    # columns; yend is exclusive because it is used as a slice end.
    y0 = int(max(0, ystart)); y1 = int(min(H, yend))
    if y1 > y0:
        mask[y0:y1, :] = True

# ------------------------
# Robust 2D fit per-shot using SmoothBivariateSpline with outlier rejection
# ------------------------
def robust_2d_spline_fit(Z, mask_bool, kx=3, ky=3, smth_mult=1.0):
    """Fit a smoothing bivariate spline to the unmasked, finite pixels of ``Z``.

    Pixels where ``mask_bool`` is True or ``Z`` is NaN are ignored. The
    remaining samples are trimmed to those within 3 standard deviations of
    their median (relaxed to 4 if too few survive) before being passed to
    ``SmoothBivariateSpline`` with smoothing ``s = smth_mult * var * n``.

    Returns ``(Z_fit, ok, info)``: on success ``Z_fit`` is an array with the
    shape of ``Z``, ``ok`` is True and ``info`` is a dict with ``"n_in"`` and
    ``"s_val"``; on failure ``Z_fit`` is None, ``ok`` is False and ``info``
    is a message string.
    """
    n_rows, n_cols = Z.shape
    row_idx, col_idx = np.indices((n_rows, n_cols))
    z_all = Z.ravel()

    keep = (~mask_bool.ravel()) & (~np.isnan(z_all))
    n_keep = keep.sum()
    if n_keep < 20:
        return None, False, "too few valid points"

    xs = col_idx.ravel()[keep]
    ys = row_idx.ravel()[keep]
    zs = z_all[keep]

    # Outlier trimming around the median, in units of the sample std.
    centre = np.nanmedian(zs)
    spread = np.nanstd(zs)
    if spread == 0 or np.isnan(spread):
        spread = 1.0
    deviation = np.abs(zs - centre)
    inliers = deviation <= (3.0 * spread)
    if inliers.sum() < max(20, int(0.3 * n_keep)):
        # Too aggressive: widen the acceptance band once.
        inliers = deviation <= (4.0 * spread)

    x_in = xs[inliers]
    y_in = ys[inliers]
    z_in = zs[inliers]
    n_in = x_in.size
    if n_in < 20:
        return None, False, "too few inliers after trimming"

    s_val = smth_mult * (np.nanstd(z_in)**2) * n_in

    try:
        spline = SmoothBivariateSpline(x_in, y_in, z_in, kx=kx, ky=ky, s=s_val)
        # Grid evaluation comes back indexed (x, y); transpose to (row, column).
        surface = spline(np.arange(n_cols), np.arange(n_rows)).T
    except Exception as e:
        return None, False, f"spline failed: {e}"
    return surface, True, {"n_in": n_in, "s_val": s_val}

def fit_2d_ridge_poly(Z, mask_bool, deg=2, lam=1e-6):
    """Fit a ridge-regularised 2D polynomial of total degree ``deg`` to ``Z``.

    Pixels where ``mask_bool`` is True or ``Z`` is NaN are excluded from the
    fit. The model contains every monomial ``x**i * y**j`` with
    ``i + j <= deg`` (x = column index, y = row index) and is solved through
    the normal equations with ``lam`` added to the diagonal.

    Returns ``(Z_fit, ok, info)``: on success ``Z_fit`` is the polynomial
    evaluated on the full ``Z.shape`` grid, ``ok`` is True and ``info`` is
    ``{"deg": deg, "lam": lam}``; on failure ``Z_fit`` is None, ``ok`` is
    False and ``info`` is a message string (too few points, or a singular
    system).
    """
    H, W = Z.shape
    # Use float coordinates: with integer pixel indices the powers and the
    # A.T @ A products are computed in fixed-width integers, which wrap
    # around silently once they exceed the integer range.
    yy, xx = np.meshgrid(np.arange(H, dtype=float), np.arange(W, dtype=float), indexing='ij')
    xx_flat = xx.ravel()
    yy_flat = yy.ravel()
    z_flat = Z.ravel()

    valid = (~mask_bool.ravel()) & (~np.isnan(z_flat))
    x = xx_flat[valid]
    y = yy_flat[valid]
    z = z_flat[valid]
    if z.size < 10:
        return None, False, "too few points"

    # Exponent pairs (i, j), grouped by total degree d = i + j.
    terms = [(i, d - i) for d in range(deg + 1) for i in range(d + 1)]

    def design_matrix(px, py):
        # One column per monomial, one row per sample point.
        return np.column_stack([(px**i) * (py**j) for (i, j) in terms])

    A = design_matrix(x, y)
    ATA_reg = A.T.dot(A) + lam * np.eye(len(terms))
    rhs = A.T.dot(z)
    try:
        coeffs = np.linalg.solve(ATA_reg, rhs)
    except np.linalg.LinAlgError as exc:
        # Report a singular system through the return tuple, like the other
        # failure path, instead of raising.
        return None, False, f"ridge solve failed: {exc}"

    z_fit_flat = design_matrix(xx_flat, yy_flat).dot(coeffs)
    return z_fit_flat.reshape(H, W), True, {"deg": deg, "lam": lam}

# Apply fits and subtract
# Per-shot outputs, pre-filled with NaN: the fitted background surface and
# the phase map with that surface subtracted. fit_details records which
# method produced each fit.
fitted_surfaces = np.full_like(phase_maps, np.nan)
corrected_maps = np.full_like(phase_maps, np.nan)
fit_details = []

for sidx in range(files_to_use):
    Z = phase_maps[sidx]
    # First choice: robust smoothing-spline fit on the unmasked pixels.
    Z_fit, ok, info = robust_2d_spline_fit(Z, mask, kx=SPLINE_KX, ky=SPLINE_KY, smth_mult=SPLINE_SMTH_MULTIPLIER)
    if not ok:
        # Spline failed: fall back to a ridge-regularised polynomial surface.
        Z_fit, ok2, info2 = fit_2d_ridge_poly(Z, mask, deg=POLY_DEGREE_FALLBACK, lam=RIDGE_LAMBDA)
        if not ok2:
            # Neither method worked: keep this shot as NaN but still log an
            # entry so fit_details stays aligned with the shot index.
            print(f"Shot {sidx+1}: both spline and poly fallback failed ({info}; {info2}). Leaving NaNs.")
            fitted_surfaces[sidx] = np.nan
            corrected_maps[sidx] = np.nan
            fit_details.append({"sidx": sidx, "method": "failed", "info": (info, info2)})
            continue
        else:
            fitted_surfaces[sidx] = Z_fit
            corrected_maps[sidx] = Z - Z_fit
            fit_details.append({"sidx": sidx, "method": "ridge_poly", "info": info2})
    else:
        fitted_surfaces[sidx] = Z_fit
        corrected_maps[sidx] = Z - Z_fit
        fit_details.append({"sidx": sidx, "method": "spline", "info": info})

# ------------------------
# Fill partial NaNs along rows (small gaps) - linear interpolation across columns per row
# ------------------------
def fill_nan_columns_by_row_interpolation(map2d):
    """Return a copy of ``map2d`` with NaNs filled row by row.

    Within each row, NaN entries are replaced by linear interpolation over
    the column index between the row's non-NaN entries; NaNs before the
    first or after the last valid entry take that nearest valid value.
    Rows that are entirely NaN are left unchanged. The input is not modified.
    """
    n_rows, _ = map2d.shape
    filled_map = map2d.copy()
    for r in range(n_rows):
        row = filled_map[r]  # view into filled_map, so writes land in the copy
        missing = np.isnan(row)
        if missing.all() or not missing.any():
            continue
        known_cols = np.flatnonzero(~missing)
        # np.interp clamps to the first/last known value outside their range.
        row[missing] = np.interp(np.flatnonzero(missing), known_cols, row[known_cols])
    return filled_map

# ------------------------
# Detect bad columns (by count/std/magnitude) and replace by nearest good column
# ------------------------
def find_bad_columns(map2d, min_valid_points=MIN_VALID_POINTS_PER_COLUMN, min_std=MIN_STD_FOR_VALID_COLUMN,
                     mag_thresh=COLUMN_MAG_THRESHOLD, mag_factor_max=COLUMN_MAG_FACTOR_MAX, mag_median_factor=COLUMN_MEDIAN_FACTOR):
    """
    Flag the columns of ``map2d`` that should be replaced.

    A column is flagged when any of these holds (NaNs are ignored in the
    statistics):
      * it has fewer than ``min_valid_points`` non-NaN entries;
      * its standard deviation is below ``min_std``;
      * its largest absolute value exceeds ``mag_thresh * mag_factor_max``;
      * its median absolute value exceeds ``mag_thresh * mag_median_factor``.

    Returns a boolean array of shape (W,), True for flagged columns.
    """
    n_rows, n_cols = map2d.shape
    flagged = np.zeros(n_cols, dtype=bool)
    for j in range(n_cols):
        column = map2d[:, j]

        n_valid = int((~np.isnan(column)).sum())
        if n_valid < min_valid_points:
            flagged[j] = True
            continue

        if float(np.nanstd(column)) < min_std:
            flagged[j] = True
            continue

        # Magnitude checks: one on the extreme value, one on the typical value.
        magnitude = np.abs(column)
        peak = float(np.nanmax(magnitude))
        typical = float(np.nanmedian(magnitude))
        if peak > (mag_thresh * mag_factor_max) or typical > (mag_thresh * mag_median_factor):
            flagged[j] = True
    return flagged

def replace_bad_columns_by_nearest(map2d, bad_mask):
    """
    Copy the nearest good column over every column flagged in ``bad_mask``.

    "Nearest" is measured in column index; on a tie the lower-index good
    column is used. If no column is good, the map is returned unchanged.
    Returns (out_map, replaced_indices), where ``out_map`` is a new array
    and ``replaced_indices`` holds the indices flagged in ``bad_mask``.
    """
    n_rows, n_cols = map2d.shape
    patched = map2d.copy()
    bad_cols = np.where(bad_mask)[0]
    good_cols = np.where(~bad_mask)[0]
    if good_cols.size > 0:
        # Distance table: one row per bad column, one entry per good column.
        # argmin returns the first minimum, i.e. the lower index on ties.
        gaps = np.abs(good_cols[np.newaxis, :] - bad_cols[:, np.newaxis])
        donors = good_cols[gaps.argmin(axis=1)]
        patched[:, bad_cols] = map2d[:, donors]
    return patched, bad_cols

# Apply small row interpolation then column rejection+replacement per shot
# One summary dict per shot with keys "sidx", "replaced_cols" and
# "n_replaced"; the plotting and reporting code reads all three.
replacement_summary = []

for sidx in range(files_to_use):
    if np.all(np.isnan(phase_maps[sidx])):
        # Nothing to repair on an all-NaN shot. The entry still carries
        # "n_replaced" so later lookups of that key do not raise KeyError.
        replacement_summary.append({"sidx": sidx, "replaced_cols": [], "n_replaced": 0})
        continue
    # small-gap interpolation
    phase_maps[sidx] = fill_nan_columns_by_row_interpolation(phase_maps[sidx])
    if not np.all(np.isnan(fitted_surfaces[sidx])):
        fitted_surfaces[sidx] = fill_nan_columns_by_row_interpolation(fitted_surfaces[sidx])
    if not np.all(np.isnan(corrected_maps[sidx])):
        corrected_maps[sidx] = fill_nan_columns_by_row_interpolation(corrected_maps[sidx])

    # detect bad columns in corrected map (prefer corrected map since it's what you'll analyse)
    bad_cols = find_bad_columns(corrected_maps[sidx], min_valid_points=MIN_VALID_POINTS_PER_COLUMN,
                                min_std=MIN_STD_FOR_VALID_COLUMN, mag_thresh=COLUMN_MAG_THRESHOLD,
                                mag_factor_max=COLUMN_MAG_FACTOR_MAX, mag_median_factor=COLUMN_MEDIAN_FACTOR)
    # also ensure we don't accidentally mark as bad columns inside masked rows only (if whole column masked) - but previous checks handle NaNs
    # Replace bad columns in all maps consistently
    phase_maps[sidx], replaced_phase_cols = replace_bad_columns_by_nearest(phase_maps[sidx], bad_cols)
    if not np.all(np.isnan(fitted_surfaces[sidx])):
        fitted_surfaces[sidx], _ = replace_bad_columns_by_nearest(fitted_surfaces[sidx], bad_cols)
    if not np.all(np.isnan(corrected_maps[sidx])):
        corrected_maps[sidx], _ = replace_bad_columns_by_nearest(corrected_maps[sidx], bad_cols)

    replacement_summary.append({"sidx": sidx, "replaced_cols": replaced_phase_cols.tolist(), "n_replaced": len(replaced_phase_cols)})

# Recompute averages AFTER filling / replacement
# nanmean ignores NaN pixels, so shots that failed only drop out where they are NaN.
avg_before = np.nanmean(phase_maps[:files_to_use], axis=0)
avg_fit    = np.nanmean(fitted_surfaces[:files_to_use], axis=0)
avg_after  = np.nanmean(corrected_maps[:files_to_use], axis=0)
avg_corrected_map = avg_after.copy()

# ------------------------
# Plotting: per-shot (original w/mask, fitted surface, corrected) and averages
# All plots flipped vertically via origin='lower'
# ------------------------
# Per-shot figure: three panels side by side (phase map with mask overlay,
# fitted surface, corrected map), all on the same colormap and limits.
for sidx in range(files_to_use):
    if not SHOW_PER_SHOT_PLOTS:
        # Per-shot plotting disabled: leave the loop on the first iteration.
        break
    fig, axs = plt.subplots(1, 3, figsize=(18,5))
    # Original with mask overlay
    im0 = axs[0].imshow(phase_maps[sidx], cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    axs[0].set_title(f"Shot {sidx+1} — Original phase map (after fixes)")
    # Hide the unmasked pixels so only the masked rows are drawn as a translucent overlay.
    mask_overlay = np.ma.masked_where(~mask, mask)
    axs[0].imshow(mask_overlay, cmap='gray', alpha=0.25, origin='lower')
    axs[0].set_axis_off()
    plt.colorbar(im0, ax=axs[0], fraction=0.046)

    # Fitted surface (use same cmap/limits for direct comparison)
    if not np.all(np.isnan(fitted_surfaces[sidx])):
        im1 = axs[1].imshow(fitted_surfaces[sidx], cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    else:
        # No fit available: show a zero image as placeholder with a label.
        im1 = axs[1].imshow(np.zeros((H,W)), cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
        axs[1].text(0.5, 0.5, 'fit failed', transform=axs[1].transAxes, ha='center')
    axs[1].set_title("Fitted 2D surface (same cmap)")
    axs[1].set_axis_off()
    plt.colorbar(im1, ax=axs[1], fraction=0.046)

    # Corrected map
    im2 = axs[2].imshow(corrected_maps[sidx], cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    axs[2].set_title("Corrected map (original − fitted)")
    axs[2].set_axis_off()
    plt.colorbar(im2, ax=axs[2], fraction=0.046)

    # NOTE(review): indexes fit_details and replacement_summary by shot and
    # reads the 'n_replaced' key — verify every summary entry carries it.
    suppl = f"Shot {sidx+1}: method={fit_details[sidx]['method']} info={fit_details[sidx]['info']}. Replaced {replacement_summary[sidx]['n_replaced']} bad cols."
    plt.suptitle(suppl)
    plt.tight_layout()
    plt.show()

# Average plots
# Same three-panel layout, using the NaN-aware averages over all loaded shots.
if SHOW_AVERAGE_PLOTS:
    fig, axs = plt.subplots(1, 3, figsize=(18,5))
    im0 = axs[0].imshow(avg_before, cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    axs[0].set_title("Average before subtraction")
    axs[0].imshow(np.ma.masked_where(~mask, mask), cmap='gray', alpha=0.25, origin='lower')
    axs[0].set_axis_off()
    plt.colorbar(im0, ax=axs[0], fraction=0.046)

    im1 = axs[1].imshow(avg_fit, cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    axs[1].set_title("Average fitted surface (same cmap)")
    axs[1].set_axis_off()
    plt.colorbar(im1, ax=axs[1], fraction=0.046)

    im2 = axs[2].imshow(avg_after, cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower')
    axs[2].set_title("Average after subtraction (final avg corrected map)")
    axs[2].set_axis_off()
    plt.colorbar(im2, ax=axs[2], fraction=0.046)

    plt.suptitle(f"Averages across first {files_to_use} shots — spline_mult={SPLINE_SMTH_MULTIPLIER}")
    plt.tight_layout()
    plt.show()

# Text report: one line per shot for the fit method and for the column replacement.
print("Done.\nFit details per shot:")
for d in fit_details:
    print(d)
print("\nColumn replacement summary per shot:")
for r in replacement_summary:
    # NOTE(review): requires the 'n_replaced' key on every entry — confirm.
    print(f"Shot {r['sidx']+1}: replaced {r['n_replaced']} columns: {r['replaced_cols']}")


In [None]:


# Load a reference phase map from a text file for visual comparison.
# NOTE(review): presumably the averaged phase from the separate FFT-based
# analysis (per the directory name) — confirm. The path is machine-specific.
fft_phase = np.loadtxt(r'C:\Users\sann7609\Documents\Oxford\ChannelAnalysis_CALA_Sept25_minisforum\channel_analysis_250910_largeFFT\phase_maps\80mbar_Bdel582_t0\Interferometry2\AvgPhase.txt')

# Symmetric colour limits (±phasemax) shared by both panels.
phasemax=0.3
# Show final averaged corrected map
fig, axs = plt.subplots(1,2, figsize=(10,6))
axs[0].imshow(avg_corrected_map, cmap=COLORMAP, vmin=-phasemax, vmax=phasemax, origin='lower')
#plt.colorbar(label='Phase (rad)')

# Right panel: the loaded reference map with the same limits.
axs[1].imshow(fft_phase, cmap='RdBu', vmin=-phasemax, vmax=phasemax,origin='lower')

In [None]:
# Side-by-side figure of the first signal and background interferograms,
# with axes scaled to microns, saved as a PNG.
fig, axs = plt.subplots(1,2, figsize=(10,6))
sig_img = RawInterferogramssig[0,:,:]
bg_img  = RawInterferogramsbg[0,:,:]
# NOTE: rebinds the module-level H and W to the size of the first shot.
H = sig_img.shape[0]
W = sig_img.shape[1]
pixsize=1.06e-6 # pixel size in metres (multiplied by 1e6 below to get microns)
extent = [0,W*pixsize*1e6,0,H*pixsize*1e6] # image extent in microns, matching the axis labels
axs[0].imshow(RawInterferogramssig[0,:,:], cmap='gray', origin='lower',extent=extent)
axs[1].imshow(RawInterferogramsbg[0,:,:], cmap='gray', origin='lower',extent=extent)
axs[0].set_xlabel('X (microns)')
axs[0].set_ylabel('Y (microns)')
axs[1].set_xlabel('X (microns)')
# Panel labels are positioned in data coordinates (microns), so x=50, y=1000
# only lands inside the image if the frame is taller than 1000 microns.
axs[0].text(s='a) Signal',x=50,y=1000,color='yellow',fontsize=14)
axs[1].text(s='b) Background',x=50,y=1000,color='yellow',fontsize=14)
# The right panel shares the left panel's y scale visually, so its tick labels are hidden.
axs[1].set_yticklabels([])
plt.tight_layout()

# Machine-specific output path.
plt.savefig(r'c:\Users\sann7609\Documents\Oxford\Thesis\images\Interferograms_sig_bg.png',dpi=300,bbox_inches='tight')


In [None]:
# Full updated plotting cell with x-limits for c)-f)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from scipy.ndimage import gaussian_filter1d, median_filter, gaussian_filter
from scipy.signal import savgol_filter, hilbert
import warnings

# --------------- USER TWEAKS ---------------
REF_SHOT = 0            # which shot to show (0-based)
REF_COL = None          # which column to inspect; None -> center column
ZOOM_FRACTION = 0.30    # portion of image height to zoom into stacked plots
FIG_W = 8.0             # figure width (inches)
FIG_H = 11.0            # figure height (inches)
# Gaussian sigma passed as post_smooth to the wrapped-phase extraction below.
# NOTE: rebinding this name (and COLORMAP / PHASE_VMIN / PHASE_VMAX) also
# changes the module-level values any earlier cell reads if it is re-run.
POST_PHASE_SMOOTH_SIGMA = 1.0
# Despite the names, these are used below as the full image width and height
# for the imshow extent (and divided by W / H to get a per-pixel size).
# NOTE(review): presumably in microns, per the axis labels — confirm.
PIXEL_MICRONS_X = 1359.0
PIXEL_MICRONS_Y = 1088.0
COLORMAP = 'RdBu'
# Colour limits applied to the phase-map panels.
PHASE_VMIN, PHASE_VMAX = -0.4, 0.4

# Optional: display-only smoothing sigma for presentation (set to 0 to disable)
DISPLAY_SMOOTH_SIGMA = 0.0

# Line colours shared by the lineout markers on the images and the lineout plots.
sig_color = 'blue'
bg_color = 'orange'
# -------------------------------------------

# small helpers (reused)
def bandpass_filter_1d(signal, low, high, order=3):
    """Zero-phase Butterworth band-pass of a 1D signal.

    ``low`` and ``high`` are band edges in cycles per sample; they are
    divided by the Nyquist frequency (0.5) and clamped to the open interval
    (0, 1) before the filter is designed. If the clamped band is empty a
    copy of ``signal`` is returned. If ``filtfilt`` raises (for example on a
    signal too short for its padding), a Gaussian-smoothed signal
    (sigma=1.0) is returned instead.
    """
    from scipy.signal import butter, filtfilt

    nyquist = 0.5
    edge_lo = max(low / nyquist, 1e-9)
    edge_hi = min(high / nyquist, 0.999999)
    if edge_lo >= edge_hi:
        return signal.copy()

    numer, denom = butter(order, [edge_lo, edge_hi], btype='band')
    try:
        filtered = filtfilt(numer, denom, signal, method='pad')
    except Exception:
        # Best-effort fallback so callers always get an array back.
        from scipy.ndimage import gaussian_filter1d
        filtered = gaussian_filter1d(signal, sigma=1.0, mode='reflect')
    return filtered

def smooth_envelope(amp, savgol_win=11, med_k=5, gauss_sigma=2.0):
    """Smooth an amplitude envelope with three filters applied in sequence.

    The stages are a median filter (size ``med_k``, at least 1), a Gaussian
    filter (sigma ``gauss_sigma``, at least 0.5) and a Savitzky-Golay filter
    of polynomial order 2. Each stage is best-effort: if it raises, its
    input is passed on unchanged. When ``savgol_win`` is not smaller than
    the signal length, a shorter odd window (at least 3) is used instead.
    Returns a new array; ``amp`` is not modified.
    """
    env = np.asarray(amp).copy()

    pre_stages = (
        lambda v: median_filter(v, size=max(1, med_k)),
        lambda v: gaussian_filter1d(v, sigma=max(0.5, gauss_sigma), mode='reflect'),
    )
    for stage in pre_stages:
        try:
            env = stage(env)
        except Exception:
            pass

    window = savgol_win
    if window >= len(env):
        # Shrink to an odd window of roughly half the signal length.
        window = max(3, (len(env) // 2) // 2 * 2 + 1)
    try:
        env = savgol_filter(env, window_length=int(window), polyorder=2)
    except Exception:
        pass
    return env

def extract_envelope_arccos_col_wrapped(signal, bp_low=0.005, bp_high=0.45,
                                        savgol_win=11, med_k=5, gauss_sigma=2.0,
                                        post_smooth=0.0, renorm_pct=99):
    """Envelope-normalise one column and return its wrapped arccos phase.

    Steps:
      1. Band-pass the signal and take the magnitude of its analytic signal
         (Hilbert transform) as the fringe amplitude, then smooth it.
      2. Build a baseline from the raw signal (Gaussian, median,
         Savitzky-Golay) and form upper/lower envelopes as baseline ± amplitude.
      3. Normalise the signal to the envelopes, rescale by the
         ``renorm_pct`` percentile of its absolute value, clip to [-1, 1]
         and take ``arccos``.
      4. Wherever the arccos phase decreases from one sample to the next,
         replace the later sample ``v`` by ``(2π - v) mod 2π``.
      5. Optionally Gaussian-smooth the corrected phase (``post_smooth`` > 0).

    No unwrapping is performed. Returns a dict holding the intermediate
    arrays and the scale factor used.
    """
    # 1. Fringe amplitude from the analytic signal of the band-passed data.
    sig_bp = bandpass_filter_1d(signal, bp_low, bp_high)
    analytic = hilbert(sig_bp - np.nanmean(sig_bp))
    amp = np.abs(analytic)
    amp_s = smooth_envelope(amp, savgol_win=savgol_win, med_k=med_k, gauss_sigma=gauss_sigma)

    # 2. Slowly varying baseline of the raw signal.
    baseline = gaussian_filter1d(signal, sigma=max(1, gauss_sigma*2), mode='reflect')
    baseline = median_filter(baseline, size=max(1, med_k))
    try:
        n_base = len(baseline)
        # Largest odd window that still fits inside the signal.
        odd_cap = n_base - 1 if (n_base - 1) % 2 == 1 else n_base - 2
        baseline = savgol_filter(baseline, window_length=min(odd_cap, savgol_win), polyorder=3)
    except Exception:
        pass

    upper = baseline + amp_s
    lower = baseline - amp_s
    centerline = 0.5*(upper+lower)
    amplitude = 0.5*(upper-lower)

    # 3. Normalise; near-zero amplitudes become NaN rather than dividing by ~0.
    safe_amp = np.where(np.abs(amplitude) < 1e-9, np.nan, amplitude)
    I_norm_raw = (signal - centerline) / safe_amp

    scale = 1.0
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        if np.any(np.isfinite(I_norm_raw)):
            pct_val = np.nanpercentile(np.abs(I_norm_raw), renorm_pct)
            if np.isfinite(pct_val) and pct_val > 0:
                scale = float(pct_val)
    if scale < 1e-6:
        scale = 1.0
    I_norm = I_norm_raw / scale
    I_norm_clipped = np.clip(I_norm, -1.0, 1.0)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        phi_tri_wrapped = np.arccos(I_norm_clipped)

    # 4. Mirror samples that follow a decrease of the arccos phase.
    flip_at = np.flatnonzero(np.diff(phi_tri_wrapped) < 0) + 1
    phi_tri_corr_wrapped = phi_tri_wrapped.copy()
    if flip_at.size > 0:
        two_pi = 2*np.pi
        phi_tri_corr_wrapped[flip_at] = (two_pi - phi_tri_corr_wrapped[flip_at]) % two_pi

    # 5. Optional smoothing of the corrected phase.
    if post_smooth > 0:
        phi_tri_corr_wrapped = gaussian_filter1d(phi_tri_corr_wrapped, sigma=post_smooth, mode='reflect')

    return {
        'sig_bp': sig_bp,
        'analytic': analytic,
        'amp': amp,
        'amp_s': amp_s,
        'upper': upper,
        'lower': lower,
        'centerline': centerline,
        'I_norm_raw': I_norm_raw,
        'I_norm': I_norm,
        'I_norm_clipped': I_norm_clipped,
        'scale_used': scale,
        'phi_tri_wrapped': phi_tri_wrapped,
        'phi_tri_corr_wrapped': phi_tri_corr_wrapped,
    }

# ----------------- ensure pipeline outputs exist -----------------
# Fail early with a clear message if the pipeline cell has not been run:
# the bare names below raise NameError when a variable is undefined.
try:
    RawSig = RawInterferogramssig
    RawBg = RawInterferogramsbg
    phase_maps   # processed sig-bg per-shot maps (from pipeline)
    corrected_maps
    fitted_surfaces
    mask
except NameError as e:
    raise RuntimeError("Pipeline outputs (phase_maps, corrected_maps, fitted_surfaces, mask, RawInterferogramssig) not found. Run the pipeline cell first.") from e

# Image stack is indexed [shot, row, column].
shots_avail = RawSig.shape[0]
H, W = RawSig.shape[1], RawSig.shape[2]
# Clamp the requested shot and column to valid indices; default column is the centre.
REF_SHOT = int(np.clip(REF_SHOT, 0, shots_avail-1))
if REF_COL is None:
    REF_COL = W // 2
else:
    REF_COL = int(np.clip(REF_COL, 0, W-1))

sig_img = RawSig[REF_SHOT]
bg_img  = RawBg[REF_SHOT]
# Use pipeline outputs for the single-shot displays
raw_single_shot_map = phase_maps[REF_SHOT]
corrected_single_shot = corrected_maps[REF_SHOT]
# Fall back to an all-NaN map if no fitted surface exists for the chosen shot.
fitted_shot = fitted_surfaces[REF_SHOT] if (fitted_surfaces is not None and fitted_surfaces.shape[0] > REF_SHOT) else np.full((H,W), np.nan)
# Only compute the average here when the pipeline cell has not already defined it.
if 'avg_corrected_map' not in globals():
    avg_corrected_map = np.nanmean(corrected_maps, axis=0)

def maybe_smooth_display(arr, sigma):
    """Return a Gaussian-smoothed copy of ``arr`` for display, NaN-aware.

    If ``sigma`` is falsy or not positive, ``arr`` itself is returned
    unchanged. Otherwise a new float array is returned in which NaN pixels
    stay NaN and every other pixel is the Gaussian-weighted average of the
    non-NaN pixels around it.

    The smoothing is a normalized convolution: the zero-filled data and a
    0/1 validity mask are blurred with the same kernel and divided. Blurring
    the zero-filled data alone would pull valid pixels next to NaN regions
    toward zero. For input without NaNs the result equals a plain
    ``gaussian_filter`` up to rounding.
    """
    if not (sigma and sigma > 0):
        return arr

    values = np.array(arr, dtype=float)  # float copy; the input is left untouched
    nanmask = np.isnan(values)
    values[nanmask] = 0.0

    blurred = gaussian_filter(values, sigma=sigma)
    coverage = gaussian_filter((~nanmask).astype(float), sigma=sigma)

    smoothed = np.full_like(values, np.nan)
    # A valid pixel always contributes its own kernel weight, so coverage is
    # positive there; the explicit check only guards against underflow.
    usable = (~nanmask) & (coverage > 0)
    smoothed[usable] = blurred[usable] / coverage[usable]
    return smoothed

# Optionally smoothed copies of the maps, used for display only.
display_raw_map = maybe_smooth_display(raw_single_shot_map, DISPLAY_SMOOTH_SIGMA)
display_corrected = maybe_smooth_display(corrected_single_shot, DISPLAY_SMOOTH_SIGMA)
display_fitted = maybe_smooth_display(fitted_shot, DISPLAY_SMOOTH_SIGMA)
display_avg = maybe_smooth_display(avg_corrected_map, DISPLAY_SMOOTH_SIGMA)

# pick reference column vectors
sig_col = sig_img[:, REF_COL].astype(float)
bg_col  = bg_img[:, REF_COL].astype(float)

# compute envelope extraction for reference column to populate stacked panels
# Fringe period estimate: strongest non-DC peak of the Hann-windowed power
# spectrum of the mean-subtracted signal column, in pixels per cycle.
period_est = None
n = len(sig_col)
if n >= 8:
    s0 = sig_col - np.nanmean(sig_col)
    w = np.hanning(n)
    S = np.fft.rfft(s0*w)
    ps = np.abs(S)**2
    ps[0] = 0.0  # ignore the DC bin when searching for the peak
    freqs = np.fft.rfftfreq(n, d=1.0)
    peak_idx = np.argmax(ps)
    f0 = freqs[peak_idx] if peak_idx < len(freqs) else 0.02
    if f0 > 0:
        period_est = 1.0 / f0

# Pass band: a wide default when no period was found, otherwise ±40% around
# the fringe frequency, kept inside (0, 0.5) cycles per pixel.
if period_est is None:
    bp_low, bp_high = 0.005, 0.45
else:
    f0 = 1.0/period_est
    bp_low = max(0.0005, f0*0.6)
    bp_high = min(0.49, f0*1.4)

# Smoothing lengths scale with the estimated period (fixed defaults if unknown).
env = extract_envelope_arccos_col_wrapped(sig_col, bp_low=bp_low, bp_high=bp_high,
                                          savgol_win=max(11, int(round((period_est or 8)*1.5))),
                                          med_k=5, gauss_sigma=max(1, (period_est*0.35) if period_est else 2),
                                          post_smooth=POST_PHASE_SMOOTH_SIGMA, renorm_pct=99)

# zoom indices
# Window of zoom_h rows centred on the row with the largest smoothed
# amplitude (image centre if the amplitude is nowhere positive), clipped to
# the image.
zoom_h = int(max(8, round(H * ZOOM_FRACTION)))
center_idx = H//2
if np.any(env['amp_s'] > 0):
    center_idx = int(np.nanargmax(env['amp_s']))
zoom_lo = max(0, center_idx - zoom_h//2)
zoom_hi = min(H, zoom_lo + zoom_h)
zoom_slice = slice(zoom_lo, zoom_hi)
t_zoom = np.arange(zoom_lo, zoom_hi)  # row indices used as the x axis of panels c-f

# ---------------- GridSpec layout & plotting ----------------
# Now 6 rows: top images, four stacked small rows (c-f), bottom row with two images (g,h)
# Grid: row 0 holds the two interferograms (a, b), rows 1-4 the stacked
# lineout panels (c-f) spanning both columns, row 5 the two phase maps (g, h).
nrows = 6
ncols = 2
h_top = 3.2          # relative height of the image rows
h_stack_row = 0.7    # relative height of each stacked lineout row
height_ratios = [h_top] + [h_stack_row]*4 + [h_top]
width_ratios = [1,1]

fig = plt.figure(figsize=(FIG_W, FIG_H))
gs = gridspec.GridSpec(nrows=nrows, ncols=ncols, figure=fig,
                       height_ratios=height_ratios, width_ratios=width_ratios)

# Top images (signal, background)
ax_a = fig.add_subplot(gs[0,0])
ax_b = fig.add_subplot(gs[0,1])

# stacked zoomed plots (c-f) spanning both columns
# d-f share the x axis of c, so limits set on c apply to all four.
ax_c = fig.add_subplot(gs[1, :])
ax_d = fig.add_subplot(gs[2, :], sharex=ax_c)
ax_e = fig.add_subplot(gs[3, :], sharex=ax_c)
ax_f = fig.add_subplot(gs[4, :], sharex=ax_c)

# --- Set custom x-limits for stacked plots (c–f) ---
# Use the zoomed row index range so c-f show the same window as the vertical markers
ax_c.set_xlim(t_zoom[0], t_zoom[-1])

# bottom 1 row with 2 images (g,h)
ax_g = fig.add_subplot(gs[5,0])
ax_h = fig.add_subplot(gs[5,1])

# Data-coordinate extent for every image panel: [x_min, x_max, y_min, y_max].
extent = [0, PIXEL_MICRONS_X, 0, PIXEL_MICRONS_Y]

# Top interferograms (square pixels)
ax_a.imshow(sig_img, cmap='gray', origin='lower', extent=extent, aspect='equal')
ax_b.imshow(bg_img,  cmap='gray', origin='lower', extent=extent, aspect='equal')

# Colored vertical segment markers that correspond to line colors in c)
# Convert the reference column and the zoom row range from pixel indices to
# extent coordinates (column centre, hence the + 0.5).
col_x_um = (REF_COL + 0.5) * (PIXEL_MICRONS_X / W)
y0_um = (zoom_lo) * (PIXEL_MICRONS_Y / H)
y1_um = (zoom_hi-1) * (PIXEL_MICRONS_Y / H)

# Two parallel lines, 0.6 pixel widths either side of the column, bracket the lineout.
offset_um = (PIXEL_MICRONS_X / W) * 0.6
ax_a.plot([col_x_um - offset_um, col_x_um - offset_um], [y0_um, y1_um], color=sig_color, linestyle='-', linewidth=2)
ax_a.plot([col_x_um + offset_um, col_x_um + offset_um], [y0_um, y1_um], color=sig_color, linestyle='-', linewidth=2)
ax_b.plot([col_x_um - offset_um, col_x_um - offset_um], [y0_um, y1_um], color=bg_color, linestyle='-', linewidth=2)
ax_b.plot([col_x_um + offset_um, col_x_um + offset_um], [y0_um, y1_um], color=bg_color, linestyle='-', linewidth=2)

# Put x ticks and labels on top of a) and b)
ax_a.xaxis.set_label_position('top'); ax_a.xaxis.tick_top()
ax_b.xaxis.set_label_position('top'); ax_b.xaxis.tick_top()

ax_a.set_xlabel("x (μm)"); ax_a.set_ylabel("y (μm)")
ax_b.set_xlabel("x (μm)")
# nicer tick spacing: ~200 µm steps on x, ~250 µm on y
x_ticks = np.arange(0, PIXEL_MICRONS_X + 1, 200)
y_ticks = np.arange(0, PIXEL_MICRONS_Y + 1, 250)

ax_a.set_xticks(x_ticks)
ax_a.set_yticks(y_ticks)
ax_b.set_xticks(x_ticks)
ax_b.set_yticks([])  # b) sits beside a), so its y ticks are dropped


# stacked plots c-f (zoomed) with matched colors for signal/bg
# c) raw signal and background lineouts over the zoom window.
ax_c.plot(t_zoom, sig_col[zoom_slice], label='signal', lw=1.1, color=sig_color)
ax_c.plot(t_zoom, bg_col[zoom_slice], label='background', lw=1.1, alpha=0.85, color=bg_color)
ax_c.set_ylabel("pixel val")

# d) signal lineout with its upper/lower envelopes and centerline.
ax_d.plot(t_zoom, sig_col[zoom_slice], lw=0.9, color=sig_color, label='signal')
ax_d.plot(t_zoom, env['upper'][zoom_slice], '--', label='upper env', color='red')
ax_d.plot(t_zoom, env['lower'][zoom_slice], '--', label='lower env', color='red')
ax_d.plot(t_zoom, env['centerline'][zoom_slice], label='centerline', color='magenta')
ax_d.set_ylabel("pixel val")

# e) envelope-normalised lineout.
# NOTE(review): the plotted curve is I_norm multiplied by 1.2, so it is not
# on the true [-1, 1] scale — confirm this stretch is intended.
ax_e.plot(t_zoom, env['I_norm'][zoom_slice]*1.2, lw=0.9, label=f'I_norm (scale={env["scale_used"]:.2f})', color='k')
ax_e.set_ylim(-1.2, 1.2)


# f) wrapped phase before and after slope correction, in units of π.
phi_tri = env['phi_tri_wrapped'][zoom_slice] / np.pi
phi_corr = env['phi_tri_corr_wrapped'][zoom_slice] / np.pi
# NOTE(review): the curves are drawn at 1.2x, while the tick positions and
# "nπ" labels below come from the unscaled values, so a tick labelled 1π
# does not line up with phase π on the curve — confirm.
ax_f.plot(t_zoom, phi_tri*1.2, lw=0.9, label='tri (wrapped)/π', color='k')
ax_f.plot(t_zoom, phi_corr*1.2, lw=0.9, linestyle='--', label='slope-corr (wrapped)/π', color='magenta')
ax_f.set_ylabel("Phase")
ymin = np.nanmin(np.concatenate([phi_tri, phi_corr]))
ymax = np.nanmax(np.concatenate([phi_tri, phi_corr]))
# Between 2 and 5 evenly spaced ticks from floor(ymin) to ceil(ymax).
yticks = np.linspace(np.floor(ymin), np.ceil(ymax), min(5, max(2, int(np.ceil(ymax)-np.floor(ymin)+1))))
ax_f.set_yticks(yticks)
ylbls = []
for v in yticks:
    # Integer multiples are labelled "nπ", everything else with two decimals.
    if abs(v - round(v)) < 1e-6:
        ylbls.append(f"{int(round(v))}π")
    else:
        ylbls.append(f"{v:.2f}π")
ax_f.set_yticklabels(ylbls)
ax_f.set_xlabel("row (pixel)")

# ---------------- Bottom 2 images only (g,h) ----------------
# g) single-shot phase map, with the fit mask drawn as a translucent overlay.
im_g = ax_g.imshow(display_raw_map, cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower', extent=extent, aspect='equal')
ax_g.set_xlabel("x (μm)"); ax_g.set_ylabel("y (μm)")
if mask is not None:
    mo = np.ma.masked_where(~mask, mask)
    ax_g.imshow(mo, cmap='gray', alpha=0.25, origin='lower', extent=extent, aspect='equal')

# h) the same shot after subtracting the fitted surface.
im_h = ax_h.imshow(display_corrected, cmap=COLORMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX, origin='lower', extent=extent, aspect='equal')
ax_h.set_xlabel("x (μm)"); ax_h.set_yticks([])

# hide redundant y-ticks on right column top plot
ax_b.set_yticks([]); ax_h.set_yticks([])

# colorbar aligned to bottom pair
# colorbar aligned precisely to subplot h)
# The colorbar gets its own axes placed just right of h), in figure coordinates.
# NOTE: pos_h is read before subplots_adjust below changes the layout.
pos_h = ax_h.get_position()
cbar_x = pos_h.x1 + 0.008  # a bit closer
cbar_y = 0.08
cbar_w = 0.018             # slightly thinner (overridden by the next line)
cbar_w = 0.02              # width of colorbar
cbar_h = pos_h.height*1.15      # 1.15x the height of h)
cax = fig.add_axes([cbar_x, cbar_y, cbar_w, cbar_h])
cbar = fig.colorbar(im_h, cax=cax)
cbar.set_label('Phase shift [rad]')


# final layout tuning: margins and spacing between the grid cells
plt.subplots_adjust(
    left=0.06,
    right=0.88,
    top=0.95,      # bring top of figure down -> less space above a,b
    bottom=0.04,
    wspace=0.02,
    hspace=0.0     # no vertical gap between grid rows
)

# Panel captions. All text positions are in each axes' data coordinates, so
# they are tuned to the current extent and zoom window.
ax_a.text(s='a) raw interferogram: signal',x=80,y=950, color='white', fontsize=14)
ax_a.text(s=' lineout\nlocation',x=700,y=200, color='white', fontsize=14)

ax_b.text(s='b) raw interferogram:\n    background',x=80,y=870, color='white', fontsize=14)
ax_b.text(s=' lineout\nlocation',x=700,y=200, color='white', fontsize=14)

# c) caption assembled from separately coloured pieces; the x offsets are in
# row-index units. The y-limits leave headroom above the data for the caption.
ax_c.text(s='c) zoomed lineout:',x=zoom_lo+10,y=np.nanmax(sig_col[zoom_slice])*1.1, color='black', fontsize=12)
ax_c.text(s='signal',x=zoom_lo+10+75,y=np.nanmax(sig_col[zoom_slice])*1.1, color=sig_color, fontsize=12)
ax_c.text(s='&',x=zoom_lo+10+100,y=np.nanmax(sig_col[zoom_slice])*1.1, color='black', fontsize=12)
ax_c.text(s='background',x=zoom_lo+10+108,y=np.nanmax(sig_col[zoom_slice])*1.1, color=bg_color, fontsize=12)
ax_c.set_ylim(np.nanmin(sig_col[zoom_slice])*0.9, np.nanmax(sig_col[zoom_slice])*1.4)

# d) same approach as c).
ax_d.text(s='d)',x=zoom_lo+10,y=np.nanmax(sig_col[zoom_slice])*1.1, color='black', fontsize=12)
ax_d.text(s='envelopes',x=zoom_lo+23,y=np.nanmax(sig_col[zoom_slice])*1.1, color='red', fontsize=12)
ax_d.text(s='&',x=zoom_lo+65,y=np.nanmax(sig_col[zoom_slice])*1.1, color='black', fontsize=12)
ax_d.text(s='centerline',x=zoom_lo+74,y=np.nanmax(sig_col[zoom_slice])*1.1, color='magenta', fontsize=12)
ax_d.set_ylim(np.nanmin(sig_col[zoom_slice])*0.9, np.nanmax(sig_col[zoom_slice])*1.4)

# e) these limits replace any y-limits set on ax_e earlier and make room for the caption.
ax_e.text(s='e) normalized lineout',x=zoom_lo+10,y=1.3, color='black', fontsize=12)
ax_e.set_ylim(-1, 1.99)

# f) caption and limits are derived from the slope-corrected phase (in units of π).
ax_f.text(s='f) phase before unwrapping: triangular &',x=zoom_lo+10,y=np.nanmax(phi_corr)*1.25, color='black', fontsize=12)
ax_f.text(s='slope-corrected',x=zoom_lo+174,y=np.nanmax(phi_corr)*1.25, color='magenta', fontsize=12)
ax_f.set_ylim(np.nanmin(phi_corr)*0.9, np.nanmax(phi_corr)*1.6)

ax_g.text(s='g) single-shot raw phase map',x=60,y=1000, color='black', fontsize=14)
ax_g.text(s='mask for background fit',x=600,y=240, color='black', fontsize=10)

ax_h.text(s='h) background-subtracted\n   phase map',x=60,y=900, color='black', fontsize=14)

# Machine-specific output path.
plt.savefig(r'c:\Users\sann7609\Documents\Oxford\Thesis\images\fringe_normalization_figure.png', dpi=300,bbox_inches='tight')