In [None]:
from IPython.core.display import display_svg
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

from pyrejection.evaluation import render_svg_fig, standard_fig_style

In [None]:
# Display names for the four rates. Each string is used both as a DataFrame
# column key and, via plot_curve, as an axis title (hence the <sup> markup).
# Unconditional error rate.
Eu = 'Unconditional Error Rate: E<sup>u</sup>'
# Conditional error rate; computed elsewhere in this notebook as Eu / C.
Ec = 'Conditional Error Rate: E<sup>c</sup>'
# Coverage rate.
C = 'Coverage Rate: C'
# Rejection rate; computed elsewhere in this notebook as 1 - C.
R = 'Rejection Rate: R'

In [None]:
# Hand-picked operating points, one (label, unconditional error, coverage) row
# per classifier. The interpolation code joins neighbouring points with straight
# segments, so every point is assumed to lie on the Pareto front of the
# capacity curve.
classifiers = pd.DataFrame(
    [
        ('f<sup>0</sup>', 0.7, 1),
        ('B', 0.35, 0.7),
        ('D (p=0.5)', 0.2, 0.45),
        ('A', 0.05, 0.2),
        ('f<sup>⊥</sup>', 0, 0),
    ],
    columns=['label', Eu, C],
)

def conditional_error(coverage, unconditional_error):
    """Return the conditional error rate, ``unconditional_error / coverage``.

    The ratio is undefined when ``coverage`` equals zero; in that case ``None``
    is returned instead of dividing.
    """
    return None if coverage == 0 else unconditional_error / coverage

# Derive the conditional error rate row by row. conditional_error returns None
# when coverage is 0, so that row gets no numeric value.
classifiers[Ec] = classifiers.apply(lambda row: conditional_error(row[C], row[Eu]), axis=1)
# Rejection rate is the complement of the coverage rate.
classifiers[R] = 1 - classifiers[C]

In [None]:
def get_interpolated_classifiers(clfs, n):
    """Linearly interpolate between classifiers at evenly spaced rejection rates.

    For each rejection rate ``i / n`` (``i = 0 .. n``) the closest classifier at
    or below that rate and the closest at or above it are looked up. If both
    have the same rejection rate, the lower one's values are copied unchanged;
    otherwise the unconditional error is interpolated linearly between the two
    and the conditional error is recomputed from the interpolated values.

    Args:
        clfs: DataFrame with the columns R, C, Eu and Ec, one row per
            classifier. Its rejection rates must bracket every sampled rate
            (i.e. include a value <= 0 and a value >= 1); otherwise the
            neighbour lookup runs on an empty selection and raises.
        n: Number of equal steps between rejection 0 and rejection 1.
            Must be at least 1.

    Returns:
        DataFrame with the columns R, C, Eu and Ec and ``n + 1`` rows.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        # Previously n == 0 surfaced as a bare ZeroDivisionError from i / n.
        raise ValueError('n must be at least 1, got {!r}'.format(n))
    keys = [R, C, Eu, Ec]
    interpolated_classifiers = []
    for i in range(n + 1):
        rejection = i / n
        # idxmax()/idxmin() return index *labels*, so the row must be fetched
        # with .loc. Using the positional .iloc only happens to work while the
        # frame carries a default 0..len-1 index.
        prev_clfs = clfs[clfs[R] <= rejection]
        prev_clf = prev_clfs.loc[prev_clfs[R].idxmax()]
        next_clfs = clfs[clfs[R] >= rejection]
        next_clf = next_clfs.loc[next_clfs[R].idxmin()]
        if prev_clf[R] == next_clf[R]:
            # The sampled rate coincides with an existing classifier.
            interpolated_classifiers.append(prev_clf[keys].to_dict())
            continue
        # Fractional position of the sampled rate between the two neighbours.
        p = (rejection - prev_clf[R]) / (next_clf[R] - prev_clf[R])
        unconditional_error = (p * next_clf[Eu]) + ((1 - p) * prev_clf[Eu])
        coverage = 1 - rejection
        interpolated_classifiers.append({
            R: rejection,
            C: coverage,
            Eu: unconditional_error,
            Ec: conditional_error(coverage, unconditional_error),
        })
    return pd.DataFrame(interpolated_classifiers)

def plot_curve(classifiers, x, y):
    """Build the plotly figure of column ``y`` against column ``x``.

    The curve is drawn from 500-step interpolated classifiers and the given
    classifiers are overlaid as labelled markers. When ``y`` is the
    unconditional error rate, the region between the curve and y = 1 is shaded
    and the text 'Capacity' is placed inside it.

    The order in which traces are added matters: the shading trace uses
    ``fill='tonexty'`` and must directly follow the curve it fills towards.

    Returns:
        The figure after ``standard_fig_style`` and the layout overrides below
        have been applied.
    """
    curve_points = get_interpolated_classifiers(classifiers, 500)
    fig = px.line(curve_points, x=x, y=y)

    if y == Eu:
        # Invisible points along y = 1; only the fill down to the curve is seen.
        ceiling_trace = go.Scatter(
            x=classifiers[x],
            y=pd.Series(1, index=classifiers.index),
            fill='tonexty',  # shade the area between the curve and this trace
            fillcolor='rgba(00,00,00,0.1)',
            mode='markers',
            marker={'opacity': 0},
        )
        fig.add_trace(ceiling_trace)

        # A single invisible marker that carries the region's text label.
        # NOTE(review): the update_traces call further down has no selector, so
        # it appears to replace this textposition with 'bottom right' - confirm
        # which placement is intended.
        capacity_label_trace = go.Scatter(
            x=[0.3],
            y=[0.55],
            text=['Capacity'],
            mode='markers+text',
            marker={'opacity': 0},
            textposition='top center',
            textfont={'size': 24, 'color': '#000000'},
        )
        fig.add_trace(capacity_label_trace)

    # The original (non-interpolated) classifiers as labelled points.
    labelled_points_trace = go.Scatter(
        x=classifiers[x],
        y=classifiers[y],
        text=classifiers['label'],
        mode='markers+text',
    )
    fig.add_trace(labelled_points_trace)

    # Applied to every trace added so far.
    fig.update_traces(
        textposition='bottom right',
        line={'color': '#000000'},
        showlegend=False,
    )
    standard_fig_style(fig)
    # Axis ranges extend slightly beyond [0, 1].
    fig.update_layout(
        width=550,
        height=450,
        xaxis={'range': [-0.01, 1.05]},
        xaxis_title=x,
        yaxis={'range': [-0.1, 1.01]},
        yaxis_title=y,
        margin={'r': 0},
        legend_title_text='',
        font={'size': 22},
    )
    return fig

In [None]:
# Unconditional error rate against coverage rate; for this y column plot_curve
# also shades and labels the 'Capacity' region.
render_svg_fig(plot_curve(classifiers, x=C, y=Eu))

In [None]:
# Conditional error rate against coverage rate for the same classifiers.
render_svg_fig(plot_curve(classifiers, x=C, y=Ec))