# A*PA2 evals

This notebook contains the latest evals for A*PA2.

In [376]:
import numpy as np
import math
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='once', category=UserWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.lines as mlines
import json
from pathlib import Path

In [377]:
# Widen pandas' display limits so large result frames are printed in full.
pd.options.display.max_rows = 200
pd.options.display.max_columns = 100
pd.options.display.width = 1000

# Data reading and preparation

In [378]:
# Shared styling constants for the plots below.
labelsize = 10    # font size of text labels drawn inside the axes
markersize = 4    # marker size of data points
linewidth = 0.75  # width of plotted lines and box outlines

def column_display_name(col):
    """Return the axis label to show for dataframe column `col`.

    Columns without a dedicated label are returned unchanged.
    """
    runtime_label = "Runtime per alignment [s]"
    avg_runtime_label = "Avg. runtime per alignment [s]"
    labels = {
        "divergence": "Divergence",
        # The capped variants share the label of the uncapped column.
        "runtime": runtime_label,
        "runtime_capped": runtime_label,
        "s_per_pair": avg_runtime_label,
        "s_per_pair_capped": avg_runtime_label,
        "length": "Sequence length [bp]",
        "band": "Equivalent band",
        "algo_key": "algorithm",
        # A single space: the algorithm names on the ticks speak for themselves.
        "algo_pretty": " ",
    }
    return labels.get(col, col)

# Display names of the datasets, keyed by the name of the directory that holds
# the dataset files. The insertion order doubles as the display order.
# NOTE(review): 'ont-50k' is labelled '10kbp ONT reads' -- confirm this is intended.
dataset_pretty = {
    "sars-cov-2": "SARS-CoV-2 pairs",
    "ont-1k": "1kbp ONT reads",
    # 'ont-10k': '10kbp ONT reads',
    "ont-50k": "10kbp ONT reads",
    "ont-500k": ">500kbp ONT reads",
    "ont-500k-genvar": ">500kbp ONT reads + gen.var.",
    "bam2seq_10kto20k": "BAM 10k",
    "overlap_10kto20k": "overlap 10k",
    "bam2seq_100kto200k": "BAM 100k",
    "overlap_100kto200k": "overlap 100k",
    "bam2seq_unrestricted": "BAM",
    "overlap_unrestricted": "overlap",
}
dataset_order = [*dataset_pretty]


def dataset_key(key):
    """Return the sort key `(position, name)` of a dataset.

    `position` is the index of `key` in `dataset_order`; datasets that are not
    listed there get position 99 and are therefore ordered by name after all
    known ones.
    """
    try:
        position = dataset_order.index(key)
    except ValueError:
        position = 99
    return (position, key)


# Styling of the algorithms in all plots.
#
# NOTE(review): the scheme described in the next lines looks like it predates
# the current table -- every entry of `algorithm_styles` below uses a solid
# line, and sh/csh/gcsh only occur in `colors`. Confirm before relying on it.
# Line style:
# - slow (no pruning): dotted
# - normal: solid
# - diagonal-transition: dashed
# Colours:
# edlib/wfa ('extern'): blue/purple
# sh/csh/gcsh: orange -> brown -> green gradient
# noprune/normal/dt: 60% -> 70% -> 85% saturation

# `colors`, `dashed` and `dotted` are not referenced in the cells shown here.
colors = {'dijkstra': '#786061', 'sh': "#e87146", 'csh': "#8c662a", 'gcsh': "#257d26"}
# Matplotlib dash patterns of the form (offset, (on, off)).
dashed = (0, (5, 5))
dotted = (0, (1, 4))

# Maps an algorithm key (as returned by `get_algorithm_key`) to a tuple
# (colour, line style, display label). Newlines in a label produce multi-line
# tick labels; callers replace them by spaces where a single line is needed.
# The insertion order matters: `algorithm_order` below is derived from it and
# determines the order of the algorithms in the plots.
algorithm_styles = {
    "edlib": ("#DE4AFF", '-', 'Edlib'),
    "biwfa": ("#625AFF", '-', 'BiWFA'),
    'astarpa': ('#0f7a10', '-', 'A*PA'),
    'astarpa-r1': ('#0f7a10', '-', 'A*PA\n(r=1)'),
    'astarpa-preprune': ('#0f7a10', '-', '+PP'),


    "wfa-adaptive": ("#44A", '-', 'WFA Adaptive'),
    "blockaligner": ("#884", '-', 'Block Aligner'),

    # Summary
    'astarpa2-simple': ('#aa0000', '-', 'A*PA2\nsimple'),
    'astarpa2-full': ('#00aaaa', '-', 'A*PA2\nfull'),

    # Timing
    'astarpa2-t_simple': ('#aa0000', '-', 'A*PA2\nsimple'),
    'astarpa2-t_full': ('#00aaaa', '-', 'A*PA2\nfull'),

    # Incremental
    'astarpa2-gapgap': ('#aa0000', '-', 'Band\nDoubling'),
    'astarpa2-gapdist': ('#aa0000', '-', '+A*'),
    'astarpa2-blocks': ('#aa0000', '-', '+Blocks'),
    'astarpa2-simd': ('#aa0000', '-', '+SIMD'),
    'astarpa2-ilp': ('#aa0000', '-', '+ILP'),
    'astarpa2-dt-trace': ('#aa0000', '-', '+DTT'),
    'astarpa2-sparse_h': ('#aa0000', '-', '+Sparse h\nA*PA2\nsimple'),
    'astarpa2-incrementaldoubling': ('#00aaaa', '-', '+ID'),
    'astarpa2-GCSH': ('#00aaaa', '-', '+GCSH'),
    'astarpa2-pre-pruning': ('#00aaaa', '-', '+PP\n'),
    'astarpa2-prune': ('#00aaaa', '-', '+Pruning\nA*PA2\nfull'),
    
    # Ablation full
    'astarpa2-GCSH-base': ('#00aaaa', '-', 'GCSH\nbase'),
    'astarpa2-noGCSH': ('#00aaaa', '-', '-GCSH\n+SH'),
    'astarpa2-noGCSH-Gap': ('#00aaaa', '-', '-GCSH\n+Gap-h'),
    'astarpa2-noGCSH-GapGap': ('#00aaaa', '-', '-A*'),
    'astarpa2-nosimd': ('#00aaaa', '-', '-SIMD'),
    'astarpa2-noilp': ('#00aaaa', '-', '-ILP'),
    'astarpa2-nodt': ('#00aaaa', '-', '-DTT'),
    'astarpa2-nosparseh': ('#00aaaa', '-', '-Sparse h'),
    'astarpa2-noid': ('#00aaaa', '-', '-ID'),
    'astarpa2-noprune': ('#00aaaa', '-', '-Prune'),
    'astarpa2-nopreprune': ('#00aaaa', '-', '-PP'),

    # Ablation simple
    'astarpa2-simple-base': ('#cc0000', '-', 'A*PA2-simple'),
    'astarpa2-simple-gapgap': ('#cc0000', '-', '-A*'),
    'astarpa2-simple-nosimd': ('#cc0000', '-', '-SIMD'),
    'astarpa2-simple-noilp': ('#cc0000', '-', '-ILP'),
    'astarpa2-simple-id': ('#cc0000', '-', '+ID'),
    'astarpa2-simple-nosparseh': ('#cc0000', '-', '-Sparse h'),
    'astarpa2-simple-nodt': ('#cc0000', '-', '-DTT'),

    # Parameters
    # heuristic related
    'astarpa2-k10': ('#00aaaa', '-', 'k10'),
    'astarpa2-k11': ('#00aaaa', '-', 'k11'),
    'astarpa2-k13': ('#00aaaa', '-', 'k13'),
    'astarpa2-k14': ('#00aaaa', '-', 'k14'),
    'astarpa2-p7': ('#00aaaa', '-', 'p7'),
    'astarpa2-p28': ('#00aaaa', '-', 'p28'),
    # engineering related
    'astarpa2-f1.5': ('#00aaaa', '-', 'f1.5'),
    'astarpa2-f2.5': ('#00aaaa', '-', 'f2.5'),
    'astarpa2-B512': ('#00aaaa', '-', 'B512'),
    'astarpa2-B128': ('#00aaaa', '-', 'B128'),
    'astarpa2-B64': ('#00aaaa', '-', 'B64'),
    'astarpa2-g80': ('#00aaaa', '-', 'g80'),
    'astarpa2-g40': ('#00aaaa', '-', 'g40'),
    'astarpa2-g20': ('#00aaaa', '-', 'g20'),
    'astarpa2-g10': ('#00aaaa', '-', 'g10'),
    'astarpa2-x5': ('#00aaaa', '-', 'x5'),
    'astarpa2-x10': ('#00aaaa', '-', 'x10'),
    'astarpa2-x20': ('#00aaaa', '-', 'x20'),
    'astarpa2-x2': ('#00aaaa', '-', 'x2'),
    'astarpa2-r2': ('#00aaaa', '-', 'r2'),
    

    # Simple parameters
    # only engineering related
    'astarpa2-simple-f1.5': ('#cc0000', '-', 'f1.5'),
    'astarpa2-simple-f2.5': ('#cc0000', '-', 'f2.5'),
    'astarpa2-simple-B512': ('#cc0000', '-', 'B512'),
    'astarpa2-simple-B128': ('#cc0000', '-', 'B128'),
    'astarpa2-simple-B64': ('#cc0000', '-', 'B64'),
    'astarpa2-simple-g80': ('#cc0000', '-', 'g80'),
    'astarpa2-simple-g40': ('#cc0000', '-', 'g40'),
    'astarpa2-simple-g20': ('#cc0000', '-', 'g20'),
    'astarpa2-simple-g10': ('#cc0000', '-', 'g10'),
    'astarpa2-simple-x5': ('#cc0000', '-', 'x5'),
    'astarpa2-simple-x10': ('#cc0000', '-', 'x10'),
    'astarpa2-simple-x20': ('#cc0000', '-', 'x20'),
    'astarpa2-simple-x2': ('#cc0000', '-', 'x2'),

    # Pseudo-algorithm: only provides the row label of the speedup row in the
    # runtime table, hence the empty colour and line style.
    'speedup': ('', '', 'A*PA2 speedup'),

}
# Display order of the algorithms: the insertion order of `algorithm_styles`.
algorithm_order = list(algorithm_styles.keys())
# Algorithm key -> colour, in the form seaborn expects for `palette=`.
palette = {k: v[0] for k, v in algorithm_styles.items()}

def get_algorithm_key(row):
    """Derive the algorithm key of one flattened result row.

    `row` is a row of the dataframe built in `read_results` (any mapping with
    `[]` and `.get` works). The key is determined from `algo_name` and a few
    algorithm parameters:

    - Edlib         -> 'edlib'
    - Wfa           -> 'wfa-adaptive' when the heuristic is anything but the
                       string "None" (this includes a missing value),
                       otherwise 'biwfa' for the MemoryUltraLow memory model
                       and 'wfa' for every other memory model
    - BlockAligner  -> 'blockaligner'
    - AstarPa       -> 'astarpa', with '-r1' appended when r == 1 and
                       '-preprune' appended when the `p` parameter is truthy
    - AstarPa2      -> 'astarpa2-<name>' when the job has a name, otherwise
                       'astarpa2' plus '-sparse' / '-simd' / '-h' flags
    - anything else -> 'unknown'

    Note that 'wfa', 'unknown' and the flag-based 'astarpa2...' keys have no
    entry in `algorithm_styles`.
    """
    name = row['algo_name']
    if name == 'Edlib':
        return 'edlib'
    if name == 'Wfa':
        # `.get` because the column is absent when no job sets a heuristic.
        if row.get('job_algo_Wfa_heuristic') != "None":
            return 'wfa-adaptive'
        if row['job_algo_Wfa_memorymodel'] == 'MemoryUltraLow':
            return 'biwfa'
        return 'wfa'
    if name == 'BlockAligner':
        return 'blockaligner'
    if name == 'AstarPa':
        key = 'astarpa'
        if row['job_algo_AstarPa_heuristic_r'] == 1:
            key += '-r1'
        if row['job_algo_AstarPa_heuristic_p']:
            key += '-preprune'
        return key
    if name == 'AstarPa2':
        key = 'astarpa2'
        variant = row['job_algo_AstarPa2_name']
        if variant:
            # Named configurations are identified by their name alone.
            return f'{key}-{variant}'
        if row['job_algo_AstarPa2_front_Bit_sparse']:
            key += '-sparse'
        if row['job_algo_AstarPa2_front_Bit_simd']:
            key += '-simd'
        if row['job_algo_AstarPa2_sparsehcalls']:
            key += '-h'
        return key
    return 'unknown'

def algorithm_display(row, split):
    """Return `(color, linestyle, label)` for the algorithm of `row`.

    The triple is looked up in `algorithm_styles` via `row['algo_key']`. When
    'r' is one of the columns in `split` and the row has a truthy `r`, the
    label is extended with ' (r=<r>)'.
    """
    color, linestyle, label = algorithm_styles[row['algo_key']]
    if 'r' in split and row.r:
        label = f'{label} (r={row.r})'
    return (color, linestyle, label)

In [379]:
def read_results(path):
    """Load one benchmark results file into a flat, sorted DataFrame.

    Steps:
    - Read the JSON file at `path` (a list of job records).
    - Strip the underscores from every JSON key (`a_b` becomes `ab`), so that
      '_' can serve as the separator between nesting levels when flattening.
    - Per record, add `algo_name` (the first key of job.algo) and `algo_full`
      (the JSON string of the algorithm parameters) and drop output.Ok.costs.
    - Flatten into a DataFrame and add `algo_key` / `algo_pretty`.
    - Rename common columns to short names: length, errorrate, timelimit,
      pcorrect, runtime, memory, divergence, dt, prune, r.
    - Sort the rows by dataset, error rate, length and algorithm order
      (whichever of these columns exist).
    - Add computed columns: costmodel, s_per_pair, timelimit_per_pair, band,
      runtime_capped, s_per_pair_capped, editdistance.
    - Drop the rows whose error is 'Unsupported'.

    Raises KeyError when a row maps to an algorithm key without an entry in
    `algorithm_styles`.
    """
    json_path = Path(path)
    data = json.loads(json_path.read_text())

    # Remove underscores from all keys
    def remove_underscores(o):
        if isinstance(o, list):
            return [remove_underscores(v) for v in o]
        if isinstance(o, dict):
            return {k.replace('_', ''): remove_underscores(v) for k, v in o.items()}
        return o

    data = remove_underscores(data)

    # Clean up algo columns
    for x in data:
        algo = x['job']['algo']
        name = next(iter(algo))
        # The name is also stored inside job.algo, so it ends up in `algo_full`
        # and in the flattened `job_algo_name` column.
        algo['name'] = name
        x['algo_name'] = name
        x['algo_full'] = json.dumps(algo)
        if 'Ok' in x['output']:
            # `pop` instead of `del`: tolerate records without a costs list.
            x['output']['Ok'].pop('costs', None)

    # Flatten the json
    df = pd.json_normalize(data, sep='_')
    df['algo_key'] = df.apply(get_algorithm_key, axis=1)
    unstyled = sorted(set(df['algo_key']) - set(algorithm_styles))
    if unstyled:
        raise KeyError(f"algorithm_styles has no entry for algorithm key(s) {unstyled}")
    df['algo_pretty'] = df['algo_key'].map(lambda key: algorithm_styles[key][2])

    # Convenience renaming
    df = df.rename({'job_dataset_Generated_length': 'length',
                    'job_dataset_Generated_errorrate': 'errorrate',
                    'job_timelimit': 'timelimit',
                    'output_Ok_pcorrect': 'pcorrect',
                    'output_Ok_measured_runtime': 'runtime',
                    'output_Ok_measured_memory': 'memory',
                    'stats_divergence_mean': 'divergence',
                    'job_algo_AstarPa_diagonaltransition': 'dt',
                    'job_algo_AstarPa_heuristic_prune': 'prune',
                    'job_algo_AstarPa_heuristic_r': 'r',
                    #'job_algo_AstarPa2_heuristic_r': 'r',
                   }, axis='columns')
    if 'r' not in df.columns:
        df['r'] = 1

    # Order rows. The sorts are stable and applied from the least to the most
    # significant key, so the final order is dataset, errorrate, length, algo.
    df['algo_ord'] = df['algo_key'].map(lambda key: algorithm_order.index(key))
    df = df.sort_values(by='algo_ord', kind='stable')
    if 'length' in df.columns:
        df = df.sort_values(by='length', kind='stable')
    if 'errorrate' in df.columns:
        df = df.sort_values(by='errorrate', kind='stable')
    # Order by dataset
    if 'job_dataset_File' in df.columns and df.job_dataset_File.notna().all():
        # The dataset is named after the directory that contains its file.
        df['dataset'] = df['job_dataset_File'].map(lambda f: Path(f).parent.name)
        df['dataset_ord'] = df['dataset'].map(dataset_key)
        df = df.sort_values(by='dataset_ord', kind='stable')

    # Computed columns
    df['costmodel'] = df.apply(lambda row: (row['job_costs_sub'], row['job_costs_open'], row['job_costs_extend']), axis=1)
    df['s_per_pair'] = df['runtime'] / df['stats_seqpairs']
    df['timelimit_per_pair'] = df['timelimit'] / df['stats_seqpairs']
    if 'length' in df.columns and 'output_Ok_stats_expanded' in df.columns:
        df['band'] = df['output_Ok_stats_expanded'] / (df['stats_seqpairs'] * df['length'])

    def runtime_capped(row):
        """Runtime of a finished job; failed jobs are placed at (timeouts) or
        10% above (other errors) the time limit."""
        if not math.isnan(row['runtime']):
            return row['runtime']
        # `.get`: the error column is absent when no job in the file failed.
        if row.get('output_Err') == 'Timeout':
            return row['timelimit']
        return row['timelimit'] * 1.1
    df['runtime_capped'] = df.apply(runtime_capped, axis=1)
    df['s_per_pair_capped'] = df['runtime_capped'] / df['stats_seqpairs']

    df['editdistance'] = df['stats_insertions'] + df['stats_deletions'] + df['stats_substitutions']

    # Rows of algorithms without an `r` parameter get r=0. A float column that
    # only holds whole numbers is converted back to integers so that labels
    # read 'r=1' rather than 'r=1.0'. This replaces the deprecated
    # `fillna(..., downcast='infer')`.
    r_filled = df['r'].fillna(0)
    if pd.api.types.is_float_dtype(r_filled) and (r_filled % 1 == 0).all():
        r_filled = r_filled.astype('int64')
    df['r'] = r_filled

    # Remove unsupported algos
    if 'output_Err' in df.columns:
        df = df[df.output_Err != 'Unsupported']

    return df

## The one plotting function

In [416]:
def plot(df,
         name='',
         file=None,
         x='length',
         y='s_per_pair',
         # Column to use for hue and style.
         # Always change both at the same time!
         hue='algo_key',
         style='r',
         # column to use for marker size
         size=None,
         # Logarithmic axes by default
         xlog=True,
         xlim=(0, None),
         ylog=True,
         ylim=None,
         # alpha (opacity) of the plotted points
         alpha=1.0,
         # Use line instead of scatter plot?
         connect=False,
         # Draw a cone from the given filter and x
         cone=None,
         cone_x=3*10**4,
         fit=False,
         line_labels=False,
         categorical=False,
         ax=None,
         width=None,
         height=None,
         png=False,
         mp=None,
         minorticks=False,
        ):
    """Draw one benchmark plot of column `y` against column `x`.

    Rows with a null `y` are dropped; when every `y` is null a message is
    printed and nothing is drawn. The remaining rows are grouped by
    (`hue`, `style`) and each group is styled through `algorithm_display`.

    Parameters
    ----------
    df : DataFrame as returned by `read_results`.
    name : optional axes title.
    file : when given and the figure is created here (`ax` is None), the
        figure is saved to plots/{file}.svg, and to plots/{file}.png when
        `png` is set.
    x, y : columns shown on the two axes.
    hue, style : columns that split the data into groups. 'r' is removed from
        the split unless both r=1 and r=2 occur in the data.
    size, mp : accepted but not used by this function.
    xlog, ylog : use a logarithmic scale on the respective axis.
    xlim : limits for a linear, non-categorical x axis.
    ylim : optional (low, high) y limits; applied after all other limits.
    alpha : opacity of the points/lines of the non-categorical plot.
    connect : connect the points of each group by a line.
    cone : optional function mapping `df` to a boolean row mask. A grey region
        between linear and quadratic growth is filled, starting at `cone_x`
        from the largest `y` among the selected rows with x == cone_x.
    fit : draw and label a power-law fit per group (log-log, x='length' only).
    line_labels : write the label of each group at its right-most point.
    categorical : draw a swarm plot with an overlaid box plot instead of
        points/lines.
    ax : axes to draw into. When omitted a new figure is created.
    width, height : size in inches of a newly created figure (default 3 x 2).
        `width` is also used to position the y-axis label.
    minorticks : show minor y ticks on a logarithmic y axis.
    """

    if df[y].isna().all():
        print(f"All values of {y} are nan.")
        return

    df = df[df[y].notnull()]
    assert not df.empty

    # We group data by this set of keys.
    split = [hue, style]

    # Remove 'r' from the split if not both r=1 and r=2 are present,
    # to prevent redundant (r=1) in plots.
    if 'r' in split and 'r' in df.columns:
        if not (1 in df.r.values and 2 in df.r.values):
            split.remove('r')

    # Group the data into datapoints per line
    groups = df.groupby(split, sort=False)

    # Not sure if needed actually.
    sns.reset_defaults()
    sns.set_context(None) # 'paper', 'notebook'

    # Set up the figure if not provided.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots()
        width = width or 3
        height = height or 2
        fig.set_size_inches(width, height, forward=True)

    # Set log scales
    ax.set(xscale='log' if xlog else 'linear', yscale='log' if ylog else 'linear')

    # limit number of ticks
    if ylog:
        ax.locator_params(axis='y', numticks=6)
        ax.yaxis.set_minor_formatter(plt.NullFormatter())
    else:
        ax.locator_params(axis='y', nbins=6)

    if xlog:
        ax.locator_params(axis='x', numticks=6)
        ax.xaxis.set_minor_formatter(plt.NullFormatter())
    else:
        ax.locator_params(axis='x', nbins=5)

    # PLOTTING

    if not categorical:
        # Show a scatterplot of points.
        # Each group is plotted separately for more control over its style.
        for _, group in groups:
            color, linestyle, grouplabel = algorithm_display(group.iloc[0], split)

            ax.plot(x,
                    y,
                    data=group.sort_values(by=x),
                    color=color,
                    linestyle=linestyle if connect else 'None',
                    marker='o',
                    alpha=alpha,
                    dash_capstyle='round',
                    label=grouplabel,
                    zorder=2,
                    markersize=markersize,
                    linewidth=linewidth
                   )
    if categorical:
        # Overlay a boxplot and swarmplot on top of each other
        for _, group in groups:
            # Exact algorithms get circles, inexact ones slightly larger triangles.
            is_exact = group.iloc[0].output_Ok_isexact
            swarm_marker = 'o' if is_exact else '^'
            swarm_size = 3 if is_exact else 4
            sns.swarmplot(data=group,
                          x=x,
                          y=y,
                          hue=hue,
                          palette=palette,
                          ax=ax,
                          size=swarm_size,
                          linewidth=0,
                          edgecolor='gray',
                          zorder=0.5,
                          dodge=False,
                          marker=swarm_marker,
            )
        # The box plot only contributes the box outline and the mean marker.
        sns.boxplot(data=df,
                    x=x,
                    y=y,
                    ax=ax,
                    linewidth=linewidth,
                    whis=0,
                    showcaps=False,
                    showfliers=False,
                    boxprops={'facecolor':'None'},
                    whiskerprops={'linewidth':0},
                    showmeans=True,
                    meanprops={"marker":"o",
                               "markerfacecolor":"white",
                               "markeredgecolor":"red",
                               "markersize":"7"}
                    )

    # TEXT

    # Title
    if name:
        ax.set_title(name, y=1.05)

    # Remove legend
    ax.legend().remove()

    # BACKGROUND
    ax.set_facecolor("#F8F8F8")
    ax.set_axisbelow(True)
    ax.grid(False)
    if categorical:
        ax.tick_params(axis="y", which="both", right=True)
        ax.grid(True, axis="y", which="major", color="black", alpha=.5, zorder=0, lw=0.5)
        ax.grid(True, axis="y", which="minor", color="black", alpha=.1, zorder=0, lw=0.5)
    else:
        ax.grid(True, axis="y", which="major", color="white", alpha=1, zorder=0)

    # AXES

    # Labels; the y label is placed horizontally above the axis.
    ax.set_xlabel(column_display_name(x))  # weight='bold',
    ax.set_ylabel(column_display_name(y), rotation=0, ha="left")
    ax.yaxis.set_label_coords(-0.5/width if width else -0.1, 1.00)

    # Limits: on log axes leave a multiplicative margin around the data.
    x_margin = 1.5
    y_margin = 1.5
    if xlog:
        ax.set_xlim(df[x].min() / x_margin, df[x].max() * x_margin)

    if ylog:
        ax.set_ylim(df[y].min() / y_margin, df[y].max() * y_margin)

    # Start linear scales at 0.
    if not xlog and not categorical and x != 'job_costs_open':
        ax.set(xlim=xlim)
    if not ylog:
        ax.set(ylim=(0,None))
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    # Show bottom spine, and left spine when xlog=false
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(not xlog and not categorical)

    # Format % scales.
    if x in ['errorrate', 'divergence']:
        ax.xaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))

    # Show major ticks
    ax.tick_params(
        axis="both",
        which="major",
        bottom=True,
        top=False,
        left=True,
        right=False,
    )
    # No minor ticks
    ax.tick_params(
        axis="both",
        which="minor",
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,  # labels along the bottom edge are off
    )
    # Do show minor ticks for small log ranges
    if ylog and minorticks:
        ax.tick_params(axis="y", which="minor", left=True)

    # CONE
    # Fills the region between x**1 and x**2
    if cone:
        x0 = cone_x
        x_max = x_margin * df[x].max()
        x_range = (x0, x_max)

        y0 = df[cone(df) & (df[x] == cone_x)][y].max()
        y_lin = (y0, y0 * (x_max / x0) ** 1)
        y_quad = (y0, y0 * (x_max / x0) ** 2)
        ax.fill_between(x_range, y_lin, y_quad, color="grey", alpha=0.15, zorder=0.4)

    # TIME LIMIT
    if y=='runtime_capped' or (y=='s_per_pair_capped' and x != 'length'):
        timeouts = df[df.runtime.isna()]
        if len(timeouts.index) > 0:
            # The limit of the first failed job is used for the whole plot.
            timelimit = timeouts.iloc[0].timelimit_per_pair

            # draw a red line at the timelimit.
            ax.axhline(y=timelimit, color="red", linestyle="-", alpha=1, linewidth=0.5)

    # POLY FIT

    def angle(slope):
        """Convert a slope in axis units (decades on a log axis) into the
        on-screen angle in degrees, using the current limits and axes size."""
        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()
        bbox = ax.get_window_extent()
        x_sz = bbox.width
        y_sz = bbox.height
        x_factor = x_sz / (np.log10(x_max) - np.log10(x_min) if xlog else x_max - x_min)
        y_factor = y_sz / (np.log10(y_max) - np.log10(y_min) if ylog else y_max - y_min)
        slope = slope * y_factor / x_factor
        return math.atan(slope)*180/math.pi

    if fit:
        assert x=='length' and xlog and ylog, "Polynomial fits only work in log-log plots with x=length"
        for _, group in groups:
            color, linestyle, grouplabel = algorithm_display(group.iloc[0], split)
            grouplabel = grouplabel.replace('\n', ' ')

            filtered = group[group.runtime.notnull()]
            ps = filtered[[x,y]].dropna()
            if ps.empty:
                # No finished run: there is nothing to fit or to label.
                continue
            xmin, xmax = filtered[x].min(), filtered[x].max()
            if len(ps) > 1:
                # A straight line in log-log space: y ~ exp(offset) * x**exponent.
                exponent, offset = np.polyfit(np.log(ps[x]), np.log(ps[y]), 1)
                ymin = xmin**exponent * np.exp(offset)
                ymax = xmax**exponent * np.exp(offset)
                # Extra {{ and }} are for the math-mode superscript
                fit_label = rf"{grouplabel} $\sim n^{{{exponent:0.2f}}}$"
                # line from xmin to xmax (use plt.axline for infinite line)
                ax.plot([xmin, xmax], [ymin, ymax], color=color, linestyle=linestyle, alpha=1, dash_capstyle='round', zorder=2, linewidth=linewidth)
                rotation = angle(exponent)
            else:
                # A single point cannot be fitted: label it horizontally.
                fit_label = grouplabel
                ymax = ps[y].iloc[0]
                rotation = 0

            ax.text(
                xmax,
                min(ymax, ax.get_ylim()[1]),
                fit_label,
                color=color,
                ha="right",
                va="bottom",
                size=labelsize,
                alpha=1,
                rotation=rotation,
                rotation_mode='anchor',
            )
    if line_labels:
        # If no legend and no fits are shown, show manual labels instead
        for _, group in groups:
            color, linestyle, grouplabel = algorithm_display(group.iloc[0], split)
            grouplabel = grouplabel.replace('\n', ' ')

            # Anchor the label at the right-most point, clipped to the axes.
            max_idx = group[x].idxmax()
            label_x = group[x][max_idx]
            label_y = min(group[y][max_idx], ax.get_ylim()[1])

            # Rotate the label along the line through the last point and the
            # point two before it (or the first point of shorter groups).
            by_x = group[x].argsort()
            last = group.iloc[by_x.iloc[-1]]
            before = group.iloc[by_x.iloc[-min(3, len(by_x))]]
            dx = last[x] - before[x]
            slope = (last[y] - before[y]) / dx if dx != 0 else 0
            ax.text(
                label_x,
                label_y,
                grouplabel,
                color=color,
                ha="right",
                va="bottom",
                size=labelsize,
                alpha=1,
                rotation=angle(slope),
                rotation_mode='anchor',
            )

    # Only save figures created here; callers that pass `ax` save themselves.
    if owns_figure:
        if file:
            plt.savefig(f"plots/{file}.svg", dpi=300, bbox_inches='tight')
            if png:
                plt.savefig(f"plots/{file}.png", dpi=300, bbox_inches='tight')

In [381]:
# Scaling with DIVERGENCE
# The same data is drawn twice: once in full and once zoomed in on the y axis.
df = read_results("results/scaling-e.json")
for out_name, y_top in (('scaling_e', 0.28), ('scaling_e_zoom', 0.079)):
    plot(
        df,
        file=out_name,
        x='divergence',
        y='s_per_pair',
        size=None,
        xlog=False,
        ylog=False,
        connect=True,
        line_labels=True,
        ylim=(0, y_top),
        width=4.4,
        height=3,
    )
    plt.close()

In [405]:
# SCALING WITH LENGTH
# One log-log plot with power-law fits per error rate.

df = read_results("results/scaling-n.json")


def cone(df):
    """Select the A*PA rows; the growth cone is anchored at their runtime."""
    return df['algo_key'].isin(['astarpa-r1', 'astarpa'])


for e, g in df.groupby('errorrate'):
    plot(
        g,
        file=f'scaling_n_e{e}',
        x='length',
        y='s_per_pair',
        fit=True,
        cone=cone,
        cone_x=10000,
        width=4.4,
        height=3,
        ylim=(10**-3.9, 10**3.9),
    )
plt.close()


In [414]:
# BOXPLOTS on real data
def boxplot(path, w0, row=False, vlines=(), wr=None, wspace=0.15, ylim=None, sharelegend=False, minorticks=False):
    """Draw one swarm/box plot panel per dataset and save it as an SVG.

    The results are read from results/{path}.json and the figure is written to
    plots/{path}.svg.

    Parameters
    ----------
    path : name of the results file, without directory and extension.
    w0 : width in inches of one panel.
    row : lay the panels out in a single row instead of a single column.
    vlines : x positions at which a thin vertical separator is drawn in
        every panel.
    wr : optional width ratios of the panel columns.
    wspace : horizontal space between the panels.
    ylim : optional list with one (low, high) pair per dataset.
    sharelegend : hide the tick labels of the algorithms and draw a single
        legend below the panels instead.
    minorticks : forwarded to `plot`.
    """
    df = read_results(f"results/{path}.json")
    ww = 1
    datasets = len(df.dataset.unique())
    hh = (datasets + ww - 1) // ww
    if row:
        ww, hh = hh, ww
    w = ww * w0
    h = 3.7 * hh
    fig, axs = plt.subplots(hh, ww, figsize=(w, h), gridspec_kw={'width_ratios': wr})
    # Normalise to a flat sequence of axes; a 2D grid is walked column by column.
    if not isinstance(axs, np.ndarray):
        axs = [axs]
    if isinstance(axs[0], np.ndarray):
        axs = [a for col in zip(*axs) for a in col]
    panel_ylims = ylim or [None] * datasets
    for (k, g), ax, panel_ylim in zip(df.groupby('dataset', sort=False), axs, panel_ylims):
        # Datasets with several pairs per job show the average per pair.
        avg = g.stats_seqpairs.unique().max() > 1
        y = 's_per_pair_capped' if avg else 'runtime_capped'
        plot(g, x='algo_pretty', y=y,
             xlog=False,
             ylog=True,
             ylim=panel_ylim,
             categorical=True,
             ax=ax,
             width=w,
             minorticks=minorticks
             )
        ax.set_xlabel(dataset_pretty.get(k, k))
        if sharelegend:
            ax.set(xticklabels=[])

        for x in vlines:
            ax.axvline(x=x, color="black", alpha=0.5, linewidth=0.5, zorder=0.1)

    if sharelegend:
        handles = []
        labels = []
        # Use the algorithms of all datasets rather than only those of the
        # dataset that happened to be plotted last.
        for key in df.algo_key.unique():
            exact_flags = df.loc[df.algo_key == key, 'output_Ok_isexact'].dropna()
            # Algorithms without any successful run are drawn as exact.
            is_exact = exact_flags.empty or bool(exact_flags.iloc[0])
            color, style, label = algorithm_styles[key]
            marker = 'o' if is_exact else '^'
            labels.append(label.replace('\n', ' '))
            handles.append(mlines.Line2D([], [], color=color, marker=marker, linestyle='None',
                                         markersize=7))

        fig.legend(handles,
                   labels,
                   loc='lower center',
                   ncols=7,
                   bbox_to_anchor=[0.5, -0.08],
                   markerscale=2.3,
                   fontsize=12,
                   frameon=False,
                   handletextpad=0.1,
                   columnspacing=1.5,
        )

    # Apply the requested spacing to every figure, not only to those with a
    # shared legend; otherwise `wspace` is silently ignored.
    fig.subplots_adjust(wspace=wspace, hspace=0.4)

    plt.savefig(f"plots/{path}.svg", bbox_inches='tight')
    plt.close()


In [411]:
# AVERAGE RUNTIME TABLE
import tabulate
def table(path):
    """Print three org-mode tables for results/{path}.json.

    1. Average capped runtime per alignment for every algorithm and dataset,
       with an extra row giving the speedup of the best A*PA2 variant over
       the best of Edlib and BiWFA.
    2. The number of runs without a runtime (timeouts and other failures).
    3. The percentage of correct results of the inexact algorithms.

    The 'astarpa-r1' rows are merged into 'astarpa'. The results must contain
    the algorithms edlib, biwfa, astarpa2-simple and astarpa2-full.
    """
    def key_fn(keys):
        return [dataset_key(key)[0] for key in keys]

    def show(pivot):
        """Order the dataset columns, apply the display names and print."""
        pivot = pivot.sort_index(axis=1, level=0, sort_remaining=False, key=key_fn)
        # `.get` keeps datasets without a display name, like `boxplot` does.
        pivot = pivot.rename(axis=1, level=0, mapper=lambda c: dataset_pretty.get(c, c))
        pivot = pivot.rename(axis=0, mapper=lambda a: algorithm_styles[a][2].replace('\n', ' '))
        print(tabulate.tabulate(pivot, headers=pivot.columns, tablefmt='orgtbl'))

    df = read_results(f"results/{path}.json")
    df.loc[df.algo_key == 'astarpa-r1', 'algo_key'] = 'astarpa'

    # AVERAGE RUNTIME TABLE
    runtimes = df.pivot_table(index='algo_key', columns=['dataset'], values='s_per_pair_capped', aggfunc='mean', sort=False)

    # Best A*PA2 is x faster than best of Edlib/Biwfa.
    best_other = np.minimum(runtimes.loc['edlib'], runtimes.loc['biwfa'])
    best_astarpa2 = np.minimum(runtimes.loc['astarpa2-simple'], runtimes.loc['astarpa2-full'])
    runtimes.loc['speedup'] = best_other / best_astarpa2

    # Number of decimals shown per dataset.
    runtimes = runtimes.round({'sars-cov-2': 5, 'ont-1k': 6, 'ont-10k': 5, 'ont-50k': 4, 'ont-500k': 2, 'ont-500k-genvar': 2})
    show(runtimes)

    # TIMEOUTS TABLE
    timeouts = df.pivot_table(index='algo_key', columns=['dataset'], values='runtime', aggfunc={'runtime': lambda x: x.isnull().sum()}, sort=False)
    print()
    show(timeouts)

    # % CORRECT TABLE
    inexact = df[df.output_Ok_isexact == False]
    correct = inexact.pivot_table(index='algo_key', columns=['dataset'], values='pcorrect', aggfunc={'pcorrect': lambda x: x.mean()*100}, sort=False).round(0).astype('int')
    print()
    show(correct)


In [384]:
# INCREMENTAL
# Single dataset panel, 11 inches wide, with vertical separators at the given
# x positions.
boxplot(
    'real-incremental',
    11,
    vlines=[1.5, 3.5, 10.5],
    ylim=[(0.04, 35)],
)

In [385]:
# SUMMARY
# Per-dataset scale of the y axis; every panel spans the same relative range.
ymax = [0.1, 0.01, 1, 100, 100]
ylim = [(0.00055*x, 0.29*x) for x in ymax]
table('real-summary')
boxplot(
    'real-summary',
    4.0,
    row=True,
    wspace=0.22,
    vlines=[2.5, 4.5],
    ylim=ylim,
    sharelegend=True,
)


|               |   SARS-CoV-2 pairs |   1kbp ONT reads |   10kbp ONT reads |   >500kbp ONT reads |   >500kbp ONT reads + gen.var. |
|---------------+--------------------+------------------+-------------------+---------------------+--------------------------------|
| Edlib         |            0.01114 |         0.00011  |            0.008  |                3.74 |                           5.2  |
| BiWFA         |            0.00113 |         4.2e-05  |            0.0093 |                4.47 |                           6.96 |
| A*PA          |            0.00625 |         0.000514 |            0.1901 |               14.01 |                          12.92 |
| WFA Adaptive  |            0.00085 |         3.8e-05  |            0.003  |                1.04 |                           0.81 |
| Block Aligner |            0.00235 |         3.8e-05  |            0.0009 |                0.63 |                           0.68 |
| A*PA2 simple  |            0.00089 |         5.2e-05  |            

In [417]:
# ABLATION
# Three panels in a row with width ratios 7:7:11.
boxplot(
    'real-ablation',
    6,
    wr=[7, 7, 11],
    row=True,
    vlines=[0.5],
    wspace=0.11,
    minorticks=True,
)

In [418]:
# PARAMS
# Three panels in a row with width ratios 10:10:17.
boxplot(
    'real-params',
    6,
    wr=[10, 10, 17],
    row=True,
    vlines=[0.5],
    wspace=0.11,
    minorticks=True,
)

In [422]:
# SCATTER PLOTS ON REAL DATA
# def scatterplot(path, w0, row=False, vlines=[], wr=None,wspace=0.15,ylim=None):
#     df = read_results(f"results/{path}.json")
#     ww=1
#     datasets = len(df.dataset.unique())
#     hh = (datasets+ww-1)//ww
#     if row:
#         ww,hh=hh,ww
#     w = ww*w0
#     h = 3.7 * hh
#     fig, axs = plt.subplots(hh, ww, figsize=(w, h), gridspec_kw={'width_ratios': wr})
#     if not isinstance(axs, np.ndarray):
#         axs = [axs]
#     if isinstance(axs[0], np.ndarray):
#         axs = [x for col in zip(*axs) for x in col]
#     if ylim is None:
#         ylim=[None] * datasets
#     for (k, g), ax,ylim in zip(df.groupby('dataset',sort=False),axs,ylim):
#         avg = g.stats_seqpairs.unique().max() > 1
#         plot(g,
#              x='divergence',
#              y='s_per_pair_capped' if avg else 'runtime_capped',
#              hue='algo_key',
#              xlog=False,
#              ylog=True,
#              categorical=False,
#              ax=ax,
#              width=w,
#              xlim=None,
#              ylim=ylim
#              )
#         ax.set_xlabel(dataset_pretty.get(k, k))
#         for x in vlines:
#             ax.axvline(x=x, color="black", alpha=0.5, linewidth=0.5, zorder=0.1)
    
#     fig.subplots_adjust(wspace=wspace, hspace=0.4)

#     plt.savefig(f"plots/{path}-scatter.svg", bbox_inches='tight')
#     plt.close()

# ymax=[0.1, 0.01, 1, 100, 100]
# ylim = [(0.00055*x, 0.29*x) for x in ymax]
# scatterplot('real-summary', 4.0, row=True, wspace=0.25, ylim=ylim)

In [389]:
# TIMING
df = read_results("results/real-timing.json")
time_labels = {
    'precomp': 'Precomputation',
    'jrange': 'Computing ranges',
    # 'fixedjrange': 'Range to reuse',
    'compute': 'Computing blocks',
    'pruning': 'Pruning matches',
    # 'contoursupdate': 'Contours update',
    'tracedt': 'DT trace',
    'tracefill': 'Fill trace',
    'rest': 'Overhead',
}

df['output_Ok_stats_tpruning'] += df['output_Ok_stats_tcontoursupdate']
df['output_Ok_stats_tjrange'] += df['output_Ok_stats_tfixedjrange']

for c in df.columns:
    prefix = 'output_Ok_stats_t'
    if c.startswith(prefix):
        name = c[len(prefix):]
        label = time_labels.get(name, name)
        df[label] = df[c] / df['stats_seqpairs']

def rest(row):
    """Return the part of `row['s_per_pair']` not covered by a labelled timer.

    Every label in `time_labels` except 'Overhead' is looked up in `row` and
    subtracted, one after the other, from the per-pair runtime.
    """
    remaining = row['s_per_pair']
    accounted = [label for label in time_labels.values() if label != 'Overhead']
    for label in accounted:
        remaining -= row[label]
    return remaining

# Overhead is the part of the per-alignment runtime that none of the
# labelled timers accounts for.
df['Overhead'] = df.apply(rest, axis=1)

datasets = len(df.dataset.unique())
# squeeze=False keeps `axes` an array even when there is a single dataset,
# so that it can always be zipped with the groups below.
fig, axes = plt.subplots(nrows=1, ncols=datasets, figsize=(3.5*datasets, 3.7), squeeze=False)
axes = axes.ravel()
fig.subplots_adjust(wspace=0.3)

# Upper y-limit of each panel, in order of first appearance of the datasets.
# ymax = [0.0024, 0.00012, 0.00060, 0.0024, 0.24, 0.24]
ymax = [0.0024, 0.00012, 0.0024, 0.24, 0.24]

for i, ((k, g), ax) in enumerate(zip(df.groupby('dataset', sort=False), axes)):
    # Datasets without a hand-picked limit are still drawn, with an
    # automatic upper limit, instead of being silently dropped.
    panel_ymax = ymax[i] if i < len(ymax) else None
    # Sort into a new frame: neither the group slice nor the global `df` is
    # modified, and the `ymax` list is not shadowed by a loop variable.
    bars = g.sort_values(by='s_per_pair', kind='stable')
    bars.plot.bar(
        y=list(time_labels.values()),
        stacked=True,
        width=.9,
        zorder=2,
        ax=ax,
        color=sns.color_palette(),
        ylim=(0, panel_ymax),
        # Only the first panel carries the legend.
        legend=(i == 0),
    )
    if bars.stats_seqpairs.max() > 1:
        label = 'Avg. runtime per alignment [s]'
    else:
        label = 'Runtime per alignment [s]'
    ax.set_ylabel(label, rotation=0, ha="left")
    # Fall back to the raw dataset key when no pretty name is registered.
    ax.set_xlabel(dataset_pretty.get(k, k))
    ax.tick_params(
        bottom=False,
    )
    # ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0),useOffset=False)
    ax.set_xticklabels([])
    ax.locator_params(axis='y', nbins=2)
    # ax.yaxis.set_major_formatter(mtick.LogFormatterSciNotation(labelOnlyBase=False, minor_thresholds=(2,2), linthresh=1))
    ax.set_facecolor("#F8F8F8")
    ax.grid(False)
    ax.grid(True, axis="y", which="major", color="w")
    # Put the (horizontal) y-label above the top-left corner of the axes.
    ax.yaxis.set_label_coords(-0.15,1.0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(False)

# Save and close this figure explicitly rather than the implicit current one.
fig.savefig("plots/real-timing.svg", bbox_inches='tight')
plt.close(fig)

In [423]:
# MEMORY USAGE
# Print a table with the median and maximum memory usage per algorithm and
# dataset.
df = read_results("results/real-summary.json")
df.loc[df.algo_key == 'astarpa-r1', 'algo_key'] = 'astarpa'
# Copy the filtered rows so the column assignments below act on an
# independent frame (avoids SettingWithCopyWarning on a slice).
df = df[df.memory.notna()].copy()
# Scale by 1e6 (presumably bytes -> MB; confirm against the benchmark output).
df['memory'] = df.memory / 1000000
# Duplicate the column so that two different aggregations can be applied.
df['memory2'] = df.memory
#df = df[df.algo_key.isin(['edlib', 'biwfa', 'gcsh-dt', 'astarnw', 'astarnw-sparse'])]
# String aggregation names select the same pandas reductions that np.median /
# np.max were mapped to, without the deprecation warning for numpy callables.
table = df.pivot_table(index='algo_key', columns=['dataset'], values=['memory2', 'memory'], aggfunc={'memory2': 'median', 'memory': 'max'}, sort=False).round(0).astype('int')
table = table.rename({'memory2': 'Median', 'memory': 'Max'}, axis='columns')
# Put the dataset on the outer column level, the statistic on the inner one.
table = table.swaplevel(axis=1)
def key_fn(keys):
    """Map dataset keys to their position in `dataset_order` (99 if unknown)."""
    return [dataset_key(key)[0] for key in keys]
table = table.sort_index(axis=1, level=0, sort_remaining=False, key=key_fn)
# Pretty column names; datasets without a pretty name keep their raw key.
table = table.rename(axis=1, level=0, mapper=lambda c: dataset_pretty.get(c, c))
# Element 2 of the style entry is used as the algorithm's display name.
table = table.rename(axis=0, mapper=lambda a: algorithm_styles[a][2].replace('\n', ' '))
# display(table)
import tabulate
headers = [' '.join(c) for c in table.columns]
print(tabulate.tabulate(table, headers=headers, tablefmt='orgtbl'))
#print(table.to_latex())

|               |   SARS-CoV-2 pairs Median |   SARS-CoV-2 pairs Max |   1kbp ONT reads Median |   1kbp ONT reads Max |   10kbp ONT reads Median |   10kbp ONT reads Max |   >500kbp ONT reads Median |   >500kbp ONT reads Max |   >500kbp ONT reads + gen.var. Median |   >500kbp ONT reads + gen.var. Max |
|---------------+---------------------------+------------------------+-------------------------+----------------------+--------------------------+-----------------------+----------------------------+-------------------------+---------------------------------------+------------------------------------|
| Edlib         |                         0 |                      0 |                       0 |                    0 |                        0 |                     0 |                          0 |                       0 |                                     0 |                                  0 |
| BiWFA         |                         0 |                      0 |                     

In [None]:
# Sanity check: CPU frequency
# Make sure that the CPU frequency is consistent over all experiments.
df = read_results("results/real-summary.json")
print(len(df.runtime))
print(df.runtime.sum())
df = df.rename(columns={
    'output_Ok_measured_cpufreqstart': 'freqstart',
    'output_Ok_measured_cpufreqend': 'freqend',
})

# List the runs whose frequency at the start, resp. at the end, was below 3550.
report_columns = ['freqstart', 'freqend', 'output_Ok_measured_timestart', 'runtime']
for freq_column in ('freqstart', 'freqend'):
    print(df.loc[df[freq_column] < 3550, report_columns])

# Every measured frequency must lie strictly between 3500 and 3650.
for freq_column in ('freqend', 'freqstart'):
    lowest, highest = df[freq_column].min(), df[freq_column].max()
    print(freq_column, lowest, highest)
    assert highest < 3650
    assert lowest > 3500

In [None]:
#%history -g -f filename