# Setup

In [None]:
import gc
import random

import astropy.units as u
import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sunpy.map
import sunpy.visualization.colormaps.color_tables as ct
from astropy.visualization import AsinhStretch, ImageNormalize
from IPython.display import clear_output, display
from matplotlib.patches import Rectangle
from sunpy.coordinates import frames
from tqdm import tqdm
import json
import random
from pathlib import Path

import ipywidgets as widgets
from IPython.display import display, clear_output

from contextlib import contextmanager

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from Library.Processing import *
from Library.IO import *
from Library.Model import *
from Library.Metrics import *
from Library.Config import *
from Library.CH import *
from Library.Plot import *
from Models import load_date_range

In [None]:
# Let pandas render very wide frames on one line and never truncate cell
# contents (max_colwidth=None), so long values stay fully visible.
pd.set_option("display.width", 10000)
pd.set_option("display.max_colwidth", None)

# Data Processing

In [None]:
# Load Paths.parquet from the configured artifact root. Its rows are used
# below through attributes such as `fits_path`, `hmi_path` and `mask_path`.
# NOTE: plain string concatenation, so paths["artifact_root"] must already
# end with a path separator.
df = pd.read_parquet(paths["artifact_root"] + "Paths.parquet")

In [None]:
# Identifiers that select both the date-range definition loaded here and the
# trained model loaded further down (load_trained_model(ARCH_ID, DATE_ID)).
ARCH_ID = "A1"
DATE_ID = "D1"

date_range = load_date_range(ARCH_ID, DATE_ID)
# select_pairs() yields a (train, validation) pair; the validation frame can
# be None, which the following cell checks for explicitly.
train_df, val_df = date_range.select_pairs(df)

In [None]:
# Inference set = every row of df whose index label was not used for training
# (or validation, when a validation frame exists).
if val_df is not None:
    combined_df = pd.concat([train_df, val_df])
    excluded_index = combined_df.index
else:
    excluded_index = train_df.index

inf_df = df.loc[~df.index.isin(excluded_index)]

# Model

## Training

In [None]:
# train_model(train_df)

In [None]:
# Trained model for the selected architecture / date range. It is read as a
# notebook-level global by plot_sdo() and by the interactive UI below.
model = load_trained_model(ARCH_ID, DATE_ID)

## Plotting

In [None]:
def plot_sdo(
    row,
    postprocessing="P0",
    oval=None,
    show_fits=True,
    pmap=None,
):
    """
    Draw the two segmentations of one row side by side via plot_ch_map().

    Both panels share axes built on the WCS of the map loaded from
    ``row.fits_path``:

    - Left: U-Net (helio-n)
    - Right: IDL

    Args:
        row: Record exposing at least ``fits_path``; it is printed and then
            forwarded to the plotting/diagnostic helpers.
        postprocessing: Preset name, or a dict of parameters that is used
            as-is for the distance diagnostic.
        oval: Optional precomputed result of generate_omask(row); it is
            generated here when omitted.
        show_fits: Forwarded unchanged to plot_ch_map() for both panels.
        pmap: Optional precomputed pmap, forwarded to the U-Net panel only.

    Reads the notebook-level global ``model``.
    """
    print(row)

    # Preset names are resolved to a parameter dict; dicts pass through.
    if isinstance(postprocessing, dict):
        params = postprocessing
    else:
        params = get_postprocessing_params(postprocessing)
    print_distance(row, model, params)

    # Build the oval once so that both panels reuse the same object.
    if oval is None:
        oval = generate_omask(row)

    reference_map = sunpy.map.Map(row.fits_path)

    panel_inches = TARGET_PX / DPI
    fig = plt.figure(figsize=(panel_inches * 2.1, panel_inches))
    left_ax = fig.add_subplot(1, 2, 1, projection=reference_map)
    right_ax = fig.add_subplot(1, 2, 2, projection=reference_map)

    # (axes, source, model, pmap, title) for each panel. The IDL panel gets
    # neither a model nor a pmap; postprocessing is passed to both panels
    # (the original author notes it is ignored for IDL).
    panels = (
        (left_ax, "unet", model, pmap, "helio-n (U-Net)"),
        (right_ax, "idl", None, None, "IDL"),
    )
    for axes, source, panel_model, panel_pmap, title in panels:
        plot_ch_map(
            row,
            source=source,
            model=panel_model,
            pmap=panel_pmap,
            postprocessing=postprocessing,
            oval=oval,
            show_fits=show_fits,
            multiplot_ax=axes,
            set_title=False,
        )
        axes.set_title(title)

    plt.tight_layout()
    plt.show()

In [None]:
# Nesting depth of active suppress_redraw() blocks; update_plot() returns
# early while this is non-zero.
_SUPPRESS_REDRAW = 0


@contextmanager
def suppress_redraw():
    """Mute redraws for the duration of a ``with`` block.

    The module-level counter is incremented on entry and decremented on exit
    (also when the body raises), so nested blocks compose correctly.
    """
    global _SUPPRESS_REDRAW
    _SUPPRESS_REDRAW = _SUPPRESS_REDRAW + 1
    try:
        yield
    finally:
        _SUPPRESS_REDRAW = _SUPPRESS_REDRAW - 1


# -------------------------
# 0) Postprocessing presets
# -------------------------
PP_DIR = Path("./Config/Postprocessing").resolve()


def list_json_presets(pp_dir: Path):
    """Discover the ``*.json`` preset files directly inside ``pp_dir``.

    Returns:
        A ``(names, mapping)`` tuple. ``names`` always starts with
        ``"Custom"``, followed by the other file stems in sorted path order.
        ``mapping`` maps each stem to its path; ``"Custom"`` is always
        present and points at ``pp_dir / "Custom.json"`` even when that file
        does not exist yet. When ``pp_dir`` itself is missing, the result is
        ``(["Custom"], {"Custom": None})``.
    """
    if not pp_dir.exists():
        return ["Custom"], {"Custom": None}

    # Regular files only; the suffix comparison is case-insensitive.
    preset_files = sorted(
        entry
        for entry in pp_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() == ".json"
    )

    mapping = {entry.stem: entry for entry in preset_files}
    # "Custom" stays selectable even before its file has been written.
    mapping.setdefault("Custom", pp_dir / "Custom.json")

    other_names = [entry.stem for entry in preset_files if entry.stem != "Custom"]
    return ["Custom"] + other_names, mapping


PRESET_NAMES, PRESET_PATHS = list_json_presets(PP_DIR)
# Prefer "P1" when it is available; otherwise fall back to the first entry.
DEFAULT_PRESET = "P1" if "P1" in PRESET_NAMES else PRESET_NAMES[0]


def load_params_from_file(path: Path, fallback: dict):
    """Read postprocessing parameters from a JSON preset file.

    Only the keys ``threshold``, ``closing_radius``, ``min_size`` and
    ``hole_size`` are taken from the file; any key the file lacks keeps its
    value from ``fallback``.

    Args:
        path: Preset file to read, or ``None`` when no file is associated
            with the preset.
        fallback: Parameters used for missing keys and whenever the file
            cannot be used. It is never mutated.

    Returns:
        A new dict with the merged parameters.

    A missing file is treated as normal (the "Custom" preset has no file
    until it is first saved) and yields a copy of ``fallback`` silently. An
    unreadable file, invalid JSON, or JSON that is not an object also yields
    a copy of ``fallback``, but emits a warning so a broken preset is not
    silently mistaken for the previously shown values.
    """
    import warnings  # local import: keeps this cell self-contained

    merged = dict(fallback)
    if path is None:
        return merged

    try:
        with Path(path).open("r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return merged
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        warnings.warn(
            f"Could not read postprocessing preset {path}: {exc}; "
            "using fallback parameters.",
            stacklevel=2,
        )
        return merged

    if not isinstance(data, dict):
        warnings.warn(
            f"Postprocessing preset {path} does not contain a JSON object; "
            "using fallback parameters.",
            stacklevel=2,
        )
        return merged

    for key in ("threshold", "closing_radius", "min_size", "hole_size"):
        if key in data:
            merged[key] = data[key]
    return merged


def write_params_to_custom(path: Path, params: dict):
    """Write ``params`` to ``path`` as JSON, creating parent folders first.

    The file is UTF-8 encoded, indented by two spaces, with sorted keys.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(params, handle, sort_keys=True, indent=2)


# --------------------------------
# 1) Register your dataframes here
# --------------------------------
# Frames selectable in the UI, keyed by the label shown on the radio buttons.
dfs = {
    "train": train_df,
    "inference": inf_df,
}

# ----------------
# 2) Widgets
# ----------------
full_row = widgets.Layout(width="100%")
row_wrap = widgets.Layout(
    display="flex", flex_flow="row wrap", align_items="center", width="100%"
)
checkbox_layout = widgets.Layout(margin="0 0 0 12px")


df_selector = widgets.RadioButtons(
    options=list(dfs),
    value="inference",
    description="DataFrame:",
    layout=widgets.Layout(margin="0 8px 0 0"),
)

# Size the index slider from the frame that is actually selected at start-up
# (previously it was always sized from "train" although the selector starts
# on "inference", so the random start index ignored the displayed frame).
_initial_rows = len(dfs[df_selector.value])

idx_slider = widgets.IntSlider(
    # Random starting row; 0 when the frame is empty (randint(0, -1) raises).
    value=random.randint(0, _initial_rows - 1) if _initial_rows else 0,
    min=0,
    max=max(0, _initial_rows - 1),
    step=1,
    description="Index:",
    continuous_update=False,
    layout=full_row,
)

# When ticked, update_plot() calls plot_sdo() with show_fits=False.
show_mask_checkbox = widgets.Checkbox(
    value=False,
    description="Show Mask Only",
    layout=checkbox_layout,
)

# One-shot trigger: update_plot() calls save_pmap() for the current row when
# this is ticked and then un-ticks it again.
regenerate_pmap_checkbox = widgets.Checkbox(
    value=False,
    description="Regenerate PMAP",
    layout=checkbox_layout,
)

# Preset selector (radio): one option per discovered preset name, starting on
# DEFAULT_PRESET.
preset_selector = widgets.RadioButtons(
    options=PRESET_NAMES,
    value=DEFAULT_PRESET,
    description="Postproc:",
    layout=widgets.Layout(margin="0 0 0 8px"),
)


# ----------------------------------------
# 3.5) When preset changes, load/lock controls
# ----------------------------------------
# Parameters of the default preset. They seed the sliders below and act as
# the fallback when a preset file cannot be loaded.
smoothing_params = get_postprocessing_params(DEFAULT_PRESET)


def on_preset_change(change=None):
    """Observer for ``preset_selector``: load the preset into the sliders.

    Enables the sliders only for the "Custom" preset, loads the selected
    preset's parameters (falling back to the current ``smoothing_params``),
    copies them into the sliders with redraws suppressed, and finally
    redraws once.

    Args:
        change: Change notification from the widget; unused.
    """
    global smoothing_params

    selected = preset_selector.value
    set_custom_controls_enabled(selected == "Custom")

    loaded = load_params_from_file(
        PRESET_PATHS.get(selected), fallback=smoothing_params
    )

    # (slider, parameter key, type the slider expects)
    slider_bindings = (
        (threshold_slider, "threshold", float),
        (closing_radius_slider, "closing_radius", int),
        (min_size_slider, "min_size", int),
        (hole_size_slider, "hole_size", float),
    )
    # Each assignment would otherwise fire update_plot() through the slider
    # observers; suppress those and draw a single time afterwards.
    with suppress_redraw():
        for slider, key, cast in slider_bindings:
            slider.value = cast(loaded[key])

    update_plot(None)


preset_selector.observe(on_preset_change, names="value")

# Sliders (Custom-only controls)
# Options shared by all four sliders: fire only when the handle is released,
# and span the full row.
_slider_common = dict(continuous_update=False, layout=full_row)

threshold_slider = widgets.FloatSlider(
    value=smoothing_params["threshold"],
    min=0.0,
    max=1.0,
    step=0.01,
    description="Threshold",
    **_slider_common,
)
closing_radius_slider = widgets.IntSlider(
    value=smoothing_params["closing_radius"],
    min=0,
    max=20,
    step=1,
    description="Closing R",
    **_slider_common,
)
min_size_slider = widgets.IntSlider(
    value=smoothing_params["min_size"],
    min=0,
    max=2000,
    step=10,
    description="Min size",
    **_slider_common,
)
hole_size_slider = widgets.FloatSlider(
    value=smoothing_params["hole_size"],
    min=0.0,
    max=5000,
    step=50,
    description="Hole area",
    **_slider_common,
)

# Output area that update_plot() clears and draws into.
out = widgets.Output()

# The controls that are editable only while the "Custom" preset is selected.
CUSTOM_CONTROLS = [
    threshold_slider,
    closing_radius_slider,
    min_size_slider,
    hole_size_slider,
]


def set_custom_controls_enabled(enabled: bool):
    """Enable (``True``) or disable (``False``) every parameter slider."""
    locked = not enabled
    for control in CUSTOM_CONTROLS:
        control.disabled = locked


# -------------------------------
# 3) Update slider range when DF changes
# -------------------------------
def update_slider_range(change=None):
    """Observer for ``df_selector``: fit ``idx_slider`` to the chosen frame.

    Sets the slider maximum to the last valid row position (0 for an empty
    frame) and pulls the current value back inside the range if needed.

    Args:
        change: Change notification from the widget; unused.
    """
    selected_df = dfs[df_selector.value]
    last_position = max(0, len(selected_df) - 1)
    idx_slider.max = last_position
    if idx_slider.value > idx_slider.max:
        idx_slider.value = idx_slider.max


df_selector.observe(update_slider_range, names="value")


# -------------------------------
# 4) Main update function
# -------------------------------
def current_slider_params():
    """Return the current slider values as a new parameter dict.

    ``threshold`` and ``hole_size`` are floats; ``closing_radius`` and
    ``min_size`` are ints.
    """
    readers = (
        ("threshold", threshold_slider, float),
        ("closing_radius", closing_radius_slider, int),
        ("min_size", min_size_slider, int),
        ("hole_size", hole_size_slider, float),
    )
    return {key: cast(slider.value) for key, slider, cast in readers}


def update_plot(change=None):
    """Redraw the comparison plot for the currently selected row.

    Used as the observer for every control. Does nothing while a
    suppress_redraw() block is active. Otherwise it clears the output area,
    stores the current slider values in the global ``smoothing_params``,
    optionally regenerates the pmap for the row, persists the slider values
    to ``Custom.json`` when the "Custom" preset is selected, and calls
    plot_sdo().

    Args:
        change: Change notification from the widget; unused.
    """
    global smoothing_params

    if _SUPPRESS_REDRAW > 0:
        return

    with out:
        clear_output(wait=True)

        selected_df = dfs[df_selector.value]
        if len(selected_df) == 0:
            print("Selected dataframe is empty.")
            return

        row = selected_df.iloc[idx_slider.value]

        # current_slider_params() already returns a fresh dict.
        smoothing_params = current_slider_params()

        pmap = None
        if regenerate_pmap_checkbox.value:
            _, pmap = save_pmap(model, row)
            # This function observes the checkbox, so un-ticking it here
            # would synchronously re-enter update_plot() and render the plot
            # a second time. Reset it with redraws suppressed instead.
            with suppress_redraw():
                regenerate_pmap_checkbox.value = False

        postprocessing = preset_selector.value
        if postprocessing == "Custom":
            # Persist first: plot_sdo() receives only the preset name, so the
            # file has to reflect the sliders before it is resolved.
            write_params_to_custom(PP_DIR / "Custom.json", smoothing_params)

        plot_sdo(
            row,
            postprocessing=postprocessing,
            show_fits=not show_mask_checkbox.value,
            pmap=pmap,
        )


# -------------------------------
# 5) Hook up callbacks
# -------------------------------
# Every control redraws on change. The parameter sliders are observed
# unconditionally; update_plot() only persists their values to disk while
# the "Custom" preset is selected.
for w in (
    idx_slider,
    show_mask_checkbox,
    regenerate_pmap_checkbox,
    df_selector,
    *CUSTOM_CONTROLS,
):
    w.observe(update_plot, names="value")

# -------------------------------
# 6) Display the UI
# -------------------------------
index_row = widgets.HBox([idx_slider], layout=full_row)
options_row = widgets.HBox(
    [df_selector, show_mask_checkbox, regenerate_pmap_checkbox, preset_selector],
    layout=row_wrap,
)
controls = [index_row, options_row]
slider_rows = [
    widgets.HBox([control], layout=full_row) for control in CUSTOM_CONTROLS
]

# Index row, option row, one row per parameter slider, then the plot output.
ui = widgets.VBox(controls + slider_rows + [out], layout=full_row)

In [None]:
display(ui)

# Initial draw
# Fit the index slider to the frame that is selected at start-up.
update_slider_range(None)
# Enables/disables the sliders, loads the parameters of the currently
# selected preset (DEFAULT_PRESET at start-up) into them, and draws once.
on_preset_change(None)

# Coronal Hole Area

In [None]:
# 400 - no oval hole
# 200 — large central hole

In [None]:
# Work on an independent copy: new columns are assigned to `d` in the cells
# below, and assigning into a slice of `df` triggers pandas' chained
# assignment warning. Label-based slicing includes both endpoints.
d = df["20170101_0000":"20180101_0000"].copy()

In [None]:
# Coefficients of the power law evaluated by v().
v_min = 300
a = 170
alpha = 0.4


def v(s):
    """Evaluate ``v_min + a * (100 * s) ** alpha`` for the fraction ``s``."""
    percent = s * 100
    return v_min + a * percent**alpha

In [None]:
# Rebind `tqdm` (imported from the plain `tqdm` module at the top of the
# notebook) to the auto-detecting variant and register it with pandas, which
# provides the `progress_apply` used in the cells below.
from tqdm.auto import tqdm

tqdm.pandas()

In [48]:
# Row-wise ch_rel_area() in reference mode, with a progress bar.
# NOTE(review): the column is called "v_idl" but holds the raw ch_rel_area()
# result; v() is not applied here — confirm that this is intended.
d["v_idl"] = d.progress_apply(lambda row: ch_rel_area(row, reference_mode=True), axis=1)

  0%|          | 0/1459 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Row-wise ch_rel_area() using the trained model (reference_mode=False).
# NOTE(review): as with "v_idl", the stored value is the ch_rel_area() result,
# not v() of it — confirm that this is intended.
d["v_ch"] = d.progress_apply(lambda row: ch_rel_area(row, reference_mode=False, model=model), axis=1)

In [None]:
# Persist the frame including the two columns added above.
# NOTE(review): presumably the "Outputs" directory must already exist — verify.
d.to_parquet("Outputs/V IDL vs U-Net.parquet")

In [None]:
# Scratch check: v() of the reference-mode ch_rel_area() for a single row.
# NOTE(review): `row` is only assigned further down in the notebook, so this
# cell depends on out-of-order execution.
v(ch_rel_area(row, model=model, reference_mode=True))

In [None]:
# Scratch check: v() of the model-based ch_rel_area() for the same row.
# NOTE(review): `model` is passed positionally here but by keyword in the
# other calls — confirm it binds to the intended parameter.
v(ch_rel_area(row, model, reference_mode=False))

In [None]:
# Scratch check: the raw reference-mode ch_rel_area() value, without v().
ch_rel_area(row, reference_mode=True)

# Polarity

In [None]:
# Pick one fixed example row by position (hard-coded) and display it.
row = df.iloc[6666]
row

In [None]:
# Side-by-side U-Net / IDL plot of the example row using preset "P1".
plot_sdo(row, postprocessing="P1")

In [None]:
# Load the two files referenced by the example row as sunpy maps.
map_aia = sunpy.map.Map(row.fits_path)
map_hmi = sunpy.map.Map(row.hmi_path)

In [None]:
# Display settings for the HMI map: "hmimag" colormap with a fixed, symmetric
# normalization range.
map_hmi.plot_settings["cmap"] = "hmimag"
map_hmi.plot_settings["norm"] = plt.Normalize(-1500, 1500)
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(121, projection=map_aia)
map_aia.plot(axes=ax1, clip_interval=(1, 99.9) * u.percent)
# The right-hand axes use the AIA WCS. Without autoalign the HMI pixel grid is
# drawn as-is and does not line up with those axes, so the title would be
# wrong; autoalign=True makes sunpy transform the image onto the axes' WCS.
ax2 = fig.add_subplot(122, projection=map_aia)
map_hmi.plot(axes=ax2, autoalign=True, title="HMI image in AIA reference frame")

In [None]:
def plot_with_polarity(row, smoothing_params, B_thresh=0.15):
    """
    Plot the U-Net and IDL masks of one row, tinted by field polarity.

    Each panel shows the prepared AIA image, the mask outline in red, and a
    half-transparent overlay inside the mask: red where the HMI array is
    >= ``B_thresh``, blue where it is <= ``-B_thresh``.

    Args:
        row: DataFrame row with ``fits_path``, ``mask_path`` and ``hmi_path``
            (the original docstring describes ``hmi_path`` as a JPG).
        smoothing_params: Parameters forwarded to pmap_to_mask() and
            print_distance().
        B_thresh: Absolute threshold applied to the HMI array.

    Reads the notebook-level global ``model`` and a ``cmap`` name that is
    not defined in this notebook (presumably supplied by the star imports —
    TODO confirm).
    """
    aia = prepare_fits(row.fits_path)
    # NOTE(review): the original called prepare_hmi_jpg() without arguments,
    # leaving row.hmi_path unused. The path is passed here by analogy with
    # prepare_fits()/prepare_mask() — confirm against the helper's signature.
    hmi = prepare_hmi_jpg(row.hmi_path)

    # Binarize both masks at 0.5.
    nn_mask = pmap_to_mask(row, smoothing_params) > 0.5
    idl_mask = prepare_mask(row.mask_path) > 0.5

    # Polarity regions of the HMI array, shared by both panels.
    positive = hmi >= B_thresh
    negative = hmi <= -B_thresh

    def make_overlay(region, channel, alpha=0.5):
        """Return an RGBA image: one pure colour channel, opaque in region."""
        rgba = np.zeros(region.shape + (4,), dtype=np.float32)
        rgba[..., channel] = 1.0
        rgba[..., 3] = alpha * region.astype(float)
        return rgba

    # Same argument order as the call in plot_sdo(): (row, model, params).
    print_distance(row, model, smoothing_params)
    plt.figure(figsize=(10, 5))

    panels = (
        ("helio-n: red=+, blue=-", nn_mask),
        ("IDL: red=+, blue=-", idl_mask),
    )
    for position, (title, mask) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(aia, cmap=cmap)
        plt.contour(mask.astype(float), levels=[0.5], colors="red")
        plt.imshow(make_overlay(mask & positive, channel=0))  # red
        plt.imshow(make_overlay(mask & negative, channel=2))  # blue
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# plot_with_polarity(inf_df.iloc[13385], smoothing_params)

In [None]:
# Pick a fixed example among the rows that have no missing values, then
# display it.
row = df.dropna().iloc[333]
row

In [None]:
# Load the row's HMI file as a sunpy map and display it.
hmi = sunpy.map.Map(row.hmi_path)
hmi

In [None]:
import re
from pathlib import Path

# Example file names used to compare how a timestamp key is derived from each.
hmi = "hmi.M_720s.20100513_000000_TAI.fits"
aia = "AIA20100513_000000_0193.fits"

# Inside a raw string the digit class is written with a single backslash.
# The doubled form matched a literal backslash followed by the letter "d",
# so the search returned None and .group(1) raised AttributeError.
hmi_key = re.search(r"(\d{8}_\d{6})", hmi).group(1)
aia_key = aia[3:16]

print(hmi_key)  # expect 20100513_000000
print(aia_key)  # current AIA key = 20100513_0000

In [None]:
# Display the training frame ordered by index (sort_index() returns a sorted
# copy; train_df itself is left unchanged).
train_df.sort_index()

In [None]:
# Stand-alone diagnostic: how many samples a date-range definition selects.
# NOTE: this cell rebinds df, train_df, val_df and date_range.
import pandas as pd
from Library.Config import paths
from Models import load_date_range

df = pd.read_parquet(paths["artifact_root"] + "Paths.parquet").sort_index()

arch_id = "A2"
date_id = "D1"
date_range = load_date_range(arch_id, date_id)

train_df, val_df = date_range.select_pairs(df)
# select_pairs() may return no validation frame (the data-processing cell at
# the top of the notebook checks for None), so do not call len() on it blindly.
val_count = 0 if val_df is None else len(val_df)
print("train:", len(train_df), "val:", val_count)

# per-year counts (train)
# Sort once instead of once per year.
train_sorted = train_df.sort_index()
for year in (2011, 2015, 2018, 2020):
    # Label slicing is inclusive on both ends.
    year_df = train_sorted[f"{year}0101_0000" : f"{year + 1}0101_0000"]
    # NOTE(review): .normalize() assumes the index is a DatetimeIndex — confirm.
    print(
        year, "rows:", len(year_df), "unique days:", year_df.index.normalize().nunique()
    )

# how keep_every would reduce it
kept = train_df.iloc[:: date_range.keep_every]
print("train kept:", len(kept), "keep_every:", date_range.keep_every)

# per-day counts (train)
per_day = train_df.groupby(train_df.index.normalize()).size()
print(
    "days:", len(per_day), "avg per day:", per_day.mean(), "max per day:", per_day.max()
)