# Benchmark


In [None]:
import numpy as np

from pepme import FeatureCache, compute_metrics, show_table
from pepme.metrics import FID, Novelty, Precision, Recall, Uniqueness
from pepme.models.embeddings import ESM2
from pepme.utils import random_subset, read_fasta_file, shuffle_sequences

Let's load each dataset from its FASTA file.


In [None]:
# FASTA file for each dataset, keyed by the display name used in the results.
# Insertion order matters: datasets are loaded, subsampled and reported in this
# order (the rows of the final results table follow it).
DATASET_PATHS = {
    "AMP-Diffusion": "../data/amp-diffusion.fasta",
    "AMP-GAN": "../data/amp-gan.fasta",
    "CPL-Diff": "../data/cpl-diff.fasta",
    "HydrAMP": "../data/hydramp.fasta",
    "OmegAMP": "../data/omegamp.fasta",
    # DBAASP is also the reference set the metrics compare against.
    "DBAASP": "../data/dbaasp.fasta",
    "UniProt": "../data/uniprot/uniprot_8_50_100.fasta",
}

In [None]:
# Read every FASTA file into memory, preserving the order of DATASET_PATHS.
datasets = {
    dataset_name: read_fasta_file(DATASET_PATHS[dataset_name])
    for dataset_name in DATASET_PATHS
}

In [None]:
# Report how many sequences were read for each dataset.
for dataset_name in datasets:
    n_sequences = len(datasets[dataset_name])
    print(f"{dataset_name}: {n_sequences} sequences")

AMP-Diffusion: 47671 sequences
AMP-GAN: 150000 sequences
CPL-Diff: 49985 sequences
HydrAMP: 50000 sequences
OmegAMP: 149504 sequences
DBAASP: 8967 sequences
UniProt: 2933310 sequences


In [None]:
n_samples = 3_000  # maximum number of sequences kept per dataset
seed = 42  # fixed seed handed to random_subset, for reproducible subsets

# Datasets larger than `n_samples` are randomly subsampled down to that size;
# datasets that already fit are used in full.
benchmark_datasets = {
    dataset_name: (
        seqs
        if len(seqs) <= n_samples
        else random_subset(seqs, n_samples=n_samples, seed=seed)
    )
    for dataset_name, seqs in datasets.items()
}

# Extra row: a shuffled copy of the (already subsampled) DBAASP set.
# NOTE(review): unlike random_subset above, shuffle_sequences receives no seed,
# so this row may change between runs unless the helper seeds itself — confirm.
benchmark_datasets["DBAASP (shuffled)"] = shuffle_sequences(
    benchmark_datasets["DBAASP"]
)


def my_embedder(sequences: list[str]) -> np.ndarray:
    """Embed each sequence as two hand-crafted numeric features.

    Column 0 holds the sequence length and column 1 the number of "K"
    characters (lysine in one-letter amino-acid code). The result has one
    row per input sequence, i.e. shape ``(len(sequences), 2)``; an empty
    input yields an empty ``(0, 2)`` array.
    """
    length_feature = [len(seq) for seq in sequences]
    lysine_feature = [seq.count("K") for seq in sequences]
    # column_stack turns each 1-D feature list into a column, so even an
    # empty input keeps the two-column shape.
    return np.column_stack((length_feature, lysine_feature))


# ESM2 feature extractor (esm2_t6_8M_UR50D checkpoint), run on the CPU in
# batches of 256 sequences with its progress output turned off.
esm2_model = ESM2(
    model_name="esm2_t6_8M_UR50D",
    batch_size=256,
    device="cpu",
    verbose=False,
)

# Register both feature extractors under the names the metrics look up via
# `cache.model(name)`: the hand-crafted `my_embedder` and the ESM2 model.
cache = FeatureCache(models={"embedder": my_embedder, "esm2": esm2_model})

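Let's select the metrics. Every reference-based metric compares against the DBAASP subset, so the DBAASP row itself serves as a self-comparison sanity check.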

In [None]:
# Feature model used by the embedding-based metrics: "esm2" for the ESM2
# model, or "embedder" for the hand-crafted length / K-count features.
embedder = "esm2"

# All reference-based metrics compare against the same DBAASP subset.
dbaasp_reference = benchmark_datasets["DBAASP"]

metrics = [
    Uniqueness(),
    Novelty(reference=dbaasp_reference, reference_name="DBAASP"),
    FID(reference=dbaasp_reference, embedder=cache.model(embedder)),
    Precision(
        reference=dbaasp_reference,
        embedder=cache.model(embedder),
        neighborhood_size=3,
        strict=False,
    ),
    Recall(
        reference=dbaasp_reference,
        embedder=cache.model(embedder),
        neighborhood_size=3,
        strict=False,
    ),
]

In [None]:
# Evaluate every metric on every benchmark dataset; the progress bar counts
# (dataset, metric) pairs, e.g. 8 datasets x 5 metrics = 40 steps.
df = compute_metrics(benchmark_datasets, metrics)

100%|██████████| 40/40 [01:19<00:00,  1.98s/it, data=DBAASP (shuffled), metric=Recall]          


Let's look at the results (↑ means higher is better, ↓ means lower is better).


In [None]:
# Render the results table. `decimals` sets the rounding of each metric column,
# in the same order as the `metrics` list (Uniqueness, Novelty, FID, Precision,
# Recall); keep it in sync if metrics are added, removed or reordered.
show_table(df, decimals=[4, 3, 2, 2, 2])

Unnamed: 0,Uniqueness↑,Novelty (DBAASP)↑,FID↓,Precision↑,Recall↑
AMP-Diffusion,0.9727,1.0,1.08,0.65,0.52
AMP-GAN,0.9997,1.0,3.64,0.82,0.25
CPL-Diff,0.99,0.999,1.27,0.7,0.77
HydrAMP,1.0,1.0,4.38,0.56,0.45
OmegAMP,0.9823,0.996,1.0,0.81,0.75
DBAASP,0.995,0.0,0.0,1.0,1.0
UniProt,1.0,1.0,3.57,0.68,0.52
DBAASP (shuffled),0.9993,0.993,1.22,0.88,0.78
