In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.optimize import curve_fit

In [None]:
def linear_fit(data, xcol, ycol):
    """Fit a straight line to ``ycol`` vs ``xcol`` in log10-log10 space.

    The model is ``log10(y) = a * log10(x) + b``, i.e. a power law, which is
    drawn as a straight line on a log-log plot.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding both columns.
    xcol, ycol : str
        Names of the independent and dependent columns.

    Returns
    -------
    xvals : numpy.ndarray
        Sorted, de-duplicated log10 values of the strictly positive entries
        of ``xcol``.
    yvals : numpy.ndarray
        The fitted line evaluated at ``xvals`` (log10 units).
    """

    def linfunc(x, a, b):
        return a * x + b

    subset = data[[xcol, ycol]]
    # log10(0) is -inf and log10(<0) is NaN. dropna() would not remove -inf,
    # and curve_fit rejects non-finite input, so keep only strictly positive
    # rows. NaN rows also fail the comparison and are dropped here.
    subset = subset[(subset[xcol] > 0) & (subset[ycol] > 0)]
    xlog = np.log10(subset[xcol])
    ylog = np.log10(subset[ycol])

    popt, _ = curve_fit(f=linfunc, xdata=xlog, ydata=ylog)

    # np.unique sorts, so ax.plot draws a monotone line instead of
    # zig-zagging through the x values in file order.
    xs = data[xcol]
    xvals = np.log10(np.unique(xs[xs > 0]))
    yvals = linfunc(xvals, *popt)

    return xvals, yvals

def logarithmic_fit(data, xcol, ycol):
    """Fit an exponential curve to ``ycol`` vs ``xcol`` in log10-log10 space.

    The model is ``log10(y) = a * exp(-b * (log10(x) + c))``. Despite the
    function name, the fitted curve is exponential in the log-transformed
    coordinates.

    NOTE: ``a`` and ``c`` only enter through the product ``a * exp(-b * c)``,
    so they are not separately identifiable. ``curve_fit`` may warn that the
    parameter covariance cannot be estimated. That does not affect the
    fitted curve, and the covariance is discarded here anyway.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding both columns.
    xcol, ycol : str
        Names of the independent and dependent columns.

    Returns
    -------
    xvals : numpy.ndarray
        Sorted, de-duplicated log10 values of the strictly positive entries
        of ``xcol``.
    yvals : numpy.ndarray
        The fitted curve evaluated at ``xvals`` (log10 units).
    """

    def logfunc(x, a, b, c):
        return a * np.exp(-b * (x + c))

    subset = data[[xcol, ycol]]
    # log10(0) is -inf and log10(<0) is NaN. dropna() would not remove -inf,
    # and curve_fit rejects non-finite input, so keep only strictly positive
    # rows. NaN rows also fail the comparison and are dropped here.
    subset = subset[(subset[xcol] > 0) & (subset[ycol] > 0)]
    xlog = np.log10(subset[xcol])
    ylog = np.log10(subset[ycol])

    popt, _ = curve_fit(f=logfunc, xdata=xlog, ydata=ylog)

    # np.unique sorts, so ax.plot draws a monotone curve instead of
    # zig-zagging through the x values in file order.
    xs = data[xcol]
    xvals = np.log10(np.unique(xs[xs > 0]))
    yvals = logfunc(xvals, *popt)

    return xvals, yvals

def get_y(data, target_col):
    """Return the ``target_col`` column of ``data`` transformed to log10.

    Zeros map to -inf and negative values to NaN, as with ``np.log10``.
    """
    column = data[target_col]
    return np.log10(column)

In [None]:
# Runtime of each clustering method versus the number of input sequences,
# shown on log10 axes with a per-method trend fit.
speed = pd.read_csv("./data/method_comparison_speed.tsv", sep="\t")

# Named y-tick positions on the log10 axis. The labels assume the runtime
# columns are in seconds (TODO confirm against the TSV's provenance).
SECONDS_PER_UNIT = {
    "second": 1,
    "minute": 60,
    "hour": 3600,
    "day": 3600 * 24,
    "week": 3600 * 24 * 7,
}
ylabs = {unit: np.log10(seconds) for unit, seconds in SECONDS_PER_UNIT.items()}

x = np.log10(speed["n_sequences"])
colors = sns.color_palette("colorblind")[:5]

# One row per plotted method:
# (column in `speed`, legend label, fit function, palette index, line style).
# Both ClusTCR variants share a colour and differ only by line style.
METHODS = [
    ("ClusTCR¹", "ClusTCR", linear_fit, 0, "-"),
    ("ClusTCR²", "ClusTCRmp", linear_fit, 0, "--"),
    ("GLIPH2", "GLIPH2", linear_fit, 1, "-"),
    ("iSMART", "iSMART", linear_fit, 2, "-"),
    ("TCRDist*", "TCRDist", linear_fit, 3, "-"),
    ("GIANA", "GIANA", logarithmic_fit, 4, "-"),
]

with sns.axes_style("whitegrid"):

    fig, ax = plt.subplots(figsize=(4, 3), dpi=200)

    for method, label, fit, color_idx, linestyle in METHODS:
        color = colors[color_idx]
        x_pred, y_pred = fit(speed, "n_sequences", method)
        # Raw measurements as points, and a faint fitted trend line.
        ax.scatter(x, get_y(speed, method), s=10, c=color)
        ax.plot(x_pred, y_pred, label=label, alpha=.2, c=color, ls=linestyle)

    ax.set_xlabel("Number of sequences")
    ax.set_ylabel("Runtime")

    ax.set_yticks(list(ylabs.values()))
    ax.set_yticklabels(list(ylabs.keys()), fontsize=5)

    # The x axis is log10(n_sequences), so 4, 5 and 6 are 1e4, 1e5 and 1e6.
    ax.set_xticks([4, 5, 6])
    ax.set_xticklabels(["10K", "100K", "1M"], fontsize=5)

    ax.legend(fontsize=8)

    sns.despine(left=True, bottom=True)

    fig.tight_layout()