In [8]:
import os
import glob
import re
import numpy as np
from astropy.io import fits


In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [9]:
IMG_SIZE = 1024
BATCH_SIZE = 4
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

In [10]:
TRAIN_FITS_ROOT  = "./Data/Training/FITS"
TRAIN_MASKS_ROOT = "./Data/Training/Masks"

MODEL_PATH = "model_CH_UNet.h5"

INFER_FITS_ROOT = "./Data/Inferrence/FITS"
INFER_MASKS_ROOT = "./Data/Inferrence/Masks"

In [11]:
def read_fits(path):
    """Read FITS and return 2D float32 numpy array."""
    with fits.open(path) as hdul:
        data = hdul[0].data.astype(np.float32)
    return np.nan_to_num(data)

In [12]:
def normalize_aia(img, clip_low=1, clip_high=99):
    low = np.percentile(img, clip_low)
    high = np.percentile(img, clip_high)
    img = np.clip(img, low, high)
    img = (img - low) / (high - low + 1e-6)
    return img


In [13]:

def read_mask_png(path):
    """
    Read PNG mask (0/255-ish) and return 0/1 float32 array.

    We *don't* rename; we just read your existing _CH_MASK_FINAL.png.
    """
    from PIL import Image
    im = Image.open(path).convert("L")
    arr = np.array(im, dtype=np.float32)
    # Consider >127 as CH
    arr = (arr > 127).astype(np.float32)
    return arr

In [14]:
def double_conv(x, filters):
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return x


def build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=32):
    inputs = keras.Input(shape=input_shape)

    # Encoder
    c1 = double_conv(inputs, base_filters)
    p1 = layers.MaxPool2D(2)(c1)

    c2 = double_conv(p1, base_filters * 2)
    p2 = layers.MaxPool2D(2)(c2)

    c3 = double_conv(p2, base_filters * 4)
    p3 = layers.MaxPool2D(2)(c3)

    c4 = double_conv(p3, base_filters * 8)
    p4 = layers.MaxPool2D(2)(c4)

    # Bottleneck
    bn = double_conv(p4, base_filters * 16)

    # Decoder
    u4 = layers.Conv2DTranspose(base_filters * 8, 2, strides=2, padding="same")(bn)
    u4 = layers.Concatenate()([u4, c4])
    c5 = double_conv(u4, base_filters * 8)

    u3 = layers.Conv2DTranspose(base_filters * 4, 2, strides=2, padding="same")(c5)
    u3 = layers.Concatenate()([u3, c3])
    c6 = double_conv(u3, base_filters * 4)

    u2 = layers.Conv2DTranspose(base_filters * 2, 2, strides=2, padding="same")(c6)
    u2 = layers.Concatenate()([u2, c2])
    c7 = double_conv(u2, base_filters * 2)

    u1 = layers.Conv2DTranspose(base_filters, 2, strides=2, padding="same")(c7)
    u1 = layers.Concatenate()([u1, c1])
    c8 = double_conv(u1, base_filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(c8)

    model = keras.Model(inputs, outputs, name="CH_UNet")
    return model

In [15]:
FITS_RE = re.compile(r"AIA(\d{8})_(\d{4,6})_(\d{4})\.fits$")
MASK_RE = re.compile(r"AIA(\d{8})_(\d{6})_(\d{4})_CH_MASK_FINAL\.png$")


In [16]:
def parse_fits_name(fname):
    m = FITS_RE.match(fname)
    if not m:
        return None
    date, time, wave = m.groups()
    # pad time to 6 digits if needed (HHMM → HHMM00)
    if len(time) == 4:
        time = time + "00"
    return date, time, wave

def parse_mask_name(fname):
    m = MASK_RE.match(fname)
    if not m:
        return None
    date, time, wave = m.groups()
    return date, time, wave


In [17]:
def time_to_int(time_str):
    """
    Convert HHMMSS (string) to integer seconds from midnight.
    Just for comparing "closest".
    """
    if len(time_str) != 6:
        # should not happen if regex correct, but fail-safe
        time_str = time_str.zfill(6)
    h = int(time_str[0:2])
    m = int(time_str[2:4])
    s = int(time_str[4:6])
    return h*3600 + m*60 + s

def build_fits_mask_pairs(fits_root, masks_root):
    """
    Recursively find all FITS and PNG masks, then pair:

    - For each FITS, find masks with same (date, wave)
    - Choose the one whose time is closest to FITS time
    - If none exist, skip that FITS

    Returns:
        list of (fits_path, mask_path)
    """
    # Gather FITS files
    fits_paths = [
        p for p in glob.glob(os.path.join(fits_root, "**", "*.fits"), recursive=True)
    ]

    if not fits_paths:
        raise RuntimeError(f"No FITS files found under {fits_root}")

    # Gather mask files
    mask_paths = [
        p for p in glob.glob(os.path.join(masks_root, "**", "*.png"), recursive=True)
    ]

    if not mask_paths:
        raise RuntimeError(f"No mask PNG files found under {masks_root}")

    # Index masks by (date, wave)
    mask_index = {}
    for mp in mask_paths:
        fname = os.path.basename(mp)
        parsed = parse_mask_name(fname)
        if parsed is None:
            continue
        date, time, wave = parsed
        key = (date, wave)
        entry = (time_to_int(time), mp)
        mask_index.setdefault(key, []).append(entry)

    # Sort mask lists by time
    for key in mask_index:
        mask_index[key].sort(key=lambda x: x[0])

    pairs = []

    for fp in fits_paths:
        fname = os.path.basename(fp)
        parsed = parse_fits_name(fname)
        if parsed is None:
            # not an AIA-style name, skip
            continue
        date, time, wave = parsed
        key = (date, wave)
        if key not in mask_index:
            # no mask for this date+wave
            continue

        t_fits = time_to_int(time)
        candidates = mask_index[key]

        # find closest time
        best_dt = None
        best_mp = None
        for t_mask, mp in candidates:
            dt = abs(t_mask - t_fits)
            if best_dt is None or dt < best_dt:
                best_dt = dt
                best_mp = mp

        if best_mp is not None:
            pairs.append((fp, best_mp))

    if not pairs:
        raise RuntimeError("No FITS↔mask pairs found. Check filename patterns.")

    print(f"Found {len(pairs)} FITS↔mask pairs for training.")
    return pairs

In [18]:
build_fits_mask_pairs(TRAIN_FITS_ROOT, TRAIN_MASKS_ROOT)

RuntimeError: No FITS↔mask pairs found. Check filename patterns.