# Setup

In [None]:
import os
import glob
import random
import json

import numpy as np
import pandas as pd

import PIL
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from sunpy.map.maputils import all_coordinates_from_map, coordinate_is_on_solar_disk
import sunpy.visualization.colormaps.color_tables as ct
import sunpy.map
import astropy.units as u
from sunpy.coordinates import frames

from skimage.morphology import (
    binary_closing,
    disk,
    remove_small_objects,
    remove_small_holes,
)

from skimage import measure
import mahotas

import ipywidgets as widgets
from IPython.display import display, clear_output

In [None]:
# onnx
# https://developer.apple.com/metal/jax/ (with keras via
# multiscale sampling

In [None]:
os.environ["KERAS_BACKEND"] = "jax"
import keras

# keras.config.set_backend("jax")
from keras import layers, ops

In [176]:
with open("./Config/Training Params.json", "r") as f:
    model_params = json.load(f)
model_params

{'img_size': 256, 'batch_size': 4, 'num_epochs': 30, 'learning_rate': 0.0001}

In [177]:
with open("./Config/Smoothing Params.json", "r") as f:
    smoothing_params = json.load(f)
smoothing_params

{'threshold': 0.32, 'closing_radius': 1, 'min_size': 150, 'hole_size': 1000.0}

# Data Processing

In [None]:
FITS_ROOT = "/Volumes/JetDrive 330/SDO Data/FITS"
MASKS_ROOT = "/Volumes/JetDrive 330/SDO Data/Masks"

MODEL_PATH = "./Outputs/model_CH_UNet.keras"

In [None]:
def mask_fits(f):
    hpc_coords = all_coordinates_from_map(f)
    mask = coordinate_is_on_solar_disk(hpc_coords)
    palette = f.cmap.copy()
    palette.set_bad("black")
    scaled_map = sunpy.map.Map(f.data, f.meta, mask=~mask)
    ff = scaled_map.data
    return ff


def prepare_fits(path, mask_disk=True, clip_low=1, clip_high=99):
    f = sunpy.map.Map(path)
    if mask_disk:
        ff = mask_fits(f).data
    else:
        ff = f.data
    ff = np.flipud(ff)

    low = np.percentile(ff, clip_low)
    high = np.percentile(ff, clip_high)
    ff = np.clip(ff, low, high)
    ff = (ff - low) / (high - low + 1e-6)
    return ff

In [None]:
def prepare_hmi_jpg(jpg_path, target_size=1024):
    """
    Load 512px HMI JPEG (grayscale), upscale to target_size.
    Return a float32 array.
    """
    img = PIL.Image.open(jpg_path).convert("F")  # 32-bit float
    img = img.resize((target_size, target_size), resample=PIL.Image.BILINEAR)
    arr = np.array(img, dtype=np.float32)
    return arr

In [None]:
def prepare_mask(path, preserve_255=False):
    from PIL import Image

    im = Image.open(path).convert("L")
    if not preserve_255:
        arr = np.array(im, dtype=np.float32)
        # Consider >127 as CH
        arr = (arr > 127).astype(np.float32)
    else:
        arr = np.array(im)
    return arr

In [None]:
def prepare_hmi_jpg(
    path="/Users/aosh/Library/Containers/net.langui.FTPMounter/Data/.FTPVolumes/dec1/mnt/sun/sdo/hmi/L0/2017/08/31/20170831_171500_M_512.jpg",
    target_size=(1024, 1024),
):
    """
    Load an HMI JPG (512×512 or similar) and upscale to AIA grid size.
    Returns float32 magnetogram-like values in range [-1,1] based on brightness.
    """
    im = Image.open(path).convert("L")  # grayscale JPG
    im = im.resize(target_size, Image.BILINEAR)
    arr = np.array(im, dtype=np.float32)

    # Convert brightness → rough polarity proxy:
    # 0 = black → strong negative
    # 255 = white → strong positive
    # midpoint 128 → zero-ish
    arr = (arr - 128.0) / 128.0  # now in approx [-1, +1]

    return arr

In [None]:
pd.set_option("display.width", 1000)


def prepare_dataset(fits_root, masks_root):
    def index(p):
        return p.split("/")[-1][3:16]

    # Collect all FITS and masks
    fits_files = glob.glob(os.path.join(fits_root, "**", "*.fits"), recursive=True)
    mask_files = glob.glob(
        os.path.join(masks_root, "**", "*_CH_MASK_FINAL.png"), recursive=True
    )

    df_fits = pd.DataFrame(
        {
            "key": [index(p) for p in fits_files],
            "fits_path": fits_files,
        }
    )

    df_masks = pd.DataFrame(
        {
            "key": [index(p) for p in mask_files],
            "mask_path": mask_files,
        }
    )

    # Optional: detect duplicate keys (same timestamp, multiple files)
    dup_fits = df_fits[df_fits.duplicated("key", keep=False)]
    dup_masks = df_masks[df_masks.duplicated("key", keep=False)]

    if not dup_fits.empty:
        print("⚠ Duplicate keys in FITS:")
        print(dup_fits.sort_values("key"))

    if not dup_masks.empty:
        print("⚠ Duplicate keys in masks:")
        print(dup_masks.sort_values("key"))

    df_fits = df_fits.drop_duplicates("key", keep="first")
    df_masks = df_masks.drop_duplicates("key", keep="first")

    # Outer join to see everything in one table
    merged = df_fits.merge(df_masks, on="key", how="outer", indicator=True)

    matches = merged[merged["_merge"] == "both"].copy()
    fits_only = merged[merged["_merge"] == "left_only"].copy()
    masks_only = merged[merged["_merge"] == "right_only"].copy()

    for df in (matches, fits_only, masks_only):
        df.drop(columns=["_merge"], inplace=True)

    matches.set_index(matches.key, inplace=True, drop=True)
    matches.drop(["key"], axis=1, inplace=True)
    matches["pmap_path"] = ""

    return matches, fits_only, masks_only

In [None]:
df = prepare_dataset(FITS_ROOT, MASKS_ROOT)[0]

In [None]:
train_df = df["20170501":"20170801"]

In [None]:
# set-wise subtraction
inf_df = df.loc[~df.index.isin(train_df.index)]

# Network Architecture

## Data

In [None]:
def augment_pair(img, mask):
    # --- random horizontal flip ---
    if np.random.rand() < 0.5:
        img = img[:, ::-1, :]
        mask = mask[:, ::-1, :]

    # --- small random brightness scaling ---
    scale = np.random.uniform(0.9, 1.1)  # ±10 %
    img = img * scale
    img = np.clip(img, 0.0, 1.0)

    return img, mask

In [None]:
def load_pair(fits_path, mask_path):
    # decode paths if coming in as bytes
    if isinstance(fits_path, (bytes, np.bytes_)):
        fits_path = fits_path.decode("utf-8")
    if isinstance(mask_path, (bytes, np.bytes_)):
        mask_path = mask_path.decode("utf-8")

    # load 2-D arrays
    img = np.asarray(prepare_fits(fits_path), dtype=np.float32)
    mask = np.asarray(prepare_mask(mask_path), dtype=np.float32)

    # resize
    img_resized = PIL.Image.fromarray((img * 255).astype(np.uint8)).resize(
        (model_params["img_size"], model_params["img_size"]),
        resample=PIL.Image.BILINEAR,
    )
    mask_resized = PIL.Image.fromarray((mask * 255).astype(np.uint8)).resize(
        (model_params["img_size"], model_params["img_size"]), resample=PIL.Image.NEAREST
    )

    # normalize back to [0,1] and add channel axis
    img = np.expand_dims(np.array(img_resized, dtype=np.float32) / 255.0, axis=-1)
    mask = np.expand_dims(np.array(mask_resized, dtype=np.float32) / 255.0, axis=-1)
    mask = (mask > 0.5).astype(np.float32)

    return img, mask

In [None]:
def pair_generator(fits_paths, mask_paths, augment=False):
    n = len(fits_paths)
    idxs = np.arange(n)

    while True:
        np.random.shuffle(idxs)
        for i in range(0, n, model_params["batch_size"]):
            batch_idx = idxs[i : i + model_params["batch_size"]]
            imgs, masks = [], []
            for j in batch_idx:
                img, mask = load_pair(fits_paths[j], mask_paths[j])
                if augment:
                    img, mask = augment_pair(img, mask)

                imgs.append(img)
                masks.append(mask)

            if not imgs:
                continue

            X = np.stack(imgs, axis=0).astype(np.float32)  # (B, H, W, 1)
            Y = np.stack(masks, axis=0).astype(np.float32)  # (B, H, W, 1)
            yield X, Y

## Model

In [None]:
def double_conv(x, filters):
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return x

In [None]:
def build_unet(
    input_shape=(model_params["img_size"], model_params["img_size"], 1), base_filters=32
):
    inputs = keras.Input(shape=input_shape)

    # ----- Encoder -----
    c1 = double_conv(inputs, base_filters)  # 256 x 256,  f
    p1 = layers.MaxPool2D(2)(c1)  # 128 x 128

    c2 = double_conv(p1, base_filters * 2)  # 128 x 128,  2f
    p2 = layers.MaxPool2D(2)(c2)  # 64 x 64

    c3 = double_conv(p2, base_filters * 4)  # 64 x 64,    4f
    p3 = layers.MaxPool2D(2)(c3)  # 32 x 32

    c4 = double_conv(p3, base_filters * 8)  # 32 x 32,    8f
    p4 = layers.MaxPool2D(2)(c4)  # 16 x 16

    # extra encoder level
    c5 = double_conv(p4, base_filters * 16)  # 16 x 16,   16f
    p5 = layers.MaxPool2D(2)(c5)  # 8 x 8

    # ----- Bottleneck -----
    bn = double_conv(
        p5, base_filters * 16
    )  # keep 16f; base*32 is also possible but heavier

    # ----- Decoder -----
    u5 = layers.Conv2DTranspose(base_filters * 16, 2, strides=2, padding="same")(bn)
    u5 = layers.Concatenate()([u5, c5])
    c6 = double_conv(u5, base_filters * 16)

    u4 = layers.Conv2DTranspose(base_filters * 8, 2, strides=2, padding="same")(c6)
    u4 = layers.Concatenate()([u4, c4])
    c7 = double_conv(u4, base_filters * 8)

    u3 = layers.Conv2DTranspose(base_filters * 4, 2, strides=2, padding="same")(c7)
    u3 = layers.Concatenate()([u3, c3])
    c8 = double_conv(u3, base_filters * 4)

    u2 = layers.Conv2DTranspose(base_filters * 2, 2, strides=2, padding="same")(c8)
    u2 = layers.Concatenate()([u2, c2])
    c9 = double_conv(u2, base_filters * 2)

    u1 = layers.Conv2DTranspose(base_filters, 2, strides=2, padding="same")(c9)
    u1 = layers.Concatenate()([u1, c1])
    c10 = double_conv(u1, base_filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(c10)

    model = keras.Model(inputs, outputs, name="CH_UNet_5lvl")
    return model

In [None]:
def dice_coef(y_true, y_pred, smooth=1.0):
    # flatten per-sample
    y_true_f = ops.reshape(y_true, (ops.shape(y_true)[0], -1))
    y_pred_f = ops.reshape(y_pred, (ops.shape(y_pred)[0], -1))

    intersection = ops.sum(y_true_f * y_pred_f, axis=1)
    denom = ops.sum(y_true_f, axis=1) + ops.sum(y_pred_f, axis=1)

    dice = (2.0 * intersection + smooth) / (denom + smooth)
    return ops.mean(dice)


def bce_dice_loss(y_true, y_pred):
    bce = keras.losses.binary_crossentropy(y_true, y_pred)
    dice_loss = 1.0 - dice_coef(y_true, y_pred)
    return 0.4 * ops.mean(bce) + 0.6 * dice_loss

In [None]:
def train_model(pairs_df):
    fits_paths = pairs_df["fits_path"].astype(str).tolist()
    mask_paths = pairs_df["mask_path"].astype(str).tolist()

    if len(fits_paths) == 0:
        raise RuntimeError("pairs_df is empty: no FITS-mask pairs to train on.")
    if len(fits_paths) != len(mask_paths):
        raise RuntimeError(
            f"Mismatch in pairs_df: {len(fits_paths)} FITS vs {len(mask_paths)} masks."
        )

    n_total = len(fits_paths)
    print(f"Training on {n_total} FITS-mask pairs")

    # 90/10 split
    n_val = max(1, int(0.1 * n_total))
    n_train = n_total - n_val

    train_fits = fits_paths[:n_train]
    train_masks = mask_paths[:n_train]
    val_fits = fits_paths[n_train:]
    val_masks = mask_paths[n_train:]

    # --- use every 3rd training sample ---
    train_fits = train_fits[::3]
    train_masks = train_masks[::3]

    # steps per epoch (integer)
    steps_per_epoch = max(1, n_train // model_params["batch_size"])
    val_steps = max(1, n_val // model_params["batch_size"])

    train_gen = pair_generator(
        train_fits,
        train_masks,
        augment=True,
    )

    val_gen = pair_generator(
        val_fits,
        val_masks,
        augment=False,
    )

    model = build_unet(base_filters=80)  # must be built with keras.layers, not tf.keras

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=model_params["learning_rate"]),
        loss=bce_dice_loss,  # your keras.ops-based loss
        metrics=[dice_coef, "accuracy"],
    )

    early_stop = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    )

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            MODEL_PATH,
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=False,
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            verbose=1,
        ),
        early_stop,
    ]

    model.fit(
        train_gen,
        epochs=model_params["num_epochs"],
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=val_steps,
        callbacks=callbacks,
    )

    print(f"Training finished. Best model saved to {MODEL_PATH}")

## Training

In [None]:
def load_trained_model():
    custom_objects = {
        "bce_dice_loss": bce_dice_loss,
        "dice_coef": dice_coef,
    }
    model = keras.models.load_model(MODEL_PATH, custom_objects=custom_objects)
    return model

In [None]:
# train_model(train_df)

In [None]:
model = load_trained_model()
!notify "Model loaded"

# Model Application Definitions

In [None]:
def fits_to_pmap(img2d, resize=False, img_size=model_params["img_size"]):
    img = np.asarray(img2d, dtype=np.float32)

    if img.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {img.shape}")

    if resize and img.shape != (img_size, img_size):
        # Use float-preserving PIL mode "F", bilinear interpolation
        pil_img = PIL.Image.fromarray(img.astype(np.float32), mode="F")
        pil_img = pil_img.resize((img_size, img_size), resample=PIL.Image.BILINEAR)
        img = np.array(pil_img, dtype=np.float32)

    x = img[np.newaxis, ..., np.newaxis]  # (1, H, W, 1)

    prob = model.predict(x, verbose=0)[0, ..., 0]

    return prob

In [None]:
def pmap_to_mask(pmap, smoothing_params=smoothing_params, save=False):
    mask = pmap > smoothing_params["threshold"]  # binary mask

    if smoothing_params["closing_radius"] > 0:
        mask = binary_closing(mask, disk(smoothing_params["closing_radius"]))

    if smoothing_params["min_size"] > 0:
        mask = remove_small_objects(mask, min_size=smoothing_params["min_size"])

    if smoothing_params["hole_size"] > 0:
        mask = remove_small_holes(mask, area_threshold=smoothing_params["hole_size"])

    # mask_uint8 = (mask * 255).clip(0, 255).astype(np.uint8)
    img = PIL.Image.fromarray(mask)

    if img.size != (1024, 1024):
        img = img.resize((1024, 1024), resample=PIL.Image.NEAREST)

    return np.array(img)

In [None]:
def save_pmap(row, pmap=None):
    path = row.mask_path.replace("CH_MASK_FINAL.png", "UNET_PMAP.npy")
    if pmap is None:
        pmap = fits_to_pmap(prepare_fits(row.fits_path))
    np.save(path, pmap)
    return path

# Testing

## Similarity Metrics

In [None]:
def _ensure_binary_mask(mask):
    """Convert mask to boolean 2D array."""
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2D, got shape {mask.shape}")
    if mask.dtype == bool:
        return mask
    return mask > 0.5

In [None]:
def compute_zernike_descriptor(mask, degree=8):
    """
    Compute Zernike-moment-based shape descriptor for a binary mask,
    using mahotas.features.zernike_moments.

    Parameters
    ----------
    mask : 2D array-like (bool or 0/1)
        Binary mask of a single region (or union of CHs).
    degree : int
        Maximum Zernike polynomial degree (typical: 6–12).
        This 'degree' in mahotas is 'n_max' in the literature.

    Returns
    -------
    desc : np.ndarray, shape (M,)
        Rotation-invariant Zernike descriptor (magnitudes), L2-normalized.
        If the mask is empty, returns a zero vector of length 1.
    """
    mask = _ensure_binary_mask(mask)

    # --- crop to bounding box of the region ---
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        # empty mask
        return np.zeros(1, dtype=float)

    y_min, y_max = ys.min(), ys.max()
    x_min, x_max = xs.min(), xs.max()
    cropped = mask[y_min : y_max + 1, x_min : x_max + 1]

    h, w = cropped.shape
    # make it square by padding (centered)
    size = max(h, w)
    pad_y = (size - h) // 2
    pad_x = (size - w) // 2

    square = np.zeros((size, size), dtype=float)
    square[pad_y : pad_y + h, pad_x : pad_x + w] = cropped.astype(float)

    # mahotas assumes the Zernike circle is centered in the image
    radius = size // 2

    # mahotas.features.zernike_moments returns a 1D array of (real) magnitudes,
    # already rotation-invariant
    zm = mahotas.features.zernike_moments(square, radius, degree)

    desc = np.asarray(zm, dtype=float)

    # Optional: L2 normalize to make scale of descriptor comparable across regions
    norm = np.linalg.norm(desc)
    if norm > 0:
        desc = desc / norm

    return desc

In [None]:
def compute_fourier_descriptor(mask, num_descriptors=20, n_samples=256):
    """
    Compute Fourier shape descriptor from the boundary of a binary mask.

    Parameters
    ----------
    mask : 2D array-like (bool or 0/1)
        Binary mask of a region (or union of CHs).
    num_descriptors : int
        Number of low-frequency coefficients to keep (excluding DC).
    n_samples : int
        Number of boundary points to resample to (uniform along contour).

    Returns
    -------
    desc : np.ndarray, shape (num_descriptors,)
        Rotation/translation/starting-point invariant boundary descriptor,
        based on magnitudes of low-frequency Fourier coefficients.
    """
    mask = _ensure_binary_mask(mask)

    # --- find contours at 0.5 ---
    contours = measure.find_contours(mask.astype(float), level=0.5)
    if len(contours) == 0:
        # empty mask
        return np.zeros(num_descriptors, dtype=float)

    # choose the longest contour (largest region)
    contour = max(contours, key=lambda c: c.shape[0])

    # contour: array of shape (N, 2) with (row, col) = (y, x)
    ys, xs = contour[:, 0], contour[:, 1]

    # --- resample to fixed number of points along the contour length ---
    # compute cumulative distance along contour
    dy = np.diff(ys)
    dx = np.diff(xs)
    dists = np.sqrt(dx**2 + dy**2)
    cumlen = np.concatenate([[0], np.cumsum(dists)])
    total_len = cumlen[-1]

    if total_len == 0:
        return np.zeros(num_descriptors, dtype=float)

    # new parameterization from 0 to total_len
    target = np.linspace(0, total_len, n_samples, endpoint=False)
    # interpolate x(t), y(t)
    xs_resampled = np.interp(target, cumlen, xs)
    ys_resampled = np.interp(target, cumlen, ys)

    # --- build complex sequence and normalize ---
    z = xs_resampled + 1j * ys_resampled

    # translation invariance: subtract centroid
    z = z - z.mean()

    # scale invariance: normalize by RMS radius
    scale = np.sqrt(np.mean(np.abs(z) ** 2))
    if scale > 0:
        z = z / scale

    # --- Fourier transform along contour index ---
    Z = np.fft.fft(z)
    # We ignore the DC term Z[0] (translation)
    # and use the first num_descriptors low-frequency terms.
    # For invariance to starting point & rotation, use magnitudes.
    # freq indices: 1..num_descriptors
    max_k = min(num_descriptors, len(Z) // 2)
    coeffs = Z[1 : max_k + 1]
    desc = np.abs(coeffs)

    # zero-pad if needed
    if len(desc) < num_descriptors:
        pad = np.zeros(num_descriptors - len(desc), dtype=float)
        desc = np.concatenate([desc, pad])

    # optional second normalization
    norm = np.linalg.norm(desc)
    if norm > 0:
        desc = desc / norm

    return desc

In [None]:
def iou(mask1, mask2):
    """
    Compute Intersection-over-Union (IoU) for two binary masks.

    Parameters
    ----------
    mask1, mask2 : array-like
        2D numpy arrays. Values can be {0,1}, {0,255}, float, or bool.

    Returns
    -------
    float
        IoU value in [0,1].
    """

    # Convert to boolean
    m1 = np.asarray(mask1) > 0.5
    m2 = np.asarray(mask2) > 0.5

    intersection = np.logical_and(m1, m2).sum()
    union = np.logical_or(m1, m2).sum()

    if union == 0:
        return 1.0 if intersection == 0 else 0.0

    return intersection / union

In [None]:
def dice(mask1, mask2):
    m1 = np.asarray(mask1) > 0.5
    m2 = np.asarray(mask2) > 0.5

    intersection = np.logical_and(m1, m2).sum()
    a1 = m1.sum()
    a2 = m2.sum()

    denom = a1 + a2
    if denom == 0:
        return 1.0

    return 2.0 * intersection / denom

In [None]:
def shape_distance(desc_a, desc_b, metric="l2"):
    """
    Compute distance between two descriptor vectors.

    Parameters
    ----------
    desc_a, desc_b : array-like
        Descriptor vectors (e.g., Zernike or Fourier descriptors).
    metric : {"l2", "l1", "cosine"}
        Distance / dissimilarity measure.

    Returns
    -------
    d : float
        Distance (larger = more dissimilar).
    """
    a = np.asarray(desc_a, dtype=float)
    b = np.asarray(desc_b, dtype=float)

    if metric == "l2":
        return np.linalg.norm(a - b)
    elif metric == "l1":
        return np.sum(np.abs(a - b))
    elif metric == "cosine":
        na = np.linalg.norm(a)
        nb = np.linalg.norm(b)
        if na == 0 or nb == 0:
            return 1.0  # maximal dissimilarity
        cos_sim = np.dot(a, b) / (na * nb)
        return 1.0 - cos_sim  # cosine distance
    else:
        raise ValueError(f"Unknown metric: {metric!r}")

In [None]:
range_x = [384, 640]
range_y = [256, 768]

In [None]:
def rect_area(mask):
    return mask[range_y[0] : range_y[1], range_x[0] : range_x[1]].flatten().sum()

## Coronal Hole Area

In [None]:
R_SUN = 6.957e8 * u.m  # solar radius

## Plotting

In [None]:
def stats(row, smoothing_params=smoothing_params, m2=None):
    m1 = prepare_mask(row.mask_path)
    if m2 is None:
        m2 = pmap_to_mask(fits_to_pmap(prepare_fits(row.fits_path)), smoothing_params)

    stats = {}

    stats["fourier_distance"] = shape_distance(
        compute_fourier_descriptor(m1),
        compute_fourier_descriptor(m2),
    )

    stats["zernike_distance"] = shape_distance(
        compute_zernike_descriptor(m1),
        compute_zernike_descriptor(m2),
    )

    stats["rel_area"] = 1 - (rect_area(m2) / rect_area(m1))

    stats["iou"] = iou(m1, m2)
    stats["dice"] = dice(m1, m2)

    return stats

In [None]:
cmap = ct.aia_color_table(u.Quantity(193, "Angstrom"))
# cmap = "gray"


def print_distance(row, smoothing_params=smoothing_params):
    s = stats(row, smoothing_params)

    print("Fourier Distance: ", s["fourier_distance"])
    print("Zernike Distance: ", s["zernike_distance"])
    print("Center CH Area Difference (non-projective): ", s["rel_area"])
    print("I over U: ", s["iou"])
    print("Dice: ", s["dice"])


def plot_mask(row, smoothing_params=smoothing_params):
    print(row)
    print_distance(row, smoothing_params)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(
        pmap_to_mask(fits_to_pmap(prepare_fits(row.fits_path)), smoothing_params),
        cmap=cmap,
    )
    plt.gca().add_patch(
        Rectangle(
            [range_x[0], range_y[0]],
            range_x[1] - range_x[0],
            range_y[1] - range_y[0],
            linewidth=1,
            edgecolor="y",
            facecolor="none",
        )
    )
    plt.title("helio-n (U-Net)")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(np.array(PIL.Image.open(row.mask_path)), cmap=cmap)
    plt.gca().add_patch(
        Rectangle(
            [range_x[0], range_y[0]],
            range_x[1] - range_x[0],
            range_y[1] - range_y[0],
            linewidth=1,
            edgecolor="y",
            facecolor="none",
        )
    )
    plt.title("IDL")
    plt.axis("off")

    plt.tight_layout()
    plt.show()


def plot_sdo(row, smoothing_params):
    print(row)
    print_distance(row, smoothing_params)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(prepare_fits(row.fits_path), cmap=cmap)
    plt.contour(
        pmap_to_mask(fits_to_pmap(prepare_fits(row.fits_path)), smoothing_params),
        levels=[0.5],
        colors="red",
        linewidths=1.5,
    )
    plt.gca().add_patch(
        Rectangle(
            (range_x[0], range_y[0]),
            range_x[1] - range_x[0],
            range_y[1] - range_y[0],
            linewidth=1,
            edgecolor="y",
            facecolor="none",
        )
    )

    plt.title("helio-n (U-Net)")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(prepare_fits(row.fits_path), cmap=cmap)
    plt.contour(
        np.array(prepare_mask(row.mask_path)),
        levels=[0.5],
        colors="red",
        linewidths=1.5,
    )
    plt.gca().add_patch(
        Rectangle(
            (range_x[0], range_y[0]),
            range_x[1] - range_x[0],
            range_y[1] - range_y[0],
            linewidth=1,
            edgecolor="y",
            facecolor="none",
        )
    )

    plt.title("IDL")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

## UI

In [None]:
# 1. Register your dataframes here
dfs = {
    "train": train_df,
    "inference": inf_df,
}

# 2. Widgets
df_selector = widgets.RadioButtons(
    options=list(dfs.keys()),
    value="inference",
    description="DataFrame:",
)

idx_slider = widgets.IntSlider(
    value=random.randint(0, len(dfs["train"]) - 1),
    min=0,
    max=len(dfs["train"]) - 1,
    step=1,
    description="Index:",
    continuous_update=False,
)

show_mask_checkbox = widgets.Checkbox(
    value=False,
    description="Show Mask Only",
)

# New sliders
threshold_slider = widgets.FloatSlider(
    value=smoothing_params["threshold"],
    min=0.0,
    max=1.0,
    step=0.01,
    description="Threshold",
    continuous_update=False,
)

closing_radius_slider = widgets.IntSlider(
    value=smoothing_params["closing_radius"],
    min=0,
    max=20,
    step=1,
    description="Closing R",
    continuous_update=False,
)

min_size_slider = widgets.IntSlider(
    value=smoothing_params["min_size"],
    min=0,
    max=2000,
    step=10,
    description="Min size",
    continuous_update=False,
)

hole_size_slider = widgets.FloatSlider(
    value=smoothing_params["hole_size"],
    min=0.0,
    max=5000,
    step=50,
    description="Hole area",
    continuous_update=False,
)

out = widgets.Output()


# 3. Update slider range when DF changes
def update_slider_range(change):
    df = dfs[df_selector.value]
    idx_slider.max = max(0, len(df) - 1)
    if idx_slider.value > idx_slider.max:
        idx_slider.value = idx_slider.max


df_selector.observe(update_slider_range, names="value")


# 4. Main update function
def update_plot(change=None):
    global smoothing_params
    with out:
        clear_output(wait=True)
        df = dfs[df_selector.value]
        if len(df) == 0:
            print("Selected dataframe is empty.")
            return

        row = df.iloc[idx_slider.value]

        thr = threshold_slider.value
        cr = closing_radius_slider.value
        ms = min_size_slider.value
        ha = hole_size_slider.value

        smoothing_params = {
            "threshold": thr,
            "closing_radius": cr,
            "min_size": ms,
            "hole_size": ha,
        }

        if show_mask_checkbox.value:
            # Should plot SDO + NN/IDL mask
            plot_mask(
                row,
                {
                    "threshold": thr,
                    "closing_radius": cr,
                    "min_size": ms,
                    "hole_size": ha,
                },
            )
        else:
            # Should plot just the SDO / FITS-based view
            plot_sdo(
                row,
                {
                    "threshold": thr,
                    "closing_radius": cr,
                    "min_size": ms,
                    "hole_size": ha,
                },
            )

    with open("./Config/smoothing_params.json", "w") as f:
        json.dump(smoothing_params, f)


# 5. Hook up callbacks
idx_slider.observe(update_plot, names="value")
show_mask_checkbox.observe(update_plot, names="value")
df_selector.observe(update_plot, names="value")
threshold_slider.observe(update_plot, names="value")
closing_radius_slider.observe(update_plot, names="value")
min_size_slider.observe(update_plot, names="value")
hole_size_slider.observe(update_plot, names="value")

# 6. Display the UI
controls_top = widgets.HBox(
    [
        df_selector,
        idx_slider,
        show_mask_checkbox,
    ]
)

controls_bottom = widgets.HBox(
    [
        threshold_slider,
        closing_radius_slider,
        min_size_slider,
        hole_size_slider,
    ]
)

ui = widgets.VBox(
    [
        controls_top,
        controls_bottom,
        out,
    ]
)

display(ui)

# Initial draw
update_slider_range(None)
update_plot(None)

In [None]:
with open("./Outputs/smoothing_params.json", "w") as f:
    json.dump(smoothing_params, f)

In [None]:
df.iloc[1]

In [None]:
ddf = df.iloc[:10]

Unnamed: 0_level_0,fits_path,mask_path,pmap_path
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
20160101_0000,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0002,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0046,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0100,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0102,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0154,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0200,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0202,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0246,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,
20160101_0300,/Volumes/JetDrive 330/SDO Data/FITS/2016/01/01...,/Volumes/JetDrive 330/SDO Data/Masks/2016/01/A...,


In [172]:
ddf.mask_path = ddf.apply(save_pmap, axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ddf.mask_path = ddf.apply(save_pmap, axis=1)


In [None]:
# inf_df["nn_map"] = inf_df.apply(lambda x: mask_via_pmap(prepare_fits(x.fits_path)), axis=1)

In [None]:
# inf_df["stats"] = inf_df.apply(lambda x: stats(x, m2=x.nn_map), axis=1)

In [None]:
def plot_with_polarity(row, smoothing_params, B_thresh=0.15):
    """
    row: DataFrame row with row.fits_path and row.mask_path and row.hmi_path (JPG)
    smoothing_params: mask-via-model parameters
    """

    aia = prepare_fits(row.fits_path)  # (1024×1024, normalized)
    hmi = prepare_hmi_jpg()  # upscale JPG → match AIA grid

    nn_mask_raw = mask_via_pmap(row, smoothing_params)
    nn_mask = nn_mask_raw > 0.5

    idl_mask_raw = prepare_mask(row.mask_path)
    idl_mask = idl_mask_raw > 0.5

    # polarity masks using pseudo-HMI field
    nn_pos = nn_mask & (hmi >= B_thresh)
    nn_neg = nn_mask & (hmi <= -B_thresh)

    idl_pos = idl_mask & (hmi >= B_thresh)
    idl_neg = idl_mask & (hmi <= -B_thresh)

    # build RGBA overlays
    def make_overlay(pos, neg, alpha=0.5):
        h, w = pos.shape
        rgba_pos = np.zeros((h, w, 4), dtype=np.float32)
        rgba_neg = np.zeros((h, w, 4), dtype=np.float32)

        rgba_pos[..., 0] = 1.0  # red
        rgba_pos[..., 3] = alpha * pos.astype(float)

        rgba_neg[..., 2] = 1.0  # blue
        rgba_neg[..., 3] = alpha * neg.astype(float)

        return rgba_pos, rgba_neg

    nn_overlay_pos, nn_overlay_neg = make_overlay(nn_pos, nn_neg)
    idl_overlay_pos, idl_overlay_neg = make_overlay(idl_pos, idl_neg)

    # plots
    print_distance(row, smoothing_params)
    plt.figure(figsize=(10, 5))

    # ------------------ U-Net ------------------
    plt.subplot(1, 2, 1)
    plt.imshow(aia, cmap=cmap)
    plt.contour(nn_mask.astype(float), levels=[0.5], colors="red")
    plt.imshow(nn_overlay_pos)
    plt.imshow(nn_overlay_neg)
    plt.title("helio-n: red=+, blue=-")
    plt.axis("off")

    # ------------------ IDL ------------------
    plt.subplot(1, 2, 2)
    plt.imshow(aia, cmap=cmap)
    plt.contour(idl_mask.astype(float), levels=[0.5], colors="red")
    plt.imshow(idl_overlay_pos)
    plt.imshow(idl_overlay_neg)
    plt.title("IDL: red=+, blue=-")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
plot_with_polarity(inf_df.iloc[13385], smoothing_params)