# Strain clustering with Bacformer tutorial

This tutorial shows how to use the Bacformer model to cluster strains. Bacformer outputs contextual protein embeddings, and we use the average of all
contextual protein embeddings in a genome as the genome embedding.
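
The pooling step itself is simple. Below is a minimal sketch with NumPy that uses a random toy array in place of real Bacformer outputs. The sizes are arbitrary assumptions, not Bacformer's actual dimensions:

```python
import numpy as np

# Hypothetical contextual protein embeddings for one genome:
# one row per protein, one column per embedding dimension.
rng = np.random.default_rng(0)
n_proteins, dim = 5, 8  # toy sizes, not Bacformer's real dimensions
protein_embeddings = rng.normal(size=(n_proteins, dim))

# Genome embedding = average over the protein axis (axis 0).
genome_embedding = protein_embeddings.mean(axis=0)
print(genome_embedding.shape)  # (8,)
```

Each genome is reduced to one fixed-length vector regardless of how many proteins it has, so genomes of different sizes become directly comparable.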

We use a small random sample of 30 genomes across 4 distinct species and 3 families to demonstrate how to embed the genomes with Bacformer and use the embeddings for clustering.
The genomes have been extracted from [MGnify](https://www.ebi.ac.uk/metagenomics).

Before you start, make sure you have `bacformer` installed (see README.md for details) and run the notebook on a machine with a GPU.

## Step 1: Import required dependencies

In [1]:
import anndata as ad
import numpy as np
import scanpy as sc
from bacformer.pp import embed_dataset_col
from datasets import load_dataset
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.preprocessing import LabelEncoder

## Step 2: Load the dataset

Load the sample dataset from HuggingFace.

In [None]:
# load the dataset
dataset = load_dataset("macwiatrak/strain-clustering-protein-sequences-sample", split="train")

## Step 3: Compute Bacformer embeddings

Convert the protein sequences to genome embeddings. This is done in 2 steps:
1. Embed the protein sequences with the base protein language model (pLM), [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).
2. Pass the protein embeddings to the Bacformer model, which computes contextual protein embeddings and averages them to obtain a genome embedding.

This step takes ~2 min on a single NVIDIA A100 GPU with `flash-attention` installed.

In [None]:
# embed the protein sequences with Bacformer
dataset = embed_dataset_col(
    dataset=dataset,
    model_path="macwiatrak/bacformer-masked-MAG",
    max_n_proteins=9000,
    genome_pooling_method="mean",
)
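
`max_n_proteins` caps how many proteins per genome are passed to the model. This tutorial does not show how `embed_dataset_col` handles longer genomes internally. The sketch below only illustrates the general pattern of truncating each genome to a cap, padding a batch to equal length and averaging over the real (unpadded) proteins. It is not Bacformer's implementation, and the helper name `masked_mean_pool` is hypothetical:

```python
import numpy as np

def masked_mean_pool(genomes, max_n_proteins):
    """Truncate each genome to max_n_proteins, pad to a common length and
    average only over real (non-padded) proteins.

    genomes: list of arrays, each of shape (n_proteins_i, dim), n_proteins_i >= 1.
    Returns an array of shape (n_genomes, dim).
    """
    dim = genomes[0].shape[1]
    truncated = [g[:max_n_proteins] for g in genomes]
    max_len = max(len(g) for g in truncated)

    # Zero-padded batch plus a boolean mask marking real proteins.
    batch = np.zeros((len(truncated), max_len, dim))
    mask = np.zeros((len(truncated), max_len), dtype=bool)
    for i, g in enumerate(truncated):
        batch[i, : len(g)] = g
        mask[i, : len(g)] = True

    # Keep only real proteins, sum over them, divide by how many there are.
    summed = (batch * mask[..., None]).sum(axis=1)
    counts = mask.sum(axis=1, keepdims=True)
    return summed / counts

rng = np.random.default_rng(0)
toy_genomes = [rng.normal(size=(n, 4)) for n in (3, 7, 12)]
pooled = masked_mean_pool(toy_genomes, max_n_proteins=10)
print(pooled.shape)  # (3, 4)
```

Dividing by the per-genome count of real proteins, rather than by the padded length, keeps padding from pulling short genomes' embeddings towards zero.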

## Step 4: Cluster the genome embeddings

We use [scanpy](https://scanpy.readthedocs.io/en/stable/) for clustering, so we convert the data to an `AnnData` object, compute the neighbor graph and then compute the `UMAP`.
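
`sc.pp.neighbors` builds a k-nearest-neighbor graph over the genome embeddings. By default, scanpy uses 15 neighbors with Euclidean distance and derives connectivity weights with the UMAP method. Both the UMAP and the Leiden steps operate on that graph. The sketch below is a simplified NumPy version of only the unweighted kNN step, run on toy data for intuition. The helper `knn_indices` is hypothetical and is not scanpy's implementation:

```python
import numpy as np

def knn_indices(X, k):
    """Return, for each row of X, the indices of its k nearest other rows
    (squared Euclidean distance), excluding the row itself."""
    # Pairwise squared Euclidean distances via broadcasting.
    diffs = X[:, None, :] - X[None, :, :]
    dists = (diffs ** 2).sum(axis=-1)
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbor
    return np.argsort(dists, axis=1)[:, :k]

# Toy "genome embeddings": two well-separated groups of 3 points each.
X = np.array([
    [0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
    [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],
])
neighbors = knn_indices(X, k=2)
print(neighbors)
```

When the embeddings separate species well, each genome's neighbors come from its own group, which is what the UMAP plots and Leiden clusters later reflect.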

In [None]:
# convert dataset to pandas DataFrame
df = dataset.to_pandas()

# create anndata object needed for clustering
embeddings = np.stack(df["embeddings"].tolist())  # get embedding matrix
adata = ad.AnnData(
    X=embeddings,
    obs=df.drop(columns=["embeddings"]).copy(),
)

# compute the neighbor graph with scanpy
sc.pp.neighbors(adata, use_rep="X")

# compute UMAP
sc.tl.umap(adata)

## Step 5: Plot UMAPs

Plot UMAPs by species and family labels.

In [None]:
# plot UMAP by species
sc.pl.umap(adata, color="species")

In [None]:
# plot UMAP by family
sc.pl.umap(adata, color="family")

## [Optional] Step 6: Compute clustering metrics

Compute `Leiden` clusters and clustering metrics, which are useful for evaluating how well the model clusters strains by label (here, species).
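
ARI and NMI compare the predicted partition with the ground truth regardless of which integer each cluster is assigned. This is why arbitrary Leiden cluster IDs can be compared directly with the encoded species labels. Here is a small check with scikit-learn on toy labels (the labels are made up for illustration):

```python
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

true_labels = [0, 0, 0, 1, 1, 2]
# Same grouping as true_labels, but with different cluster IDs.
relabeled = [2, 2, 2, 0, 0, 1]
# A worse grouping that merges the last two species into one cluster.
merged = [0, 0, 0, 1, 1, 1]

ari_same = adjusted_rand_score(true_labels, relabeled)
nmi_same = normalized_mutual_info_score(true_labels, relabeled)
ari_merged = adjusted_rand_score(true_labels, merged)
print(f"relabeled: ARI={ari_same:.3f}, NMI={nmi_same:.3f}")  # ARI=1.000, NMI=1.000
print(f"merged:    ARI={ari_merged:.3f}")
```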

In [None]:
# compute Leiden clusters on the neighbor graph (optional)
sc.tl.leiden(adata, resolution=0.1, key_added="leiden_clusters")

# Convert Leiden cluster labels to integer labels
leiden_clusters = adata.obs["leiden_clusters"].astype(int)

# Encode ground-truth labels
label_encoder = LabelEncoder()
numeric_labels = label_encoder.fit_transform(adata.obs["species"])

# Compute ARI, NMI, and Silhouette
ari = adjusted_rand_score(numeric_labels, leiden_clusters)
nmi = normalized_mutual_info_score(numeric_labels, leiden_clusters)
# Silhouette requires sample-level features + predicted labels
sil = silhouette_score(adata.X, leiden_clusters)
print(f"ARI: {ari:.3f}, NMI: {nmi:.3f}, Silhouette Score: {sil:.3f}")
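
`silhouette_score` is only defined when the number of clusters is between 2 and `n_samples - 1`. With a low Leiden `resolution` such as 0.1, a small dataset may end up in a single cluster, and scikit-learn then raises a `ValueError`. Below is a hedged sketch of a guard. The helper name `safe_silhouette` is hypothetical:

```python
import numpy as np
from sklearn.metrics import silhouette_score

def safe_silhouette(X, labels):
    """Return the silhouette score, or NaN when it is undefined
    (fewer than 2 clusters, or as many clusters as samples)."""
    n_clusters = len(np.unique(labels))
    if not 2 <= n_clusters <= len(labels) - 1:
        return float("nan")
    return silhouette_score(X, labels)

X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
print(safe_silhouette(X, [0, 0, 1, 1]))  # close to 1 for well-separated clusters
print(safe_silhouette(X, [0, 0, 0, 0]))  # nan: only one cluster
```

In the cell above, `safe_silhouette(adata.X, leiden_clusters)` could replace the direct `silhouette_score` call so the remaining metrics still print when Leiden returns a single cluster.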

----------------------

#### Voilà, you made it 👏! 

If you have any issues or questions, please raise an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues.