# Benchmark


In this notebook, we use `pepme` to evaluate antimicrobial peptide sequences.


In [None]:
import numpy as np

from pepme import FeatureCache, compute_metrics, show_table
from pepme.metrics import FID, HV, MMD, Count, Fold, HitRate, Novelty, Precision, Recall, Uniqueness
from pepme.models.embeddings import ESM2
from pepme.models.properties import Charge, Gravy
from pepme.utils import random_subset, read_fasta_file, shuffle_sequences

Let's load the datasets.


In [None]:
# Maps each dataset's display name to the FASTA file it is read from.
# All paths are relative, so they resolve against the notebook's working directory.
# The "DBAASP" key is looked up by name further down, so keep it in sync when renaming.
DATASET_PATHS = {
    "AMP-Diffusion": "./data/amp-diffusion.fasta",
    "AMP-GAN": "./data/amp-gan.fasta",
    "CPL-Diff": "./data/cpl-diff.fasta",
    "HydrAMP": "./data/hydramp.fasta",
    "OmegAMP": "./data/omegamp.fasta",
    "DBAASP": "./data/dbaasp.fasta",
    "UniProt": "./data/uniprot/uniprot_8_50_100.fasta",
}

In [None]:
# Read every FASTA file into memory, keyed by the dataset's display name.
datasets = {
    dataset_name: read_fasta_file(fasta_path)
    for dataset_name, fasta_path in DATASET_PATHS.items()
}

In [None]:
# Report how many sequences each loaded dataset contains.
for dataset_name, dataset_sequences in datasets.items():
    n_sequences = len(dataset_sequences)
    print(f"{dataset_name}: {n_sequences} sequences")

AMP-Diffusion: 47671 sequences
AMP-GAN: 150000 sequences
CPL-Diff: 49985 sequences
HydrAMP: 50000 sequences
OmegAMP: 149504 sequences
DBAASP: 8967 sequences
UniProt: 2933310 sequences


In [None]:
n_samples = 3_000
seed = 42

# Cap every dataset at n_samples sequences; datasets that are already small enough
# are used as they are.
benchmark_datasets = {
    dataset_name: (
        seqs if len(seqs) <= n_samples else random_subset(seqs, n_samples=n_samples, seed=seed)
    )
    for dataset_name, seqs in datasets.items()
}

# Take the DBAASP subsample out of the benchmark (it is used as the reference set by the
# metrics) and benchmark a shuffled copy of it instead.
# NOTE(review): shuffle_sequences is called without a seed here - confirm that it is
# deterministic, otherwise this entry is not reproducible across runs.
seqs_dbaasp = benchmark_datasets.pop("DBAASP")
benchmark_datasets["DBAASP (shuffled)"] = shuffle_sequences(seqs_dbaasp)


def my_embedder(sequences: list[str]) -> np.ndarray:
    """Embed each sequence as two hand-crafted features: its length and its number of "K"s.

    Args:
        sequences: Sequences given as strings.

    Returns:
        Array with one row per sequence and two columns, ``[length, count of "K"]``.
    """
    per_feature = np.array(
        [
            [len(seq) for seq in sequences],
            [seq.count("K") for seq in sequences],
        ]
    )
    # The array is built feature-major (2, n); transpose so each row describes one sequence.
    return per_feature.T


# ESM2 embedding model ("esm2_t6_8M_UR50D" checkpoint), configured to run on the CPU.
esm2_model = ESM2(
    model_name="esm2_t6_8M_UR50D",
    batch_size=256,
    device="cpu",
    verbose=False,
)

# Register both embedders under the keys that are later passed to cache.model(...).
cache = FeatureCache(models={"embedder": my_embedder, "esm2": esm2_model})

Let's select the metrics.


In [None]:
# Key of the FeatureCache model that the embedding-based metrics below will use.
embedder = "esm2"  # or "embedder" to use the hand-written my_embedder features instead


def hit_rate_condition_fn(sequences: list[str]) -> np.ndarray:
    """Flag the sequences that contain no cysteine ("C") residue.

    Args:
        sequences: Sequences given as strings.

    Returns:
        Boolean array with one entry per sequence; ``True`` where "C" does not occur.
        An empty input yields an empty boolean array.
    """
    # Build the boolean mask directly with an explicit dtype: negating ``np.array([])``
    # with ``~`` fails on empty input because an empty array defaults to float64.
    no_cysteine = np.array(["C" not in seq for seq in sequences], dtype=bool)
    return no_cysteine


# Metrics evaluated on every dataset in `benchmark_datasets`; each entry becomes one
# column of the result table. The DBAASP subsample (`seqs_dbaasp`) is the reference
# set for all reference-based metrics.
metrics = [
    Count(),
    Uniqueness(),
    # `reference_name` shows up in the column title ("Novelty (DBAASP)").
    Novelty(reference=seqs_dbaasp, reference_name="DBAASP"),
    # The next four metrics compare each dataset with the reference using the features
    # of the embedder selected above.
    FID(
        reference=seqs_dbaasp,
        embedder=cache.model(embedder),
    ),
    # NOTE(review): the meaning of strict=False is not visible in this notebook - see the
    # pepme documentation for MMD / Precision / Recall.
    MMD(
        reference=seqs_dbaasp,
        embedder=cache.model(embedder),
        strict=False,
    ),
    Precision(
        neighborhood_size=3,
        reference=seqs_dbaasp,
        embedder=cache.model(embedder),
        strict=False,
    ),
    Recall(
        neighborhood_size=3,
        reference=seqs_dbaasp,
        embedder=cache.model(embedder),
        strict=False,
    ),
    # Fold-wrapped metrics are reported as "value±spread" in the result table
    # (presumably statistics over the k=5 folds - TODO confirm).
    Fold(
        metric=HitRate(condition_fn=hit_rate_condition_fn),
        k=5,
    ),
    Fold(
        metric=HV(
            predictors=[Gravy(), Charge()],
            # Presumably one nadir coordinate per predictor, in the same order - TODO confirm.
            nadir=np.array([-10, -50]),
        ),
        k=5,
    ),
]

In [None]:
# Evaluate the selected metrics on the benchmark datasets. The recorded progress bar
# shows 63 steps, which matches 7 datasets x 9 metrics.
df = compute_metrics(benchmark_datasets, metrics)

100%|██████████| 63/63 [01:10<00:00,  1.11s/it, data=DBAASP (shuffled), metric=HV]              


Let's look at the results.


In [None]:
# `decimals` has one entry per metric column, in the order the metrics were listed:
# Count, Uniqueness, Novelty, FID, MMD, Precision, Recall, Hit-rate, HV.
show_table(df, decimals=[0, 4, 3, 2, 2, 2, 2, 2, 1])

Unnamed: 0,Count↑,Uniqueness↑,Novelty (DBAASP)↑,FID↓,MMD↓,Precision↑,Recall↑,Hit-rate↑,HV↑
AMP-Diffusion,3000,0.9727,1.0,1.08,3.12,0.65,0.52,0.61±0.03,830.4±64.9
AMP-GAN,3000,0.9997,1.0,3.64,12.4,0.82,0.25,0.52±0.02,775.5±32.7
CPL-Diff,3000,0.99,0.999,1.27,5.12,0.7,0.77,0.72±0.02,886.8±33.5
HydrAMP,3000,1.0,1.0,4.38,17.81,0.56,0.45,0.60±0.01,785.4±19.1
OmegAMP,3000,0.9823,0.996,1.0,4.7,0.81,0.75,0.69±0.02,822.7±31.6
UniProt,3000,1.0,1.0,3.57,12.74,0.68,0.52,0.56±0.01,814.2±25.9
DBAASP (shuffled),3000,0.9993,0.993,1.22,4.61,0.88,0.78,0.66±0.01,939.7±72.5
