In [1]:
import os
import glob

import numpy as np
import pandas as pd

import PIL
import matplotlib.pyplot as plt

from sunpy.map.maputils import all_coordinates_from_map, coordinate_is_on_solar_disk
import sunpy.visualization.colormaps.color_tables as ct
import sunpy.map

import astropy.units as u

In [2]:
# onnx
# https://developer.apple.com/metal/jax/ (with keras via
# multiscale sampling

In [3]:
os.environ["KERAS_BACKEND"] = "jax"
import keras
# keras.config.set_backend("jax")
from keras import layers, ops

In [4]:
IMG_SIZE = 256
BATCH_SIZE = 4
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4

### Data Processing

In [5]:
TRAIN_FITS_ROOT = "./Data/Training/FITS"
TRAIN_MASKS_ROOT = "./Data/Training/Masks"

MODEL_PATH = "model_CH_UNet.keras"

INFER_FITS_ROOT = "./Data/Inferrence/FITS"
INFER_MASKS_ROOT = "./Data/Inferrence/Masks"

In [6]:
def mask_fits(f):
    hpc_coords = all_coordinates_from_map(f)
    mask = coordinate_is_on_solar_disk(hpc_coords)
    palette = f.cmap.copy()
    palette.set_bad("black")
    scaled_map = sunpy.map.Map(f.data, f.meta, mask=~mask)
    ff = scaled_map.data
    return ff


def prepare_fits(path, mask_disk=True, clip_low=1, clip_high=99):
    f = sunpy.map.Map(path)
    if mask_disk:
        ff = mask_fits(f).data
    else:
        ff = f.data

    low = np.percentile(ff, clip_low)
    high = np.percentile(ff, clip_high)
    ff = np.clip(ff, low, high)
    ff = (ff - low) / (high - low + 1e-6)
    return ff

In [7]:
def prepare_mask(path):
    """
    Read PNG mask (0/255-ish) and return 0/1 float32 array.

    We *don't* rename; we just read your existing _CH_MASK_FINAL.png.
    """
    from PIL import Image

    im = Image.open(path).convert("L")
    arr = np.array(im, dtype=np.float32)
    # Consider >127 as CH
    arr = (arr > 127).astype(np.float32)
    return arr

In [8]:
pd.set_option("display.width", 1000)


def prepare_dataset(fits_root, masks_root):
    def index(p):
        return p.split("/")[-1][3:16]

    # Collect all FITS and masks
    fits_files = glob.glob(os.path.join(fits_root, "**", "*.fits"), recursive=True)
    mask_files = glob.glob(
        os.path.join(masks_root, "**", "*_CH_MASK_FINAL.png"), recursive=True
    )

    df_fits = pd.DataFrame(
        {
            "key": [index(p) for p in fits_files],
            "fits_path": fits_files,
        }
    )

    df_masks = pd.DataFrame(
        {
            "key": [index(p) for p in mask_files],
            "mask_path": mask_files,
        }
    )

    # Optional: detect duplicate keys (same timestamp, multiple files)
    dup_fits = df_fits[df_fits.duplicated("key", keep=False)]
    dup_masks = df_masks[df_masks.duplicated("key", keep=False)]

    if not dup_fits.empty:
        print("⚠ Duplicate keys in FITS:")
        print(dup_fits.sort_values("key"))

    if not dup_masks.empty:
        print("⚠ Duplicate keys in masks:")
        print(dup_masks.sort_values("key"))

    df_fits = df_fits.drop_duplicates("key", keep="first")
    df_masks = df_masks.drop_duplicates("key", keep="first")

    # Outer join to see everything in one table
    merged = df_fits.merge(df_masks, on="key", how="outer", indicator=True)

    matches = merged[merged["_merge"] == "both"].copy()
    fits_only = merged[merged["_merge"] == "left_only"].copy()
    masks_only = merged[merged["_merge"] == "right_only"].copy()

    for df in (matches, fits_only, masks_only):
        df.drop(columns=["_merge"], inplace=True)

    matches.set_index(matches.key, inplace=True, drop=True)
    matches.drop(["key"], axis=1, inplace=True)

    return matches, fits_only, masks_only

In [9]:
train_df = prepare_dataset(TRAIN_FITS_ROOT, TRAIN_MASKS_ROOT)[0]
inf_df = prepare_dataset(INFER_FITS_ROOT, INFER_MASKS_ROOT)[0]

⚠ Duplicate keys in FITS:
                key                                          fits_path
9182  20170125_2146  ./Data/Inferrence/FITS/2017/01/25/AIA20170125_...
9130  20170125_2146  ./Data/Inferrence/FITS/2017/01/25/AIA20170125_...
5507  20170712_2154  ./Data/Inferrence/FITS/2017/07/12/AIA20170712_...
5596  20170712_2154  ./Data/Inferrence/FITS/2017/07/12/AIA20170712_...
7885  20170801_0054  ./Data/Inferrence/FITS/2017/08/01/AIA20170801_...
...             ...                                                ...
7790  20170831_2154  ./Data/Inferrence/FITS/2017/08/31/AIA20170831_...
6787  20170831_2254  ./Data/Inferrence/FITS/2017/09/08/31/AIA201708...
7800  20170831_2254  ./Data/Inferrence/FITS/2017/08/31/AIA20170831_...
6794  20170831_2354  ./Data/Inferrence/FITS/2017/09/08/31/AIA201708...
7807  20170831_2354  ./Data/Inferrence/FITS/2017/08/31/AIA20170831_...

[1550 rows x 2 columns]


### Network Architecture

#### Data

In [10]:
def augment_pair(img, mask):
    # --- random horizontal flip ---
    if np.random.rand() < 0.5:
        img = img[:, ::-1, :]
        mask = mask[:, ::-1, :]

    # --- small random brightness scaling ---
    scale = np.random.uniform(0.9, 1.1)  # ±10 %
    img = img * scale
    img = np.clip(img, 0.0, 1.0)

    return img, mask

In [11]:
def load_pair(fits_path, mask_path):
    # decode paths if coming in as bytes
    if isinstance(fits_path, (bytes, np.bytes_)):
        fits_path = fits_path.decode("utf-8")
    if isinstance(mask_path, (bytes, np.bytes_)):
        mask_path = mask_path.decode("utf-8")

    # load 2-D arrays
    img = np.asarray(prepare_fits(fits_path), dtype=np.float32)
    mask = np.asarray(prepare_mask(mask_path), dtype=np.float32)

    # resize
    img_resized = PIL.Image.fromarray((img * 255).astype(np.uint8)).resize(
        (IMG_SIZE, IMG_SIZE), resample=PIL.Image.BILINEAR
    )
    mask_resized = PIL.Image.fromarray((mask * 255).astype(np.uint8)).resize(
        (IMG_SIZE, IMG_SIZE), resample=PIL.Image.NEAREST
    )

    # normalize back to [0,1] and add channel axis
    img = np.expand_dims(np.array(img_resized, dtype=np.float32) / 255.0, axis=-1)
    mask = np.expand_dims(np.array(mask_resized, dtype=np.float32) / 255.0, axis=-1)
    mask = (mask > 0.5).astype(np.float32)

    return img, mask

In [12]:
def pair_generator(fits_paths, mask_paths, batch_size, augment=False):
    n = len(fits_paths)
    idxs = np.arange(n)

    while True:
        np.random.shuffle(idxs)
        for i in range(0, n, batch_size):
            batch_idx = idxs[i : i + batch_size]
            imgs, masks = [], []
            for j in batch_idx:
                img, mask = load_pair(fits_paths[j], mask_paths[j])
                if augment:
                    img, mask = augment_pair(img, mask)

                imgs.append(img)
                masks.append(mask)

            if not imgs:
                continue

            X = np.stack(imgs, axis=0).astype(np.float32)  # (B, H, W, 1)
            Y = np.stack(masks, axis=0).astype(np.float32)  # (B, H, W, 1)
            yield X, Y

#### Model

In [13]:
def double_conv(x, filters):
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return x

In [14]:
def build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=32):
    inputs = keras.Input(shape=input_shape)

    # processedinputs = .....(inputs)

    # Encoder
    c1 = double_conv(inputs, base_filters)
    p1 = layers.MaxPool2D(2)(c1)

    c2 = double_conv(p1, base_filters * 2)
    p2 = layers.MaxPool2D(2)(c2)

    c3 = double_conv(p2, base_filters * 4)
    p3 = layers.MaxPool2D(2)(c3)

    c4 = double_conv(p3, base_filters * 8)
    p4 = layers.MaxPool2D(2)(c4)

    # Bottleneck
    bn = double_conv(p4, base_filters * 16)

    # Decoder
    u4 = layers.Conv2DTranspose(base_filters * 8, 2, strides=2, padding="same")(bn)
    u4 = layers.Concatenate()([u4, c4])
    c5 = double_conv(u4, base_filters * 8)

    u3 = layers.Conv2DTranspose(base_filters * 4, 2, strides=2, padding="same")(c5)
    u3 = layers.Concatenate()([u3, c3])
    c6 = double_conv(u3, base_filters * 4)

    u2 = layers.Conv2DTranspose(base_filters * 2, 2, strides=2, padding="same")(c6)
    u2 = layers.Concatenate()([u2, c2])
    c7 = double_conv(u2, base_filters * 2)

    u1 = layers.Conv2DTranspose(base_filters, 2, strides=2, padding="same")(c7)
    u1 = layers.Concatenate()([u1, c1])
    c8 = double_conv(u1, base_filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(c8)

    model = keras.Model(inputs, outputs, name="CH_UNet")
    return model

In [15]:
def dice_coef(y_true, y_pred, smooth=1.0):
    # flatten per-sample
    y_true_f = ops.reshape(y_true, (ops.shape(y_true)[0], -1))
    y_pred_f = ops.reshape(y_pred, (ops.shape(y_pred)[0], -1))

    intersection = ops.sum(y_true_f * y_pred_f, axis=1)
    denom = ops.sum(y_true_f, axis=1) + ops.sum(y_pred_f, axis=1)

    dice = (2.0 * intersection + smooth) / (denom + smooth)
    return ops.mean(dice)


def bce_dice_loss(y_true, y_pred):
    bce = keras.losses.binary_crossentropy(y_true, y_pred)
    dice_loss = 1.0 - dice_coef(y_true, y_pred)
    return 0.4 * ops.mean(bce) + 0.6 * dice_loss

In [16]:
def train_model(pairs_df):
    fits_paths = pairs_df["fits_path"].astype(str).tolist()
    mask_paths = pairs_df["mask_path"].astype(str).tolist()

    if len(fits_paths) == 0:
        raise RuntimeError("pairs_df is empty: no FITS-mask pairs to train on.")
    if len(fits_paths) != len(mask_paths):
        raise RuntimeError(
            f"Mismatch in pairs_df: {len(fits_paths)} FITS vs {len(mask_paths)} masks."
        )

    n_total = len(fits_paths)
    print(f"Training on {n_total} FITS-mask pairs")

    # 90/10 split
    n_val = max(1, int(0.1 * n_total))
    n_train = n_total - n_val

    train_fits = fits_paths[:n_train]
    train_masks = mask_paths[:n_train]
    val_fits = fits_paths[n_train:]
    val_masks = mask_paths[n_train:]

    # steps per epoch (integer)
    steps_per_epoch = max(1, n_train // BATCH_SIZE)
    val_steps = max(1, n_val // BATCH_SIZE)

    train_gen = pair_generator(
        train_fits,
        train_masks,
        batch_size=BATCH_SIZE,
        augment=True,  # set False if you don't want augmentation
    )

    val_gen = pair_generator(
        val_fits,
        val_masks,
        batch_size=BATCH_SIZE,
        augment=False,
    )

    model = build_unet(base_filters=32)  # must be built with keras.layers, not tf.keras

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=bce_dice_loss,  # your keras.ops-based loss
        metrics=[dice_coef, "accuracy"],
    )

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            MODEL_PATH,
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=False,
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            verbose=1,
        ),
    ]

    model.fit(
        train_gen,
        epochs=NUM_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=val_steps,
        callbacks=callbacks,
    )

    print(f"Training finished. Best model saved to {MODEL_PATH}")

### Model Training

In [17]:
def load_trained_model():
    custom_objects = {
        "bce_dice_loss": bce_dice_loss,
        "dice_coef": dice_coef,
    }
    model = keras.models.load_model(MODEL_PATH, custom_objects=custom_objects)
    return model

In [18]:
# train_model(train_df)

In [19]:
model = load_trained_model()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1763025204.127443 9056719 service.cc:145] XLA service 0x3141e7770 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1763025204.127467 9056719 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1763025204.131211 9056719 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1763025204.131227 9056719 mps_client.cc:384] XLA backend will use up to 12712722432 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.92 GB



### Testing

In [20]:
def apply_model_to_array(model, img2d, resize=False, img_size=IMG_SIZE):
    """
    img2d : 2D numpy array (H, W), already preprocessed (normalized etc.)
    model : trained Keras model (Keras 3, any backend)
    resize: if True and shape != (img_size, img_size), resizes with bilinear
    img_size: target size for the model input

    Returns
    -------
    prob_map : 2D numpy array (img_size, img_size) with values in [0, 1]
    """
    img = np.asarray(img2d, dtype=np.float32)

    if img.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {img.shape}")

    # Optional resize to model input size
    if resize and img.shape != (img_size, img_size):
        # Use float-preserving PIL mode "F", bilinear interpolation
        pil_img = PIL.Image.fromarray(img.astype(np.float32), mode="F")
        pil_img = pil_img.resize((img_size, img_size), resample=PIL.Image.BILINEAR)
        img = np.array(pil_img, dtype=np.float32)

    # Add channel and batch dims: (H, W) -> (1, H, W, 1)
    x = img[np.newaxis, ..., np.newaxis]  # (1, H, W, 1)

    # Predict: output shape (1, H, W, 1) -> (H, W)
    prob = model.predict(x, verbose=0)[0, ..., 0]

    return prob

In [21]:
def mask_via_model(path):
    prob_map = apply_model_to_array(model, path, resize=True)
    mask = (prob_map > 0.5).astype(np.float32)  # binary mask if you want it
    mask_uint8 = (mask * 255).clip(0, 255).astype(np.uint8)
    img = PIL.Image.fromarray(mask_uint8, mode="L")

    if img.size != (1024, 1024):
        img = img.resize((1024, 1024), resample=PIL.Image.NEAREST)

    return img

In [22]:
def prob_map_via_model(path):
    prob_map = apply_model_to_array(model, path, resize=True)
    img_uint8 = np.clip(prob_map * 255, 0, 255).astype(np.uint8)
    img = PIL.Image.fromarray(img_uint8, mode="L")
    return img

In [23]:
cmap = ct.aia_color_table(u.Quantity(193, "Angstrom"))
# cmap = "gray"


def plot_mask(row, sdo=False):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(np.array(mask_via_model(prepare_fits(row.fits_path))), cmap=cmap)
    plt.title("helio-n")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(np.array(PIL.Image.open(row.mask_path)), cmap=cmap)
    plt.title("IDL")
    plt.axis("off")

    plt.tight_layout()
    plt.show()


def plot_sdo(row):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(prepare_fits(row.fits_path), cmap=cmap)
    plt.contour(
        np.array(mask_via_model(prepare_fits(row.fits_path))),
        levels=[0.5],
        colors="red",
        linewidths=1.5,
    )
    plt.title("helio-n")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(prepare_fits(row.fits_path), cmap=cmap)
    plt.contour(
        np.array(prepare_mask(row.mask_path)),
        levels=[0.5],
        colors="red",
        linewidths=1.5,
    )
    plt.title("IDL")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# 1. Register your dataframes here
dfs = {
    "train": train_df,
    "inferrence": inf_df,
}

# 2. Widgets
df_selector = widgets.RadioButtons(
    options=list(dfs.keys()),
    value="train",
    description="DataFrame:",
)

idx_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dfs["train"]) - 1,
    step=1,
    description="Index:",
    continuous_update=False,
)

show_mask_checkbox = widgets.Checkbox(
    value=True,
    description="Show mask (off = SDO)",
)

out = widgets.Output()


# 3. Update slider range when DF changes
def update_slider_range(change):
    df = dfs[df_selector.value]
    idx_slider.max = max(0, len(df) - 1)
    if idx_slider.value > idx_slider.max:
        idx_slider.value = idx_slider.max


df_selector.observe(update_slider_range, names="value")


# 4. Main update function
def update_plot(change=None):
    with out:
        clear_output(wait=True)
        df = dfs[df_selector.value]
        if len(df) == 0:
            print("Selected dataframe is empty.")
            return

        row = df.iloc[idx_slider.value]

        if show_mask_checkbox.value:
            # You implement this: should plot SDO + NN/IDL mask
            plot_mask(row)
        else:
            # You implement this: should plot just the SDO / FITS-based view
            plot_sdo(row)


# 5. Hook up callbacks
idx_slider.observe(update_plot, names="value")
show_mask_checkbox.observe(update_plot, names="value")
df_selector.observe(update_plot, names="value")

# 6. Display the UI
ui = widgets.VBox(
    [
        df_selector,
        idx_slider,
        show_mask_checkbox,
        out,
    ]
)

display(ui)

# Initial draw
update_slider_range(None)
update_plot(None)

VBox(children=(RadioButtons(description='DataFrame:', options=('train', 'inferrence'), value='train'), IntSlid…