# In [1]:  (notebook cell marker left over from export)
#!/usr/bin/env python3
"""
Spectral plotting utilities for binary star analysis
"""

import numpy as np
import matplotlib.pyplot as plt
from astropy.convolution import convolve, Gaussian1DKernel
from astropy.io import fits


def _composite_draw_observed(ax, spectrum, label):
    """Draw one observed spectrum and its ±1σ band on *ax*."""
    wavelength = spectrum['wavelength']
    flux = spectrum['flux']
    flux_error = spectrum['flux_error']
    ax.plot(wavelength, flux, 'k-', alpha=0.7, label=label)
    ax.fill_between(wavelength, flux - flux_error, flux + flux_error,
                    alpha=0.3, color='gray', label='1σ uncertainty')


def _composite_draw_model(ax, wavelengths_model, star1_flux, star2_flux, ratio, norm):
    """Overlay the composite model and both stellar components on *ax*.

    The same normalization is applied to the composite and to each component
    so that the dashed component curves always sum to the solid model curve.
    Returns the (normalized) composite model flux.
    """
    star1_part = ratio * star1_flux
    star2_part = (1 - ratio) * star2_flux
    composite = star1_part + star2_part
    if norm is not None:
        composite = composite * norm
        star1_part = star1_part * norm
        star2_part = star2_part * norm

    ax.plot(wavelengths_model, composite, 'r-', alpha=0.8,
            label=f'Best-fit model (ratio={ratio:.2f})')
    ax.plot(wavelengths_model, star1_part, 'b--', alpha=0.6,
            label=f'Star 1 contribution ({ratio:.2f}×)')
    ax.plot(wavelengths_model, star2_part, 'g--', alpha=0.6,
            label=f'Star 2 contribution ({1-ratio:.2f}×)')
    return composite


def _composite_draw_residuals(ax, spectrum, wavelengths_model, model_flux):
    """Plot error-normalized residuals (obs - model) / σ for one spectrum."""
    from scipy.interpolate import interp1d

    wavelength = spectrum['wavelength']
    # Model values outside the model wavelength range become NaN (not plotted).
    model_on_obs = interp1d(wavelengths_model, model_flux,
                            bounds_error=False, fill_value=np.nan)(wavelength)

    # Non-positive uncertainties would produce inf spikes; mask them instead.
    sigma = np.asarray(spectrum['flux_error'], dtype=float)
    sigma = np.where(sigma > 0, sigma, np.nan)

    ax.plot(wavelength, (spectrum['flux'] - model_on_obs) / sigma, 'k-', alpha=0.7)
    ax.axhline(0, color='r', linestyle='--', alpha=0.5)
    ax.axhline(1, color='gray', linestyle=':', alpha=0.5)
    ax.axhline(-1, color='gray', linestyle=':', alpha=0.5)


def plot_composite_spectral_comparison(results, best_fit_params=None, save_path=None, figsize=(15, 10)):
    """
    Create comprehensive spectral comparison plot for composite spectra showing:
    - Both observed composite spectra (A and B)
    - Best-fitting composite model spectra
    - Individual stellar component contributions
    - Residuals for both spectra

    Parameters:
    results: output from complete_composite_binary_workflow(). Must contain
        'composite_spec_data' with 'spectrum_A' / 'spectrum_B' entries, each
        providing 'wavelength', 'flux' and 'flux_error'. If
        'composite_likelihood_results' is present, the best-fit model is
        overlaid (best effort: any failure is reported and skipped).
    best_fit_params: dict with best-fit parameters (optional; currently unused,
        the best fit is taken from the likelihood weights in ``results``)
    save_path: path to save the plot (optional)
    figsize: figure size tuple

    Returns:
    fig: matplotlib figure object, or None when no composite spectra are available
    """
    if 'composite_spec_data' not in results:
        print("No composite spectral data available in results")
        return None

    composite_spec_data = results['composite_spec_data']
    spectrum_A = composite_spec_data['spectrum_A']
    spectrum_B = composite_spec_data['spectrum_B']

    # Top row: spectra, bottom row: residuals
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    # Observed wavelengths are plotted as stored (no unit conversion is applied).
    _composite_draw_observed(ax1, spectrum_A, 'Observed A')
    _composite_draw_observed(ax2, spectrum_B, 'Observed B')

    # If we have likelihood results, show best-fit models
    if 'composite_likelihood_results' in results:
        try:
            from sbi_workflow import generate_desi_spectra_from_posterior

            likelihood_results = results['composite_likelihood_results']
            log_weights = np.asarray(likelihood_results['log_likelihood_weights'])

            # nanargmax: a NaN weight must never be selected as the best fit.
            # The weight grid is indexed as (star1 sample, star2 sample,
            # ratio A, ratio B), as implied by how best_idx is used below.
            best_idx = np.unravel_index(np.nanargmax(log_weights), log_weights.shape)

            photo_samples = results['photometry_only_samples']
            best_star1_params = photo_samples['star1'][best_idx[0]]
            best_star2_params = photo_samples['star2'][best_idx[1]]

            ratios_used = likelihood_results['flux_ratios_used']
            best_ratio_A = ratios_used['spectrum_A'][best_idx[2]]
            best_ratio_B = ratios_used['spectrum_B'][best_idx[3]]

            # Generate model spectra for the two best-fit stars
            model_spec1, wavelengths = generate_desi_spectra_from_posterior(
                best_star1_params.reshape(1, -1), exclude_ebv=True, device='cpu'
            )
            model_spec2, _ = generate_desi_spectra_from_posterior(
                best_star2_params.reshape(1, -1), exclude_ebv=True, device='cpu'
            )

            norm_A = norm_B = None
            if 'normalizations' in likelihood_results:
                norm_A = likelihood_results['normalizations']['spectrum_A'][best_idx]
                norm_B = likelihood_results['normalizations']['spectrum_B'][best_idx]

            # Model grid is scaled by 1e4 to match the observed axis
            # (presumably micron -> Angstrom; confirm against the emulator).
            wavelengths_model = wavelengths * 1e4

            model_composite_A = _composite_draw_model(
                ax1, wavelengths_model, model_spec1[0], model_spec2[0], best_ratio_A, norm_A)
            model_composite_B = _composite_draw_model(
                ax2, wavelengths_model, model_spec1[0], model_spec2[0], best_ratio_B, norm_B)

            _composite_draw_residuals(ax3, spectrum_A, wavelengths_model, model_composite_A)
            _composite_draw_residuals(ax4, spectrum_B, wavelengths_model, model_composite_B)

        except Exception as e:
            # Deliberate best effort: the observed spectra are still returned.
            print(f"Could not generate model comparison: {e}")

    # Formatting
    for ax, title in ((ax1, 'Composite Spectrum A'), (ax2, 'Composite Spectrum B')):
        ax.set_title(title)
        ax.set_ylabel('Flux')
        ax.legend()
        ax.grid(True, alpha=0.3)

    for ax, title in ((ax3, 'Residuals A'), (ax4, 'Residuals B')):
        ax.set_title(title)
        ax.set_xlabel('Wavelength (Å)')
        ax.set_ylabel('(Obs - Model) / σ')
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved composite spectral comparison to {save_path}")

    return fig


def plot_flux_ratio_comparison(results, save_path=None, figsize=(12, 6)):
    """
    Plot flux ratio analysis showing the different flux contributions in both composite spectra

    Left panel: distribution (or fixed value) of the flux ratio used for each
    spectrum. Right panel: stacked bars of the mean Star 1 / Star 2 fractions.

    Parameters:
    results: output from complete_composite_binary_workflow(); reads
        results['composite_likelihood_results']['flux_ratios_used'] with keys
        'spectrum_A' and 'spectrum_B' (a scalar or a sequence of ratios each)
    save_path: path to save the plot (optional)
    figsize: figure size tuple

    Returns:
    fig: matplotlib figure object, or None if the likelihood results are
        missing or a flux-ratio array is empty
    """
    if 'composite_likelihood_results' not in results:
        print("No composite likelihood results available")
        return None

    ratios_used = results['composite_likelihood_results']['flux_ratios_used']
    # atleast_1d lets a single fixed ratio be passed as a bare scalar.
    flux_ratios_A = np.atleast_1d(np.asarray(ratios_used['spectrum_A'], dtype=float)).ravel()
    flux_ratios_B = np.atleast_1d(np.asarray(ratios_used['spectrum_B'], dtype=float)).ravel()

    if flux_ratios_A.size == 0 or flux_ratios_B.size == 0:
        print("No flux ratios available to plot")
        return None

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Explicit histogram colours keep A/B colours stable even when only one of
    # the two spectra has a sampled distribution.
    panels = (
        ('A', flux_ratios_A, 'C0', 'blue'),
        ('B', flux_ratios_B, 'C1', 'red'),
    )
    for tag, ratios, hist_color, line_color in panels:
        if ratios.size > 1:
            mean_ratio = np.mean(ratios)
            ax1.hist(ratios, bins=20, alpha=0.6, density=True,
                     color=hist_color, label=f'Spectrum {tag}')
            ax1.axvline(mean_ratio, color=line_color, linestyle='--',
                        label=f'Mean: {mean_ratio:.3f}')
        else:
            ax1.axvline(ratios[0], color=line_color, linestyle='-', linewidth=3,
                        label=f'Fixed ratio {tag}: {ratios[0]:.3f}')

    ax1.set_xlabel('Flux Ratio (Star 1 / (Star 1 + Star 2))')
    ax1.set_ylabel('Density')
    ax1.set_title('Flux Ratio Distributions')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Stacked bars: mean fraction of the light contributed by each star
    categories = ['Spectrum A', 'Spectrum B']
    star1_contributions = [np.mean(flux_ratios_A), np.mean(flux_ratios_B)]
    star2_contributions = [1 - share for share in star1_contributions]

    width = 0.6
    ax2.bar(categories, star1_contributions, width, label='Star 1', alpha=0.8)
    ax2.bar(categories, star2_contributions, width, bottom=star1_contributions,
            label='Star 2', alpha=0.8)

    # Label each bar segment at its vertical midpoint
    for i, (s1, s2) in enumerate(zip(star1_contributions, star2_contributions)):
        ax2.text(i, s1 / 2, f'{s1:.2f}', ha='center', va='center', fontweight='bold')
        ax2.text(i, s1 + s2 / 2, f'{s2:.2f}', ha='center', va='center', fontweight='bold')

    ax2.set_ylabel('Flux Contribution')
    ax2.set_title('Stellar Contributions by Spectrum')
    ax2.legend()
    ax2.set_ylim(0, 1)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved flux ratio comparison to {save_path}")

    return fig


def create_composite_spectral_summary_plot(results, save_path=None, figsize=(16, 12)):
    """
    Create a comprehensive summary plot for composite spectral analysis

    The 3x2 grid contains, when the corresponding data is present in
    ``results`` (missing inputs simply leave their panels empty):
    - top row: observed composite spectra A and B ('composite_spec_data')
    - middle row: histograms of column 0 of each star's samples, labelled
      Teff ('photometry_only_samples', optionally overlaid with
      'composite_spectroscopy_reweighted_samples')
    - bottom left: flux ratios used ('composite_likelihood_results')
    - bottom right: original vs. effective sample sizes
      ('effective_sample_sizes' together with 'photometry_only_samples')

    Parameters:
    results: output from complete_composite_binary_workflow()
    save_path: path to save the plot (optional)
    figsize: figure size tuple

    Returns:
    fig: matplotlib figure object
    """
    fig, axes = plt.subplots(3, 2, figsize=figsize)

    # Plot 1: Composite spectral comparison (top row)
    if 'composite_spec_data' in results:
        composite_spec_data = results['composite_spec_data']
        for ax, key, tag in ((axes[0, 0], 'spectrum_A', 'A'),
                             (axes[0, 1], 'spectrum_B', 'B')):
            spectrum = composite_spec_data[key]
            ax.plot(spectrum['wavelength'], spectrum['flux'], 'k-', alpha=0.7,
                    label=f'Observed {tag}')
            ax.fill_between(spectrum['wavelength'],
                            spectrum['flux'] - spectrum['flux_error'],
                            spectrum['flux'] + spectrum['flux_error'],
                            alpha=0.3, color='gray')
            ax.set_title(f'Composite Spectrum {tag}')
            ax.set_ylabel('Flux')
            ax.legend()
            ax.grid(True, alpha=0.3)

    # Plot 2: Parameter comparison (middle row)
    if 'photometry_only_samples' in results:
        photo_samples = results['photometry_only_samples']
        reweighted_key = 'composite_spectroscopy_reweighted_samples'
        spec_samples = results[reweighted_key] if reweighted_key in results else None

        # Column 0 of each sample array is treated as Teff
        # (NOTE(review): confirm against the sampler's parameter ordering).
        for ax, star_key, title in ((axes[1, 0], 'star1', 'Star 1 Temperature'),
                                    (axes[1, 1], 'star2', 'Star 2 Temperature')):
            ax.hist(photo_samples[star_key][:, 0], bins=30, alpha=0.6, density=True,
                    label='Photometry only')
            if spec_samples is not None:
                ax.hist(spec_samples[star_key][:, 0], bins=30, alpha=0.6, density=True,
                        label='+ Composite spectra')
            ax.set_xlabel('Teff (K)')
            ax.set_ylabel('Density')
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

    # Plot 3: Flux ratio analysis (bottom left)
    if 'composite_likelihood_results' in results:
        ratios_used = results['composite_likelihood_results']['flux_ratios_used']
        ax = axes[2, 0]

        # A sampled ratio is drawn as a histogram, a single fixed ratio as a
        # vertical line - including the case where both ratios are fixed.
        for label, key, hist_color, line_color in (('Spectrum A', 'spectrum_A', 'C0', 'blue'),
                                                   ('Spectrum B', 'spectrum_B', 'C1', 'red')):
            ratios = np.atleast_1d(np.asarray(ratios_used[key], dtype=float)).ravel()
            if ratios.size > 1:
                ax.hist(ratios, bins=20, alpha=0.6, density=True,
                        color=hist_color, label=label)
            elif ratios.size == 1:
                ax.axvline(ratios[0], color=line_color, label=f'{label}: {ratios[0]:.3f}')

        ax.set_xlabel('Flux Ratio (Star 1 / Total)')
        ax.set_ylabel('Density')
        ax.set_title('Flux Ratio Distributions')
        if ax.get_legend_handles_labels()[0]:
            # Avoid matplotlib's "no artists with labels" warning on an empty panel.
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Effective sample size comparison (bottom right); needs the original
    # photometry samples to know the un-weighted sample counts.
    if 'effective_sample_sizes' in results and 'photometry_only_samples' in results:
        eff_sizes = results['effective_sample_sizes']
        photo_samples = results['photometry_only_samples']
        original_sizes = [len(photo_samples['star1']), len(photo_samples['star2'])]

        stars = ['Star 1', 'Star 2']
        eff_values = [eff_sizes['star1'], eff_sizes['star2']]

        x_pos = np.arange(len(stars))
        ax = axes[2, 1]
        ax.bar(x_pos - 0.2, original_sizes, 0.4, label='Original samples', alpha=0.7)
        ax.bar(x_pos + 0.2, eff_values, 0.4, label='Effective samples', alpha=0.7)

        ax.set_xlabel('Star')
        ax.set_ylabel('Sample Size')
        ax.set_title('Sample Size Comparison')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(stars)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved composite spectral summary to {save_path}")

    return fig


def plot_spectral_comparison(results, best_fit_params=None, boss_spectrum_path=None,
                           model_spec_dir='/data/jls/triple_dc_model_grid/',
                           save_path=None, figsize=(15, 7)):
    """
    Create comprehensive spectral comparison plot showing:
    - DESI observed spectra for both stars
    - Best-fitting single star models
    - Combined model spectrum
    - BOSS spectrum (if available)

    Parameters:
    results: output from complete_binary_star_workflow(). Requires 'spec_data'
        with 'star1' / 'star2' spectra; 'sbi_instance' and 'best_fits' are only
        needed when model spectra are drawn.
    best_fit_params: dict with best-fit parameters for both stars (optional);
        keys 'star1' and 'star2' map to dicts read with .get() for 'teff',
        'logg', 'metallicity', 'alpha', 'carbon', 'nitrogen'
    boss_spectrum_path: path to BOSS FITS spectrum file (optional)
    model_spec_dir: directory containing model spectra
    save_path: path to save the plot (optional)
    figsize: figure size tuple

    Returns:
    fig: matplotlib figure object, or None when the DESI spectra or the
        helper modules (fetch_spectra, extinction) are unavailable
    """
    if 'spec_data' not in results:
        print("No DESI spectral data available in results")
        return None

    try:
        from fetch_spectra import load_spectrum_with_photometry, find_nearest_spectrum
        from extinction import interpolate_extinction_law, get_extinction_coefficients
    except ImportError:
        # Name the modules that are actually imported above.
        print("Missing required modules: fetch_spectra, extinction")
        return None

    spec_data = results['spec_data']

    # Set up the plot
    fig, ax = plt.subplots(figsize=figsize)

    # Parameters for plotting
    cont_region = [4450., 4550.]  # Continuum normalization region
    offsets = [-15., 17.5, 30., 45.]  # Vertical offsets for different spectra
    desi_smoothing = 3.0
    model_smoothing = 200.0

    spec1_data = spec_data['star1']
    spec2_data = spec_data['star2']

    # Star 1 DESI spectrum: smooth, then measure the mean continuum level
    desi_flux1 = convolve(spec1_data['flux'], Gaussian1DKernel(desi_smoothing), boundary='extend')
    cont_mask1 = (spec1_data['wavelength'] > cont_region[0]) & (spec1_data['wavelength'] < cont_region[1])
    desi_norm1 = np.nanmean(desi_flux1[cont_mask1])

    ax.plot(spec1_data['wavelength'], desi_flux1, lw=0.5,
            label='DESI DR1 fibre 1 (C-rich)', color='C0')

    # Star 2 DESI spectrum (with offset)
    desi_flux2 = convolve(spec2_data['flux'], Gaussian1DKernel(desi_smoothing), boundary='extend')
    cont_mask2 = (spec2_data['wavelength'] > cont_region[0]) & (spec2_data['wavelength'] < cont_region[1])
    desi_norm2 = np.nanmean(desi_flux2[cont_mask2])

    ax.plot(spec2_data['wavelength'], desi_flux2 + offsets[0], lw=0.5,
            label='DESI DR1 fibre 2 (non-C)', color='C1')

    # Models are drawn when parameters are supplied, or when results carries a
    # non-empty 'best_fits' list (an empty list used to raise IndexError).
    have_model_params = (best_fit_params is not None
                         or ('best_fits' in results and len(results['best_fits']) > 0))

    if have_model_params:

        # Use provided parameters or extract from results
        if best_fit_params is not None:
            star1_params = best_fit_params['star1']
            star2_params = best_fit_params['star2']
        else:
            best_fit = results['best_fits'][0]
            star1_params = best_fit['star1_dict']
            star2_params = best_fit['star2_dict']

        print("Loading model spectra with parameters:")
        print(f"  Star 1: Teff={star1_params.get('teff', 5700):.0f}K, [Fe/H]={star1_params.get('metallicity', -1.6):.2f}, [C/Fe]={star1_params.get('carbon', 2.4):.1f}")
        print(f"  Star 2: Teff={star2_params.get('teff', 5780):.0f}K, [Fe/H]={star2_params.get('metallicity', -1.6):.2f}, [C/Fe]={star2_params.get('carbon', 0.0):.1f}")

        def load_smoothed_model(params, default_teff, default_carbon):
            """Load the nearest grid spectrum and return (model, smoothed flux)."""
            model = load_spectrum_with_photometry(find_nearest_spectrum(
                params.get('teff', default_teff),
                params.get('logg', 4.2),
                params.get('metallicity', -1.6),
                target_alpha=params.get('alpha', 0.17),
                target_carbon=params.get('carbon', default_carbon),
                target_nitrogen=params.get('nitrogen', 0.0),
                spec_dir=model_spec_dir,
                weights={'teff': 10.0, 'logg': 0.2, 'metallicity': 0.1,
                         'carbon': 0.1, 'nitrogen': 0.1}
            ))
            smoothed = convolve(model['spectrum']['flux'], Gaussian1DKernel(model_smoothing))
            return model, smoothed

        try:
            # Star 1 defaults describe the C-rich star, star 2 the non-C star
            F1, F1smooth = load_smoothed_model(star1_params, 5700, 2.4)
            F2, F2smooth = load_smoothed_model(star2_params, 5780, 0.0)

            # Get extinction coefficients and apply extinction
            EE = get_extinction_coefficients()
            ebv = 0.07  # Default extinction value

            # Observed photometry from the SBI analysis sets the flux scaling.
            # Looked up here (not up front) so plots without models do not
            # require 'sbi_instance'.
            sbi = results['sbi_instance']
            observed_mags1, _ = sbi.get_observed_photometry_for_inference(object_index=0)
            observed_mags2, _ = sbi.get_observed_photometry_for_inference(object_index=1)

            # Calibrate on g when available, otherwise fall back to r
            g_idx = sbi.bands.index('g') if 'g' in sbi.bands else sbi.bands.index('r')
            calib_band = sbi.bands[g_idx]
            band_wavelength_nm = {'g': 475, 'r': 625}  # nm
            calib_wavelength = band_wavelength_nm.get(calib_band, 625)

            def mag_to_flux(mag):
                """Convert a magnitude to the flux-like quantity used for scaling."""
                return 10**(-0.4*mag) * 1.21e14/calib_wavelength**2

            # Model-to-observed scaling ratio for each star
            band_extinction = EE[calib_band]*ebv
            ratio1 = mag_to_flux(F1['photometry'][calib_band]) / mag_to_flux(observed_mags1[g_idx] - band_extinction)
            ratio2 = mag_to_flux(F2['photometry'][calib_band]) / mag_to_flux(observed_mags2[g_idx] - band_extinction)

            # Extinction curve evaluated on star 1's model grid and applied to
            # both models (assumes both grid spectra share that wavelength grid).
            extinction_factor = 10**(-0.4*ebv*interpolate_extinction_law(F1['spectrum']['wavelength_um']))

            wavelength_aa = F1['spectrum']['wavelength_um'] * 1e4

            # Normalize to continuum for plotting
            cont_mask_model = (wavelength_aa > cont_region[0]) & (wavelength_aa < cont_region[1])

            # Individual star 1 model
            model1_flux = F1smooth/ratio1 * extinction_factor
            model1_norm = np.nanmean(model1_flux[cont_mask_model])
            ax.plot(wavelength_aa, model1_flux * desi_norm1/model1_norm + offsets[3],
                    lw=1.0, alpha=0.7, color='k', label='C model')

            # Individual star 2 model
            model2_flux = F2smooth/ratio2 * extinction_factor
            model2_norm = np.nanmean(model2_flux[cont_mask_model])
            ax.plot(wavelength_aa, model2_flux * desi_norm2/model2_norm + offsets[2],
                    lw=1.0, alpha=0.7, color='k', label='non-C model')

            # Combined model spectra for comparison with observations.
            # NOTE(review): the third (BOSS) mix is drawn even when no BOSS
            # spectrum is supplied - confirm whether that is intended.
            flux_ratios = [
                [0.54, 0.18],  # For star 1 comparison
                [0.142, 0.315],  # For star 2 comparison
                [0.748, 0.292]   # For BOSS comparison (if available)
            ]

            spectrum_offsets = [0, offsets[0], offsets[1]]  # DESI 1, DESI 2, BOSS
            spectrum_norms = [desi_norm1, desi_norm2, desi_norm1]  # Normalization references

            for mix, offset, norm_ref in zip(flux_ratios, spectrum_offsets, spectrum_norms):
                mix = np.array(mix)
                mix /= np.sum(mix)

                combined_flux = (mix[0]*F1smooth/ratio1 + mix[1]*F2smooth/ratio2) * extinction_factor
                combined_norm = np.nanmean(combined_flux[cont_mask_model])

                ax.plot(wavelength_aa, combined_flux * norm_ref/combined_norm + offset,
                        lw=1.0, alpha=0.5, color='k', zorder=-1)

        except Exception as e:
            # Best effort: keep the observed spectra even if models fail.
            print(f"Warning: Could not load model spectra: {e}")

    # Load and plot BOSS spectrum if provided
    if boss_spectrum_path is not None:
        try:
            # Context manager closes the FITS file; the arrays computed inside
            # are fresh copies, so they stay valid after the file is closed.
            with fits.open(boss_spectrum_path) as sdss_data:
                table = sdss_data[1].data
                wav_sdss = np.power(10., table['loglam'])
                flux_sdss = convolve(table['flux'], Gaussian1DKernel(2.5))

            # Linear wavelength-dependent rescaling pivoted at 4500 Å
            wavelength_scaling = ((wav_sdss - 4500.)/1000.*0.15 + 1.)
            flux_sdss_scaled = wavelength_scaling * flux_sdss

            ax.plot(wav_sdss, flux_sdss_scaled + offsets[1], lw=0.5,
                    color='red', label='BOSS SDSS DR19')

        except Exception as e:
            print(f"Warning: Could not load BOSS spectrum: {e}")

    # Formatting and labels
    ax.set_ylabel(r'$F_\lambda$ [$10^{-17}$ erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$]')
    ax.set_xlabel(r'Wavelength ($\AA$)')
    ax.set_xlim(3600., 6000.)
    ax.set_ylim(-10., 66.)

    # Right-aligned text labels placed next to each spectrum
    annotations = [
        ('DESI DR1 fibre 1', (5980, 14), 'C0'),
        ('DESI DR1 fibre 2', (5980, 0), 'C1'),
    ]

    if boss_spectrum_path is not None:
        annotations.append(('BOSS SDSS DR19', (5980, 27), 'red'))

    if have_model_params:
        annotations.extend([
            ('C model', (5980, 57), 'k'),
            ('non-C model', (5980, 43), 'k')
        ])

    for text, (x, y), color in annotations:
        ax.annotate(text, xy=(x, y), fontsize=16, color=color, ha='right')

    ax.grid(True, alpha=0.3)

    # Try to add inner ticks if available
    try:
        from plotting_general import add_inner_ticks
        add_inner_ticks()
    except ImportError:
        pass

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved spectral comparison plot to {save_path}")

    return fig


def plot_model_data_residuals(results, best_fit_params=None, model_spec_dir='/data/jls/triple_dc_model_grid/',
                             flux_ratio=0.5, save_path=None, figsize=(15, 10),
                             emulator_path='desi_emulator_model.pth'):
    """
    Create a detailed model vs. data comparison plot with residuals

    The two observed spectra are combined as
    ``flux_ratio * star1 + (1 - flux_ratio) * star2`` on star 1's wavelength
    grid and always plotted. If ``results['best_fits']`` is non-empty, the
    emulator model for the first best fit is combined the same way, scaled by
    the median observed/model ratio, and drawn with residuals and a reduced
    chi-squared label.

    Parameters:
    results: output from complete_binary_star_workflow()
    best_fit_params: dict with best-fit parameters for both stars (optional;
        currently unused, the model comes from results['best_fits'][0])
    model_spec_dir: directory containing model spectra (currently unused)
    flux_ratio: flux ratio for combining model spectra
    save_path: path to save the plot (optional)
    figsize: figure size tuple
    emulator_path: path to the neural emulator checkpoint passed to
        SpectralEmulator (defaults to the previously hard-coded file name)

    Returns:
    fig: matplotlib figure object, or None when the DESI spectra or the
        emulator/extinction modules are unavailable
    """
    if 'spec_data' not in results:
        print("No DESI spectral data available in results")
        return None

    try:
        from neural_emulator import SpectralEmulator
        from extinction import apply_extinction_to_spectrum
    except ImportError:
        print("Neural emulator not available for detailed comparison")
        return None

    from scipy.interpolate import interp1d

    # Spectra on top (3/4 of the height), residuals below
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize,
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.05})

    spec_data = results['spec_data']
    wavelengths = spec_data['star1']['wavelength']

    # Interpolate star2 onto star1 wavelength grid (NaN outside star2's range)
    interp_flux2 = interp1d(
        spec_data['star2']['wavelength'],
        spec_data['star2']['flux'],
        bounds_error=False,
        fill_value=np.nan
    )(wavelengths)

    observed_combined = flux_ratio * spec_data['star1']['flux'] + (1 - flux_ratio) * interp_flux2

    # Draw the data exactly once, up front: it is shown even without a model,
    # and a late failure in the model branch cannot duplicate the curve.
    ax1.plot(wavelengths, observed_combined, 'k-', lw=1, label='Observed (combined)', alpha=0.7)

    # Generate model spectrum using best-fit parameters
    if 'best_fits' in results and len(results['best_fits']) > 0:
        best_fit = results['best_fits'][0]

        # Shape (1, n_params); the last column is used as E(B-V) below
        star1_params = np.array([best_fit['star1_params']])
        star2_params = np.array([best_fit['star2_params']])

        try:
            emulator = SpectralEmulator(model_path=emulator_path, device='cpu')

            # Generate intrinsic spectra (excluding EBV, the last parameter)
            spectrum1 = emulator.predict(star1_params[:, :-1])[0]
            spectrum2 = emulator.predict(star2_params[:, :-1])[0]
            model_wavelengths = emulator.get_wavelengths()

            # Apply extinction only when it is positive
            ebv1, ebv2 = star1_params[0, -1], star2_params[0, -1]
            if ebv1 > 0:
                spectrum1 = apply_extinction_to_spectrum(model_wavelengths, spectrum1, ebv1)
            if ebv2 > 0:
                spectrum2 = apply_extinction_to_spectrum(model_wavelengths, spectrum2, ebv2)

            # Interpolate to observed wavelength grid; the emulator grid is
            # scaled by 1e4 (presumably micron -> Angstrom; confirm).
            model_flux1 = interp1d(model_wavelengths*1e4, spectrum1,
                                   bounds_error=False, fill_value=np.nan)(wavelengths)
            model_flux2 = interp1d(model_wavelengths*1e4, spectrum2,
                                   bounds_error=False, fill_value=np.nan)(wavelengths)

            model_combined = flux_ratio * model_flux1 + (1 - flux_ratio) * model_flux2

            # Normalize to match observed spectrum (median ratio over valid pixels)
            valid_mask = ~(np.isnan(observed_combined) | np.isnan(model_combined))
            if np.sum(valid_mask) > 0:
                normalization = np.median(observed_combined[valid_mask] / model_combined[valid_mask])
                model_combined *= normalization

            ax1.plot(wavelengths, model_combined, 'r-', lw=1, label='Best-fit model', alpha=0.8)

            # Plot residuals
            residuals = observed_combined - model_combined
            ax2.plot(wavelengths, residuals, 'b-', lw=0.5, alpha=0.7)
            ax2.axhline(0, color='k', linestyle='--', alpha=0.5)

            # Reduced chi-squared; star1 errors stand in for the errors of the
            # combined spectrum (approximation), dof = number of valid pixels.
            errors = spec_data['star1']['flux_error']
            valid_error_mask = valid_mask & ~np.isnan(errors) & (errors > 0)
            if np.sum(valid_error_mask) > 0:
                chi_squared = np.sum((residuals[valid_error_mask] / errors[valid_error_mask])**2)
                dof = np.sum(valid_error_mask)
                reduced_chi_squared = chi_squared / dof

                ax1.text(0.02, 0.98, f'χ²/dof = {reduced_chi_squared:.2f}',
                         transform=ax1.transAxes, verticalalignment='top',
                         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        except Exception as e:
            # Best effort: the observed data is already on the axes.
            print(f"Could not generate model spectrum: {e}")

    # Formatting
    ax1.set_ylabel(r'$F_\lambda$ [$10^{-17}$ erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$]')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(3600, 6000)

    ax2.set_ylabel('Residuals')
    ax2.set_xlabel(r'Wavelength ($\AA$)')
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(3600, 6000)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved model-data comparison plot to {save_path}")

    return fig


def _photometry_mag_to_flux(mag, wavelength_aa):
    """
    Convert a magnitude to the flux scale used by plot_photometry_comparison
    
    Parameters:
    mag: magnitude (with any extinction correction already applied)
    wavelength_aa: band centre in Angstroms
    
    Returns:
    flux: 10**(-0.4*mag) * 1.21e14 / (wavelength_aa/10)**2
    """
    return 10**(-0.4 * mag) * 1.21e14 / (wavelength_aa / 10)**2


def plot_photometry_comparison(results, show_spectra=True, save_path=None, figsize=(15, 8),
                               ebv=0.07, model_spec_dir='/data/jls/triple_dc_model_grid/'):
    """
    Create photometry comparison plot showing observed vs. model photometry
    with filter transmission curves underneath and optional best-fitting spectra
    
    Observed magnitudes of both stars are read from results['sbi_instance'],
    corrected with the extinction coefficients scaled by `ebv`, and converted
    to fluxes. If results['best_fits'] is present and non-empty, the nearest
    grid models for its first entry are loaded, normalised to the observed
    flux in the g band (r band if g is not in the band list) and over-plotted.
    Failures while loading models or filter curves are reported with print()
    and the rest of the figure is still produced.
    
    Parameters:
    results: output from complete_binary_star_workflow()
    show_spectra: whether to show best-fitting spectra as transparent overlay
    save_path: path to save the plot (optional)
    figsize: figure size tuple
    ebv: E(B-V) used to correct the observed magnitudes
    model_spec_dir: directory containing model spectra
    
    Returns:
    fig: matplotlib figure object, or None if the extinction /
         compute_photometry modules cannot be imported
    """
    try:
        from extinction import get_extinction_coefficients
        from compute_photometry import PhotMetry
    except ImportError as e:
        # Either import may be the one that failed, so report the real culprit
        print(f"Missing required module: {e.name or e}")
        return None
    
    sbi = results['sbi_instance']
    
    # Create subplot layout: photometry on top, filter curves on bottom
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, 1, height_ratios=[2, 1], hspace=0.3)
    ax_phot = fig.add_subplot(gs[0])
    ax_filt = fig.add_subplot(gs[1])
    
    # Get observed photometry for both stars
    observed_mags1, observed_errors1 = sbi.get_observed_photometry_for_inference(object_index=0)
    observed_mags2, observed_errors2 = sbi.get_observed_photometry_for_inference(object_index=1)
    
    # Wavelength centers for bands (Angstroms)
    wavelengths = {
        'NUV': 2300, 'u': 3650, 'g': 4750, 'r': 6250, 'i': 7750, 
        'z': 8900, 'Z': 9000, 'Y': 10000, 'J': 12500, 'H': 16500, 'Ks': 22000
    }
    
    EE = get_extinction_coefficients()
    
    # observed_bands records which band each observed point belongs to, so the
    # normalisation below can look points up by band name rather than position
    observed_bands = []
    observed_wavelengths = []
    observed_fluxes1 = []
    observed_fluxes2 = []
    observed_flux_errors1 = []
    observed_flux_errors2 = []
    
    for i, band in enumerate(sbi.bands):
        if band not in wavelengths:
            continue
        wl = wavelengths[band]
        observed_bands.append(band)
        observed_wavelengths.append(wl)
        
        # Extinction-corrected magnitudes converted to flux units
        flux1 = _photometry_mag_to_flux(observed_mags1[i] - EE[band]*ebv, wl)
        flux2 = _photometry_mag_to_flux(observed_mags2[i] - EE[band]*ebv, wl)
        observed_fluxes1.append(flux1)
        observed_fluxes2.append(flux2)
        
        # Convert magnitude errors to flux errors (approximate)
        observed_flux_errors1.append(flux1 * observed_errors1[i] * np.log(10) / 2.5)
        observed_flux_errors2.append(flux2 * observed_errors2[i] * np.log(10) / 2.5)
    
    observed_wavelengths = np.array(observed_wavelengths)
    observed_fluxes1 = np.array(observed_fluxes1)
    observed_fluxes2 = np.array(observed_fluxes2)
    observed_flux_errors1 = np.array(observed_flux_errors1)
    observed_flux_errors2 = np.array(observed_flux_errors2)
    
    # Plot observed photometry
    ax_phot.errorbar(observed_wavelengths, observed_fluxes1, yerr=observed_flux_errors1,
                     fmt='s', markersize=8, capsize=3, label='Star 1 (C-rich) - Observed', 
                     color='C0', alpha=0.8)
    ax_phot.errorbar(observed_wavelengths, observed_fluxes2, yerr=observed_flux_errors2,
                     fmt='s', markersize=8, capsize=3, label='Star 2 (non-C) - Observed',
                     color='C1', alpha=0.8)
    
    # Get model photometry if available
    if 'best_fits' in results and len(results['best_fits']) > 0:
        try:
            from fetch_spectra import load_spectrum_with_photometry, find_nearest_spectrum
            
            best_fit = results['best_fits'][0]
            star1_params = best_fit['star1_dict']
            star2_params = best_fit['star2_dict']
            
            grid_weights = {'teff': 10.0, 'logg': 0.2, 'metallicity': 0.1,
                            'carbon': 0.1, 'nitrogen': 0.1}
            
            # Each call gets its own copy of the weights dict
            F1 = load_spectrum_with_photometry(find_nearest_spectrum(
                star1_params.get('teff', 5700),
                star1_params.get('logg', 4.2),
                star1_params.get('metallicity', -1.6),
                target_alpha=star1_params.get('alpha', 0.17),
                target_carbon=star1_params.get('carbon', 2.4),
                target_nitrogen=star1_params.get('nitrogen', 0.0),
                spec_dir=model_spec_dir,
                weights=dict(grid_weights)
            ))
            
            F2 = load_spectrum_with_photometry(find_nearest_spectrum(
                star2_params.get('teff', 5780),
                star2_params.get('logg', 4.2),
                star2_params.get('metallicity', -1.6),
                target_alpha=star2_params.get('alpha', 0.17),
                target_carbon=star2_params.get('carbon', 0.0),
                target_nitrogen=star2_params.get('nitrogen', 0.0),
                spec_dir=model_spec_dir,
                weights=dict(grid_weights)
            ))
            
            # Only bands present in BOTH models can be plotted for both stars
            model_bands = [band for band in sbi.bands
                           if band in wavelengths
                           and band in F1['photometry']
                           and band in F2['photometry']]
            model_wavelengths = np.array([wavelengths[band] for band in model_bands])
            model_fluxes1 = np.array([
                _photometry_mag_to_flux(F1['photometry'][band], wavelengths[band])
                for band in model_bands
            ])
            model_fluxes2 = np.array([
                _photometry_mag_to_flux(F2['photometry'][band], wavelengths[band])
                for band in model_bands
            ])
            
            # Scale models to the observed flux in the g (or r) band. Observed
            # and model lists can contain different bands, so each is indexed
            # by band name. Without a common normalisation band the models are
            # left unscaled (scale factors of 1).
            scale1 = scale2 = 1.0
            norm_band = 'g' if 'g' in sbi.bands else 'r'
            if norm_band in observed_bands and norm_band in model_bands:
                obs_idx = observed_bands.index(norm_band)
                mod_idx = model_bands.index(norm_band)
                scale1 = observed_fluxes1[obs_idx] / model_fluxes1[mod_idx]
                scale2 = observed_fluxes2[obs_idx] / model_fluxes2[mod_idx]
                model_fluxes1 = model_fluxes1 * scale1
                model_fluxes2 = model_fluxes2 * scale2
            
            # Plot model photometry
            ax_phot.plot(model_wavelengths, model_fluxes1, 'o', markersize=10, 
                        markerfacecolor='none', markeredgewidth=2, markeredgecolor='C0',
                        label='Star 1 (C-rich) - Model')
            ax_phot.plot(model_wavelengths, model_fluxes2, 'o', markersize=10,
                        markerfacecolor='none', markeredgewidth=2, markeredgecolor='C1', 
                        label='Star 2 (non-C) - Model')
            
            # Show best-fitting spectra as transparent overlay if requested
            if show_spectra:
                try:
                    from astropy.convolution import convolve, Gaussian1DKernel
                    
                    # Heavily smoothed model spectra, on the same scale as the
                    # model photometry points
                    spectrum_flux1 = convolve(F1['spectrum']['flux'], Gaussian1DKernel(200.0)) * scale1
                    spectrum_flux2 = convolve(F2['spectrum']['flux'], Gaussian1DKernel(200.0)) * scale2
                    
                    # Each spectrum is drawn on its own wavelength grid (um -> Angstroms)
                    wavelength_aa1 = F1['spectrum']['wavelength_um'] * 1e4
                    wavelength_aa2 = F2['spectrum']['wavelength_um'] * 1e4
                    
                    ax_phot.plot(wavelength_aa1, spectrum_flux1, '-', alpha=0.3, color='C0', 
                               linewidth=1, zorder=-1)
                    ax_phot.plot(wavelength_aa2, spectrum_flux2, '-', alpha=0.3, color='C1',
                               linewidth=1, zorder=-1)
                    
                except Exception as e:
                    print(f"Could not plot spectra: {e}")
                    
        except Exception as e:
            # Model overlay is best-effort: keep the observed-only figure
            print(f"Could not load model photometry: {e}")
    
    # Format photometry plot. This happens before the filter panel is drawn so
    # that the filter panel can copy the final x scale and limits.
    ax_phot.set_ylabel(r'$F_\lambda$ [relative units]')
    ax_phot.set_xscale('log')
    ax_phot.set_yscale('log')
    ax_phot.legend(fontsize=10)
    ax_phot.grid(True, alpha=0.3)
    ax_phot.set_title('Observed vs. Model Photometry')
    
    # Set wavelength range
    if len(observed_wavelengths) > 0:
        wl_min = np.min(observed_wavelengths) * 0.8
        wl_max = np.max(observed_wavelengths) * 1.2
        ax_phot.set_xlim(wl_min, wl_max)
    
    # Plot filter transmission curves
    try:
        phot_computer = PhotMetry()
        
        # Get available bands that have transmission curves
        available_bands = [band for band in sbi.bands if band in phot_computer.transmission_curves]
        
        # Plot transmission curves with different colors
        for i, band in enumerate(available_bands):
            trans_data = phot_computer.transmission_curves[band]
            wavelength = trans_data['Wavelength'].value * 1e4  # Convert to Angstroms
            transmission = trans_data['Transmission'].value
            
            # Normalize transmission for better visibility
            transmission_norm = transmission / np.max(transmission) * 0.8
            
            ax_filt.fill_between(wavelength, 0, transmission_norm, 
                               alpha=0.6, color=f'C{i}', label=f'{band} band')
            ax_filt.plot(wavelength, transmission_norm, color=f'C{i}', linewidth=1)
        
        ax_filt.set_xlabel('Wavelength (Å)')
        ax_filt.set_ylabel('Normalized\nTransmission')
        # Mirror the photometry panel so each filter sits under its data point
        ax_filt.set_xscale(ax_phot.get_xscale())
        ax_filt.set_xlim(ax_phot.get_xlim())
        if available_bands:
            ax_filt.legend(ncol=len(available_bands)//2 + 1, fontsize=10)
        ax_filt.grid(True, alpha=0.3)
        
    except Exception as e:
        print(f"Could not plot filter curves: {e}")
        # Hide the filter subplot if we can't plot it
        ax_filt.set_visible(False)
    
    fig.tight_layout()
    
    if save_path:
        # Save through the figure handle rather than pyplot's current figure
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved photometry comparison plot to {save_path}")
    
    return fig




def _print_posterior_summary(samples_by_star, parameter_names, param_labels):
    """
    Print median and 16th/84th percentile offsets of every parameter for both stars
    
    Parameters:
    samples_by_star: mapping with 'star1' and 'star2' sample arrays, indexed
                     as samples[:, parameter_index]
    parameter_names: parameter names; only used to limit how many columns are read
    param_labels: display labels, paired positionally with parameter_names
    """
    for star_idx, star_name in enumerate(['Star 1', 'Star 2']):
        samples = samples_by_star[f'star{star_idx+1}']
        print(f"\n{star_name}:")
        
        for i, (_, label) in enumerate(zip(parameter_names, param_labels)):
            median_val = np.median(samples[:, i])
            p16, p84 = np.percentile(samples[:, i], [16, 84])
            
            print(f"  {label:12}: {median_val:8.3f} +{p84-median_val:6.3f}/-{median_val-p16:6.3f}")


def visualize_photometry_results(results, save_plots=True, output_dir='./'):
    """
    Create comprehensive visualizations of the photometry results for both stars
    
    Produces a corner plot (via sbi.plot_corner), per-parameter histograms and,
    when 'teff' and 'logg' are among sbi.parameter_names, a Teff-log g scatter
    plot. Also prints median / percentile summaries. Spectroscopy-reweighted
    samples are overlaid and summarised when
    results['spectroscopy_reweighted_samples'] is present.
    
    Parameters:
    results: output from complete_binary_star_workflow(); must contain
             'sbi_instance' and 'photometry_only_samples'
    save_plots: whether to save plots to disk
    output_dir: directory to save plots
    
    Returns:
    figures: dict of matplotlib figures ('corner_plot', 'parameter_histograms'
             and, when available, 'hr_diagram')
    """
    import matplotlib.pyplot as plt
    
    figures = {}
    sbi = results['sbi_instance']
    photo_samples = results['photometry_only_samples']
    
    # Look the optional spectroscopy samples up once and reuse them below
    has_spec = 'spectroscopy_reweighted_samples' in results
    spec_samples = results['spectroscopy_reweighted_samples'] if has_spec else None
    
    # 1. Corner plot comparing both stars (photometry only)
    print("Creating corner plot for photometry-only results...")
    
    samples_dict = {
        'Star 1 (Photometry)': photo_samples['star1'],
        'Star 2 (Photometry)': photo_samples['star2']
    }
    
    # Add spectroscopy results if available
    if has_spec:
        samples_dict['Star 1 (Photo+Spec)'] = spec_samples['star1']
        samples_dict['Star 2 (Photo+Spec)'] = spec_samples['star2']
    
    corner_fig = sbi.plot_corner(
        samples_dict, 
        save_path=f'{output_dir}/corner_plot_comparison.png' if save_plots else None
    )
    figures['corner_plot'] = corner_fig
    
    # 2. Parameter summary plots
    hist_fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()
    
    # NOTE(review): labels are paired positionally with sbi.parameter_names;
    # confirm the order matches the sampler's parameter order
    param_labels = ['Teff (K)', 'log g', '[Fe/H]', '[α/Fe]', '[C/Fe]', '[N/Fe]', 'E(B-V)']
    labelled_params = list(zip(sbi.parameter_names, param_labels))
    
    for i, (_, label) in enumerate(labelled_params):
        ax = axes[i]
        
        # Photometry-only histograms
        ax.hist(photo_samples['star1'][:, i], bins=30, alpha=0.6, 
                label='Star 1 (Photo)', color='C0', density=True)
        ax.hist(photo_samples['star2'][:, i], bins=30, alpha=0.6, 
                label='Star 2 (Photo)', color='C1', density=True)
        
        # Add spectroscopy results if available
        if has_spec:
            ax.hist(spec_samples['star1'][:, i], bins=30, alpha=0.6, 
                    label='Star 1 (Photo+Spec)', color='C0', density=True, 
                    histtype='step', linewidth=2)
            ax.hist(spec_samples['star2'][:, i], bins=30, alpha=0.6, 
                    label='Star 2 (Photo+Spec)', color='C1', density=True,
                    histtype='step', linewidth=2)
        
        ax.set_xlabel(label)
        ax.set_ylabel('Density')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # The 2x4 grid has more panels than parameters; hide the ones left empty
    for spare_ax in axes[len(labelled_params):]:
        spare_ax.set_visible(False)
    
    hist_fig.tight_layout()
    if save_plots:
        hist_fig.savefig(f'{output_dir}/parameter_histograms.png', dpi=150, bbox_inches='tight')
    figures['parameter_histograms'] = hist_fig
    
    # 3. Summary statistics table
    print("\n=== Photometry-Only Results Summary ===")
    _print_posterior_summary(photo_samples, sbi.parameter_names, param_labels)
    
    # 4. If spectroscopy available, show comparison
    if has_spec:
        print("\n=== Spectroscopy-Reweighted Results Summary ===")
        _print_posterior_summary(spec_samples, sbi.parameter_names, param_labels)
        
        # Show effective sample sizes
        if 'effective_sample_sizes' in results:
            eff_sizes = results['effective_sample_sizes']
            print("\nEffective sample sizes after reweighting:")
            print(f"  Star 1: {eff_sizes['star1']:.0f}/{len(photo_samples['star1'])}")
            print(f"  Star 2: {eff_sizes['star2']:.0f}/{len(photo_samples['star2'])}")
    
    # 5. HR diagram if available
    if 'teff' in sbi.parameter_names and 'logg' in sbi.parameter_names:
        teff_idx = sbi.parameter_names.index('teff')
        logg_idx = sbi.parameter_names.index('logg')
        
        hr_fig, ax = plt.subplots(figsize=(10, 8))
        
        # Plot photometry samples
        ax.scatter(photo_samples['star1'][:, teff_idx], photo_samples['star1'][:, logg_idx], 
                  alpha=0.3, s=1, label='Star 1 (Photo)', color='C0')
        ax.scatter(photo_samples['star2'][:, teff_idx], photo_samples['star2'][:, logg_idx], 
                  alpha=0.3, s=1, label='Star 2 (Photo)', color='C1')
        
        # Add spectroscopy samples if available
        if has_spec:
            ax.scatter(spec_samples['star1'][:, teff_idx], spec_samples['star1'][:, logg_idx], 
                      alpha=0.5, s=2, label='Star 1 (Photo+Spec)', color='C0', marker='x')
            ax.scatter(spec_samples['star2'][:, teff_idx], spec_samples['star2'][:, logg_idx], 
                      alpha=0.5, s=2, label='Star 2 (Photo+Spec)', color='C1', marker='x')
        
        ax.set_xlabel('Teff (K)')
        ax.set_ylabel('log g')
        # Both axes are inverted so Teff and log g decrease along their axes
        ax.invert_xaxis()
        ax.invert_yaxis()
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_title('HR Diagram')
        
        if save_plots:
            hr_fig.savefig(f'{output_dir}/hr_diagram.png', dpi=150, bbox_inches='tight')
        figures['hr_diagram'] = hr_fig
    
    return figures


def create_spectral_comparison_plots(results, boss_spectrum_path=None, 
                                    model_spec_dir='/data/jls/triple_dc_model_grid/',
                                    save_plots=True, output_dir='./'):
    """
    Build the spectral diagnostic figures for a workflow result
    
    Three figures are attempted in turn (spectral comparison, model-vs-data
    residuals, photometry comparison). Each attempt is independent: a failure
    is printed as a warning and the remaining figures are still attempted.
    
    Parameters:
    results: output from complete_binary_star_workflow()
    boss_spectrum_path: path to BOSS FITS spectrum file (optional)
    model_spec_dir: directory containing model spectra
    save_plots: whether to save plots to disk
    output_dir: directory to save plots
    
    Returns:
    figures: dict holding the return value of each plotting call that
             succeeded; empty if the spectral_plotting module cannot be
             imported or results has no 'spec_data'
    """
    try:
        from spectral_plotting import plot_spectral_comparison, plot_model_data_residuals
    except ImportError:
        print("Warning: spectral_plotting module not available")
        return {}
    
    produced = {}
    
    # Nothing to draw without spectral data
    if 'spec_data' not in results:
        print("No spectral data available for plotting")
        return produced
    
    print("Creating spectral comparison plots...")
    
    # 1. Main spectral comparison plot
    try:
        produced['spectral_comparison'] = plot_spectral_comparison(
            results,
            boss_spectrum_path=boss_spectrum_path,
            model_spec_dir=model_spec_dir,
            save_path=f'{output_dir}/spectral_comparison.png' if save_plots else None
        )
    except Exception as err:
        print(f"Warning: Could not create spectral comparison plot: {err}")
    
    # 2. Detailed model vs. data comparison
    try:
        # Take the flux ratio of the first best fit, defaulting to 0.5
        has_best_fit = 'best_fits' in results and len(results['best_fits']) > 0
        ratio = results['best_fits'][0]['flux_ratio'] if has_best_fit else 0.5
        
        produced['model_data_comparison'] = plot_model_data_residuals(
            results,
            model_spec_dir=model_spec_dir,
            flux_ratio=ratio,
            save_path=f'{output_dir}/model_data_comparison.png' if save_plots else None
        )
    except Exception as err:
        print(f"Warning: Could not create model-data comparison plot: {err}")
    
    # 3. Photometry comparison plot with filter curves
    try:
        from spectral_plotting import plot_photometry_comparison
        
        produced['photometry_comparison'] = plot_photometry_comparison(
            results,
            show_spectra=True,
            save_path=f'{output_dir}/photometry_comparison.png' if save_plots else None
        )
    except Exception as err:
        print(f"Warning: Could not create photometry comparison plot: {err}")
    
    return produced


def add_spectral_plotting_to_workflow():
    """
    Print step-by-step usage notes for calling the spectral plotting functions
    after the binary star workflow has been run
    """
    # The text is printed exactly as written, including its indentation
    usage_notes = """
    To add spectral plotting to your workflow:
    
    1. Import the spectral plotting module:
       from spectral_plotting import plot_spectral_comparison, plot_model_data_residuals
    
    2. After running the complete workflow:
       results = complete_binary_star_workflow(...)
       
    3. Create spectral comparison plot:
       fig1 = plot_spectral_comparison(
           results, 
           boss_spectrum_path='spec-102347-59731-27021598157843923.fits',
           save_path='spectral_comparison.png'
       )
       
    4. Create detailed model vs. data plot:
       fig2 = plot_model_data_residuals(
           results,
           flux_ratio=0.54,
           save_path='model_data_comparison.png'
       )
    """
    print(usage_notes)
