In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to local source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

from collections import defaultdict
from pathlib import Path
from operator import itemgetter
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [3]:
# All locations below default to the original cluster paths but can be
# redirected through environment variables, so the notebook can run on another
# machine or under another user without editing this cell.

# Root of the InterPro 103.0 data tree.
interpro_path = Path(
    os.environ.get(
        "MAGNETON_INTERPRO_PATH",
        "/weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/",
    )
)
# Sequence-level dataset splits and the label files, both under the InterPro root.
dataset_path = interpro_path / "swissprot/sharded_swissprot/with_ss/dataset_splits/seq_splits/"
labels_path = interpro_path / "label_sets/selected_subset"

# Directory holding the final_evaluation_results_esmc_*.csv files.
evals_path = Path(
    os.environ.get(
        "MAGNETON_EVALS_PATH",
        "/net/vast-storage/scratch/vast/kellislab/artliang/magneton/magneton/tests/esmc/",
    )
)

# Figure output directory; "~" is expanded and the directory is created if missing.
outpath = Path(os.environ.get("MAGNETON_FIGS_PATH", "~/temp/figs/")).expanduser()
outpath.mkdir(parents=True, exist_ok=True)

In [4]:
# InterPro annotation types for which an evaluation CSV is loaded below.
types_with_evals = ["Domain", "Homologous_superfamily"]

# One evaluation table per InterPro type, keyed by the type name.
results = dict(
    (
        interpro_type,
        pd.read_csv(evals_path / f"final_evaluation_results_esmc_{interpro_type}.csv"),
    )
    for interpro_type in types_with_evals
)

# Sanity check: row count and fraction of rows where the prediction equals the true label.
for interpro_type, df in results.items():
    raw_acc = df.true_labels.eq(df.pred_labels).mean()
    print(f"InterPro type: {interpro_type}: {len(df)} rows, raw accuracy: {raw_acc:.3f}")


InterPro type: Domain: 32305 rows, raw accuracy: 0.960
InterPro type: Homologous_superfamily: 104754 rows, raw accuracy: 0.886


In [5]:
# Train-set summary statistics live next to the sharded training data.
train_set_stats_path = dataset_path / "train_sharded" / "summary_stats"

# Per InterPro type: mapping from interpro_id to its label, read from the
# "<type>.labels.tsv" file (later rows win if an interpro_id repeats).
label_maps = {
    interpro_type: dict(
        pd.read_table(labels_path / f"{interpro_type}.labels.tsv")
        [["interpro_id", "label"]]
        .itertuples(index=False, name=None)
    )
    for interpro_type in types_with_evals
}

# Per InterPro type: the train-set summary table with an added "label" column
# looked up from element_id; ids absent from the label map get NaN.
stats_dfs = {
    interpro_type: pd.read_table(
        train_set_stats_path / f"{interpro_type}_summaries.tsv"
    ).assign(
        label=lambda stats: stats["element_id"].map(label_maps[interpro_type])
    )
    for interpro_type in types_with_evals
}


for interpro_type, df in stats_dfs.items():
    print(f"InterPro type: {interpro_type}: {len(df)} types in train set")

InterPro type: Domain: 15868 types in train set
InterPro type: Homologous_superfamily: 3511 types in train set


In [6]:
def per_type_perf_plot(
    interpro_type: str,
    truncate_q: Optional[float]=0.99,
):
    """Joint plot of per-label test accuracy against train-set frequency.

    Groups ``results[interpro_type]`` by ``true_labels`` to get, for each label,
    the fraction of rows predicted correctly and the number of rows. The result
    is left-joined with ``stats_dfs[interpro_type]`` on ``label`` to attach each
    label's train-set ``count`` (renamed ``train_set_count``), and accuracy is
    plotted against that count with marginal distributions.

    Args:
        interpro_type: Key into the module-level ``results`` and ``stats_dfs``
            dicts.
        truncate_q: Quantile of the train-set ``count`` column used as the upper
            x-axis limit. ``None`` leaves the x-axis untruncated.

    Returns:
        The seaborn ``JointGrid`` holding the figure, so callers can customise
        or save it without relying on pyplot's current-figure state.

    Raises:
        AssertionError: If any evaluated label has no train-set count after the
            join.
    """
    per_type_res = (
        results[interpro_type]
        .assign(correct=lambda x: x.true_labels == x.pred_labels)
        .groupby("true_labels")
        .agg(
            correct=("correct", "mean"),
            count=("correct", "size"),
        )
        .rename_axis(index="label")
        .reset_index()
        .merge(
            stats_dfs[interpro_type].rename(columns={"count": "train_set_count"}),
            on="label",
            how="left",
        )
    )
    # Every label seen at test time must have a train-set count, otherwise the
    # point would silently be dropped from the plot.
    assert not per_type_res.train_set_count.isna().any(), (
        f"Some {interpro_type} labels in the evaluation results have no train set count"
    )

    with sns.plotting_context("notebook"):
        grid = sns.jointplot(
            x="train_set_count",
            y="correct",
            data=per_type_res,
            joint_kws={
                "alpha": 0.3,
            },
        )
        joint_ax = grid.ax_joint
        # Truncate due to some outliers that are highly represented.
        # NOTE(review): the quantile is taken over every row of the train-set
        # stats table, not only the labels plotted here -- confirm this is intended.
        if truncate_q is not None:
            xmax = stats_dfs[interpro_type]["count"].quantile(truncate_q)
            joint_ax.set_xlim(0, xmax)
        joint_ax.set_xlabel("# times element appears in train set")
        joint_ax.set_ylabel("Per element\ntest set accuracy", rotation="horizontal", ha="right")
        joint_ax.figure.suptitle(
            f"Performance by {interpro_type.lower().replace('_', ' ')} representation in train set",
            y=1,
        )

    return grid

In [7]:
def acc_by_per_prot_num(
    interpro_type: str,
):
    """Bar plot of per-protein accuracy, binned by number of elements per protein.

    Each row of ``results[interpro_type]`` is treated as one predicted element.
    Rows are grouped by ``protein_ids`` to obtain every protein's accuracy
    (fraction of its rows where ``true_labels == pred_labels``) and its number
    of rows. Proteins are then binned by that number: bar height is seaborn's
    aggregate of per-protein accuracy within the bin, and the number written
    inside each bar is how many proteins fall in the bin.

    The plot is drawn on pyplot's current axes (one is created if none exists).

    Args:
        interpro_type: Key into the module-level ``results`` dict; also used,
            lower-cased and pluralised, in the x-axis label.

    Returns:
        The matplotlib ``Axes`` the bars were drawn on.
    """
    # Accuracy and element count per protein in a single pass. Naming the count
    # column explicitly avoids relying on the column name that
    # ``value_counts().to_frame()`` produces, which differs between pandas 1.x
    # and 2.x.
    per_prot_res = (
        results[interpro_type]
        .assign(correct=lambda x: x.true_labels == x.pred_labels)
        .groupby("protein_ids")
        .agg(
            correct=("correct", "mean"),
            count=("correct", "size"),
        )
    )

    # Number of proteins behind each bar; groupby sorts by element count, which
    # is the same order in which the bars are drawn.
    prots_per_count = per_prot_res.groupby("count").size()

    # Human-readable plural of the type name, e.g. "domains" or
    # "homologous superfamilies" (a trailing "y" becomes "ies").
    type_name = interpro_type.lower().replace("_", " ")
    if type_name.endswith("y"):
        type_plural = f"{type_name[:-1]}ies"
    else:
        type_plural = f"{type_name}s"

    with sns.plotting_context("notebook"):
        ax = sns.barplot(
            x="count",
            y="correct",
            data=per_prot_res,
            alpha=0.7,
            edgecolor="k",
            # Fixed seed so the bootstrapped error bars are identical between runs.
            seed=0,
        )
        _ = ax.bar_label(
            ax.containers[0],
            labels=prots_per_count.tolist(),
            label_type="center",
            fontsize=12,
        )
        sns.despine()
        ax.set_xlabel(f"Number of {type_plural} per protein")
        ax.set_ylabel(
            "Average accuracy\nper protein",
            rotation="horizontal",
            ha="right",
        )

    return ax

In [8]:
# Render both diagnostic figures for every evaluated InterPro type and write
# each one to a PDF in `outpath`. After each plotting call the figure it drew is
# pyplot's current figure, so it is picked up with gcf(), saved, and then closed
# before the next plot is drawn.
for interpro_type in types_with_evals:
    per_type_perf_plot(interpro_type)
    fig = plt.gcf()
    fig.savefig(outpath / f"per_type_perf_{interpro_type}.pdf", bbox_inches="tight")
    plt.close(fig)

    acc_by_per_prot_num(interpro_type)
    fig = plt.gcf()
    fig.savefig(outpath / f"acc_by_per_prot_num_{interpro_type}.pdf", bbox_inches="tight")
    plt.close(fig)