# Benchmark


In this notebook, we use pepme to evaluate antimicrobial peptide sequences.


In [None]:
import numpy as np

from pepme import FeatureCache, compute_metrics, show_table
from pepme.metrics import FBD, HV, MMD, Fold, HitRate, Novelty, Precision, Recall, Uniqueness
from pepme.models.embeddings import Esm2, Esm2Checkpoint
from pepme.models.properties import Charge, Gravy
from pepme.utils import random_subset, read_fasta_file, shuffle_sequences

Let's load the datasets.


In [None]:
# Display name of each sequence source -> FASTA file holding its sequences.
# The pair order is kept deliberately: it is the order the datasets are loaded and reported in.
DATASET_PATHS = dict(
    [
        ("AMP-Diffusion", "./data/amp-diffusion.fasta"),
        ("AMP-GAN", "./data/amp-gan.fasta"),
        ("CPL-Diff", "./data/cpl-diff.fasta"),
        ("HydrAMP", "./data/hydramp.fasta"),
        ("OmegAMP", "./data/omegamp.fasta"),
        ("DBAASP", "./data/dbaasp.fasta"),
        ("UniProt", "./data/uniprot/uniprot_8_50_100.fasta"),
    ]
)

In [None]:
# Read each FASTA file, keeping the dataset names (and their order) from DATASET_PATHS.
datasets = {
    label: read_fasta_file(fasta_path)
    for label, fasta_path in DATASET_PATHS.items()
}

In [None]:
# Report the raw size of every source before any subsampling.
for source_name, sequences in datasets.items():
    n_seqs = len(sequences)
    print(f"{source_name}: {n_seqs} sequences")

AMP-Diffusion: 47671 sequences
AMP-GAN: 150000 sequences
CPL-Diff: 49985 sequences
HydrAMP: 50000 sequences
OmegAMP: 149504 sequences
DBAASP: 8967 sequences
UniProt: 2933310 sequences


Let's set up the data and models.


In [None]:
# Maximum number of sequences per source used in the benchmark.
n_samples = 3_000
# Shared seed so every random subset below is reproducible.
seed = 42

# Draw a random subset of `n_samples` sequences from every source larger than that;
# sources with at most `n_samples` sequences are kept as-is.
benchmark_datasets = {
    name: random_subset(sequences, n_samples=n_samples, seed=seed) if len(sequences) > n_samples else sequences
    for name, sequences in datasets.items()
}
# Shuffled-DBAASP baseline. It is inserted after all other sources, so it is the last row of the results.
# NOTE(review): presumably `shuffle_sequences` permutes residues within each sequence — confirm in pepme.
benchmark_datasets["DBAASP (shuffled)"] = shuffle_sequences(benchmark_datasets["DBAASP"])
# Remove DBAASP itself from the benchmarked sources. 1000 of its sequences become the
# reference set (`seqs_dbaasp`) that the reference-based metrics compare against.
# NOTE(review): the reference is drawn from the same DBAASP subset the shuffled baseline was
# built from, so the two overlap. Confirm this is intended, since it may favour that baseline.
# The statement order here matters: it fixes the dict key order and the order of the random calls.
seqs_dbaasp = random_subset(benchmark_datasets.pop("DBAASP"), n_samples=1000, seed=seed)


def my_embedder(sequences: list[str]) -> np.ndarray:
    """Toy hand-crafted embedding with two features per sequence.

    Columns are (sequence length, number of "K" (lysine) residues). The result has
    shape (len(sequences), 2); an empty input gives an array of shape (0, 2).
    """
    seq_lengths = [len(seq) for seq in sequences]
    lysine_counts = [seq.count("K") for seq in sequences]
    # Each list becomes one column of the (n, 2) feature matrix.
    return np.column_stack((seq_lengths, lysine_counts))


# Models registered by name. The metrics cell retrieves them with `cache.model(<name>)`:
# either of the two embedders, plus the "gravy" and "charge" predictors used by HV.
cache = FeatureCache(
    models=dict(
        embedder=my_embedder,  # toy (length, lysine count) features defined above
        esm2=Esm2(
            model_name=Esm2Checkpoint.t6_8M,
            batch_size=256,
            device="cpu",
            verbose=False,
        ),
        gravy=Gravy(),
        charge=Charge(),
    )
)

Let's select the metrics.


In [None]:
# Key of the `cache` model used by FBD, MMD, Precision and Recall.
embedder = "esm2"  # set to "embedder" to use the toy length/lysine-count features instead


def hit_rate_condition_fn(sequences: list[str]) -> np.ndarray:
    """Mark which sequences count as hits: those with no cysteine ("C") residue.

    Args:
        sequences: Peptide sequences as one-letter amino-acid strings.

    Returns:
        Boolean array with one entry per input sequence. An entry is True where that
        sequence contains no "C". An empty input yields an empty boolean array.
    """
    # Build the mask directly with dtype=bool. Negating np.array([...]) breaks on an empty
    # list: np.array([]) defaults to float64, and `~` on floats raises TypeError.
    no_cysteine = np.array(["C" not in seq for seq in sequences], dtype=bool)
    return no_cysteine


# The order of this list is also the column order of the results table.
metrics = [
    Uniqueness(),
    # Reference-based metrics: all compare against the DBAASP reference subset `seqs_dbaasp`.
    Novelty(reference=seqs_dbaasp, reference_name="DBAASP"),
    FBD(reference=seqs_dbaasp, embedder=cache.model(embedder)),
    MMD(reference=seqs_dbaasp, embedder=cache.model(embedder)),
    # k-NN metrics (presumably neighborhood_size is the k). Each is folded into splits the size of
    # the reference set, so every comparison uses equally sized samples; `drop_last=True` presumably
    # discards an incomplete final split. Reported as mean ± std over the folds.
    Fold(
        Precision(neighborhood_size=3, reference=seqs_dbaasp, embedder=cache.model(embedder), strict=True),
        split_size=len(seqs_dbaasp),
        drop_last=True,
    ),
    Fold(
        Recall(neighborhood_size=3, reference=seqs_dbaasp, embedder=cache.model(embedder), strict=True),
        split_size=len(seqs_dbaasp),
        drop_last=True,
    ),
    # Hit rate under `hit_rate_condition_fn` (cysteine-free sequences), presumably over k=5 folds.
    Fold(
        metric=HitRate(condition_fn=hit_rate_condition_fn),
        k=5,
    ),
    # Hypervolume over the GRAVY and charge predictor outputs, presumably over k=5 folds.
    # NOTE(review): `nadir` appears to hold one reference value per predictor, in the same
    # order as `predictors` (-10 for GRAVY, -50 for charge) — confirm these bounds with pepme.
    Fold(
        metric=HV(predictors=[cache.model("gravy"), cache.model("charge")], nadir=np.array([-10, -50])),
        k=5,
    ),
]

`Fold` computes a metric several times, once per fold of the data, and aggregates the values (mean and standard deviation).

Wrapping `Fold` around the `Precision` and `Recall` metrics removes the sample-size bias inherent in these metrics (introduced by k-NN). Each fold has the same size as the reference set (`split_size=len(seqs_dbaasp)`), while the folds together still use as many of the available sequences as possible.


In [None]:
# Evaluate every metric on every benchmark dataset.
# The result has one row per dataset and one column per metric (see the table below).
df = compute_metrics(benchmark_datasets, metrics)

100%|██████████| 56/56 [01:11<00:00,  1.28s/it, data=DBAASP (shuffled), metric=HV-2]            


Let's look at the results.


In [None]:
# `decimals` sets the rounding for each metric column, in the same order as `metrics`:
# Uniqueness, Novelty, FBD, MMD, Precision, Recall, Hit-rate, HV.
show_table(df, decimals=[3, 3, 2, 2, 2, 2, 2, 0])

Unnamed: 0,Uniqueness↑,Novelty (DBAASP)↑,FBD↓,MMD↓,Precision↑,Recall↑,Hit-rate↑,HV-2↑
AMP-Diffusion,0.973,1.0,1.18,3.24,0.75±0.01,0.58±0.02,0.61±0.03,830±65
AMP-GAN,1.0,1.0,3.51,11.64,0.86±0.00,0.27±0.01,0.52±0.02,776±33
CPL-Diff,0.99,1.0,1.35,5.41,0.73±0.01,0.80±0.01,0.72±0.02,887±33
HydrAMP,1.0,1.0,4.18,16.82,0.60±0.00,0.45±0.03,0.60±0.01,785±19
OmegAMP,0.982,0.998,0.93,4.11,0.85±0.01,0.77±0.02,0.69±0.02,823±32
UniProt,1.0,1.0,3.54,12.29,0.71±0.01,0.47±0.07,0.56±0.01,814±26
DBAASP (shuffled),0.999,0.998,1.09,3.95,0.88±0.00,0.75±0.02,0.66±0.01,940±72
