# Image Correction

In [None]:
%matplotlib inline
import sys,os
import shutil
from os.path import basename
from glob import glob
import logging 
import numpy as np
from astropy.io import fits


# Image correction algorithm 
#sys.path.append('add path to class directory if different')
from skipper_roi_img_reduction import *

# Attach a default handler to the root logger, then let the root logger pass
# through messages of every level so pipeline diagnostics are not filtered.
logging.basicConfig()
logging.getLogger().setLevel(logging.NOTSET)

# Named logger shared by all pipeline helpers in this notebook.
log = logging.getLogger("data_reduction")


## Helper Functions

In [None]:
def apply_bias_correction(raw_file_path, out_file_path, ROI=True, 
                          correction_type='SP', single_sample=False, 
                          gain_correction=False):
    """
    Build a corrected Skipper image from a raw FITS file and write it to disk.

    The raw file is loaded through `SkipperImageROI` with automatic slice
    detection (``manual_input_slices=False``) and cubic-spline overscan
    correction; the resulting object is then saved to `out_file_path`.
    The corrections themselves are presumably carried out inside
    `SkipperImageROI` (this function only constructs and saves it).

    Parameters
    ----------
    raw_file_path : str
        Path of the raw FITS file to read.
    out_file_path : str
        Destination path for the corrected FITS file.
    ROI : bool, optional
        Forwarded as ``ROI``; when True the ROI definitions from the FITS
        header are used. Defaults to True.
    correction_type : str, optional
        Forwarded as ``overscan_correction_type``. Typical values are
        'S' (serial only), 'P' (parallel only) and 'SP' (both).
        Defaults to 'SP'.
    single_sample : bool, optional
        Forwarded as ``calibration_single_sample``; treat the image as a
        single-sample calibration product. Defaults to False.
    gain_correction : bool, optional
        Forwarded as ``gain_correction``; when True the gain file known to
        `SkipperImageROI` is applied. Defaults to False.

    Returns
    -------
    None
    """
    # Options passed unchanged to the SkipperImageROI constructor.
    settings = dict(
        overscan_correction_type=correction_type,
        manual_input_slices=False,
        correction_method="CUBIC_SPLINE",
        ROI=ROI,
        calibration_single_sample=single_sample,
        gain_correction=gain_correction,
    )
    corrected = SkipperImageROI(filename=raw_file_path, **settings)
    corrected.save(out_file_path)


def create_combined_product(fits_dir, out_file_path, norm=True):
    """
    Create a median-combined (and optionally row-normalized) FITS product from 'proc' files.

    Every regular ``.fits`` file in `fits_dir` whose name contains 'proc'
    (case-insensitive) is read. For each extension HDU (index 1 onward),
    the data are median-combined pixel by pixel across all selected
    files. The result is written to `out_file_path` with the primary
    header and per-extension headers taken from the first file in sorted
    filename order.

    Parameters
    ----------
    fits_dir : str
        Directory containing the 'proc' FITS files to be combined.
    out_file_path : str
        Path where the combined FITS product is written (overwritten if it
        already exists).
    norm : bool, optional
        If True, every row of every combined extension is divided by that
        row's median. Defaults to True.

    Returns
    -------
    int or None
        0 if no matching 'proc' FITS files were found in `fits_dir`;
        otherwise None.

    Notes
    -----
    All selected files are assumed to have the same number of extensions
    and the same 2-D data shape as the first file; a mismatch raises from
    NumPy during stacking. A row whose median is zero yields inf/nan after
    normalization (NumPy emits a RuntimeWarning).
    """
    # Sorted so the "first file" (source of the output headers) is deterministic.
    fits_files = sorted(
        name for name in os.listdir(fits_dir)
        if 'proc' in name.lower()
        and name.lower().endswith('.fits')
        and os.path.isfile(os.path.join(fits_dir, name))
    )

    if not fits_files:
        log.warning("No 'proc' files found in the directory")
        return 0

    file_paths = [os.path.join(fits_dir, name) for name in fits_files]

    # Take the headers and image geometry from the first file, closing it afterwards.
    with fits.open(file_paths[0]) as first_file:
        primary_header = first_file[0].header.copy()
        amp_headers = [hdu.header.copy() for hdu in first_file[1:]]
        num_rows, num_cols = first_file[1].data.shape

    num_amps = len(amp_headers)

    # Stack every file's extension data: shape (files, amps, rows, cols).
    all_images = np.zeros((len(file_paths), num_amps, num_rows, num_cols))
    for i, file_path in enumerate(file_paths):
        with fits.open(file_path) as hdul:
            for j in range(num_amps):
                all_images[i, j] = hdul[j + 1].data

    # Pixel-wise median across files -> shape (amps, rows, cols).
    combined_product = np.median(all_images, axis=0)

    if norm:
        # Divide each row of each amplifier by that row's median (median over columns).
        combined_product /= np.median(combined_product, axis=2, keepdims=True)

    # Build extension HDUs directly as ImageHDUs; astropy derives the
    # structural keywords (BITPIX/NAXISn) from the combined float data.
    new_hdul = fits.HDUList([fits.PrimaryHDU(header=primary_header)])
    for amp_data, amp_header in zip(combined_product, amp_headers):
        new_hdul.append(fits.ImageHDU(data=amp_data, header=amp_header))

    new_hdul.writeto(out_file_path, overwrite=True)


## File Classification

In [None]:
def file_classification(dir_path):
    """
    Classify and organize FITS files based on their primary-header keywords.

    This function:
      1. Copies every top-level ``.fits`` file in `dir_path` into a
         subdirectory named after its "OBSTYPE" keyword ("UNKNOWN" if
         absent), creating that subdirectory if needed. The originals are
         left in place.
      2. Walks every subdirectory of `dir_path` (including ones that
         existed before this call) and, for each ``.fits`` file in it,
         reads the "OBJECT" keyword ("UNKNOWN" if absent):
           - if it contains "focus" (case-insensitive), the file is moved
             into a "FOCUS" subdirectory (these are not corrected later);
           - otherwise the file is moved into a subdirectory named after
             the OBJECT value.
      3. Returns the object subdirectories that were newly created by this
         call (those intended to be corrected later in the pipeline).

    Parameters
    ----------
    dir_path : str
        Directory containing the ``.fits`` files to classify.

    Returns
    -------
    list of str
        Paths of object directories created during this call, in the order
        they were created.
    """
    header_index_info = 0
    images_to_correct_dirs = []

    # Pass 1: copy each top-level FITS file into its OBSTYPE directory.
    for file in sorted(os.listdir(dir_path)):
        src_path = os.path.join(dir_path, file)
        if not (file.endswith(".fits") and os.path.isfile(src_path)):
            continue
        # getheader opens and closes the file, so no handle is leaked.
        header = fits.getheader(src_path, header_index_info)
        type_dir = os.path.join(dir_path, str(header.get("OBSTYPE", "UNKNOWN")))
        os.makedirs(type_dir, exist_ok=True)
        shutil.copy2(src_path, os.path.join(type_dir, file))

    # Pass 2: inside every subdirectory, route each FITS file either to
    # FOCUS/ (focusing frames, not corrected) or to a per-OBJECT directory.
    for type_name in sorted(os.listdir(dir_path)):
        type_dir = os.path.join(dir_path, type_name)
        if not os.path.isdir(type_dir):
            continue

        focus_dir = os.path.join(type_dir, "FOCUS")
        for file in sorted(os.listdir(type_dir)):
            file_path = os.path.join(type_dir, file)
            if not (file.endswith(".fits") and os.path.isfile(file_path)):
                continue

            header = fits.getheader(file_path, header_index_info)
            # str() guards against non-string OBJECT values.
            object_name = str(header.get("OBJECT", "UNKNOWN"))

            if "focus" in object_name.lower():
                os.makedirs(focus_dir, exist_ok=True)
                shutil.move(file_path, os.path.join(focus_dir, file))
                continue

            # NOTE(review): OBJECT is used verbatim as a path component; a value
            # containing os.sep (or an absolute path) creates nested/other dirs.
            object_dir = os.path.join(type_dir, object_name)
            if not os.path.exists(object_dir):
                os.makedirs(object_dir)
                images_to_correct_dirs.append(object_dir)

            shutil.move(file_path, os.path.join(object_dir, file))

    return images_to_correct_dirs


## Image Correction

In [None]:
# Filename prefixes of products written by this pipeline itself; such files are
# never treated as raw input, so a re-run does not re-process its own outputs.
_DERIVED_PREFIXES = ('proc_', 'combined_', 'flat_filed_')


def _read_nrois(file_path, header_index=0):
    """Return the header's NROIS keyword as an int, or None if it is missing or not an integer."""
    value = fits.getheader(file_path, header_index).get("NROIS")
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _relative_tag(path, root):
    """Return `path` relative to `root`, lower-cased, for keyword matching.

    Matching on the relative part (e.g. 'object/ngc1234') prevents a keyword
    that merely appears in `root` itself from matching every directory.
    """
    return os.path.relpath(path, root).lower()


def _bias_subtract_dir(path, header_index_info=0):
    """Write a 'proc_' bias-corrected copy of every raw FITS file in `path`."""
    for file in sorted(os.listdir(path)):
        if not file.endswith('.fits') or file.startswith(_DERIVED_PREFIXES):
            continue
        img_path = os.path.join(path, file)
        out_path = os.path.join(path, 'proc_' + file)

        roi_num = _read_nrois(img_path, header_index_info)
        if roi_num is None or roi_num < 0:
            log.warning(str(img_path) + " Unknown number of regions. Proceeding to next image.")
            continue

        if roi_num == 0:
            # No ROIs: single-sample product, serial overscan only.
            apply_bias_correction(
                img_path,
                out_path,
                ROI=False,
                correction_type='S',
                single_sample=True,
                gain_correction=True
            )
        else:
            # ROI readout: serial + parallel overscan.
            apply_bias_correction(
                img_path,
                out_path,
                ROI=True,
                correction_type='SP',
                single_sample=False,
                gain_correction=True
            )


def _build_combined_products(images_to_correct, dir_path):
    """Median-combine 'proc' files in milky-flat (row-normalized) and bias directories."""
    for path in images_to_correct:
        tag = _relative_tag(path, dir_path)
        dir_name = os.path.basename(os.path.normpath(path)).lower()
        if 'milky' in tag:
            out_file_path = os.path.join(path, 'combined_norm_' + dir_name + '.fits')
            create_combined_product(path, out_file_path, norm=True)
        elif 'bias' in tag:
            out_file_path = os.path.join(path, 'combined_' + dir_name + '.fits')
            create_combined_product(path, out_file_path, norm=False)


def _load_flat_field(milky_dir):
    """Load the combined milky flat in `milky_dir` as an array, or return None if none exists."""
    candidates = sorted(f for f in os.listdir(milky_dir) if 'combined' in f.lower())
    if not candidates:
        return None
    if len(candidates) > 1:
        log.warning("Several combined flat-field products found; using " + candidates[0])

    flat_field = SkipperImageROI(
        filename=os.path.join(milky_dir, candidates[0]),
        overscan_correction_type='none',
        manual_input_slices=False,
        correction_method="CUBIC_SPLINE",
        ROI=False,
        calibration_single_sample=True,
        gain_correction=False
    )
    flat_field.processImage()
    return flat_field.get_full_image()


def _pad_and_divide(img_data, flat_field):
    """Divide an (amp, row, col) image by a flat after padding both with ones to a common row/col size."""
    amp_object, row_object, col_object = img_data.shape
    _, row_flat, col_flat = flat_field.shape

    new_shape = (amp_object, max(row_object, row_flat), max(col_object, col_flat))
    padded_image = np.ones(new_shape)
    padded_flat = np.ones(new_shape)
    padded_image[:, :row_object, :col_object] = img_data
    padded_flat[:, :row_flat, :col_flat] = flat_field
    return padded_image / padded_flat


def _flat_field_dir(path, flat_field, header_index_info=0):
    """Divide every 'proc_' image in `path` by `flat_field`, saving 'flat_filed_proc_*' files."""
    for file in sorted(os.listdir(path)):
        if not (file.startswith('proc_') and file.endswith('.fits')):
            continue
        to_correct_path = os.path.join(path, file)

        roi_num = _read_nrois(to_correct_path, header_index_info)
        if roi_num is None or roi_num < 0:
            # Skip instead of silently reusing the previous image's data.
            log.warning(str(to_correct_path) + " Unknown number of regions. Proceeding to next image.")
            continue

        # The bias-corrected product is reloaded the same way for any valid NROIS.
        object_img = SkipperImageROI(
            filename=to_correct_path,
            overscan_correction_type='none',
            manual_input_slices=False,
            correction_method="CUBIC_SPLINE",
            ROI=False,
            calibration_single_sample=True,
            gain_correction=False
        )
        object_img.processImage()
        img_data = object_img.get_full_image()

        object_img.set_full_image(_pad_and_divide(img_data, flat_field))
        object_img.save(os.path.join(path, 'flat_filed_' + file), injected=False)


def image_reduction(dir_path, flat_filed=True):
    """
    Perform a multi-step image reduction process on SIFS-Skipper CCD data.

    Workflow:

    1. Classify the FITS files in `dir_path` with `file_classification`,
       which returns the newly created per-object directories.
    2. Bias-subtract every raw ``.fits`` file in those directories with
       `apply_bias_correction`, writing ``'proc_' + <filename>``:
         - ``NROIS == 0``: ROI=False, correction_type='S', single_sample=True;
         - ``NROIS > 0``: ROI=True, correction_type='SP', single_sample=False;
         - missing, non-integer or negative NROIS: a warning is logged and
           the file is skipped.
       Gain correction is enabled in both cases. Files whose names start
       with 'proc_', 'combined_' or 'flat_filed_' are never used as input.
    3. Build combined products with `create_combined_product`:
         - directories whose path below `dir_path` contains 'milky':
           ``combined_norm_<dirname>.fits`` (row-normalized);
         - otherwise, if it contains 'bias': ``combined_<dirname>.fits``.
    4. If `flat_filed` is True, load the combined product from the first
       'milky' directory and divide every ``proc_*.fits`` image in
       directories whose path below `dir_path` contains 'object' by it.
       Both arrays are padded with ones to a common row/column size first.
       Output files are named ``'flat_filed_' + <proc filename>``.

    Parameters
    ----------
    dir_path : str
        Directory containing the raw FITS files. Subdirectories are created
        automatically for each OBSTYPE and OBJECT found in the headers.
    flat_filed : bool, optional
        If True, perform the flat-field step (4). Defaults to True.

    Raises
    ------
    SystemExit
        If `flat_filed` is True and no milky-flatfield directory or no
        combined milky-flatfield product is found. A warning is logged and
        ``sys.exit(0)`` is called.
    """
    header_index_info = 0

    images_to_correct = file_classification(dir_path)

    # Step 1: bias subtraction.
    for path in images_to_correct:
        if not os.path.isdir(path):
            log.warning(str(path) + " is not a directory. Proceeding...")
            continue
        _bias_subtract_dir(path, header_index_info)

    # Step 2: combined products for milky flatfields and bias frames.
    _build_combined_products(images_to_correct, dir_path)

    # Step 3: optional flat-field correction.
    if not flat_filed:
        return

    milky_dirs = [p for p in images_to_correct if 'milky' in _relative_tag(p, dir_path)]
    if not milky_dirs:
        log.warning("Milky-flatfields directory not found: Possibly 'Milky' is missing in the header.")
        log.warning("Flatfield correction will not be applied.")
        sys.exit(0)

    flat_field = _load_flat_field(milky_dirs[0])
    if flat_field is None:
        log.warning("Combined flat-field product not found.")
        log.warning("Flatfield correction will not be applied.")
        sys.exit(0)

    for path in images_to_correct:
        if 'object' in _relative_tag(path, dir_path):
            _flat_field_dir(path, flat_field, header_index_info)


In [None]:
def main():
    """Reduce the hard-coded SIFS test data set without flat-field correction."""
    # Location of the raw images for this run.
    data_path = "/data/des81.b/data/emarrufo/SIFS_data/SIFS_DATA/20240722_TEST/20240722_test"
    image_reduction(dir_path=data_path, flat_filed=False)


if __name__ == "__main__":
    main()
