In [None]:
def plot_RG_border_prob(
    mat_p,
    border_p,
    region_s,
    genome=None,
    title="TITLE",
    merge_hom=True,
    line_title=None,
    pm_args=None,
    prob_args=None,
):
    """Plot a matrix heatmap above a per-bin TAD-boundary probability track.

    The top panel (80% of the height) shows a matrix for ``region_s``. The
    bottom panel shows, for every row of the border table inside the region,
    the row sum divided by the number of columns, i.e. the mean boundary
    indicator across columns (samples).

    Args:
        mat_p: Source of the matrix.
            A ``str``/``Path`` is read with ``cool2mat(mat_p, region_s)``.
            An ``Mchr`` object yields ``mat_p.PM(None, [region, region], proximity=3)``.
            A ``pd.DataFrame`` is treated as an "all RG" table: the root mean
            square across its columns is computed per row and ``unstack()``-ed
            into a matrix.
        border_p: Boundary calls.
            A ``str``/``Path`` is read with ``pd.read_parquet``; cells equal to
            ``"boundary"`` count as 1, all others as 0. The parquet index is
            assumed to be a (chrom, start) MultiIndex -- TODO confirm.
            Anything else is assumed to be a DataFrame indexed by start position
            with numeric boundary indicators (presumably 0/1 -- verify); it is
            relabelled with the region's chromosome.
        region_s: Region string, parsed with ``Region(region_s, genome=genome)``.
        genome: Genome passed to ``Region``. Required.
        title: Figure title.
        merge_hom: If True, chromosome names are passed through
            ``chrom_rm_suffix`` so homologous copies share boundary data.
        line_title: Subplot title of the probability panel. Defaults to
            ``"Probability of scTLD boundary"``.
        pm_args: Extra keyword arguments forwarded to ``_plot_mat``.
        prob_args: Accepted for backward compatibility. Currently not used.

    Returns:
        The figure built with ``make_subplots`` (heatmap in row 1, line in row 2).

    Raises:
        ValueError: If ``genome`` is missing or ``mat_p`` has an unsupported type.
    """
    if genome is None:
        raise ValueError("genome must be specified")
    # None sentinels avoid shared mutable default arguments.
    pm_args = {} if pm_args is None else pm_args
    # NOTE(review): prob_args is accepted but never applied to the plot.

    # Presumably ((chrom, start), (chrom, end)) -- indexed that way below.
    region = Region(region_s, genome=genome).r

    # --- slice keys for the border table --- #
    if merge_hom:
        # Don't separate homologous chromosomes in TAD borders.
        tad_region_0 = (chrom_rm_suffix(region[0][0]), region[0][1])
        tad_region_1 = (chrom_rm_suffix(region[1][0]), region[1][1])
    else:
        tad_region_0 = region[0]
        tad_region_1 = region[1]

    # --- parse border data --- #
    if isinstance(border_p, (str, Path)):
        data = pd.read_parquet(border_p)
        data = (data == "boundary").astype(int)
    else:  # assume a DataFrame indexed by start position
        data = border_p.copy()
        # Label rows with the chromosome *name* (not the (chrom, pos) tuple)
        # so level 0 matches the chromosome used in the slice keys above.
        data.index = pd.MultiIndex.from_product(
            [[tad_region_0[0]], data.index],
            names=["chrom", "start"],
        )
    # Range slicing a MultiIndex with .loc requires a lexsorted index.
    data = data.sort_index()

    # Row sum / column count == mean boundary indicator across samples.
    prob = data.loc[tad_region_0:tad_region_1, :].sum(axis=1) / data.shape[1]
    prob = prob.rename("TAD_prob").reset_index()
    if line_title is None:
        line_title = "Probability of scTLD boundary"

    # --- generate pm --- #
    if isinstance(mat_p, (str, Path)):
        pm = cool2mat(mat_p, region_s)
        heatmap_title = "Proximity Matrix"
    elif isinstance(mat_p, Mchr):
        pm = mat_p.PM(None, [region, region], proximity=3)
        heatmap_title = "Proximity Matrix"
    elif isinstance(mat_p, pd.DataFrame):
        # Assume an all-RG table; unstack() expects a (row, col) MultiIndex
        # -- verify against the producer of this table.
        pm = ((mat_p ** 2).mean(axis=1) ** 0.5).unstack()
        heatmap_title = "mean RG Matrix"
    else:
        raise ValueError("mat_p must be a path, Mchr object or DataFrame")

    # --- plot --- #
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        row_heights=[0.8, 0.2],
        subplot_titles=(heatmap_title, line_title),
    )
    heatmap = _plot_mat(
        pm,
        title="",
        donorm=False,
        ignore_diags=False,
        cmap="Viridis",
        **pm_args,
    ).data[0]
    fig.add_trace(heatmap, row=1, col=1)
    fig.add_trace(
        go.Scatter(
            x=prob["start"],
            y=prob["TAD_prob"],
            mode="lines",
            name="TAD boundary probability",
            line=dict(color="black", width=2),
        ),
        row=2,
        col=1,
    )
    fig.update_layout(
        height=650,
        width=500,
        title=title,
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return fig