In [None]:
import os
import copy
import scanpy as sc
import plotly.express as px
import numpy as np
import pandas as pd
import benchplots as bp
import plotly.express as px
import plotly.graph_objects as go


from benchdb import * 
from benchmodels import *
from benchplots import *
from benchutils import *
from plotly.subplots import make_subplots
from sklearn.model_selection import GroupKFold

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0


Tabula Muris Senis (FACS): benchmark box plot

In [2]:
# Read the FACS benchmark results and flatten them into a DataFrame for plotting.
tmf = load_benchdb("tmuris_facs.json")
# NOTE(review): the return value is discarded, so mvote_all presumably updates
# `tmf` in place for the methods listed in RM_METHODS — confirm in benchutils.
mvote_all(tmf, to_mvote=RM_METHODS)
df = bench_to_df(tmf)

In [6]:
# Box plot of the FACS benchmark; sizing, title and top margin are set in one layout update.
fig = plot_box(
    df,
    'Tabula Muris Senis',
    merge_mv_rcm=False,
    horizontal=False,
    color_boxes=False,
)
fig.update_layout(
    height=800,
    width=1000,
    title="Tabula Muris Senis (FACS)",
    margin=dict(t=60),
)
fig

In [7]:
# Export the FACS figure to SUP/facs.png, creating the output directory if needed.
os.makedirs("SUP", exist_ok=True)
fig.write_image("SUP/facs.png", scale=3)

Tabula Muris Senis (droplet): accuracy scatter plots by cell type count

In [14]:
# Droplet benchmark results; `db` is read as a module-level global by plot_scatter
# and by the method-ranking cell below.
db = load_benchdb('tmuris_drop.json')

In [None]:
# Load the droplet dataset and prepare the inputs for the grouped K-fold split.
ds = load_adata('../data/tmuris_senis_droplet.h5ad')
key = "cell_ontology_class"   # obs column holding the cell type label
split_on = 'mouse.id'         # obs column used as the grouping variable for the split

# Use the declared `split_on` column rather than repeating the literal, so that
# changing the grouping column in one place actually takes effect.
groups = ds.obs[split_on]
X = np.arange(ds.shape[0])    # positional row indices; only used to drive the splitter
y = ds.obs[key]

In [5]:
# nref[fold - 1] maps every label to the number of cells carrying it in that
# fold's reference split (first index array yielded by the splitter).
gkf = GroupKFold(n_splits=5)
labels = ds.obs[key].unique().tolist()

nref = []
for ref_idx, _ in gkf.split(X, y, groups=groups):
    ref_labels = ds[ref_idx].obs[key]
    nref.append({label: int((ref_labels == label).sum()) for label in labels})


In [None]:
# nref:  [ <idx fold-1> { label: count } ]
def plot_scatter(method, nref, logall=True, msize=6.0, size_max=18, use_symbols=True):
    """Scatter per-label accuracy of one method against label counts in query and reference.

    Reads the module-level ``db``: for every fold ``k`` (1-based) it uses
    ``db[method][f"fold_{k}"]["fref"]`` and its ``preds`` / ``true`` attributes.
    One point is produced per (fold, label) pair; a label counts as "correct"
    when at least half of its query cells were predicted as that label.

    Parameters
    ----------
    method : key into ``db`` selecting the method to plot.
    nref : list with one dict per fold (index ``fold - 1``) mapping label -> count
        in the reference split. The number of folds processed is ``len(nref)``.
    logall : use log scale on both axes.
    msize : marker size applied to all traces.
    size_max : forwarded to ``px.scatter`` as ``size_max``.
    use_symbols : also encode the status with marker symbols ("x" / "circle").

    Returns
    -------
    (fig, df) : the Plotly figure and the un-aggregated per-(fold, label) table
        with columns ``label``, ``n_q``, ``n_ref``, ``acc``, ``status``.
    """
    records = []
    # Follow the folds present in `nref` rather than assuming exactly five.
    for fold_idx, ref_counts in enumerate(nref):
        result = db[method][f"fold_{fold_idx+1}"]["fref"]
        preds = np.array(result.preds)
        true = np.array(result.true)

        for ql in np.unique(true):
            n_ref = ref_counts.get(str(ql), None)
            if n_ref is None:
                # label has no entry in this fold's reference counts -> skip it
                continue
            is_ql = true == ql          # computed once, reused for count and accuracy
            n_q = is_ql.sum()
            acc = (preds[is_ql] == ql).mean()
            records.append([ql, n_q, n_ref, acc])

    df = pd.DataFrame(records, columns=["label", "n_q", "n_ref", "acc"])
    df["status"] = (df["acc"] >= 0.5).astype(int).map({0: "incorrect", 1: "correct"})

    # aggregate duplicates -> multiplicity k of identical (n_q, n_ref, status) points
    g = (
        df.groupby(["n_q", "n_ref", "status"], as_index=False)
          .size()
          .rename(columns={"size": "k"})
    )

    colors = {"correct": "#0072B2", "incorrect": "#D55E00"}  # Okabe–Ito

    kwargs = dict(
        data_frame=g,
        x="n_q",
        y="n_ref",
        size="k",
        size_max=size_max,
        color="status",
        color_discrete_map=colors,
        log_x=logall,
        log_y=logall,
    )

    if use_symbols:
        kwargs.update(
            symbol="status",
            symbol_map={"incorrect": "x", "correct": "circle"},
        )

    fig = px.scatter(**kwargs)

    # clean styling for print
    # NOTE(review): a scalar marker.size set here appears to replace the per-point
    # size array that px.scatter derives from size="k", so every marker is drawn at
    # `msize` and `size_max` has no visible effect. Rendering is kept as-is; confirm
    # whether multiplicity scaling is actually wanted before changing it.
    fig.update_traces(
        marker=dict(
            size=msize,
            opacity=0.85,
            line=dict(width=0.6)  # subtle outline helps separation
        )
    )

    fig.update_xaxes(title="count in query", showline=True, mirror=True, ticks="outside")
    fig.update_yaxes(title="count in reference", showline=True, mirror=True, ticks="outside")
    fig.update_layout(
        template="simple_white",
        title=f"{method} accuracy by cell type count in query and reference",
        legend_title_text="",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
    )

    bp.add_paper_styling(fig)
    return fig, df


def plot_scatter_panel(
    nref,
    methods: list[str],
    rows: int = 1,
    cols: int = None,
    logall: bool = True,
    msize: float = 6.0,
    size_max: int = 18,
    use_symbols: bool = True,
    title: str = "Accuracy by cell type count",
    ytitle: str = "count in reference",
    xtitle: str = "count in query",
    fs: int = 16,
    show_outer_ticks: bool = True,
    show_all_ticks: bool = False,
    remove_tick_marks: bool = False,
    title_x_off: float = 0.5,
    title_y_off: float = 0.98,
    ytitle_x_off: float = -0.10,
    ytitle_y_off: float = 0.5,
    xtitle_x_off: float = 0.5,
    xtitle_y_off: float = -0.12,
    cell_w: int = 320,
    cell_h: int = 300,
    hspace: float = 0.06,
    vspace: float = 0.10,
):
    """
    Multi-panel scatter plot for method comparison.

    Builds one ``plot_scatter`` figure per entry of ``methods`` and copies its
    traces into a ``rows`` x ``cols`` subplot grid (filled row by row), then
    re-applies axis settings after ``bp.add_paper_styling``.

    Parameters
    ----------
    nref : per-fold reference label counts, forwarded to ``plot_scatter``.
    methods : method names; also used as subplot titles.
    rows, cols : grid shape. ``cols`` defaults to ``len(methods)``.
    logall, msize, size_max, use_symbols : forwarded to ``plot_scatter``;
        ``logall`` additionally forces log axes on every populated subplot.
    title, xtitle, ytitle : main title and shared axis titles (drawn as annotations).
    fs : font size of shared axis titles and subplot titles; main title uses ``fs + 4``.
    show_all_ticks : show tick labels on every populated subplot.
    show_outer_ticks : if ``show_all_ticks`` is False, show y tick labels on the
        first column only and x tick labels only on the bottom-most populated
        subplot of each column.
    remove_tick_marks : hide tick marks even where tick labels are shown.
    title_*_off, xtitle_*_off, ytitle_*_off : positions in paper coordinates.
    cell_w, cell_h : per-cell pixel size used to derive the figure width/height.
    hspace, vspace : subplot spacing passed to ``make_subplots``.

    Returns
    -------
    The assembled Plotly figure.

    Raises
    ------
    AssertionError : if ``rows * cols`` is smaller than ``len(methods)``.
    """
    # Generate individual figures (the per-method DataFrame is not needed here)
    figs = [
        plot_scatter(m, nref, logall=logall, msize=msize,
                     size_max=size_max, use_symbols=use_symbols)[0]
        for m in methods
    ]

    n = len(figs)
    if cols is None:
        cols = n
    assert rows * cols >= n, "insufficient rows/cols"

    # Create subplot grid
    panel = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=methods,
        horizontal_spacing=hspace,
        vertical_spacing=vspace,
    )

    for k, f in enumerate(figs):
        r = (k // cols) + 1
        c = (k % cols) + 1

        # Sort traces: "incorrect" is added last so it is drawn on top
        sorted_traces = sorted(f.data, key=lambda t: t.name == "incorrect")

        for trace in sorted_traces:
            tr = copy.deepcopy(trace)
            tr.showlegend = False  # the panel carries no legend
            panel.add_trace(tr, row=r, col=c)

    # Layout sizing
    panel.update_layout(
        width=cell_w * cols + 60,
        height=cell_h * rows + 80,
        margin=dict(l=100, r=40, t=80, b=100),
    )

    # Apply paper styling (this may mess with axes)
    bp.add_paper_styling(panel, lines=False)

    # --- Post-add_paper_styling axis corrections ---
    # Settings shared by every axis; identical for all subplots, so built once.
    axis_common = dict(
        showgrid=False,
        zeroline=False,
        showline=True,
        linewidth=1,
        linecolor="black",
        mirror=True,
    )

    for k in range(n):
        r = (k // cols) + 1
        c = (k % cols) + 1

        # Decide tick label visibility
        if show_all_ticks:
            show_x = True
            show_y = True
        elif show_outer_ticks:
            # Bottom-most populated panel of its column: nothing is plotted at
            # index k + cols. Equivalent to "last row" when the grid is full, and
            # still labels the lowest panel when the last row is partial or empty.
            show_x = (k + cols >= n)
            show_y = (c == 1)      # left column only
        else:
            show_x = False
            show_y = False

        # X axis — ticks conditional on show_x
        x_ticks_on = show_x and not remove_tick_marks
        xaxis_update = dict(
            **axis_common,
            showticklabels=show_x,
            ticks="outside" if x_ticks_on else "",
            ticklen=5 if x_ticks_on else 0,
            title=None,
        )
        if logall:
            xaxis_update["type"] = "log"

        # Y axis — ticks conditional on show_y
        y_ticks_on = show_y and not remove_tick_marks
        yaxis_update = dict(
            **axis_common,
            showticklabels=show_y,
            ticks="outside" if y_ticks_on else "",
            ticklen=5 if y_ticks_on else 0,
            title=None,
        )
        if logall:
            yaxis_update["type"] = "log"

        panel.update_xaxes(**xaxis_update, row=r, col=c)
        panel.update_yaxes(**yaxis_update, row=r, col=c)

    # Shared axis title annotations
    # X title (bottom center)
    panel.add_annotation(
        text=xtitle,
        x=xtitle_x_off,
        y=xtitle_y_off,
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(size=fs),
    )
    # Y title (left center, rotated)
    panel.add_annotation(
        text=ytitle,
        x=ytitle_x_off,
        y=ytitle_y_off,
        xref="paper",
        yref="paper",
        showarrow=False,
        textangle=-90,
        font=dict(size=fs),
    )

    # Main title
    panel.update_layout(
        title=dict(
            text=title,
            x=title_x_off,
            xanchor="center",
            y=title_y_off,
            yanchor="top",
            font=dict(size=fs + 4),
        )
    )

    # Shift subplot titles slightly (identified by their text matching a method name)
    for ann in panel.layout.annotations:
        if ann.text in methods:
            ann.yshift = (ann.yshift or 0) + 5
            ann.font = dict(size=fs)

    return panel

In [18]:
# Rank methods by their mean `cacc` over the five folds (best first);
# `methods` fixes the panel order used for plotting.
data = [
    [m, db[m][f"fold_{i+1}"]['fref'].cacc]
    for m in RCM_METHODS + MV_METHODS
    for i in range(5)
]

df = (
    pd.DataFrame(data, columns=['method', 'acc'])
      .groupby('method')
      .mean()
      .sort_values('acc', ascending=False)
)
methods = df.index.tolist()

In [20]:
# Draw every ranked method on a 4x3 grid with tight spacing between cells.
panel = plot_scatter_panel(
    nref,
    methods=methods,
    rows=4, cols=3,
    title="",  # main title left empty; candidate: "Annotation accuracy vs. cell type abundance"
    msize=6,
    hspace=0.01,
    vspace=0.03,
    title_y_off=1.0,
    ytitle_x_off=-0.08,
    xtitle_y_off=-0.05
)

# Replace the figure size that plot_scatter_panel derives from cell_w / cell_h.
panel.update_layout(width=1200, height=1600)
panel.show()