# **Color Cutter: Color-Color Cuts Notebook for SuperBIT Cluster Data**
## *First Iteration: Parabolic Cut*

This notebook implements parabolic color cuts to select background galaxies for weak lensing analysis of galaxy clusters observed by SuperBIT. We use a rotated parabola in color-color space (B-G vs U-B) to identify background galaxies. The parabola parameters are tuned using galaxies with known redshifts from the NED, DESI, and LoVoCCS surveys to maximize purity (the fraction of selected galaxies that are truly background galaxies).

## Imports and Setup

### Pip install `ipympl` to make your plots interactive
Uncomment and run the cell below once, then comment it out again (prefix the line with `#`).

In [None]:
#pip install ipympl

In [None]:
from astropy.table import Table
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# Import ipywidgets for interactive controls
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

## File Paths and Data Loading

In [None]:
# Data file paths
# Replace with your own /path/to/data
REDSHIFT_CATALOG = '/Users/mayaamit/Desktop/mega_color_mag_with_redshift.fits'  # Training data with known redshifts
FULL_CATALOG = '/Users/mayaamit/Desktop/mega_color_mag_catalog.fits'            # Full catalog

# Load the redshift catalog (for visualization and validation).
# Its columns (Z_best, color_bg, color_ub, ...) are read in the next cell.
print("Loading redshift catalog for training/validation...")
redshift_cat = Table.read(REDSHIFT_CATALOG)
print(f"Loaded {len(redshift_cat)} objects with known redshifts")

# Load the full catalog (for applying cuts).
# NOTE(review): apply_parabola_cut_to_cluster re-reads FULL_CATALOG itself, so
# this copy is only used for the count printed here.
print("\nLoading full photometric catalog...")
full_cat = Table.read(FULL_CATALOG)
print(f"Loaded {len(full_cat)} total objects")

# Display available columns
print(f"\nAvailable columns: {redshift_cat.colnames}")


In [None]:
# Pull per-object quantities out of the redshift catalog as float columns.
redshift = redshift_cat['Z_best'].astype(float)            # best redshift
redshift_err = redshift_cat["ZERR_best"].astype(float)     # its uncertainty

color_bg = redshift_cat['color_bg'].astype(float)          # B-G color
color_bg_err = redshift_cat['color_bg_err'].astype(float)

color_ub = redshift_cat['color_ub'].astype(float)          # U-B color
color_ub_err = redshift_cat['color_ub_err'].astype(float)

# Trim surrounding whitespace so labels compare cleanly against 'NED', 'DESI', ...
redshift_source = np.array([label.strip() for label in redshift_cat['Z_source']])

print("Data summary:")
print(f"Redshift range: {np.nanmin(redshift):.3f} to {np.nanmax(redshift):.3f}")
print(f"B-G color range: {np.nanmin(color_bg):.3f} to {np.nanmax(color_bg):.3f}")
print(f"U-B color range: {np.nanmin(color_ub):.3f} to {np.nanmax(color_ub):.3f}")
print(f"Redshift sources: {np.unique(redshift_source)}")

## Functions

In [None]:
def plot_sources_separately(redshift, color_bg, color_ub, color_bg_err, color_ub_err, 
                           redshift_source, z_thresh, err_thresh, xlim, ylim,
                           sources=('NED', 'DESI', 'LoVoCCS')):
    """
    Plot U-B vs B-G color-color diagrams, one panel per redshift source.

    For each source, objects with a NaN redshift or color are dropped, as are
    objects whose B-G or U-B color error is not strictly below ``err_thresh``.
    The remaining objects are split at ``z_thresh``: z < z_thresh in blue,
    z >= z_thresh in red. Per-source counts are printed.

    Parameters
    ----------
    redshift, color_bg, color_ub, color_bg_err, color_ub_err : array-like
        Per-object redshift, colors and color errors (all the same length).
    redshift_source : array-like of str
        Source label of each object's redshift, compared against ``sources``.
    z_thresh : float
        Redshift threshold for the low-z / high-z split.
    err_thresh : float
        Objects need both color errors strictly below this value.
    xlim, ylim : tuple
        Axis limits for B-G (x) and U-B (y).
    sources : sequence of str, optional
        Source labels to plot, one panel each (default: NED, DESI, LoVoCCS).

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : numpy.ndarray
        One Axes per entry of ``sources``.

    Raises
    ------
    ValueError
        If ``sources`` is empty.
    """
    sources = list(sources)
    if not sources:
        raise ValueError("sources must contain at least one source label")

    redshift_source = np.asarray(redshift_source)

    # Panel title colors; sources not listed here fall back to orange too.
    colors = {'NED': 'orange', 'DESI': 'orange', 'LoVoCCS': 'orange'}

    fig, axes = plt.subplots(1, len(sources), figsize=(17, 6))
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    axes = np.atleast_1d(axes)

    print("Source Statistics:")
    print("=" * 60)

    for ax, source in zip(axes, sources):
        source_mask = redshift_source == source

        if not np.any(source_mask):
            print(f"{source}: No objects found")
            ax.text(0.5, 0.5, f'No {source} objects', ha='center', va='center', 
                   transform=ax.transAxes, fontsize=14)
            ax.set_title(f'{source}')
            continue

        # Extract data for this source
        z_src = redshift[source_mask]
        bg_src = color_bg[source_mask]
        ub_src = color_ub[source_mask]
        bg_err_src = color_bg_err[source_mask]
        ub_err_src = color_ub_err[source_mask]

        # Valid data mask (no NaN values)
        valid_mask = ~(np.isnan(z_src) | np.isnan(bg_src) | np.isnan(ub_src))

        # Good color measurements (a NaN error compares False, so it is excluded)
        good_colors_mask = (bg_err_src < err_thresh) & (ub_err_src < err_thresh)

        plot_mask = valid_mask & good_colors_mask

        # Split by redshift
        low_z_mask = (z_src < z_thresh) & plot_mask
        high_z_mask = (z_src >= z_thresh) & plot_mask

        # Plot high-z objects (background)
        if np.any(high_z_mask):
            ax.scatter(bg_src[high_z_mask], ub_src[high_z_mask], 
                      c='red', alpha=0.3, s=10, label=f'z ≥ {z_thresh}')

        # Plot low-z objects (foreground/cluster)
        if np.any(low_z_mask):
            ax.scatter(bg_src[low_z_mask], ub_src[low_z_mask], 
                      c='blue', alpha=0.3, s=10, label=f'z < {z_thresh}')

        # Formatting
        ax.set_xlabel('B - G Color')
        ax.set_ylabel('U - B Color')
        ax.set_title(f'{source}', color=colors.get(source, 'orange'), fontweight='bold')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3)
        # Only draw a legend when something was plotted; otherwise matplotlib
        # warns that no labeled artists were found.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

        # Calculate statistics
        total_source = np.sum(source_mask)
        total_valid = np.sum(valid_mask)
        total_good_colors = np.sum(plot_mask)
        low_z_count = np.sum(low_z_mask)
        high_z_count = np.sum(high_z_mask)
        cut_by_errors = total_valid - total_good_colors

        print(f"{source}: {total_source:,} total | {low_z_count:,} low-z | {high_z_count:,} high-z")
        print(f"       {cut_by_errors:,} cut due to color errors (>{err_thresh})")
        print()

    plt.tight_layout()
    plt.show()

    return fig, axes

In [None]:

def plot_color_color_interactive(redshift, color_bg, color_ub, redshift_source, 
                                training_mask, z_thresh, xlim, ylim, cluster_name):
    """
    Interactive color-color plot of the training sample with toggleable layers:
    • Scatter points (background z > z_thresh in red, foreground z ≤ z_thresh in blue)
    • NED contours (separated by redshift)
    • DESI contours (separated by redshift)
    • Combined contours (all sources, separated by redshift)

    Contours are drawn from a Gaussian-smoothed 2D histogram rather than a KDE.

    Parameters
    ----------
    redshift, color_bg, color_ub : array-like
        Galaxy properties
    redshift_source : array-like
        Source of redshift measurement ('NED', 'DESI', etc.)
    training_mask : boolean array
        Mask for high-quality training data
    z_thresh : float
        Redshift threshold: z > z_thresh is background, z ≤ z_thresh foreground
    xlim, ylim : tuple
        Plot limits for color axes (also the histogram range for contours)
    cluster_name : str
        Name of target cluster for plot title

    Returns
    -------
    None
        The checkboxes and plot are shown via IPython ``display``.
    """

    def contour_xy(x, y, color, label, ax, xlim, ylim):
        """
        Draw smoothed density contours of (x, y) on ``ax`` plus a legend entry.

        Points are binned into an 80x80 histogram over xlim x ylim, smoothed
        with a Gaussian filter (sigma = 2 bins), scaled to a peak of 1 and
        contoured at fixed fractions of that peak. Nothing is drawn when there
        are fewer than 20 points or none of them fall inside the limits.
        """
        if len(x) < 20:  # Need enough points for meaningful contours
            return

        H, xe, ye = np.histogram2d(x, y, bins=80, range=[xlim, ylim])
        # Transpose so rows index y, as ax.contour expects.
        H = gaussian_filter(H.T, sigma=2.0)
        peak = H.max()
        if peak <= 0:
            # Every point lies outside the plot range; normalizing would give NaNs.
            return
        H /= peak  # Normalize to [0, 1]

        # Evaluate at bin centers (not left edges) so the contours are not
        # shifted by half a bin relative to the scatter points.
        x_centers = 0.5 * (xe[:-1] + xe[1:])
        y_centers = 0.5 * (ye[:-1] + ye[1:])
        X, Y = np.meshgrid(x_centers, y_centers)

        ax.contour(X, Y, H,
                  levels=[0.1, 0.25, 0.5, 0.75, 0.9],
                  colors=color, linewidths=1.5, alpha=1.0)  # Fully opaque contours

        # Empty line so the contour set gets a legend entry
        ax.plot([], [], color=color, linewidth=2, label=label)

    # Extract training data
    train_z = redshift[training_mask]
    train_bg = color_bg[training_mask]
    train_ub = color_ub[training_mask]
    train_source = redshift_source[training_mask]

    # Create redshift masks
    background_mask = train_z > z_thresh   # High-z (background, red)
    foreground_mask = train_z <= z_thresh  # Low-z (foreground, blue)

    # Create checkboxes
    scatter_cb = widgets.Checkbox(value=True, description="Scatter points")
    ned_cb = widgets.Checkbox(value=False, description="NED contours")
    desi_cb = widgets.Checkbox(value=False, description="DESI contours")
    combined_cb = widgets.Checkbox(value=False, description="Combined contours")

    # Output widget that holds the (re)drawn figure
    output = widgets.Output()

    def update_plot(*args):
        """Redraw the figure according to the current checkbox states."""
        with output:
            clear_output(wait=True)

            fig, ax = plt.subplots(figsize=(10, 8))

            # Scatter points (semi-transparent)
            if scatter_cb.value:
                # Background objects (red)
                if np.any(background_mask):
                    ax.scatter(train_bg[background_mask], train_ub[background_mask], 
                              c='red', alpha=0.3, s=8, label=f'Background (z > {z_thresh})')

                # Foreground objects (blue)
                if np.any(foreground_mask):
                    ax.scatter(train_bg[foreground_mask], train_ub[foreground_mask], 
                              c='blue', alpha=0.3, s=8, label=f'Foreground (z ≤ {z_thresh})')

            # NED contours (separated by redshift)
            if ned_cb.value:
                ned_mask = train_source == 'NED'
                if np.any(ned_mask):
                    # NED foreground (blue contours)
                    ned_fg_mask = ned_mask & foreground_mask
                    if np.any(ned_fg_mask):
                        contour_xy(train_bg[ned_fg_mask], train_ub[ned_fg_mask], 
                                  'blue', f'NED z ≤ {z_thresh}', ax, xlim, ylim)

                    # NED background (red contours)
                    ned_bg_mask = ned_mask & background_mask
                    if np.any(ned_bg_mask):
                        contour_xy(train_bg[ned_bg_mask], train_ub[ned_bg_mask], 
                                  'red', f'NED z > {z_thresh}', ax, xlim, ylim)

            # DESI contours (separated by redshift)
            if desi_cb.value:
                desi_mask = train_source == 'DESI'
                if np.any(desi_mask):
                    # DESI foreground (blue contours)
                    desi_fg_mask = desi_mask & foreground_mask
                    if np.any(desi_fg_mask):
                        contour_xy(train_bg[desi_fg_mask], train_ub[desi_fg_mask], 
                                  'blue', f'DESI z ≤ {z_thresh}', ax, xlim, ylim)

                    # DESI background (red contours)
                    desi_bg_mask = desi_mask & background_mask
                    if np.any(desi_bg_mask):
                        contour_xy(train_bg[desi_bg_mask], train_ub[desi_bg_mask], 
                                  'red', f'DESI z > {z_thresh}', ax, xlim, ylim)

            # Combined contours (all sources, separated by redshift)
            if combined_cb.value:
                # Combined foreground (blue contours)
                if np.any(foreground_mask):
                    contour_xy(train_bg[foreground_mask], train_ub[foreground_mask], 
                              'blue', f'All sources z ≤ {z_thresh}', ax, xlim, ylim)

                # Combined background (red contours)
                if np.any(background_mask):
                    contour_xy(train_bg[background_mask], train_ub[background_mask], 
                              'red', f'All sources z > {z_thresh}', ax, xlim, ylim)

            # Formatting
            ax.set_xlabel('B - G Color', fontsize=12)
            ax.set_ylabel('U - B Color', fontsize=12)
            ax.set_title(f'{cluster_name}: Color-Color Distribution by Redshift\n'
                        f'Training sample (z_thresh = {z_thresh})', fontsize=14)
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.grid(True, alpha=0.3)

            # Add legend only if there are labeled items
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend(fontsize=10)

            # Sample size info box
            n_ned = np.sum(train_source == 'NED')
            n_desi = np.sum(train_source == 'DESI')
            n_total = np.sum(training_mask)
            n_bg = np.sum(background_mask)
            n_fg = np.sum(foreground_mask)

            info_text = f'Sample: {n_total:,} total | {n_ned:,} NED | {n_desi:,} DESI\n'
            info_text += f'Background: {n_bg:,} | Foreground: {n_fg:,}'

            ax.text(0.02, 0.98, info_text, transform=ax.transAxes, 
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.9),
                   verticalalignment='top', fontsize=10)

            plt.tight_layout()
            plt.show()

    # Redraw whenever any checkbox changes
    scatter_cb.observe(update_plot, names='value')
    ned_cb.observe(update_plot, names='value')
    desi_cb.observe(update_plot, names='value')
    combined_cb.observe(update_plot, names='value')

    # Display controls and output
    controls = widgets.HBox([scatter_cb, ned_cb, desi_cb, combined_cb])
    display(controls)
    display(output)

    # Initial plot
    update_plot()

In [None]:
def add_parabola_overlay(ax, A, h, k, theta, xlim, ylim, color='black', linewidth=2):
    """
    Draw the rotated-parabola boundary on an existing axis.

    The boundary is the zero level of f(x, y) = y' - A * x'^2. Here (x', y')
    is (x - h, y - k) rotated by -theta. f is sampled on a 300 x 300 grid
    spanning xlim x ylim, and its zero contour is drawn. Only the part of the
    parabola inside those limits appears.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    A : float
        Curvature of the parabola in its rotated frame.
    h, k : float
        Vertex position (x, y) in data coordinates.
    theta : float
        Rotation angle in radians.
    xlim, ylim : tuple
        (min, max) ranges covered by the sampling grid.
    color : str
        Line color.
    linewidth : float
        Line width.

    Returns
    -------
    The object returned by ``ax.contour`` for the zero level.
    """
    grid_x, grid_y = np.meshgrid(np.linspace(xlim[0], xlim[1], 300),
                                 np.linspace(ylim[0], ylim[1], 300))

    c = np.cos(theta)
    s = np.sin(theta)

    # Move the vertex to the origin, then rotate by -theta into the parabola frame.
    shifted_x = grid_x - h
    shifted_y = grid_y - k
    along = c * shifted_x + s * shifted_y
    across = -s * shifted_x + c * shifted_y

    # Points on the parabola satisfy across == A * along**2.
    level_field = across - A * along**2

    return ax.contour(grid_x, grid_y, level_field, levels=[0], colors=color,
                      linewidths=linewidth, linestyles='-')

def plot_color_color_with_parabola(redshift, color_bg, color_ub, redshift_source, 
                                  training_mask, z_thresh, xlim, ylim, cluster_name,
                                  A, h, k, theta):
    """
    Interactive color-color plot with a parabola cut overlay and purity statistics.

    Same layers as plot_color_color_interactive, plus a toggleable parabola
    boundary. First, it prints how the training sample splits across the parabola:

    * purity above = fraction of above-parabola points with z ≤ z_thresh (foreground)
    * purity below = fraction of below-parabola points with z > z_thresh (background)

    A point is "above" when, after shifting the vertex (h, k) to the origin
    and rotating by -theta, y_rot > A * x_rot**2. It is "below" when
    y_rot <= A * x_rot**2.

    Parameters
    ----------
    redshift, color_bg, color_ub, redshift_source, training_mask : array-like
        Galaxy data and masks
    z_thresh : float
        Redshift threshold: z > z_thresh is background, z ≤ z_thresh foreground
    xlim, ylim : tuple
        Plot limits (also the histogram range for contours)
    cluster_name : str
        Cluster name for title
    A, h, k, theta : float
        Parabola curvature, vertex coordinates and rotation angle (radians)

    Returns
    -------
    None
        Statistics are printed; checkboxes and the plot are shown via IPython.
    """

    def apply_parabola_cut(bg, ub, A, h, k, theta):
        """Return boolean masks (above, below) of points relative to the parabola."""
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)

        # Translate to vertex coordinates
        dx = bg - h
        dy = ub - k

        # Rotate coordinates by -theta (inverse rotation)
        x_rot = cos_theta * dx + sin_theta * dy
        y_rot = -sin_theta * dx + cos_theta * dy

        boundary = A * x_rot**2
        above_parabola = y_rot > boundary
        below_parabola = y_rot <= boundary

        return above_parabola, below_parabola

    def contour_xy(x, y, color, label, ax, xlim, ylim):
        """
        Draw smoothed density contours of (x, y) on ``ax`` plus a legend entry.

        Points are binned into an 80x80 histogram over xlim x ylim, smoothed
        with a Gaussian filter (sigma = 2 bins), scaled to a peak of 1 and
        contoured at fixed fractions of that peak. Nothing is drawn when there
        are fewer than 20 points or none of them fall inside the limits.
        """
        if len(x) < 20:
            return

        H, xe, ye = np.histogram2d(x, y, bins=80, range=[xlim, ylim])
        # Transpose so rows index y, as ax.contour expects.
        H = gaussian_filter(H.T, sigma=2.0)
        peak = H.max()
        if peak <= 0:
            # Every point lies outside the plot range; normalizing would give NaNs.
            return
        H /= peak

        # Evaluate at bin centers (not left edges) so the contours are not
        # shifted by half a bin relative to the scatter points.
        x_centers = 0.5 * (xe[:-1] + xe[1:])
        y_centers = 0.5 * (ye[:-1] + ye[1:])
        X, Y = np.meshgrid(x_centers, y_centers)
        ax.contour(X, Y, H,
                  levels=[0.1, 0.25, 0.5, 0.75, 0.9],
                  colors=color, linewidths=1.5, alpha=1.0)
        ax.plot([], [], color=color, linewidth=2, label=label)

    # Extract training data
    train_z = redshift[training_mask]
    train_bg = color_bg[training_mask]
    train_ub = color_ub[training_mask]
    train_source = redshift_source[training_mask]

    # Redshift classes
    background_mask = train_z > z_thresh   # High-z (background, red points)
    foreground_mask = train_z <= z_thresh  # Low-z (foreground, blue points)

    # Apply parabola cut to calculate statistics
    above_parabola, below_parabola = apply_parabola_cut(train_bg, train_ub, A, h, k, theta)

    n_total = len(train_z)
    n_above = np.sum(above_parabola)
    n_below = np.sum(below_parabola)

    # Cross-tabulate parabola side vs redshift class
    above_and_background = np.sum(above_parabola & background_mask)  # Red points above
    above_and_foreground = np.sum(above_parabola & foreground_mask)  # Blue points above
    below_and_background = np.sum(below_parabola & background_mask)  # Red points below
    below_and_foreground = np.sum(below_parabola & foreground_mask)  # Blue points below

    # Purity of each side with respect to the class expected there
    purity_above = above_and_foreground / n_above if n_above > 0 else 0  # Blue fraction above
    purity_below = below_and_background / n_below if n_below > 0 else 0  # Red fraction below

    print("Parabola Cut Statistics:")
    print("=" * 50)
    print(f"Selected {n_above:,}/{n_total:,} pts (above the rotated parabola)")
    print(f"Of those: {above_and_foreground:,} blue (z ≤ {z_thresh}), {above_and_background:,} red (z > {z_thresh})")
    print(f"Purity above = {purity_above:.3f}, Purity below = {purity_below:.3f}")
    print()
    print(f"Below parabola: {n_below:,} pts ({below_and_foreground:,} blue, {below_and_background:,} red)")
    # The purities are fractions of each side's points, not of each redshift class.
    print(f"Note: 'Purity above' = fraction of above-parabola points with z ≤ {z_thresh} (foreground)")
    print(f"      'Purity below' = fraction of below-parabola points with z > {z_thresh} (background)")

    # Create checkboxes
    scatter_cb = widgets.Checkbox(value=True, description="Scatter points")
    ned_cb = widgets.Checkbox(value=False, description="NED contours")
    desi_cb = widgets.Checkbox(value=False, description="DESI contours")
    combined_cb = widgets.Checkbox(value=False, description="Combined contours")
    parabola_cb = widgets.Checkbox(value=True, description="Parabola cut")

    # Output widget that holds the (re)drawn figure
    output = widgets.Output()

    def update_plot(*args):
        """Redraw the figure according to the current checkbox states."""
        with output:
            clear_output(wait=True)

            fig, ax = plt.subplots(figsize=(10, 8))

            # Scatter points (background red, foreground blue)
            if scatter_cb.value:
                if np.any(background_mask):
                    ax.scatter(train_bg[background_mask], train_ub[background_mask], 
                              c='red', alpha=0.4, s=8, label=f'Background (z > {z_thresh})')

                if np.any(foreground_mask):
                    ax.scatter(train_bg[foreground_mask], train_ub[foreground_mask], 
                              c='blue', alpha=0.4, s=8, label=f'Foreground (z ≤ {z_thresh})')

            # NED contours (separated by redshift)
            if ned_cb.value:
                ned_mask = train_source == 'NED'
                if np.any(ned_mask):
                    ned_fg_mask = ned_mask & foreground_mask
                    if np.any(ned_fg_mask):
                        contour_xy(train_bg[ned_fg_mask], train_ub[ned_fg_mask], 
                                  'blue', f'NED z ≤ {z_thresh}', ax, xlim, ylim)

                    ned_bg_mask = ned_mask & background_mask
                    if np.any(ned_bg_mask):
                        contour_xy(train_bg[ned_bg_mask], train_ub[ned_bg_mask], 
                                  'red', f'NED z > {z_thresh}', ax, xlim, ylim)

            # DESI contours (separated by redshift)
            if desi_cb.value:
                desi_mask = train_source == 'DESI'
                if np.any(desi_mask):
                    desi_fg_mask = desi_mask & foreground_mask
                    if np.any(desi_fg_mask):
                        contour_xy(train_bg[desi_fg_mask], train_ub[desi_fg_mask], 
                                  'blue', f'DESI z ≤ {z_thresh}', ax, xlim, ylim)

                    desi_bg_mask = desi_mask & background_mask
                    if np.any(desi_bg_mask):
                        contour_xy(train_bg[desi_bg_mask], train_ub[desi_bg_mask], 
                                  'red', f'DESI z > {z_thresh}', ax, xlim, ylim)

            # Combined contours (all sources, separated by redshift)
            if combined_cb.value:
                if np.any(foreground_mask):
                    contour_xy(train_bg[foreground_mask], train_ub[foreground_mask], 
                              'blue', f'All sources z ≤ {z_thresh}', ax, xlim, ylim)

                if np.any(background_mask):
                    contour_xy(train_bg[background_mask], train_ub[background_mask], 
                              'red', f'All sources z > {z_thresh}', ax, xlim, ylim)

            # Parabola overlay (plus an empty line for its legend entry)
            if parabola_cb.value:
                add_parabola_overlay(ax, A, h, k, theta, xlim, ylim, 
                                   color='black', linewidth=3)
                ax.plot([], [], color='black', linewidth=3, label='Parabola cut')

            # Formatting
            ax.set_xlabel('B - G Color', fontsize=12)
            ax.set_ylabel('U - B Color', fontsize=12)
            ax.set_title(f'{cluster_name}: Color-Color Distribution with Parabola Cut\n'
                        f'A={A}, vertex=({h:.3f}, {k:.3f}), θ={np.degrees(theta):.1f}°', 
                        fontsize=14)
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.grid(True, alpha=0.3)

            # Add legend only if there are labeled items
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend(fontsize=10)

            # Purity statistics box
            purity_text = f'Above parabola: {n_above:,} pts | Purity: {purity_above:.3f}\n'
            purity_text += f'Below parabola: {n_below:,} pts | Purity: {purity_below:.3f}'

            ax.text(0.02, 0.02, purity_text, transform=ax.transAxes, 
                   bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.9),
                   fontsize=10)

            plt.tight_layout()
            plt.show()

    # Redraw whenever any checkbox changes
    scatter_cb.observe(update_plot, names='value')
    ned_cb.observe(update_plot, names='value')
    desi_cb.observe(update_plot, names='value')
    combined_cb.observe(update_plot, names='value')
    parabola_cb.observe(update_plot, names='value')

    # Display controls and output
    controls = widgets.HBox([scatter_cb, ned_cb, desi_cb, combined_cb, parabola_cb])
    display(controls)
    display(output)

    # Initial plot
    update_plot()

In [None]:
def apply_parabola_cut_to_cluster(catalog_file, cluster_name, A, h, k, theta, 
                                 output_dir=None, verbose=True):
    """
    Apply the parabola cut to one cluster's objects and save both samples.

    The function proceeds in five steps:

    1. Read the full photometric catalog.
    2. Keep rows whose CLUSTER column equals ``cluster_name``.
    3. Drop rows with NaN (or masked) color_bg / color_ub.
    4. Split the remaining rows at the rotated parabola.
    5. Write two FITS files, overwriting existing ones:
       ``<output_dir>/<name>_background_galaxies.fits`` (on/below the parabola)
       and ``<output_dir>/<name>_foreground_galaxies.fits`` (above it).

    ``<name>`` is the lower-cased cluster name with spaces and hyphens
    replaced by underscores. ``output_dir`` is created if it does not exist.

    Parameters
    ----------
    catalog_file : str
        Path to the full catalog FITS file
    cluster_name : str
        Name of cluster to filter (must match CLUSTER column exactly)
    A, h, k, theta : float
        Parabola parameters (curvature, vertex coords, rotation angle in radians)
    output_dir : str, optional
        Output directory. If None, saves to the current directory.
    verbose : bool
        Print progress and statistics

    Returns
    -------
    background_catalog : astropy.table.Table or None
        Background galaxies (on or below the parabola)
    foreground_catalog : astropy.table.Table or None
        Foreground galaxies (above the parabola)

    Both are None if the catalog cannot be read, the cluster has no objects,
    or no object has valid colors. If saving fails, an error is printed and
    the catalogs are still returned.
    """
    import os

    if verbose:
        print(f"Applying Parabola Cut to Cluster: {cluster_name}")
        print("=" * 50)
        print(f"Loading catalog: {catalog_file}")

    # Load the full catalog (best-effort: report and bail out on failure)
    try:
        full_cat = Table.read(catalog_file)
    except Exception as e:
        print(f"Error loading catalog: {e}")
        return None, None

    if verbose:
        print(f"Total objects in catalog: {len(full_cat):,}")

    # Filter for the specified cluster
    cluster_mask = full_cat['CLUSTER'] == cluster_name
    cluster_objects = full_cat[cluster_mask]

    if verbose:
        print(f"Objects in {cluster_name}: {len(cluster_objects):,}")

    if len(cluster_objects) == 0:
        print(f"WARNING: No objects found for cluster '{cluster_name}'")
        available_clusters = np.unique(full_cat['CLUSTER'])
        print(f"Available clusters: {available_clusters[:10]}...")  # Show first 10
        return None, None

    # Table.read may return MaskedColumns (recent astropy versions mask NaNs in
    # FITS float columns by default -- TODO confirm for the installed version).
    # Fill masked entries with NaN so the NaN check below drops them;
    # np.ma.filled leaves unmasked columns unchanged.
    color_bg = np.asarray(np.ma.filled(cluster_objects['color_bg'].astype(float), np.nan), dtype=float)
    color_ub = np.asarray(np.ma.filled(cluster_objects['color_ub'].astype(float), np.nan), dtype=float)

    # Remove any objects with NaN colors
    valid_colors = ~(np.isnan(color_bg) | np.isnan(color_ub))
    if verbose:
        print(f"Objects with valid colors: {np.sum(valid_colors):,}")
        print(f"Objects with NaN colors: {np.sum(~valid_colors):,}")

    cluster_objects = cluster_objects[valid_colors]
    color_bg = color_bg[valid_colors]
    color_ub = color_ub[valid_colors]

    if len(cluster_objects) == 0:
        print("ERROR: No objects with valid colors found")
        return None, None

    if verbose:
        print("\nApplying parabola cut:")
        print(f"Parameters: A={A}, vertex=({h:.3f}, {k:.3f}), θ={np.degrees(theta):.1f}°")

    # Translate to the vertex and rotate by -theta into the parabola's frame
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    dx = color_bg - h
    dy = color_ub - k
    x_rot = cos_theta * dx + sin_theta * dy
    y_rot = -sin_theta * dx + cos_theta * dy

    boundary = A * x_rot**2
    below_parabola = y_rot <= boundary  # Background galaxies
    above_parabola = y_rot > boundary   # Foreground/cluster galaxies

    background_catalog = cluster_objects[below_parabola]
    foreground_catalog = cluster_objects[above_parabola]

    if verbose:
        print("\nCut Results:")
        print(f"Background galaxies (below parabola): {len(background_catalog):,}")
        print(f"Foreground galaxies (above parabola): {len(foreground_catalog):,}")
        print(f"Background fraction: {len(background_catalog)/len(cluster_objects):.3f}")

    if output_dir is None:
        output_dir = "."

    # Clean cluster name for filename
    clean_name = cluster_name.lower().replace(' ', '_').replace('-', '_')
    background_file = os.path.join(output_dir, f"{clean_name}_background_galaxies.fits")
    foreground_file = os.path.join(output_dir, f"{clean_name}_foreground_galaxies.fits")

    if verbose:
        print("\nSaving catalogs:")
        print(f"Background: {background_file}")
        print(f"Foreground: {foreground_file}")

    try:
        # Create the output directory if needed so writing into a new path works.
        os.makedirs(output_dir, exist_ok=True)
        background_catalog.write(background_file, overwrite=True)
        foreground_catalog.write(foreground_file, overwrite=True)

        if verbose:
            print("✓ Catalogs saved successfully!")

    except Exception as e:
        # Saving is best-effort: still hand the in-memory catalogs back.
        print(f"Error saving catalogs: {e}")
        return background_catalog, foreground_catalog

    if verbose:
        print(f"\nSummary for {cluster_name}:")
        print(f"Total cluster objects: {len(cluster_objects):,}")
        print(f"Background sample: {len(background_catalog):,} ({len(background_catalog)/len(cluster_objects)*100:.1f}%)")
        print(f"Foreground sample: {len(foreground_catalog):,} ({len(foreground_catalog)/len(cluster_objects)*100:.1f}%)")
        print(f"Files saved: {background_file}, {foreground_file}")

    return background_catalog, foreground_catalog


## **Step 1**: Pick a Cluster, Plot Catalog Sources Separately

- Choose a cluster and find its redshift here: `https://github.com/superbit-collaboration/superbit-lensing/blob/main/data/SuperBIT_target_list.csv`

- Define `Z_THRESH` as your cluster's redshift plus 0.025 (to account for the cluster's extent in redshift). This is the redshift boundary used to split “low-z” vs “high-z” objects.
Objects with Z_best < z_thresh will be plotted in blue, and those with Z_best ≥ z_thresh in red.

- Define `ERR_THRESH`. This is the maximum allowed color measurement error in either band. Any object with color_bg_err or color_ub_err ≥ err_thresh will be excluded from the plots.

- Plot all your objects separated by catalog source (DESI, NED, and LoVoCCS) by calling `plot_sources_separately` below. This will help you get a general feel for what your data looks like! You can see that the demarcation between low-z and high-z sources is not so clear for the LoVoCCS data, which also has the largest population of objects with high color error. Because of this, moving forward we will work only with DESI and NED.

In [None]:
# Target cluster and the redshift split used for background/foreground
CLUSTER_NAME = 'Abell3411'
Z_THRESH = 0.193    # cluster redshift + 0.025, allowing for the cluster's depth
ERR_THRESH = 0.5    # largest color error (in either color) kept in the sample

# Color-color plot limits: sensible defaults, adjust as needed
XLIM, YLIM = (-2, 4), (-2, 3)   # B-G range (x axis), U-B range (y axis)

In [None]:
%matplotlib inline
# One color-color panel per redshift source, split at Z_THRESH, keeping only
# objects whose color errors are below ERR_THRESH; per-source counts are printed.
fig, axes = plot_sources_separately(
    redshift, color_bg, color_ub, color_bg_err, color_ub_err, 
    redshift_source, Z_THRESH, ERR_THRESH, XLIM, YLIM
)

## **Step 2**: Create reusable masks for the redshift catalog, filtering out LoVoCCS and high-error data

Just run this, no other steps required :)

In [None]:
# Reusable masks for the redshift catalog, built once here.

# Valid redshift sources (NED and DESI are the most reliable)
reliable_sources_mask = (redshift_source == "DESI") | (redshift_source == "NED")

# Good color measurements: both color errors strictly below ERR_THRESH
good_colors_mask = (color_bg_err < ERR_THRESH) & (color_ub_err < ERR_THRESH)

# No NaN values in colors or redshift
valid_data_mask = ~(np.isnan(color_bg) | np.isnan(color_ub) | np.isnan(redshift))

# Combined mask for "training"
training_mask = reliable_sources_mask & good_colors_mask & valid_data_mask

# Background vs foreground classification. Use z > Z_THRESH / z <= Z_THRESH,
# the same split as the plotting functions, so an object exactly at Z_THRESH
# is counted as foreground instead of falling into neither class.
background_mask = redshift > Z_THRESH   # Objects behind the cluster
foreground_mask = redshift <= Z_THRESH  # Objects in front of / in the cluster

# Data quality statistics
total_objects = len(redshift)
nan_objects = np.sum(~valid_data_mask)
high_error_objects = np.sum(~good_colors_mask & valid_data_mask)  # Valid data but high errors
unreliable_sources = np.sum(~reliable_sources_mask)

print("Summary:")
print("=" * 45)
print(f"Total objects in catalog: {total_objects:,}")
print(f"Objects with NaN values: {nan_objects:,}")
print(f"LoVoCCS Objects Excluded: {unreliable_sources:,}")
print()
print(f"Reliable sources (NED/DESI): {np.sum(reliable_sources_mask):,}")
print(f"Objects with high color error (>{ERR_THRESH}): {high_error_objects:,}")
print(f"Combined Mask for Training: {np.sum(training_mask):,}")
print()
print("Redshifts of our Training Sample:")
print(f"Background objects (z > {Z_THRESH}): {np.sum(background_mask & training_mask):,}")
print(f"Foreground objects (z ≤ {Z_THRESH}): {np.sum(foreground_mask & training_mask):,}")

## **Step 3:** Plot DESI and NED together in an interactive plot with contours

In [None]:
%matplotlib inline
# Create the interactive plot: checkboxes toggle the scatter points and the
# NED / DESI / combined density contours for the training sample.
plot_color_color_interactive(
    redshift, color_bg, color_ub, redshift_source, 
    training_mask, Z_THRESH, XLIM, YLIM, CLUSTER_NAME
)

## **Step 4:** Define parabola parameters for the cut


In [None]:
# Rotated-parabola cut parameters
H_VERTEX, K_VERTEX = 1.251, 0.256   # vertex position: (B-G, U-B)
A_PARABOLA = 5                      # curvature; larger values give a narrower parabola
THETA_ROTATION = np.radians(9)      # rotation of the parabola's axis: 9 degrees, in radians


In [None]:
# Overlay the parabola on the training sample and report purity on each side.
plot_color_color_with_parabola(
    redshift=redshift,
    color_bg=color_bg,
    color_ub=color_ub,
    redshift_source=redshift_source,
    training_mask=training_mask,
    z_thresh=Z_THRESH,
    xlim=XLIM,
    ylim=YLIM,
    cluster_name=CLUSTER_NAME,
    A=A_PARABOLA,
    h=H_VERTEX,
    k=K_VERTEX,
    theta=THETA_ROTATION,
)

## **Step 5:** Apply the cut to your cluster and save foreground and background FITS files

In [None]:
# Split CLUSTER_NAME's objects in the full catalog at the parabola and write the
# background (below) and foreground (above) samples; with no output_dir given,
# the files go to the current directory.
background_catalog, foreground_catalog = apply_parabola_cut_to_cluster(
    FULL_CATALOG, CLUSTER_NAME,
    A_PARABOLA, H_VERTEX, K_VERTEX, THETA_ROTATION,
    verbose=True,
)
