In [None]:
import os
import numpy as np
import warnings
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
from regions import Regions
import pandas as pd
import plotly.graph_objects as go

# Suppress UserWarnings for the rest of the session (the filter is global),
# installed before any file is read so it covers the reads below.
warnings.filterwarnings("ignore", category=UserWarning, append=True)

# Load region file: DS9-format regions; one spectrum is extracted per region.
reg_path = "data/region_file"
regions = Regions.read(reg_path, format="ds9")

# Redshift of the target; observed wavelengths are divided by (1 + z)
# to place the spectra in the rest frame.
z = 0.016268

# Define 4 chapters (MIRI ch1-ch4 in the file names), each with its
# short/medium/long sub-band cube, listed in that order.
chapters = {
    f"chapter{channel}": [
        f'./data/jw01328-c1006_t014_miri_ch{channel}-{band}_s3d.fits'
        for band in ("short", "medium", "long")
    ]
    for channel in range(1, 5)
}

# Emission features to mark: rest-frame wavelengths in microns, grouped by
# category. Each category's plot colour is given in `colors` below.
features = {
    'PAHs': {
        'PAH 7.7': 7.7,
        'PAH 8.6': 8.6,
        'PAH 11.3': 11.3,
    },
    'Neon': {
        '[Ne VI]': 7.65,
    },
    'Other': {
        '[Ar III]': 8.991,
        '[S IV]': 10.51,
    },
    'H₂': {
        'S(3)': 9.66,
        'S(4)': 8.03,
    },
}

# Hex colour per feature category (keys must match `features`).
colors = {
    'PAHs': '#FF7F0E',
    'Neon': '#D62728',
    'Other': '#9467BD',
    'H₂': '#8C564B',
}

def _load_cube(file_path):
    """Read one s3d cube and build its rest-frame wavelength grid.

    Returns ``(data, data_err, celestial_wcs, wavelength)``:
    ``data`` / ``data_err`` are float64 copies of HDU 1 / HDU 2 with negative
    ``data`` values replaced by NaN; ``celestial_wcs`` is the spatial part of
    HDU 1's WCS; ``wavelength`` is the linear CRVAL3/CDELT3/CRPIX3 axis
    divided by ``1 + z`` (module-level redshift).
    """
    with fits.open(file_path) as hdul:
        header = hdul[1].header
        # Copy into memory so the arrays stay valid after the file is closed.
        data = np.array(hdul[1].data, dtype=float)
        data_err = np.array(hdul[2].data, dtype=float)

    data[data < 0] = np.nan
    num_channels = data.shape[0]
    # CRPIX3 is 1-based (FITS convention) while channel indices are 0-based.
    wavelength = (np.arange(num_channels) - (header['CRPIX3'] - 1)) * header['CDELT3'] + header['CRVAL3']
    wavelength = wavelength / (1 + z)
    return data, data_err, WCS(header).celestial, wavelength


def _region_spectrum(data, data_err, celestial_wcs, region):
    """Return per-channel (mean intensity, RMS uncertainty) inside *region*.

    Only pixels whose region-mask weight is positive contribute, each
    multiplied by its weight (the values ``RegionMask.get_values`` returns).
    Pixels outside the region or outside the cube are excluded rather than
    counted as zeros. Channels with no finite pixel give 0 for both outputs.

    NOTE(review): the uncertainty is the RMS of the per-pixel ERR values in
    the aperture, not the propagated error of the mean.
    """
    num_channels = data.shape[0]
    mask = region.to_pixel(celestial_wcs).to_mask()
    slices_large, slices_small = mask.get_overlap_slices(data.shape[1:])
    if slices_large is None:
        # Region lies entirely outside this cube's field of view.
        return np.zeros(num_channels), np.zeros(num_channels)

    weights = mask.data[slices_small]
    inside = weights > 0
    pixel_weights = weights[inside]
    cube_slices = (slice(None),) + tuple(slices_large)
    # (num_channels, n_pixels) arrays of mask-weighted values.
    flux = data[cube_slices][:, inside] * pixel_weights
    flux_err = data_err[cube_slices][:, inside] * pixel_weights

    with warnings.catch_warnings():
        # nanmean warns on all-NaN / empty channels; they are zero-filled below.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        avg = np.nanmean(flux, axis=1)
        err = np.sqrt(np.nanmean(flux_err ** 2, axis=1))
    return np.where(np.isnan(avg), 0.0, avg), np.where(np.isnan(err), 0.0, err)


def _add_feature_lines(fig, wavelength):
    """Draw a dotted vertical line for each entry of `features` within the data range.

    Lines outside the plotted wavelength range are skipped so they do not
    widen the autoranged x-axis far beyond the data.
    """
    if len(wavelength) == 0:
        return
    wl_min, wl_max = min(wavelength), max(wavelength)
    for category, lines in features.items():
        for name, wl in lines.items():
            if not wl_min <= wl <= wl_max:
                continue
            fig.add_vline(
                x=wl, line=dict(color=colors[category], width=1.5, dash='dot'),
                annotation=dict(text=name, font=dict(color=colors[category], size=10), yshift=10)
            )


def _uncertainty_band(wavelength, spectrum, spectrum_err, fillcolor, **scatter_kwargs):
    """Return a filled +/-1 sigma band as a plotly Scatter trace."""
    x = np.asarray(wavelength, dtype=float)
    y = np.asarray(spectrum, dtype=float)
    dy = np.asarray(spectrum_err, dtype=float)
    return go.Scatter(
        x=np.concatenate([x, x[::-1]]),
        y=np.concatenate([y + dy, (y - dy)[::-1]]),
        fill='toself', fillcolor=fillcolor,
        line=dict(color='rgba(255,255,255,0)'), hoverinfo='skip',
        **scatter_kwargs
    )


def _write_region_outputs(chapter_name, output_dir, r_idx, wavelength, spectrum, spectrum_err):
    """Write the CSV table, PNG plot and interactive HTML plot for one region."""
    stem = f"{output_dir}/{chapter_name}_region{r_idx+1}"

    pd.DataFrame({
        'Wavelength_microns': wavelength,
        'Intensity_MJy_sr': spectrum,
        'Uncertainty': spectrum_err
    }).to_csv(f"{stem}_spectrum.csv", index=False)

    mpl_fig, ax = plt.subplots(figsize=(15, 8))
    ax.errorbar(wavelength, spectrum, yerr=spectrum_err, color='blue', ecolor='black')
    ax.set_xlabel('Wavelength (microns)')
    ax.set_ylabel('Average Intensity (MJy/sr)')
    ax.set_title(f'{chapter_name.upper()} - Region {r_idx+1}')
    ax.grid(True)
    mpl_fig.savefig(f"{stem}_plot.png", dpi=300)
    plt.close(mpl_fig)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=wavelength, y=spectrum, mode='lines',
        line=dict(color='blue', width=1.5), name='Spectrum'
    ))
    fig.add_trace(_uncertainty_band(wavelength, spectrum, spectrum_err,
                                    'rgba(31, 119, 180, 0.2)', name='Uncertainty'))
    _add_feature_lines(fig, wavelength)
    fig.update_layout(
        title=f'{chapter_name.upper()} - Region {r_idx+1} Spectrum',
        xaxis_title='Wavelength (μm)', yaxis_title='Intensity (MJy/sr)',
        hovermode='x unified'
    )
    fig.write_html(f"{stem}_plot.html")


def _write_combined_outputs(chapter_name, output_dir, wavelength, region_flux, region_err):
    """Overlay every region's spectrum in one PNG and one HTML plot.

    Each region is its own trace. Concatenating them into a single line would
    draw a spurious segment from the end of one region's spectrum back to the
    start of the next.
    """
    mpl_fig, ax = plt.subplots(figsize=(15, 8))
    for r_idx, (spectrum, spectrum_err) in enumerate(zip(region_flux, region_err)):
        ax.errorbar(wavelength, spectrum, yerr=spectrum_err, label=f'Region {r_idx+1}')
    ax.set_xlabel('Wavelength (microns)')
    ax.set_ylabel('Average Intensity (MJy/sr)')
    ax.grid(True)
    ax.set_title(f'{chapter_name.upper()} - Combined Regions Spectrum')
    if region_flux:
        ax.legend()
    mpl_fig.savefig(f"{output_dir}/{chapter_name}_combined_plot.png", dpi=300)
    plt.close(mpl_fig)

    fig = go.Figure()
    for r_idx, (spectrum, spectrum_err) in enumerate(zip(region_flux, region_err)):
        label = f'Region {r_idx+1}'
        fig.add_trace(go.Scatter(x=wavelength, y=spectrum, mode='lines',
                                 name=label, legendgroup=label))
        fig.add_trace(_uncertainty_band(wavelength, spectrum, spectrum_err,
                                        'rgba(128, 0, 128, 0.2)', name=f'{label} Uncertainty',
                                        legendgroup=label, showlegend=False))
    _add_feature_lines(fig, wavelength)
    fig.update_layout(
        title=f'{chapter_name.upper()} - Combined Spectrum',
        xaxis_title='Wavelength (μm)', yaxis_title='Intensity (MJy/sr)', hovermode='x unified'
    )
    fig.write_html(f"{output_dir}/{chapter_name}_combined_plot.html")


def extract_and_plot(chapter_name, file_paths, output_dir, global_region_data):
    """Extract, save and plot one chapter's spectrum for every region in `regions`.

    Each cube in *file_paths* is read once. For every region, the per-channel
    mask-weighted mean intensity and RMS uncertainty are computed. The
    sub-band spectra are concatenated in *file_paths* order.

    Parameters
    ----------
    chapter_name : str
        Label used in output file names and plot titles.
    file_paths : iterable of str
        s3d FITS cubes (intensity in HDU 1, uncertainty in HDU 2).
    output_dir : str
        Directory for the outputs; created if missing.
    global_region_data : list of dict
        One dict per region with 'wavelength', 'intensity' and 'uncertainty'
        lists. This chapter's values are appended to them in place.

    Per region this writes ``{chapter_name}_region{n}_spectrum.csv``,
    ``..._plot.png`` and ``..._plot.html``. It also writes
    ``{chapter_name}_combined_plot.png`` / ``.html`` overlaying all regions.
    Relies on the module-level ``regions``, ``z``, ``features`` and ``colors``.
    """
    os.makedirs(output_dir, exist_ok=True)

    wavelength_all = []
    region_flux = [[] for _ in regions]
    region_err = [[] for _ in regions]

    # Cubes in the outer loop: each file is opened once, not once per region.
    for file_path in file_paths:
        data, data_err, celestial_wcs, wavelength = _load_cube(file_path)
        wavelength_all.extend(wavelength.tolist())
        for r_idx, region in enumerate(regions):
            avg, err = _region_spectrum(data, data_err, celestial_wcs, region)
            region_flux[r_idx].extend(avg.tolist())
            region_err[r_idx].extend(err.tolist())

    for r_idx, (spectrum_all, spectrum_all_err) in enumerate(zip(region_flux, region_err)):
        global_region_data[r_idx]['wavelength'].extend(wavelength_all)
        global_region_data[r_idx]['intensity'].extend(spectrum_all)
        global_region_data[r_idx]['uncertainty'].extend(spectrum_all_err)
        _write_region_outputs(chapter_name, output_dir, r_idx,
                              wavelength_all, spectrum_all, spectrum_all_err)

    _write_combined_outputs(chapter_name, output_dir, wavelength_all, region_flux, region_err)

# Run chapters: each chapter appends its slice of every region's spectrum to
# the matching entry of global_region_data (one dict per region).
global_region_data = [
    dict(wavelength=[], intensity=[], uncertainty=[])
    for _ in regions
]

for chapter_name, file_paths in chapters.items():
    output_dir = os.path.join("output", chapter_name)
    extract_and_plot(chapter_name, file_paths, output_dir, global_region_data)

# Final all-chapters-combined: one PNG + one HTML per region, spanning every
# chapter's wavelengths in the order they were appended above.
os.makedirs("output/all_chapters", exist_ok=True)

for r_idx, data in enumerate(global_region_data):
    wl, inten, err = data['wavelength'], data['intensity'], data['uncertainty']
    region_no = r_idx + 1

    mpl_fig, ax = plt.subplots(figsize=(15, 8))
    ax.errorbar(wl, inten, yerr=err, color='darkgreen')
    ax.set_title(f'All Chapters - Region {region_no}')
    ax.set_xlabel('Wavelength (microns)')
    ax.set_ylabel('Average Intensity (MJy/sr)')
    ax.grid(True)
    mpl_fig.savefig(f"output/all_chapters/region{region_no}_plot.png", dpi=300)
    plt.close(mpl_fig)

    # Upper and lower edges of the +/-1 sigma band, traced as one closed polygon.
    inten_arr, err_arr = np.array(inten), np.array(err)
    band_x = np.concatenate([wl, wl[::-1]])
    band_y = np.concatenate([inten_arr + err_arr, (inten_arr - err_arr)[::-1]])

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=wl, y=inten, mode='lines', name='Spectrum'))
    fig.add_trace(go.Scatter(
        x=band_x, y=band_y,
        fill='toself', fillcolor='rgba(0,128,0,0.2)',
        line=dict(color='rgba(255,255,255,0)'), hoverinfo='skip', name='Uncertainty'
    ))
    fig.update_layout(
        title=f'All Chapters - Region {region_no} Spectrum',
        xaxis_title='Wavelength (μm)', yaxis_title='Intensity (MJy/sr)', hovermode='x unified'
    )
    fig.write_html(f"output/all_chapters/region{region_no}_plot.html")
