In [None]:
import os
import glob
import re
import numpy as np
import sunpy.map
from sunpy.map.maputils import all_coordinates_from_map, coordinate_is_on_solar_disk
import pandas as pd
import PIL
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [72]:
IMG_SIZE = 512
BATCH_SIZE = 4
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4

### Data Processing

In [73]:
TRAIN_FITS_ROOT = "./Data/Training/FITS"
TRAIN_MASKS_ROOT = "./Data/Training/Masks"

MODEL_PATH = "model_CH_UNet.h5"

INFER_FITS_ROOT = "./Data/Inferrence/FITS"
INFER_MASKS_ROOT = "./Data/Inferrence/Masks"

In [None]:
def mask_fits(f):
    hpc_coords = all_coordinates_from_map(f)
    mask = coordinate_is_on_solar_disk(hpc_coords)
    palette = f.cmap.copy()
    palette.set_bad("black")
    scaled_map = sunpy.map.Map(f.data, f.meta, mask=~mask)
    ff = scaled_map.data
    return ff


def prepare_fits(path, mask_disk=True, clip_low=1, clip_high=99):
    f = sunpy.map.Map(path)
    if mask_disk:
        ff = mask_fits(f).data
    else:
        ff = f.data

    low = np.percentile(ff, clip_low)
    high = np.percentile(ff, clip_high)
    ff = np.clip(ff, low, high)
    ff = (ff - low) / (high - low + 1e-6)
    return ff

In [75]:
def prepare_mask(path):
    """
    Read PNG mask (0/255-ish) and return 0/1 float32 array.

    We *don't* rename; we just read your existing _CH_MASK_FINAL.png.
    """
    from PIL import Image

    im = Image.open(path).convert("L")
    arr = np.array(im, dtype=np.float32)
    # Consider >127 as CH
    arr = (arr > 127).astype(np.float32)
    return arr

In [None]:
def prepare_dataset(fits_root, masks_root):
    def index(p):
        return p.split("/")[-1][3:16]

    fits_re = re.compile(r"AIA(\d{8})_\d{4,6}_(\d{4})\.fits$")
    mask_re = re.compile(r"AIA(\d{8})_\d{6}_(\d{4})_CH_MASK_FINAL\.png$")

    # Collect all FITS and masks
    fits_files = glob.glob(os.path.join(fits_root, "**", "*.fits"), recursive=True)
    mask_files = glob.glob(
        os.path.join(masks_root, "**", "*_CH_MASK_FINAL.png"), recursive=True
    )

    df_fits = pd.DataFrame(
        {
            "key": [index(p) for p in fits_files],
            "fits_path": fits_files,
        }
    )

    df_masks = pd.DataFrame(
        {
            "key": [index(p) for p in mask_files],
            "mask_path": mask_files,
        }
    )

    # Optional: detect duplicate keys (same timestamp, multiple files)
    dup_fits = df_fits[df_fits.duplicated("key", keep=False)]
    dup_masks = df_masks[df_masks.duplicated("key", keep=False)]

    if not dup_fits.empty:
        print("⚠ Duplicate keys in FITS:")
        print(dup_fits.sort_values("key"))

    if not dup_masks.empty:
        print("⚠ Duplicate keys in masks:")
        print(dup_masks.sort_values("key"))

    # Outer join to see everything in one table
    merged = df_fits.merge(df_masks, on="key", how="outer", indicator=True)

    matches = merged[merged["_merge"] == "both"].copy()
    fits_only = merged[merged["_merge"] == "left_only"].copy()
    masks_only = merged[merged["_merge"] == "right_only"].copy()

    for df in (matches, fits_only, masks_only):
        df.drop(columns=["_merge"], inplace=True)

    matches.set_index(matches.key, inplace=True, drop=True)
    matches.drop(["key"], axis=1, inplace=True)

    return matches, fits_only, masks_only

In [77]:
train_df = prepare_dataset(TRAIN_FITS_ROOT, TRAIN_MASKS_ROOT)[0]
inf_df = prepare_dataset(INFER_FITS_ROOT, INFER_MASKS_ROOT)[0]

⚠ Duplicate keys in FITS:
                key                                          fits_path
929   20170712_2154  ./Data/Training/FITS/2017/07/12/AIA20170712_21...
1018  20170712_2154  ./Data/Training/FITS/2017/07/12/AIA20170712_21...
⚠ Duplicate keys in FITS:
                key                                          fits_path
7128  20170125_2146  ./Data/Inferrence/FITS/2017/01/25/AIA20170125_...
7180  20170125_2146  ./Data/Inferrence/FITS/2017/01/25/AIA20170125_...


### Network Architecture

In [None]:
def augment_pair(img, mask):
    # --- random horizontal flip ---
    stacked = tf.concat([img, mask], axis=-1)  # (H, W, 2)
    stacked = tf.cond(
        tf.random.uniform(()) > 0.5,
        lambda: tf.image.flip_left_right(stacked),
        lambda: stacked,
    )
    img, mask = stacked[..., :1], stacked[..., 1:]

    # --- small random brightness scaling ---

    scale = tf.random.uniform((), 0.9, 1.1)  # ±10 %
    img = img * scale
    img = tf.clip_by_value(img, 0.0, 1.0)

    return img, mask

In [None]:
def _load_pair_numpy(fits_path, mask_path):
    # fits_path, mask_path come from tf.numpy_function as bytes or strings
    fits_path = (
        fits_path.decode("utf-8")
        if isinstance(fits_path, (bytes, np.bytes_))
        else fits_path
    )
    mask_path = (
        mask_path.decode("utf-8")
        if isinstance(mask_path, (bytes, np.bytes_))
        else mask_path
    )

    img = prepare_fits(fits_path)  # 2D np array
    mask = prepare_mask(mask_path)  # 2D np array

    img = np.asarray(img, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)

    if img.ndim == 2:
        img = img[..., np.newaxis]
    if mask.ndim == 2:
        mask = mask[..., np.newaxis]

    img = tf.image.resize(img, (IMG_SIZE, IMG_SIZE), method="bilinear").numpy()
    mask = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE), method="nearest").numpy()

    return img, mask


def load_pair_tf(fits_path, mask_path):
    img, mask = tf.numpy_function(
        _load_pair_numpy,
        [fits_path, mask_path],
        [tf.float32, tf.float32],
    )

    img.set_shape((IMG_SIZE, IMG_SIZE, 1))
    mask.set_shape((IMG_SIZE, IMG_SIZE, 1))

    img, mask = augment_pair(img, mask)

    return img, mask

In [None]:
def double_conv(x, filters):
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return x

In [81]:
def build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=32):
    inputs = keras.Input(shape=input_shape)

    # Encoder
    c1 = double_conv(inputs, base_filters)
    p1 = layers.MaxPool2D(2)(c1)

    c2 = double_conv(p1, base_filters * 2)
    p2 = layers.MaxPool2D(2)(c2)

    c3 = double_conv(p2, base_filters * 4)
    p3 = layers.MaxPool2D(2)(c3)

    c4 = double_conv(p3, base_filters * 8)
    p4 = layers.MaxPool2D(2)(c4)

    # Bottleneck
    bn = double_conv(p4, base_filters * 16)

    # Decoder
    u4 = layers.Conv2DTranspose(base_filters * 8, 2, strides=2, padding="same")(bn)
    u4 = layers.Concatenate()([u4, c4])
    c5 = double_conv(u4, base_filters * 8)

    u3 = layers.Conv2DTranspose(base_filters * 4, 2, strides=2, padding="same")(c5)
    u3 = layers.Concatenate()([u3, c3])
    c6 = double_conv(u3, base_filters * 4)

    u2 = layers.Conv2DTranspose(base_filters * 2, 2, strides=2, padding="same")(c6)
    u2 = layers.Concatenate()([u2, c2])
    c7 = double_conv(u2, base_filters * 2)

    u1 = layers.Conv2DTranspose(base_filters, 2, strides=2, padding="same")(c7)
    u1 = layers.Concatenate()([u1, c1])
    c8 = double_conv(u1, base_filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(c8)

    model = keras.Model(inputs, outputs, name="CH_UNet")
    return model

In [None]:
def build_tiny_unet(input_shape=(IMG_SIZE, IMG_SIZE, 1), base_filters=16):
    """
    Smaller, dumber U-Net:
      - Only 3 downsampling levels
      - Much fewer filters
      - Cheaper to train, less capacity

    Suitable when you care more about speed than squeezing out last % IoU.
    """
    inputs = keras.Input(shape=input_shape)

    # Encoder (3 levels instead of 4–5)
    c1 = double_conv(inputs, base_filters)  #  H x  W
    p1 = layers.MaxPool2D(2)(c1)  # H/2 x W/2

    c2 = double_conv(p1, base_filters * 2)  # H/2 x W/2
    p2 = layers.MaxPool2D(2)(c2)  # H/4 x W/4

    c3 = double_conv(p2, base_filters * 4)  # H/4 x W/4
    p3 = layers.MaxPool2D(2)(c3)  # H/8 x W/8

    # Bottleneck
    bn = double_conv(p3, base_filters * 8)

    # Decoder
    u3 = layers.Conv2DTranspose(base_filters * 4, 2, strides=2, padding="same")(bn)
    u3 = layers.Concatenate()([u3, c3])
    c4 = double_conv(u3, base_filters * 4)

    u2 = layers.Conv2DTranspose(base_filters * 2, 2, strides=2, padding="same")(c4)
    u2 = layers.Concatenate()([u2, c2])
    c5 = double_conv(u2, base_filters * 2)

    u1 = layers.Conv2DTranspose(base_filters, 2, strides=2, padding="same")(c5)
    u1 = layers.Concatenate()([u1, c1])
    c6 = double_conv(u1, base_filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(c6)

    model = keras.Model(inputs, outputs, name="CH_TinyUNet")
    return model

In [83]:
def dice_coef(y_true, y_pred, smooth=1.0):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    y_true_f = tf.reshape(y_true, [tf.shape(y_true)[0], -1])
    y_pred_f = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])

    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=1)
    denom = tf.reduce_sum(y_true_f, axis=1) + tf.reduce_sum(y_pred_f, axis=1)

    dice = (2.0 * intersection + smooth) / (denom + smooth)
    return tf.reduce_mean(dice)


def bce_dice_loss(y_true, y_pred):
    bce = keras.losses.binary_crossentropy(y_true, y_pred)
    return 0.4 * tf.reduce_mean(bce) + 0.6 * (1.0 - dice_coef(y_true, y_pred))

### Model Training

In [84]:
def load_trained_model():
    custom_objects = {
        "bce_dice_loss": bce_dice_loss,
        "dice_coef": dice_coef,
    }
    model = keras.models.load_model(MODEL_PATH, custom_objects=custom_objects)
    return model

In [None]:
def train_model(pairs_df):
    fits_paths = pairs_df["fits_path"].astype(str).tolist()
    mask_paths = pairs_df["mask_path"].astype(str).tolist()

    if len(fits_paths) == 0:
        raise RuntimeError("pairs_df is empty: no FITS-mask pairs to train on.")
    if len(fits_paths) != len(mask_paths):
        raise RuntimeError(
            f"Mismatch in pairs_df: {len(fits_paths)} FITS vs {len(mask_paths)} masks."
        )

    n_total = len(fits_paths)
    print(f"Training on {n_total} FITS-mask pairs")

    fits_paths_tf = tf.constant(fits_paths)
    mask_paths_tf = tf.constant(mask_paths)

    # 90/10 split
    n_val = max(1, int(0.1 * n_total))
    n_train = n_total - n_val

    train_fits = fits_paths_tf[:n_train]
    train_masks = mask_paths_tf[:n_train]
    val_fits = fits_paths_tf[n_train:]
    val_masks = mask_paths_tf[n_train:]

    # steps per epoch (integer)
    steps_per_epoch = n_train // BATCH_SIZE
    val_steps = max(1, n_val // BATCH_SIZE)

    # --- datasets ---

    train_ds = (
        tf.data.Dataset.from_tensor_slices((train_fits, train_masks))
        .shuffle(buffer_size=n_train)
        .map(
            lambda f, m: load_pair_tf(f, m),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .repeat()  # important: infinite stream
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    val_ds = (
        tf.data.Dataset.from_tensor_slices((val_fits, val_masks))
        .map(
            lambda f, m: load_pair_tf(f, m),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .repeat()  # also infinite so Keras can do val_steps each epoch
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    # --- model ---

    model = build_unet(base_filters=32)

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=bce_dice_loss,
        metrics=[dice_coef, "accuracy"],
    )

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            MODEL_PATH,
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=False,
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            verbose=1,
        ),
    ]

    model.fit(
        train_ds,
        epochs=NUM_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_ds,
        validation_steps=val_steps,
        callbacks=callbacks,
    )

    print(f"Training finished. Best model saved to {MODEL_PATH}")

In [86]:
train_model(train_df)

Training on 2364 FITS-mask pairs
Epoch 1/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 630ms/step - accuracy: 0.8440 - dice_coef: 0.0789 - loss: 0.7294



[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 671ms/step - accuracy: 0.9347 - dice_coef: 0.1098 - loss: 0.6641 - val_accuracy: 0.9261 - val_dice_coef: 0.1521 - val_loss: 0.6352 - learning_rate: 1.0000e-04
Epoch 2/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 603ms/step - accuracy: 0.9760 - dice_coef: 0.1976 - loss: 0.5562



[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m342s[0m 642ms/step - accuracy: 0.9794 - dice_coef: 0.2308 - loss: 0.5263 - val_accuracy: 0.9662 - val_dice_coef: 0.1839 - val_loss: 0.5540 - learning_rate: 1.0000e-04
Epoch 3/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m343s[0m 642ms/step - accuracy: 0.9870 - dice_coef: 0.3790 - loss: 0.4087 - val_accuracy: 0.9672 - val_dice_coef: 0.1436 - val_loss: 0.5687 - learning_rate: 1.0000e-04
Epoch 4/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 614ms/step - accuracy: 0.9899 - dice_coef: 0.5011 - loss: 0.3235



[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m348s[0m 652ms/step - accuracy: 0.9904 - dice_coef: 0.5436 - loss: 0.2956 - val_accuracy: 0.9695 - val_dice_coef: 0.2684 - val_loss: 0.4924 - learning_rate: 1.0000e-04
Epoch 5/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m364s[0m 683ms/step - accuracy: 0.9923 - dice_coef: 0.6794 - loss: 0.2076 - val_accuracy: 0.9683 - val_dice_coef: 0.2478 - val_loss: 0.5121 - learning_rate: 1.0000e-04
Epoch 6/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 621ms/step - accuracy: 0.9931 - dice_coef: 0.7447 - loss: 0.1661



[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m353s[0m 661ms/step - accuracy: 0.9932 - dice_coef: 0.7597 - loss: 0.1568 - val_accuracy: 0.9680 - val_dice_coef: 0.3071 - val_loss: 0.4822 - learning_rate: 1.0000e-04
Epoch 7/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m348s[0m 653ms/step - accuracy: 0.9938 - dice_coef: 0.8083 - loss: 0.1262 - val_accuracy: 0.9686 - val_dice_coef: 0.2994 - val_loss: 0.4915 - learning_rate: 1.0000e-04
Epoch 8/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m344s[0m 644ms/step - accuracy: 0.9942 - dice_coef: 0.8369 - loss: 0.1082 - val_accuracy: 0.9693 - val_dice_coef: 0.3126 - val_loss: 0.4867 - learning_rate: 1.0000e-04
Epoch 9/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 605ms/step - accuracy: 0.9945 - dice_coef: 0.8521 - loss: 0.0987



[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m343s[0m 643ms/step - accuracy: 0.9946 - dice_coef: 0.8560 - loss: 0.0962 - val_accuracy: 0.9692 - val_dice_coef: 0.3932 - val_loss: 0.4402 - learning_rate: 1.0000e-04
Epoch 10/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m351s[0m 659ms/step - accuracy: 0.9948 - dice_coef: 0.8679 - loss: 0.0886 - val_accuracy: 0.9693 - val_dice_coef: 0.3551 - val_loss: 0.4674 - learning_rate: 1.0000e-04
Epoch 11/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m351s[0m 658ms/step - accuracy: 0.9949 - dice_coef: 0.8741 - loss: 0.0847 - val_accuracy: 0.9672 - val_dice_coef: 0.2490 - val_loss: 0.5408 - learning_rate: 1.0000e-04
Epoch 12/30
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 621ms/step - accuracy: 0.9952 - dice_coef: 0.8813 - loss: 0.0800
Epoch 12: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
[1m532/532[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m351s[0m 658ms/s

In [87]:
model = load_trained_model()



### Testing

In [None]:
def apply_model_to_array(model, img2d, resize=False):
    """
    img2d: 2D numpy array (H, W), already preprocessed (normalized, disk masked, etc.)
    model: your trained Keras model

    Returns:
        prob_map: 2D numpy array (img_size, img_size) with values in [0, 1]
    """
    img = np.asarray(img2d, dtype=np.float32)

    if img.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {img.shape}")

    # Optional: resize if shape doesn't match training size
    if resize and img.shape != (IMG_SIZE, IMG_SIZE):
        import tensorflow as tf

        img_tf = tf.convert_to_tensor(img[..., np.newaxis])  # (H, W, 1)
        img_tf = tf.image.resize(img_tf, (IMG_SIZE, IMG_SIZE), method="bilinear")
        img = img_tf.numpy()[..., 0]

    # Add channel and batch dims: (H, W) → (1, H, W, 1)
    x = img[np.newaxis, ..., np.newaxis]  # (1, H, W, 1)

    # Predict
    prob = model.predict(x, verbose=0)[0, ..., 0]  # back to (H, W)

    return prob

In [None]:
def mask_via_model(path):
    prob_map = apply_model_to_array(model, path, resize=True)
    mask = (prob_map > 0.5).astype(np.float32)  # binary mask if you want it
    mask_uint8 = (mask * 255).clip(0, 255).astype(np.uint8)
    img = PIL.Image.fromarray(mask_uint8, mode="L")

    if img.size != (1024, 1024):
        img = img.resize((1024, 1024), resample=PIL.Image.NEAREST)

    return img

In [None]:
def prob_map_via_model(path):
    prob_map = apply_model_to_array(model, path, resize=True)
    img_uint8 = np.clip(prob_map * 255, 0, 255).astype(np.uint8)
    img = PIL.Image.fromarray(img_uint8, mode="L")
    return img

In [149]:
def plot_mask(row, sdo=False):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(np.array(mask_via_model(prepare_fits(row.fits_path))), cmap="gray")
    # plt.title(titles[0])
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(np.array(PIL.Image.open(row.mask_path)), cmap="gray")
    # plt.title(titles[1])
    plt.axis("off")

    plt.tight_layout()
    plt.show()

def plot_sdo(row):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(prepare_fits(row.fits_path), cmap="gray")
    plt.contour(np.array(mask_via_model(prepare_fits(row.fits_path))), levels=[0.5], colors="red", linewidths=1.5)
    # plt.title(titles[0])
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(prepare_fits(row.fits_path), cmap="gray")
    plt.contour(np.array(prepare_mask(row.mask_path)), levels=[0.5], colors="red", linewidths=1.5)
    # plt.title(titles[1])
    plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# 1. Register your dataframes here
dfs = {
    "train": train_df,
    "inferrence": inf_df,
}

# 2. Widgets
df_selector = widgets.RadioButtons(
    options=list(dfs.keys()),
    value="train",
    description="DataFrame:",
)

idx_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dfs["train"]) - 1,
    step=1,
    description="Index:",
    continuous_update=False,
)

show_mask_checkbox = widgets.Checkbox(
    value=True,
    description="Show mask (off = SDO)",
)

out = widgets.Output()

# 3. Update slider range when DF changes
def update_slider_range(change):
    df = dfs[df_selector.value]
    idx_slider.max = max(0, len(df) - 1)
    if idx_slider.value > idx_slider.max:
        idx_slider.value = idx_slider.max

df_selector.observe(update_slider_range, names="value")

# 4. Main update function
def update_plot(change=None):
    with out:
        clear_output(wait=True)
        df = dfs[df_selector.value]
        if len(df) == 0:
            print("Selected dataframe is empty.")
            return

        row = df.iloc[idx_slider.value]

        if show_mask_checkbox.value:
            # You implement this: should plot SDO + NN/IDL mask
            plot_mask(row)
        else:
            # You implement this: should plot just the SDO / FITS-based view
            plot_sdo(row)

# 5. Hook up callbacks
idx_slider.observe(update_plot, names="value")
show_mask_checkbox.observe(update_plot, names="value")
df_selector.observe(update_plot, names="value")

# 6. Display the UI
ui = widgets.VBox([
    df_selector,
    idx_slider,
    show_mask_checkbox,
    out,
])

display(ui)

# Initial draw
update_slider_range(None)
update_plot(None)

VBox(children=(RadioButtons(description='DataFrame:', options=('train', 'inferrence'), value='train'), IntSlid…