In [None]:
import copy
import os
import warnings

import pandas as pd
import numpy as np
from scipy.stats import entropy

# Warning categories silenced for the whole notebook so that benchmark output
# stays readable.
IGNORED_WARNING_CATEGORIES = [
    FutureWarning,
    UserWarning,
    pd.errors.PerformanceWarning,
]

# SettingWithCopyWarning is not reachable under the same attribute path in
# every pandas release (newer releases expose it from ``pandas.errors`` rather
# than ``pandas.core.common``), so resolve it defensively and skip it when the
# installed pandas does not provide it at all.
try:
    _setting_with_copy_warning = pd.errors.SettingWithCopyWarning
except AttributeError:
    _setting_with_copy_warning = getattr(pd.core.common, 'SettingWithCopyWarning', None)
if _setting_with_copy_warning is not None:
    IGNORED_WARNING_CATEGORIES.append(_setting_with_copy_warning)

for _category in IGNORED_WARNING_CATEGORIES:
    warnings.simplefilter(action='ignore', category=_category)

In [None]:
from polyphony import Polyphony
from polyphony.data import load_pancreas
from polyphony.tool import eval, projection

In [None]:
def auto_confirm(pp, entropy_threshold=0.5):
    """Confirm every unconfirmed anchor whose query cells look homogeneous.

    An anchor is confirmed when both conditions hold:

    * the Shannon entropy (in bits) of the ``cell_type`` distribution of its
      query cells is strictly below ``entropy_threshold * log2(n_types)``,
      where ``n_types`` is the number of distinct values in
      ``pp.ref.cell_type``;
    * the most frequent ``cell_type`` among its query cells equals the most
      frequent ``cell_type`` among the reference cells whose
      ``pp.ref.anchor_assign`` entry equals ``str(anchor.reference_id)``.

    Anchors that have no query cells, or no assigned reference cells, are
    never confirmed.

    Args:
        pp: object exposing ``ref.cell_type``, ``ref.adata.obs``,
            ``ref.anchor_assign``, ``qry.adata.obs``, ``anchors`` and
            ``confirm_anchor(anchor_id)``.
        entropy_threshold: multiplier applied to ``log2(n_types)`` to obtain
            the entropy cut-off.

    Returns:
        None. Matching anchors are confirmed through ``pp.confirm_anchor``.
    """
    n_cell_types = len(pp.ref.cell_type.unique())
    # The cut-off is expressed in bits (log2), so the entropy below must be
    # computed in base 2 as well for the comparison to be meaningful.
    norm_threshold = entropy_threshold * np.log2(n_cell_types)

    def valid_anchor(anchor):
        """Return True when ``anchor`` passes the entropy and majority checks."""
        cell_ids = [cell['cell_id'] for cell in anchor.cells]
        qry_cells = pp.qry.adata.obs.loc[cell_ids]
        ref_cells = pp.ref.adata.obs[pp.ref.anchor_assign == str(anchor.reference_id)]
        # Without cells on either side there is no majority type to compare.
        if len(qry_cells) == 0 or len(ref_cells) == 0:
            return False
        qry_counts = qry_cells['cell_type'].value_counts()
        if entropy(qry_counts, base=2) >= norm_threshold:
            return False
        ref_counts = ref_cells['cell_type'].value_counts()
        # value_counts() sorts by descending frequency: index[0] is the mode.
        return qry_counts.index[0] == ref_counts.index[0]

    for a in pp.anchors:
        if not a.confirmed and valid_anchor(a):
            pp.confirm_anchor(a.id)
            
            
def confirm_all(pp):
    """Confirm every anchor held by ``pp``, in iteration order, unconditionally."""
    for anchor_id in (candidate.id for candidate in pp.anchors):
        pp.confirm_anchor(anchor_id)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Three font-size tiers shared by every figure drawn in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 6, 8, 10

# Register the tiers as matplotlib rc defaults, one call per rc group.
plt.rc('font', size=SMALL_SIZE)                              # default text
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)  # axes title; x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)                        # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)                        # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                        # legend entries
plt.rc('figure', titlesize=BIGGER_SIZE)                      # figure-level title

def plot_results(results, metrics='ilisi', legend_prefix=None):
    """Draw one line-chart panel per metric, with one line per experiment.

    Args:
        results: mapping from an experiment key to a sequence of records;
            every record is indexed with the metric name (``record[metric]``).
            The x coordinate of a point is the record's position in its
            sequence.
        metrics: a single metric name or a list of metric names. One subplot
            is created per metric, laid out in a single row.
        legend_prefix: when not None, legend labels are rendered as
            ``"<legend_prefix>=<key>"``; otherwise the key itself is used.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    # squeeze=False makes subplots() always return a 2-D grid of axes. With
    # the default squeezing a single metric yields a bare Axes object, which
    # cannot be indexed and made the single-metric (default) call fail.
    fig, axes = plt.subplots(1, len(metrics), figsize=(6, 2), squeeze=False)
    for panel, metric in zip(axes[0], metrics):
        for key, records in results.items():
            values = [record[metric] for record in records]
            positions = list(range(len(values)))
            label = "{}={}".format(legend_prefix, key) \
                if legend_prefix is not None else key
            panel.plot(positions, values, linewidth=2.0, label=label)
        # Name each panel after its metric so multi-metric figures are readable.
        panel.set_title(metric)
        panel.legend()
    plt.show()

In [None]:
# One entry per benchmark run; each entry maps a strategy label to the value
# returned by eval.benchmark for that strategy.
benchmark_results = []

# Anchor-confirmation strategies under comparison. The value is handed to
# eval.benchmark as ``confirm_fn``; 'baseline' passes None
# (presumably "no automatic confirmation" -- confirm against eval.benchmark).
confirm_fn = {
    'baseline': None,
    'threshold = 0.25': lambda pp: auto_confirm(pp, 0.25),
    'threshold = 0.5': lambda pp: auto_confirm(pp, 0.5),
    'threshold = 1': lambda pp: auto_confirm(pp, 1),
}

for i in range(1):
    exp_results = {}
    snapshot_name = 'pancreas-benchmark-{}'.format(i)

    ref, qry = load_pancreas()
    pp = Polyphony(snapshot_name, ref, qry)

    # First saved snapshot: state after the reference step.
    pp.setup_anndata()
    pp.init_reference_step()
    pp.save_snapshot()

    # Second saved snapshot: state after the query model update. Every
    # strategy below restarts from this snapshot (loaded with index 1).
    pp = Polyphony.load_snapshot(snapshot_name, 0)
    pp.update_query_model()
    pp.save_snapshot()

    for k, v in confirm_fn.items():
        exp = Polyphony.load_snapshot(snapshot_name, 1)
        exp_results[k] = eval.benchmark(exp, confirm_fn=v, warm_epochs=0, step_epochs=100)

    # Record the run exactly once, after all strategies have been evaluated.
    # Appending inside the strategy loop stored the same dict once per
    # strategy.
    benchmark_results.append(exp_results)