# Benchmark


In this notebook, we use pepme to evaluate antimicrobial peptide (AMP) sequences from several generative models and reference datasets.


In [None]:
import numpy as np

from pepme import FeatureCache, compute_metrics, show_table
from pepme.metrics import FBD, HV, ID, MMD, Fold, HitRate, Novelty, Precision, Recall, Uniqueness
from pepme.models import Charge, Esm2, Esm2Checkpoint, Gravy
from pepme.utils import random_subset, read_fasta_file, shuffle_sequences

Let's load the datasets.


In [None]:
# Display name -> FASTA file. Insertion order is preserved and becomes the row
# order of the results table at the end of the notebook.
DATASET_PATHS = dict(
    [
        ("AMP-Diffusion", "./data/amp-diffusion.fasta"),
        ("AMP-GAN", "./data/amp-gan.fasta"),
        ("CPL-Diff", "./data/cpl-diff.fasta"),
        ("HydrAMP", "./data/hydramp.fasta"),
        ("OmegAMP", "./data/omegamp.fasta"),
        ("DBAASP", "./data/dbaasp.fasta"),
        ("UniProt", "./data/uniprot/uniprot_8_50_100.fasta"),
        ("AMPs (E. coli)", "./data/amps_ecoli.fasta"),
        ("AMPs", "./data/amps.fasta"),
    ]
)

In [None]:
# Read each FASTA file, keeping the same display names as keys.
datasets = dict(
    (name, read_fasta_file(path)) for name, path in DATASET_PATHS.items()
)

In [None]:
# Report the full size of every dataset before any subsampling.
for dataset_name, dataset_sequences in datasets.items():
    n_sequences = len(dataset_sequences)
    print(f"{dataset_name}: {n_sequences} sequences")

AMP-Diffusion: 47671 sequences
AMP-GAN: 150000 sequences
CPL-Diff: 49985 sequences
HydrAMP: 50000 sequences
OmegAMP: 149504 sequences
DBAASP: 8967 sequences
UniProt: 2933310 sequences
AMPs (E. coli): 4928 sequences
AMPs: 7204 sequences


Let's set up the data and models.


In [None]:
# Sequences drawn per dataset. 100 keeps the run quick; the original comment
# suggests 3_000 for a larger run.
n_samples = 100
seed = 42

# Datasets with at most `n_samples` sequences are used whole; larger ones are
# subsampled with the same seed.
benchmark_datasets = {
    name: (
        sequences
        if len(sequences) <= n_samples
        else random_subset(sequences, n_samples=n_samples, seed=seed)
    )
    for name, sequences in datasets.items()
}

# Control set derived from the DBAASP subset via `shuffle_sequences`.
# NOTE(review): no seed is passed here, so this row may vary between runs unless
# `shuffle_sequences` is deterministic by default — confirm.
benchmark_datasets["DBAASP (shuffled)"] = shuffle_sequences(benchmark_datasets["DBAASP"])

# The "AMPs" subset doubles as the reference for every reference-based metric, so
# its own row compares the reference with itself (Novelty, FBD and MMD come out 0).
# Alternative: remove "AMPs" from the benchmark and use a 1000-sequence subset of it
# as the reference instead:
# seqs_amps = random_subset(benchmark_datasets.pop("AMPs"), n_samples=1000, seed=seed)
seqs_amps = benchmark_datasets["AMPs"]


def my_embedder(sequences: list[str]) -> np.ndarray:
    """Toy embedder that maps each sequence to two hand-crafted features.

    Column 0 is the sequence length; column 1 is the number of "K" (lysine)
    residues in the sequence.

    Args:
        sequences: Peptide sequences as one-letter amino-acid strings.

    Returns:
        Array of shape ``(len(sequences), 2)``, one row per sequence.
    """
    return np.column_stack(
        [
            [len(seq) for seq in sequences],
            [seq.count("K") for seq in sequences],
        ]
    )


# Smallest ESM-2 checkpoint listed here (t6_8M), run on CPU with verbose output off.
esm2 = Esm2(
    model_name=Esm2Checkpoint.t6_8M,
    device="cpu",
    batch_size=256,
    verbose=False,
)

# Register every feature extractor under a name; the metrics cell retrieves them
# with `cache.model(name)`.
cache = FeatureCache(
    models=dict(
        [
            ("embedder", my_embedder),
            ("esm2-embed", esm2.embed),
            ("esm2-perplexity", lambda seqs: esm2.compute_pseudo_perplexity(seqs, mask_size=3)),
            ("gravy", Gravy()),
            ("charge", Charge()),
        ]
    )
)

Let's select the metrics.


In [None]:
# Feature-cache key for the embedding-based metrics; set it to "embedder" to use
# the length/"K"-count features from `my_embedder` instead of ESM-2 embeddings.
embedder = "esm2-embed"


def hit_rate_condition_fn(sequences: list[str]) -> np.ndarray:
    """Flag the sequences that count as hits: those containing no cysteine ("C").

    Args:
        sequences: Peptide sequences as one-letter amino-acid strings.

    Returns:
        Boolean array with one entry per sequence; True when the sequence contains
        no "C". An empty input yields an empty boolean array.
    """
    # Build the mask with an explicit bool dtype. The previous `~np.array([...])`
    # raised TypeError on empty input, because `np.array([])` is float64 and `~` is
    # not defined for floats.
    no_cysteine = np.array(["C" not in seq for seq in sequences], dtype=bool)
    return no_cysteine


# Metric suite. Reference-based metrics compare each dataset with `seqs_amps`;
# embedding-based ones use the cached features selected by `embedder`.
# Every `Fold` now receives its wrapped metric as `metric=` (previously positional
# for Precision/Recall but keyword for HitRate/HV).
metrics = [
    Uniqueness(),
    Novelty(reference=seqs_amps),
    FBD(reference=seqs_amps, embedder=cache.model(embedder)),
    MMD(reference=seqs_amps, embedder=cache.model(embedder)),
    # k-NN precision/recall with 3 neighbours, split into chunks of the reference
    # size so each estimate uses the same sample size; presumably `drop_last`
    # discards a trailing partial chunk.
    Fold(
        metric=Precision(neighborhood_size=3, reference=seqs_amps, embedder=cache.model(embedder), strict=True),
        split_size=len(seqs_amps),
        drop_last=True,
    ),
    Fold(
        metric=Recall(neighborhood_size=3, reference=seqs_amps, embedder=cache.model(embedder), strict=True),
        split_size=len(seqs_amps),
        drop_last=True,
    ),
    # Hit rate under `hit_rate_condition_fn` (sequences without "C"), over 5 folds.
    Fold(
        metric=HitRate(condition_fn=hit_rate_condition_fn),
        k=5,
    ),
    # HV (presumably hypervolume) of GRAVY and charge predictions with nadir
    # point (-10, -50), over 5 folds.
    Fold(
        metric=HV(predictors=[cache.model("gravy"), cache.model("charge")], nadir=np.array([-10, -50])),
        k=5,
    ),
    # ESM-2 pseudo-perplexity reported through `ID`; lower is better.
    ID(predictor=cache.model("esm2-perplexity"), name="Perplexity", objective="minimize"),
]

`Fold` computes a metric multiple times on different folds of the data and aggregates the values (mean and standard deviation).

Wrapping `Fold` around the `Precision` and `Recall` metrics removes the sample-size bias inherent in these metrics (introduced by the k-NN neighbourhoods), while still using as many of the available sequences as possible.


In [None]:
# Evaluate every metric on every benchmark dataset; the resulting table has one
# row per dataset and one column per metric.
df = compute_metrics(benchmark_datasets, metrics)

100%|██████████| 90/90 [01:25<00:00,  1.05it/s, data=DBAASP (shuffled), metric=Perplexity]


Let's look at the results.


In [None]:
# Decimal places per column, in metric order: Uniqueness, Novelty, FBD, MMD,
# Precision, Recall, Hit-rate, HV-2, Perplexity.
show_table(df, decimals=[3, 3, 2, 2, 2, 2, 2, 0, 2])

Unnamed: 0,Uniqueness↑,Novelty↑,FBD↓,MMD↓,Precision↑,Recall↑,Hit-rate↑,HV-2↑,Perplexity↓
AMP-Diffusion,0.99,1.0,3.22,8.56,0.68,0.84,0.53±0.12,707±48,10.66±3.58
AMP-GAN,1.0,1.0,6.23,27.05,0.95,0.24,0.62±0.20,652±24,13.35±2.25
CPL-Diff,1.0,1.0,2.44,5.07,0.83,0.91,0.71±0.05,721±61,9.39±3.41
HydrAMP,1.0,1.0,5.61,25.07,0.7,0.45,0.63±0.11,665±27,13.81±3.76
OmegAMP,1.0,1.0,3.25,13.95,0.81,0.84,0.69±0.07,699±26,12.19±3.29
DBAASP,1.0,1.0,2.36,5.93,0.8,0.92,0.61±0.10,655±38,12.05±4.00
UniProt,1.0,1.0,6.79,25.16,0.6,0.4,0.53±0.10,675±56,14.37±3.21
AMPs (E. coli),1.0,1.0,1.58,1.78,0.88,0.8,0.82±0.05,684±25,10.61±3.71
AMPs,1.0,0.0,0.0,0.0,1.0,1.0,0.83±0.07,693±35,10.65±3.52
DBAASP (shuffled),1.0,1.0,4.13,17.14,0.82,0.72,0.61±0.10,655±38,13.56±3.50
