In [None]:
import os
import glob
import re
import numpy as np
import sunpy.map
from sunpy.map.maputils import all_coordinates_from_map, coordinate_is_on_solar_disk
import pandas as pd
import PIL

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Side length (pixels) of the square, single-channel images the networks are built for.
IMG_SIZE = 1024
# Number of image/mask pairs per batch in the tf.data pipelines.
BATCH_SIZE = 4
# Number of epochs passed to model.fit.
NUM_EPOCHS = 30
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-4

### Data Processing

In [None]:
# Root folders searched recursively for the training FITS images and their PNG masks.
TRAIN_FITS_ROOT = "./Data/Training/FITS"
TRAIN_MASKS_ROOT = "./Data/Training/Masks"

# File the best checkpoint is written to during training and reloaded from afterwards.
MODEL_PATH = "model_CH_UNet.h5"

# Root folders for the inference set.
# NOTE(review): "Inferrence" is spelled with a double "r"; presumably this matches
# the folder name on disk — confirm before correcting it.
INFER_FITS_ROOT = "./Data/Inferrence/FITS"
INFER_MASKS_ROOT = "./Data/Inferrence/Masks"

In [None]:
def mask_fits(f, fill_value=0.0):
    """
    Return the pixel data of a map with everything off the solar disk blanked.

    A map's ``.data`` attribute does not carry the map's mask, so building a
    masked map and reading ``.data`` back hands out the untouched pixels.
    The on-disk test is therefore applied to the pixel array directly.

    Parameters
    ----------
    f : sunpy map
        Map whose pixels are to be masked; must expose ``.data`` and be
        accepted by ``all_coordinates_from_map``.
    fill_value : float, optional
        Value written into every pixel that is not on the solar disk
        (default 0.0).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``f.data``; on-disk pixels are copied
        unchanged, all other pixels equal ``fill_value``.
    """
    hpc_coords = all_coordinates_from_map(f)
    # Boolean array, True where the pixel centre lies on the solar disk.
    on_disk = np.asarray(coordinate_is_on_solar_disk(hpc_coords))
    return np.where(on_disk, f.data, fill_value)

def prepare_fits(path, mask_disk=True, clip_low=1, clip_high=99):
    """
    Load a FITS image and rescale it with a percentile stretch.

    Parameters
    ----------
    path : str
        Location of the FITS file, passed to ``sunpy.map.Map``.
    mask_disk : bool, optional
        When True the pixels are taken from ``mask_fits``; otherwise the raw
        map data is used.
    clip_low, clip_high : float, optional
        Percentiles (0-100) used as the lower and upper clipping bounds.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as the image, clipped to the two percentiles
        and mapped so that the lower bound becomes 0 and the upper bound
        becomes (almost exactly) 1.

    Raises
    ------
    ValueError
        If ``clip_low`` is greater than ``clip_high``.
    """
    if clip_low > clip_high:
        raise ValueError(
            f"clip_low ({clip_low}) must not exceed clip_high ({clip_high})"
        )

    f = sunpy.map.Map(path)
    # mask_fits already returns pixel values. Taking `.data` of an ndarray
    # yields a raw memoryview, so both branches are normalised with asarray.
    ff = np.asarray(mask_fits(f) if mask_disk else f.data)

    low = np.percentile(ff, clip_low)
    high = np.percentile(ff, clip_high)
    ff = np.clip(ff, low, high)
    # The small epsilon keeps the division finite for constant images.
    return (ff - low) / (high - low + 1e-6)

In [None]:
def prepare_mask(path):
    """
    Load a mask image and binarise it.

    The file is opened with PIL, converted to 8-bit greyscale ("L") and
    thresholded: pixels brighter than 127 become 1.0, all others 0.0.
    The file on disk is only read, never renamed or rewritten.

    Parameters
    ----------
    path : str
        Location of the mask image (e.g. a ``*_CH_MASK_FINAL.png`` file).

    Returns
    -------
    numpy.ndarray
        2-D float32 array containing only 0.0 and 1.0.
    """
    from PIL import Image

    grey = Image.open(path).convert("L")
    pixels = np.array(grey, dtype=np.float32)
    is_hole = pixels > 127
    return is_hole.astype(np.float32)

In [None]:
def prepare_dataset(fits_root, masks_root):
    """
    Pair FITS images with their mask files by a timestamp key.

    Both roots are searched recursively (``*.fits`` and
    ``*_CH_MASK_FINAL.png``). The key of a file is characters 3-15 of its
    base name, i.e. ``YYYYMMDD_HHMM`` for names of the form
    ``AIAYYYYMMDD_HHMMSS_....``. Duplicate keys on either side are reported
    on stdout but not removed.

    Parameters
    ----------
    fits_root : str
        Directory tree containing the FITS files.
    masks_root : str
        Directory tree containing the mask PNG files.

    Returns
    -------
    matches : pandas.DataFrame
        Rows present on both sides, indexed by key, with columns
        ``fits_path`` and ``mask_path``.
    fits_only : pandas.DataFrame
        Rows with a FITS file but no mask (columns ``key``, ``fits_path``,
        ``mask_path``).
    masks_only : pandas.DataFrame
        Rows with a mask but no FITS file (same columns).
    """
    def index(p):
        # basename honours the platform's path separator, so the key is taken
        # from the file name on Windows as well as on POSIX.
        return os.path.basename(p)[3:16]

    # Collect all FITS and masks
    fits_files = glob.glob(os.path.join(fits_root, "**", "*.fits"), recursive=True)
    mask_files = glob.glob(
        os.path.join(masks_root, "**", "*_CH_MASK_FINAL.png"), recursive=True
    )

    df_fits = pd.DataFrame(
        {"key": [index(p) for p in fits_files], "fits_path": fits_files}
    )
    df_masks = pd.DataFrame(
        {"key": [index(p) for p in mask_files], "mask_path": mask_files}
    )

    # Report duplicate keys (same timestamp, multiple files) on either side.
    for label, frame in (("FITS", df_fits), ("masks", df_masks)):
        duplicates = frame[frame.duplicated("key", keep=False)]
        if not duplicates.empty:
            print(f"⚠ Duplicate keys in {label}:")
            print(duplicates.sort_values("key"))

    # Outer join so matched and unmatched rows end up in one table; the
    # indicator column records which side(s) each key came from.
    merged = df_fits.merge(df_masks, on="key", how="outer", indicator=True)

    def _rows(origin):
        return merged.loc[merged["_merge"] == origin].drop(columns=["_merge"])

    matches = _rows("both").set_index("key")
    fits_only = _rows("left_only")
    masks_only = _rows("right_only")

    return matches, fits_only, masks_only

In [None]:
# prepare_dataset returns (matches, fits_only, masks_only); only the matched
# FITS/mask pairs (element 0) are kept for each data set.
train_df = prepare_dataset(TRAIN_FITS_ROOT, TRAIN_MASKS_ROOT)[0]
inf_df = prepare_dataset(INFER_FITS_ROOT, INFER_MASKS_ROOT)[0]

### Network Architecture

In [None]:
def augment_pair(img, mask):
    """
    Randomly augment one image/mask pair.

    With probability one half both tensors are mirrored left-right together,
    so they stay aligned. The image alone is then multiplied by a random
    factor drawn from [0.9, 1.1) and clipped back to [0, 1]; the mask values
    are not altered.

    Parameters
    ----------
    img, mask : tf.Tensor
        Tensors that can be concatenated along the last axis; the first
        channel of the concatenation is treated as the image and the
        remaining channels as the mask.

    Returns
    -------
    (tf.Tensor, tf.Tensor)
        The augmented image and mask.
    """
    # Flip image and mask as one tensor so the same decision applies to both.
    pair = tf.concat([img, mask], axis=-1)
    flip_it = tf.random.uniform(()) > 0.5
    pair_out = tf.cond(
        flip_it,
        lambda: tf.image.flip_left_right(pair),
        lambda: pair,
    )
    flipped_img = pair_out[..., :1]
    flipped_mask = pair_out[..., 1:]

    # Brightness jitter on the image only.
    gain = tf.random.uniform((), 0.9, 1.1)
    jittered_img = tf.clip_by_value(flipped_img * gain, 0.0, 1.0)

    return jittered_img, flipped_mask

In [None]:
def _load_pair_numpy(fits_path, mask_path):
    """
    NumPy-side loader for one FITS/mask pair.

    Paths may arrive as bytes (as they do through ``tf.numpy_function``) or
    as ordinary strings. The image comes from ``prepare_fits`` and the mask
    from ``prepare_mask``; both are converted to float32 and 2-D results get
    a trailing channel axis.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Image and mask as float32 arrays.
    """
    def _as_text(value):
        # Byte strings are decoded as UTF-8; anything else is passed through.
        if isinstance(value, (bytes, np.bytes_)):
            return value.decode("utf-8")
        return value

    def _with_channel(array):
        # float32 conversion plus a channel axis for plain 2-D arrays.
        array = np.asarray(array, dtype=np.float32)
        return array[..., np.newaxis] if array.ndim == 2 else array

    fits_name = _as_text(fits_path)
    mask_name = _as_text(mask_path)

    image = prepare_fits(fits_name)
    target = prepare_mask(mask_name)

    return _with_channel(image), _with_channel(target)


def load_pair_tf(fits_path, mask_path, augment=True):
    """
    tf.data-compatible loader for one FITS/mask pair.

    The actual file reading happens in ``_load_pair_numpy`` through
    ``tf.numpy_function``, which loses static shape information; the shape
    is therefore re-attached (and checked when the pipeline runs) with
    ``tf.ensure_shape``.

    Parameters
    ----------
    fits_path, mask_path : tf.Tensor (string)
        Paths of the image and of its mask.
    augment : bool, optional
        When True (default) the pair is passed through ``augment_pair``.
        Pass False for validation or inference so that no random flip or
        brightness change is applied.

    Returns
    -------
    (tf.Tensor, tf.Tensor)
        Image and mask, each float32 with shape (IMG_SIZE, IMG_SIZE, 1).
    """
    img, mask = tf.numpy_function(
        _load_pair_numpy,
        [fits_path, mask_path],
        [tf.float32, tf.float32],
    )

    # Unlike set_shape, ensure_shape also validates the loaded arrays at run
    # time, so a wrongly sized file fails here with a clear message.
    expected_shape = (IMG_SIZE, IMG_SIZE, 1)
    img = tf.ensure_shape(img, expected_shape)
    mask = tf.ensure_shape(mask, expected_shape)

    if augment:
        img, mask = augment_pair(img, mask)

    return img, mask

In [None]:
def double_conv(x, filters):
    """
    Apply two consecutive Conv2D(3x3, same padding) -> BatchNormalization -> ReLU stages.

    Parameters
    ----------
    x : tensor
        Input feature map.
    filters : int
        Number of output channels of both convolutions.

    Returns
    -------
    tensor
        Feature map after the second ReLU.
    """
    features = x
    for _stage in range(2):
        features = layers.Conv2D(filters, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)
        features = layers.ReLU()(features)
    return features


In [None]:
def build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=32):
    """
    Build the four-level U-Net named "CH_UNet".

    The encoder applies ``double_conv`` followed by 2x2 max pooling four
    times, doubling the channel count at each level (base_filters, x2, x4,
    x8). A ``double_conv`` bottleneck with 16x base_filters follows. The
    decoder mirrors the encoder: a stride-2 transposed convolution, a
    concatenation with the encoder output of the same level, and another
    ``double_conv``. A 1x1 convolution with sigmoid activation produces the
    single-channel output.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, (height, width, channels).
    base_filters : int
        Channel count of the first encoder level.

    Returns
    -------
    keras.Model
        The (uncompiled) model.
    """
    depth = 4
    inputs = keras.Input(shape=input_shape)

    # Encoder: keep each level's output for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        x = double_conv(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPool2D(2)(x)

    # Bottleneck
    x = double_conv(x, base_filters * 2 ** depth)

    # Decoder: deepest level first, each stage fused with its encoder twin.
    for level in reversed(range(depth)):
        width = base_filters * 2 ** level
        x = layers.Conv2DTranspose(width, 2, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skips[level]])
        x = double_conv(x, width)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)

    return keras.Model(inputs, outputs, name="CH_UNet")

In [None]:

def build_tiny_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=16):
    """
    Build the reduced U-Net named "CH_TinyUNet".

    Same layout as the full network but with three pooling levels
    (base_filters, x2, x4), an 8x base_filters bottleneck and a smaller
    default channel count, which makes it cheaper to train at the price of
    less capacity.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, (height, width, channels).
    base_filters : int
        Channel count of the first encoder level.

    Returns
    -------
    keras.Model
        The (uncompiled) model.
    """
    depth = 3
    inputs = keras.Input(shape=input_shape)

    # Encoder: resolution halves after every level (H -> H/2 -> H/4 -> H/8).
    skips = []
    x = inputs
    for level in range(depth):
        x = double_conv(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPool2D(2)(x)

    # Bottleneck
    x = double_conv(x, base_filters * 2 ** depth)

    # Decoder: upsample, concatenate the matching encoder output, convolve.
    for level in reversed(range(depth)):
        width = base_filters * 2 ** level
        x = layers.Conv2DTranspose(width, 2, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skips[level]])
        x = double_conv(x, width)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)

    return keras.Model(inputs, outputs, name="CH_TinyUNet")

In [None]:
def dice_coef(y_true, y_pred, smooth=1.0):
    """
    Mean Dice coefficient over a batch.

    Each sample is flattened and its score is
    ``(2 * sum(true * pred) + smooth) / (sum(true) + sum(pred) + smooth)``;
    the per-sample scores are then averaged.

    Parameters
    ----------
    y_true, y_pred : tensor
        Target and prediction with the batch on axis 0; both are cast to
        float32.
    smooth : float
        Constant added to numerator and denominator.

    Returns
    -------
    tf.Tensor
        Scalar mean Dice score.
    """
    truth = tf.cast(y_true, tf.float32)
    guess = tf.cast(y_pred, tf.float32)

    # One row per sample.
    truth_rows = tf.reshape(truth, [tf.shape(truth)[0], -1])
    guess_rows = tf.reshape(guess, [tf.shape(guess)[0], -1])

    overlap = tf.reduce_sum(truth_rows * guess_rows, axis=1)
    total = tf.reduce_sum(truth_rows, axis=1) + tf.reduce_sum(guess_rows, axis=1)

    per_sample = (2.0 * overlap + smooth) / (total + smooth)
    return tf.reduce_mean(per_sample)


def bce_dice_loss(y_true, y_pred):
    """
    Weighted sum of binary cross-entropy and Dice loss.

    The result is ``0.4 * mean(BCE) + 0.6 * (1 - dice_coef)``.
    """
    bce_term = tf.reduce_mean(keras.losses.binary_crossentropy(y_true, y_pred))
    dice_term = 1.0 - dice_coef(y_true, y_pred)
    return 0.4 * bce_term + 0.6 * dice_term

### Model Training

In [None]:
def load_trained_model():
    """
    Load the model stored at MODEL_PATH.

    The custom loss and metric are supplied by name through
    ``custom_objects`` so Keras can resolve them while deserialising.

    Returns
    -------
    keras.Model
        The model read from MODEL_PATH.
    """
    return keras.models.load_model(
        MODEL_PATH,
        custom_objects={"bce_dice_loss": bce_dice_loss, "dice_coef": dice_coef},
    )

In [None]:
def _load_val_pair_tf(fits_path, mask_path):
    """Load one FITS/mask pair for validation, without random augmentation."""
    img, mask = tf.numpy_function(
        _load_pair_numpy,
        [fits_path, mask_path],
        [tf.float32, tf.float32],
    )
    # numpy_function drops static shapes; restore them for batching/the model.
    img.set_shape((IMG_SIZE, IMG_SIZE, 1))
    mask.set_shape((IMG_SIZE, IMG_SIZE, 1))
    return img, mask


def train_model(pairs_df):
    """
    Train U-Net using a DataFrame with columns:
      - 'fits_path' : paths to AIA FITS files
      - 'mask_path' : paths to corresponding CH_MASK_FINAL images

    The DataFrame index is not used. Rows are split in their given order:
    the first ~90% are used for training (shuffled and augmented) and the
    last ~10% (at least one row) for validation. Validation samples are
    loaded without augmentation so that ``val_loss`` is comparable between
    epochs. The model with the lowest ``val_loss`` is saved to MODEL_PATH.

    Raises
    ------
    RuntimeError
        If ``pairs_df`` holds fewer than two pairs (no pairs at all, or too
        few to form both a training and a validation set).
    """

    # --- Extract aligned lists from the DataFrame ---
    fits_paths = pairs_df["fits_path"].astype(str).tolist()
    mask_paths = pairs_df["mask_path"].astype(str).tolist()

    n_total = len(fits_paths)
    if n_total == 0:
        raise RuntimeError("pairs_df is empty: no FITS-mask pairs to train on.")

    if n_total < 2:
        # With a single pair the training split would be empty and the
        # shuffle buffer size would be zero.
        raise RuntimeError(
            "Need at least 2 FITS-mask pairs to form a train/validation split."
        )

    print(f"Training on {n_total} FITS-mask pairs")

    # --- Build tf.data datasets from the lists ---
    fits_paths_tf = tf.constant(fits_paths)
    mask_paths_tf = tf.constant(mask_paths)

    # Simple temporal-ish split: first 90% train, last 10% val
    n_val = max(1, int(0.1 * n_total))
    n_train = n_total - n_val

    train_fits = fits_paths_tf[:n_train]
    train_masks = mask_paths_tf[:n_train]
    val_fits = fits_paths_tf[n_train:]
    val_masks = mask_paths_tf[n_train:]

    train_ds = (
        tf.data.Dataset.from_tensor_slices((train_fits, train_masks))
        .shuffle(buffer_size=n_train)
        .map(load_pair_tf, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Validation uses the non-augmenting loader: random flips and brightness
    # changes would make val_loss fluctuate for reasons unrelated to the model.
    val_ds = (
        tf.data.Dataset.from_tensor_slices((val_fits, val_masks))
        .map(_load_val_pair_tf, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    # --- Build and compile model ---
    model = build_unet(base_filters=32)

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=bce_dice_loss,
        metrics=[dice_coef, "accuracy"],
    )

    model.summary()

    # --- Callbacks for checkpointing & LR scheduling ---
    callbacks = [
        # Keep only the epoch with the lowest validation loss on disk.
        keras.callbacks.ModelCheckpoint(
            MODEL_PATH,
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=False,
        ),
        # Halve the learning rate after 3 epochs without val_loss improvement.
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            verbose=1,
        ),
    ]

    # --- Train ---
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=NUM_EPOCHS,
        callbacks=callbacks,
    )

    print(f"Training finished. Best model saved to {MODEL_PATH}")

In [None]:
# Runs the full training loop; the checkpoint with the lowest val_loss is written to MODEL_PATH.
train_model(train_df)

In [None]:
# Reload the saved checkpoint; mask_via_model reads this `model` as a global.
model = load_trained_model()

### Testing

In [None]:
def apply_model_to_array(model, img2d, resize=False):
    """
    Run a model on a single 2-D image.

    Parameters
    ----------
    model : object with a Keras-style ``predict(x, verbose=...)`` method
        The network to evaluate.
    img2d : array-like, shape (H, W)
        Image that has already been preprocessed; it is converted to
        float32.
    resize : bool, optional
        When True and the image is not (IMG_SIZE, IMG_SIZE), it is resized
        bilinearly to that size before prediction.

    Returns
    -------
    numpy.ndarray
        First sample and first channel of the model output, i.e. a 2-D map.

    Raises
    ------
    ValueError
        If ``img2d`` is not two-dimensional.
    """
    pixels = np.asarray(img2d, dtype=np.float32)
    if pixels.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {pixels.shape}")

    # The size comparison is only evaluated when resizing was requested.
    needs_resize = resize and pixels.shape != (IMG_SIZE, IMG_SIZE)
    if needs_resize:
        import tensorflow as tf
        as_tensor = tf.convert_to_tensor(pixels[..., np.newaxis])  # (H, W, 1)
        as_tensor = tf.image.resize(as_tensor, (IMG_SIZE, IMG_SIZE), method="bilinear")
        pixels = as_tensor.numpy()[..., 0]

    # (H, W) -> (1, H, W, 1): batch axis in front, channel axis at the end.
    batch = pixels[np.newaxis, ..., np.newaxis]

    prediction = model.predict(batch, verbose=0)
    return prediction[0, ..., 0]

In [None]:
def mask_via_model(path, threshold=0.5):
    """
    Predict a binary mask image with the notebook-level ``model``.

    Parameters
    ----------
    path : array-like, shape (H, W)
        Despite its name this is the preprocessed 2-D image itself (it is
        forwarded unchanged to ``apply_model_to_array``), not a file path.
    threshold : float, optional
        Values of the model output strictly above this become white (255);
        everything else becomes black (0). Default 0.5.

    Returns
    -------
    PIL.Image.Image
        8-bit greyscale image of the thresholded prediction.

    Notes
    -----
    Relies on a global variable ``model`` being defined before the call.
    """
    # `import PIL` alone does not guarantee that the Image submodule is
    # loaded, so it is imported explicitly here.
    from PIL import Image

    prob_map = apply_model_to_array(model, path, resize=True)
    # A 2-D uint8 array is interpreted by PIL as a greyscale ("L") image.
    mask_uint8 = np.where(prob_map > threshold, 255, 0).astype(np.uint8)
    return Image.fromarray(mask_uint8)

In [None]:
# Predicted mask for row 666 of the training table.
# NOTE(review): this needs at least 667 matched pairs and evaluates a sample the
# model was trained on; inf_df is never used in the visible cells — confirm intent.
mask_via_model(prepare_fits(train_df.iloc[666].fits_path))

In [None]:
# Stored mask file of the same row (666), presumably shown for visual comparison
# with the predicted mask.
PIL.Image.open(train_df.iloc[666].mask_path)