In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.stats.proportion import proportion_confint
import matplotlib.ticker as mticker
from matplotlib.lines import Line2D
import numpy as np

In [None]:
# Read the tab-separated Terra results table handed over by Snakemake.
df = pd.read_csv( snakemake.input.terra_results, sep="\t" )

# Only rows that carry a Vibecheck lineage call are kept.
has_call = df["vibecheck_reads_lineage"].notna()
df = df.loc[has_call]

# Every call whose label begins with "anc" is relabelled as "UNK".
is_anc_call = df["vibecheck_reads_lineage"].str.startswith( "anc" )
df.loc[is_anc_call, "vibecheck_reads_lineage"] = "UNK"

# A call counts as correct when it equals the "te" column (the column the
# summary and plot below treat as the actual lineage).
df["correct"] = df["vibecheck_reads_lineage"].eq( df["te"] )

# Rows whose "te" value is one of these labels are left out of the evaluation.
excluded_te = ["sporadic", "T17"]
df = df.loc[~df["te"].isin( excluded_te )]
df.head()

In [None]:
# Overall accuracy over every retained sample, with a 95% Jeffreys interval.
successes = df["correct"].sum()
observations = df.shape[0]
accuracy = successes / observations

# statsmodels names this method "jeffreys"; an unrecognised method name makes
# proportion_confint raise NotImplementedError instead of returning an interval.
ci = proportion_confint( successes, observations, alpha=0.05, method="jeffreys" )
print( f"Vibecheck correctly reported {accuracy:.1%} of samples (95% confidence interval: {ci[0]:.1%} - {ci[1]:.1%})" )

In [None]:
# Per-lineage tally: for each value of "te", the number of samples ("count")
# and the number of correct calls ("sum" of the boolean "correct" column).
summ = df.groupby( "te" )["correct"].agg( ["count", "sum"] ).reset_index()
summ.columns = ["lineage_actual", "observations", "successes"]
summ["accuracy"] = summ["successes"] / summ["observations"]

# 95% Jeffreys interval for every lineage in one call: proportion_confint
# accepts array-like counts, so no row-wise apply is required.
summ["accuracy_low"], summ["accuracy_high"] = proportion_confint( summ["successes"], summ["observations"], alpha=0.05, method="jeffreys" )

# Sort key: the first run of digits in the lineage name. Names without any
# digits get the sentinel 99.
summ["numeric_lineage"] = pd.to_numeric( summ["lineage_actual"].str.extract( r"(\d+)", expand=False ) ).fillna( 99 ).astype( int )
summ = summ.sort_values( by="numeric_lineage" ).reset_index( drop=True )

# Despite its name, "failures" holds the *fraction* of incorrect calls
# (1 - accuracy); the bar chart stacks it on top of "accuracy".
summ["failures"] = (summ["observations"] - summ["successes"]) / summ["observations"]
summ

In [None]:
# Count samples for every (assigned lineage, actual lineage) pair.
# crosstab counts rows directly, so the totals do not depend on whether some
# unrelated value column happens to be populated for a given sample.
confusion_matrix = pd.crosstab( df["vibecheck_reads_lineage"], df["te"] )

# Put both axes in the display order of `summ`; pairs that never occur are 0.
# NOTE: rows are restricted to the labels in summ["lineage_actual"], so an
# assigned lineage that never appears as an actual lineage is not shown.
confusion_matrix = confusion_matrix.reindex( index=summ["lineage_actual"], columns=summ["lineage_actual"], fill_value=0 )
confusion_matrix

In [None]:
# Two side-by-side axes: this part fills the left one with stacked
# correct / incorrect bars per lineage plus one pooled bar.
fig, ax = plt.subplots( ncols=2, figsize=(10, 4), dpi=200 )
acc_ax = ax[0]

lineages = summ.shape[0]

# Shared keyword sets so per-lineage bars and the pooled bar look the same.
correct_kw = dict( color="skyblue", zorder=100 )
incorrect_kw = dict( color="red", hatch="/////", edgecolor="gainsboro", linewidth=0, zorder=100 )

# One stacked bar per lineage: correct fraction below, incorrect fraction on top.
acc_ax.bar( summ.index, summ["accuracy"], **correct_kw )
acc_ax.bar( summ.index, summ["failures"], bottom=summ["accuracy"], **incorrect_kw )

# Pooled accuracy occupies the slot right after the last lineage.
acc_ax.bar( lineages, accuracy, **correct_kw )
acc_ax.bar( lineages, 1 - accuracy, bottom=accuracy, **incorrect_kw )

# Lineage names along x ("UNK" displayed as "Other"), followed by "All".
tick_labels = summ["lineage_actual"].replace( {"UNK" : "Other"} ).to_list()
tick_labels.append( "All" )
acc_ax.set_xticks( range( lineages + 1 ), tick_labels, fontsize=7 )

# Percent-formatted y axis with major and minor tick positions.
acc_ax.yaxis.set_major_formatter( mticker.PercentFormatter( 1 ) )
acc_ax.set_yticks( np.arange( 0, 1.2, 0.2 ) )
acc_ax.set_yticks( np.arange( 0, 1.05, 0.05 ), minor=True )

acc_ax.set_ylim( 0, 1 )
acc_ax.set_xlim( -0.5, lineages + 0.5 )

# Thin vertical rule between the per-lineage bars and the pooled bar.
acc_ax.axvline( lineages - 0.5, color="black", linewidth=0.75 )

acc_ax.set_ylabel( "Proportion of samples", fontweight="bold" )

# Pale horizontal gridlines, given a lower zorder than the bars.
for tick_kind, width in ( ("major", 1), ("minor", 0.5) ):
    acc_ax.grid( which=tick_kind, axis="y", linewidth=width, color="#F1F1F1", zorder=1 )

# Square proxy markers used as legend entries for the two bar colours.
legend_handles = [
    Line2D( [0], [0], linestyle='none', marker='s', color=face, markeredgecolor="black", markeredgewidth=1, label=text, markersize=10 )
    for face, text in ( ("skyblue", "Correct"), ("red", "Incorrect") )
]

legend1 = acc_ax.legend( handles=legend_handles, title="Lineage assignment", loc="lower left", handletextpad=0, edgecolor="white", fancybox=False, alignment="left", fontsize=10, title_fontproperties={"size" : 10, "weight" : "bold"} )
# Raise the legend above the bars (which are drawn at zorder 100).
legend1.set_zorder( 150 )

# Right panel: confusion matrix heat-map. Rows are the assigned lineage,
# columns are the actual lineage.
n_rows, n_cols = confusion_matrix.shape

ax[1].imshow( confusion_matrix, cmap="Blues", vmax=75 )

# x ticks follow the columns and y ticks follow the rows, so labels and tick
# counts stay matched even if the matrix is not square.
ax[1].set_xticks( range( n_cols ), [i.replace( "UNK", "Other") for i in confusion_matrix.columns], fontsize=7 )
ax[1].set_yticks( range( n_rows ), [i.replace( "UNK", "Other") for i in confusion_matrix.index], fontsize=7 )

# Write the count into every non-empty cell; use white text from 50 upward,
# where the "Blues" colormap (capped at vmax=75) is dark.
for i in range( n_rows ):
    for j in range( n_cols ):
        value = confusion_matrix.iloc[i, j]
        if value > 0:
            ax[1].text( j, i, int( value ), ha="center", va="center", color="black" if value < 50 else "white", fontsize=8 )

# Minor ticks on the cell boundaries carry a white grid that separates cells.
ax[1].set_xticks( np.arange( n_cols ) + 0.5, minor=True )
ax[1].set_yticks( np.arange( n_rows ) + 0.5, minor=True )

ax[1].grid( which="minor", color="white", linewidth=1, zorder=100)
ax[1].tick_params( axis="both", which="minor", size=0 )

ax[1].set_xlabel( "Actual lineage", fontweight="bold" )
ax[1].set_ylabel( "Assigned lineage", fontweight="bold" )

# Save to the Snakemake-declared output path, then render inline.
plt.tight_layout()
plt.savefig( snakemake.output.accuracy_plot )
plt.show()