In [None]:
# --- add these helpers somewhere above the function ---
from typing import Optional, Tuple


def _order_roots_clockwise_by_distance(distances: dict[str, int], top: str = "h_21"):
    """
    Order roots by ascending distance, rotate so `top` is first.
    This list will be laid out clockwise by the arc generator.
    """
    roots = [k for k, _ in sorted(distances.items(), key=lambda kv: kv[1])]
    if top in roots:
        i = roots.index(top)
        roots = roots[i:] + roots[:i]
    return roots

def _generate_arcs_clockwise(n: int, theta_deg: float, gap_total_deg: float, start_at_deg: float = 90.0):
    """
    Make n arcs centered at angles that step clockwise from `start_at_deg`.
    Each arc has angular width `theta_deg`; total gaps sum to `gap_total_deg`.
    """
    if n <= 0:
        return []
    gap_per = gap_total_deg / n
    step = theta_deg + gap_per  # center-to-center angular step
    arcs = []
    for i in range(n):
        mid = (start_at_deg - i * step) % 360.0  # clockwise
        start = mid - theta_deg / 2.0
        end   = mid + theta_deg / 2.0
        arcs.append({"start_deg": start, "end_deg": end, "mid_deg": mid})
    return arcs



# Default root -> distance table used to order and label the petals.
_DEFAULT_ROOT_DISTANCES = {
    "h_21": 5,  "h_22": 12, "h_23": 15, "h_24": 21, "h_25": 26,
    "h_26": 22, "h_27": 17, "h_28": 15, "h_29": 18, "h_30": 10,
    "h_31": 18, "h_32": 12, "h_33": 18, "h_34": 18, "h_35": 20,
    "h_36": 22, "h_37": 8,
}


def _draw_flower_panel(
    ax,
    *,
    order,
    arcs,
    values_by_root,
    cmap,
    norm,
    distances,
    center,
    r_inner,
    r_outer,
    title,
):
    """Draw one flower (stacked-ring petals plus distance labels) onto `ax`.

    `order` and `arcs` are parallel lists: petal i belongs to root order[i]
    and spans arcs[i]. Each root's values in `values_by_root` are split into
    equal-thickness rings between `r_inner` and `r_outer`, with the first
    value innermost. Each ring is coloured by ``cmap(norm(value))``.
    """
    cx, cy = center
    band = r_outer - r_inner

    for root, arc in zip(order, arcs):
        vals = values_by_root.get(root, [])
        if not vals:
            continue
        m = len(vals)
        for idx, val in enumerate(vals):
            rin = r_inner + (band * idx) / m
            rout = r_inner + (band * (idx + 1)) / m
            ax.add_patch(Wedge(
                (cx, cy), rout,
                theta1=arc["start_deg"], theta2=arc["end_deg"],
                width=(rout - rin),
                facecolor=cmap(norm(val)), edgecolor="none",
            ))

    # Labels are placed just outside the outer rim, rotated to follow the circle.
    # Every root in `order` is labelled, even if it has no values in this panel.
    label_r = r_outer + 0.06
    for root, arc in zip(order, arcs):
        mid = arc["mid_deg"]
        ang = math.radians(mid)
        ax.text(
            cx + label_r * math.cos(ang),
            cy + label_r * math.sin(ang),
            _label_from_distance(root, distances),
            ha="center", va="center",
            rotation=mid - 90, rotation_mode="anchor", fontsize=8,
        )

    pad = 0.08 * r_outer
    ax.set_xlim(cx - r_outer - pad, cx + r_outer + pad)
    ax.set_ylim(cy - r_outer - pad, cy + r_outer + pad)
    ax.set_aspect("equal", adjustable="box")
    ax.axis("off")
    ax.set_title(title)


def plot_flower_initial_and_final_two_gradients(
    tsv_path: str,
    init_col: int = 4,
    final_col: int = 5,
    theta: float = 20.0,
    gap_total: float = 20.0,
    r_inner: float = 0.70,
    r_outer: float = 1.00,
    center: Tuple[float, float] = (0.0, 0.0),
    cmap_init: str = "plasma",
    cmap_final: str = "plasma",
    root_col: int = 0,
    title_left: str = "Initial log-likelihood",
    title_right: str = "Final log-likelihood",
    show_colorbars: bool = True,
    top_root: str = "h_21",
    vmin_init: Optional[float] = None,
    vmax_init: Optional[float] = None,
    vmin_final: Optional[float] = None,
    vmax_final: Optional[float] = None,
    distances: Optional[dict[str, int]] = None,
):
    """
    Read a TSV and draw two flower plots (initial & final) with separate colour scales.

    Roots are ordered by ascending distance, rotated so `top_root` comes first,
    and laid out CLOCKWISE starting at 12 o'clock. Each petal stacks that
    root's values as equal-thickness rings from `r_inner` (first value)
    outwards to `r_outer`. Petal labels come from
    ``_label_from_distance(root, distances)``.

    Args:
        tsv_path: path handed to ``_load_grouped``.
        init_col, final_col, root_col: column indices handed to ``_load_grouped``.
        theta: angular width of each petal, in degrees.
        gap_total: total angular gap, split evenly between petals, in degrees.
            Petals fill exactly one full turn only when
            ``n_roots * theta + gap_total == 360``. This holds for the 17
            default roots with the default theta/gap_total.
        r_inner, r_outer: radial extent of the petals.
        center: (x, y) centre of both flowers.
        cmap_init, cmap_final: Matplotlib colormap names for each panel.
        title_left, title_right: panel titles.
        show_colorbars: add one colorbar per panel when True.
        top_root: root placed at 12 o'clock (ignored if not in `distances`).
        vmin_init, vmax_init, vmin_final, vmax_final: optional colour-scale
            limits. Any limit left as None falls back to the min/max of the
            plotted data for that panel.
        distances: optional mapping of root name -> distance. Defaults to the
            built-in table. Roots present in the TSV but absent from this
            mapping are not drawn.

    Returns:
        ``(fig, (axL, axR))``.

    Raises:
        ValueError: if no TSV root appears in `distances`, or if the matching
            roots yield no initial values or no final values.
    """
    if distances is None:
        distances = _DEFAULT_ROOT_DISTANCES

    # NOTE(review): _load_grouped presumably returns two dicts mapping
    # root name -> list of numeric values; confirm against its definition.
    init_by_root, final_by_root = _load_grouped(
        tsv_path, root_col=root_col, init_col=init_col, final_col=final_col
    )

    order = _order_roots_clockwise_by_distance(distances, top=top_root)
    order = [r for r in order if r in init_by_root or r in final_by_root]
    if not order:
        raise ValueError("No matching roots from TSV in the hard-coded distance list.")

    arcs = _generate_arcs_clockwise(
        n=len(order), theta_deg=theta, gap_total_deg=gap_total, start_at_deg=90.0
    )

    all_init = [v for r in order for v in init_by_root.get(r, [])]
    all_final = [v for r in order for v in final_by_root.get(r, [])]
    if not all_init or not all_final:
        raise ValueError("Parsed data has no initial or final values.")

    # Explicit limits win; otherwise fall back to the data range of each panel.
    if vmin_init is None:
        vmin_init = min(all_init)
    if vmax_init is None:
        vmax_init = max(all_init)
    if vmin_final is None:
        vmin_final = min(all_final)
    if vmax_final is None:
        vmax_final = max(all_final)

    norm_init = colors.Normalize(vmin=vmin_init, vmax=vmax_init)
    norm_final = colors.Normalize(vmin=vmin_final, vmax=vmax_final)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cm_init = plt.get_cmap(cmap_init)
    cm_final = plt.get_cmap(cmap_final)

    fig, (axL, axR) = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)

    shared = dict(
        order=order, arcs=arcs, distances=distances,
        center=center, r_inner=r_inner, r_outer=r_outer,
    )
    _draw_flower_panel(
        axL, values_by_root=init_by_root, cmap=cm_init, norm=norm_init,
        title=title_left, **shared,
    )
    _draw_flower_panel(
        axR, values_by_root=final_by_root, cmap=cm_final, norm=norm_final,
        title=title_right, **shared,
    )

    if show_colorbars:
        smL = cm.ScalarMappable(norm=norm_init, cmap=cm_init)
        smL.set_array([])
        cbarL = fig.colorbar(smL, ax=axL, fraction=0.046, pad=0.04)
        cbarL.set_label("Initial log likelihood")

        smR = cm.ScalarMappable(norm=norm_final, cmap=cm_final)
        smR.set_array([])
        cbarR = fig.colorbar(smR, ax=axR, fraction=0.046, pad=0.04)
        cbarR.set_label("Final log likelihood")

    return fig, (axL, axR)
