# Gene essentiality prediction with Bacformer tutorial

This tutorial shows how to finetune the Bacformer model to predict gene essentiality.

We use data from the [Database of Essential Genes](http://origin.tubic.org/deg/public/index.php/browse/bacteria) for training and evaluation, and we evaluate performance at the genome level.

Before you start, make sure you have `bacformer` installed (see README.md for details) and run the notebook on a machine with a GPU.

## Step 1: Import required dependencies

In [1]:
import os
from functools import partial

import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from bacformer.modeling import (
    SPECIAL_TOKENS_DICT,
    BacformerTrainer,
    collate_genome_samples,
    compute_metrics_gene_essentiality_pred,
    adjust_prot_labels,
)
from bacformer.pp import dataset_col_to_bacformer_inputs
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForTokenClassification, EarlyStoppingCallback, TrainingArguments

## Step 2: Load the dataset

We use the gene essentiality dataset preprocessed for this task. Each protein in a genome has a binary essentiality label (stored in the `essential` column), which we predict given 1) the protein sequence itself and 2) the whole-genome context.

In [None]:
# load the dataset from HuggingFace
dataset = load_dataset("macwiatrak/bacbench-essential-genes-protein-sequences")

## Step 3: Embed the dataset with the base protein language model (pLM)

The first step in using Bacformer is to embed the protein sequences with the base pLM, [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).

This step takes ~5 min on a single NVIDIA A100 GPU.
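
A pLM produces one embedding per amino acid, while Bacformer works with one vector per protein. A common way to get there is masked mean pooling over residues; the sketch below assumes that is how per-protein embeddings are formed (the function name `mean_pool_residues` is hypothetical, and whether special tokens such as `<cls>`/`<eos>` are excluded is not settled here). It only uses random tensors, with the hidden size of 480 used by ESM-2 t12 35M.

```python
import torch


def mean_pool_residues(residue_emb: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average per-residue embeddings into one vector per protein.

    residue_emb: (n_proteins, seq_len, hidden_dim) hidden states from a pLM
    attention_mask: (n_proteins, seq_len), 1 for real residues, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(residue_emb.dtype)  # (n, L, 1)
    summed = (residue_emb * mask).sum(dim=1)  # sum over real residues only
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero for empty rows
    return summed / counts


# toy example: 3 proteins of different lengths padded to 5 positions,
# hidden size 480 (the width of ESM-2 t12 35M)
hidden = torch.randn(3, 5, 480)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
protein_emb = mean_pool_residues(hidden, mask)
print(protein_emb.shape)
```

Padding positions contribute nothing to the sum, so each protein's vector depends only on its own residues.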

In [None]:
# embed the protein sequences with the ESM-2 base model and map the labels
for split_name in dataset.keys():
    dataset[split_name] = dataset_col_to_bacformer_inputs(
        dataset=dataset[split_name],
        max_n_proteins=7000,
    )
    # map the essentiality labels to a binary format
    dataset[split_name] = dataset[split_name].map(
        lambda row: adjust_prot_labels(
            labels=row["essential"],
            special_tokens=row["special_tokens_mask"],
        ),
        batched=False,
    )
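
The plotting cell in Step 8 drops rows whose label is `-100`, which suggests that positions that are not proteins (special tokens, padding) carry the usual ignore index. The sketch below illustrates that alignment under the assumption that the special-tokens mask is a boolean sequence with `True` at special positions; `align_labels_to_tokens` is a hypothetical helper, not the implementation of `adjust_prot_labels`.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # conventional "skip this position" label


def align_labels_to_tokens(labels: Sequence[int], is_special: Sequence[bool]) -> List[int]:
    """Hypothetical helper: spread per-protein labels over a token sequence.

    Every special-token position receives IGNORE_INDEX; the remaining positions
    consume the protein labels in order.
    """
    n_proteins = sum(not s for s in is_special)
    if n_proteins != len(labels):
        raise ValueError(f"{n_proteins} protein positions but {len(labels)} labels")
    label_iter = iter(labels)
    return [IGNORE_INDEX if s else next(label_iter) for s in is_special]


# toy genome: a leading special token, two proteins, a separator token, one protein
print(align_labels_to_tokens([1, 0, 1], [True, False, False, True, False]))
# [-100, 1, 0, -100, 1]
```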

## Step 4: Load the Bacformer model

Load the pretrained Bacformer model with a token (here, protein) classification head on top, which we will finetune.

In [None]:
# load the Bacformer model for protein classification
# for this task we use the Bacformer model trained on masked complete genomes
# with a token (here protein) classification head
config = AutoConfig.from_pretrained("macwiatrak/bacformer-masked-complete-genomes", trust_remote_code=True)
config.num_labels = 1
config.problem_type = "binary_classification"

bacformer_model = AutoModelForTokenClassification.from_pretrained(
    "macwiatrak/bacformer-masked-complete-genomes", config=config, trust_remote_code=True
).to(torch.bfloat16)

print("Nr of parameters:", sum(p.numel() for p in bacformer_model.parameters()))
print("Nr of trainable parameters:", sum(p.numel() for p in bacformer_model.parameters() if p.requires_grad))
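
Setting `config.num_labels = 1` means the head emits a single score (logit) per protein, which Step 7 turns into a probability with `torch.sigmoid`. A loss consistent with that setup is binary cross-entropy on logits, skipping positions labelled `-100`. The sketch below assumes this form; it is an illustration, not Bacformer's actual loss code, and `masked_bce_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_bce_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical loss: binary cross-entropy over proteins, ignoring ignore_index positions.

    logits: (batch, n_tokens, 1), one score per token as produced with num_labels = 1
    labels: (batch, n_tokens), values in {0, 1, ignore_index}
    """
    logits = logits.squeeze(-1)  # (batch, n_tokens)
    keep = labels != ignore_index  # boolean mask of scored positions
    return F.binary_cross_entropy_with_logits(logits[keep].float(), labels[keep].float())


# one genome with four tokens; the third position is ignored
logits = torch.tensor([[[2.0], [-1.0], [0.0], [5.0]]])
labels = torch.tensor([[1, 0, -100, 1]])
loss = masked_bce_loss(logits, labels)
print(loss.item())
```

The ignored position has no effect on the result, so the loss is the mean over the three real proteins only.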

## Step 5: Set up the trainer for finetuning

Set up a trainer object for finetuning.

In [None]:
# define the output directory for the model and metrics
output_dir = "output/gene_essentiality_pred"
os.makedirs(output_dir, exist_ok=True)

# define the training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    learning_rate=0.00015,
    num_train_epochs=100,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    seed=1,
    dataloader_num_workers=4,
    bf16=True,
    metric_for_best_model="eval_macro_auroc",
    load_best_model_at_end=True,
    greater_is_better=True,
)

# define a collate function for the dataset
collate_genome_samples_fn = partial(collate_genome_samples, SPECIAL_TOKENS_DICT["PAD"], 7000, 1000)
trainer = BacformerTrainer(
    model=bacformer_model,
    data_collator=collate_genome_samples_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    args=training_args,
    compute_metrics=compute_metrics_gene_essentiality_pred,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
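
`collate_genome_samples` is bound to the padding token id and two integers (7000, the same value as `max_n_proteins` in Step 3, and 1000); their exact roles are defined in the Bacformer source. Because genomes contain different numbers of proteins, a collate function has to bring each batch to a common length. The hypothetical `pad_genome_batch` below illustrates that idea with made-up field names (`protein_embeddings`, `labels`); it pads labels with `-100` so padded positions are ignored, and it is not the library's implementation.

```python
from typing import Dict, List

import torch

IGNORE_INDEX = -100


def pad_genome_batch(samples: List[Dict[str, torch.Tensor]], pad_value: float = 0.0) -> Dict[str, torch.Tensor]:
    """Hypothetical collate: pad genomes of different lengths to the longest one in the batch.

    Each sample holds "protein_embeddings" (n_proteins, dim) and "labels" (n_proteins,).
    """
    max_len = max(s["protein_embeddings"].shape[0] for s in samples)
    dim = samples[0]["protein_embeddings"].shape[1]
    emb = torch.full((len(samples), max_len, dim), pad_value)
    labels = torch.full((len(samples), max_len), IGNORE_INDEX, dtype=torch.long)
    attention_mask = torch.zeros((len(samples), max_len), dtype=torch.long)
    for i, s in enumerate(samples):
        n = s["protein_embeddings"].shape[0]
        emb[i, :n] = s["protein_embeddings"]  # copy real proteins, keep padding elsewhere
        labels[i, :n] = s["labels"]
        attention_mask[i, :n] = 1  # 1 marks real proteins
    return {"protein_embeddings": emb, "labels": labels, "attention_mask": attention_mask}


batch = pad_genome_batch([
    {"protein_embeddings": torch.randn(4, 8), "labels": torch.tensor([1, 0, 0, 1])},
    {"protein_embeddings": torch.randn(2, 8), "labels": torch.tensor([0, 1])},
])
print(batch["labels"].tolist())  # [[1, 0, 0, 1], [0, 1, -100, -100]]
```

With `per_device_train_batch_size=1` and `gradient_accumulation_steps=8`, each optimizer step accumulates gradients over 8 genomes per device, while each forward pass holds a single genome; in this sketch padding is then a no-op, and it only matters if you raise the batch size.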

## Step 6: Finetune the model 🎉🚂😎

Finetune the model to predict gene essentiality. Training should take ~15 min on a single NVIDIA A100 GPU.
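
Although `num_train_epochs=100`, the `EarlyStoppingCallback(early_stopping_patience=10)` ends training once the validation `eval_macro_auroc` has failed to improve for 10 consecutive evaluations (one per epoch here, given `eval_strategy="epoch"`), and `load_best_model_at_end=True` then restores the best checkpoint. The loop below is a simplified re-implementation of that patience rule for illustration; the real callback also supports a minimum-improvement threshold, and `early_stopping_epoch` is a hypothetical name.

```python
from typing import List, Optional


def early_stopping_epoch(scores: List[float], patience: int = 10) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Simplified patience rule for a metric where higher is better (like eval_macro_auroc):
    stop once `patience` consecutive evaluations fail to beat the best score so far.
    """
    best = float("-inf")
    bad_evals = 0
    for epoch, score in enumerate(scores):
        if score > best:
            best = score  # new best: reset the patience counter
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return epoch
    return None


# toy AUROC curve: improves for three epochs, then plateaus
aurocs = [0.70, 0.75, 0.78] + [0.77] * 12
print(early_stopping_epoch(aurocs, patience=10))  # 12
```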

In [None]:
trainer.train()

## Step 7: Evaluate on the test set and run predictions

With the trained model, you can now evaluate it on the test set and run predictions.

In [None]:
# evaluate the model on the test set
test_output = trainer.predict(dataset["test"])
print("Test output:", test_output.metrics)

# get the predictions and labels for a single genome from the test set
preds_strain = torch.sigmoid(torch.tensor(test_output.predictions.squeeze(-1)))[0]
labels_strain = torch.tensor(test_output.label_ids)[0]
genome_id = dataset["test"][0]["genome_id"]
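
The introduction states that performance is evaluated at the genome level, and the best checkpoint is selected by `eval_macro_auroc`. One plausible reading, assumed here rather than taken from `compute_metrics_gene_essentiality_pred`, is to compute an AUROC for each genome separately and then average those values. The sketch below does that with NumPy on arrays shaped like the sigmoid of `test_output.predictions.squeeze(-1)` (one row per genome) and `test_output.label_ids`; `auroc` and `macro_auroc` are hypothetical helpers.

```python
import numpy as np

IGNORE_INDEX = -100


def auroc(probs: np.ndarray, labels: np.ndarray) -> float:
    """AUROC as the probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = probs[labels == 1]
    neg = probs[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # undefined when a genome has only one class
    diff = pos[:, None] - neg[None, :]  # all positive/negative pairs
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def macro_auroc(probs: np.ndarray, labels: np.ndarray) -> float:
    """Hypothetical genome-level metric: mean of per-genome AUROCs, skipping -100 positions."""
    scores = []
    for p, y in zip(probs, labels):
        keep = y != IGNORE_INDEX
        scores.append(auroc(p[keep], y[keep]))
    return float(np.nanmean(scores))


probs = np.array([[0.9, 0.2, 0.5, 0.1], [0.5, 0.8, 0.4, 0.0]])
labels = np.array([[1, 0, -100, 0], [0, 1, 1, -100]])
print(macro_auroc(probs, labels))  # 0.75
```

Averaging per genome gives each genome equal weight regardless of how many proteins it has, which is one reason to prefer a macro average when comparing across genomes.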

## [Optional] Step 8: Plot gene essentiality probabilities

Plot predicted gene essentiality probabilities for a single genome.

In [None]:
# make DF for plotting
df = pd.DataFrame({'probability': preds_strain.tolist(), 'label': labels_strain.tolist()})
# remove the ignore index rows
df = df[df.label != -100]

# Create the KDE plot
plt.figure(figsize=(8, 6))
sns.kdeplot(
    data=df[df['label'] == 0],
    x='probability',
    fill=True,
    color='blue',
    alpha=0.6,
    label='Non-essential genes'
)
sns.kdeplot(
    data=df[df['label'] == 1],
    x='probability',
    fill=True,
    color='goldenrod',
    alpha=0.6,
    label='Essential genes'
)

# add the title, legend and axis labels
plt.title(f"Gene Essentiality Prediction for {genome_id}", fontsize=18)
plt.legend(fontsize=20, frameon=False, loc="upper left")
plt.xlabel("Predicted essentiality probability", fontsize=12)
plt.ylabel("Density", fontsize=12)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.xlim(-0.1, 1.1)

# Show the plot
plt.tight_layout()
plt.show()

----------------------

#### Voilà, you made it 👏! 

If you run into any problems or have questions, please open an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues.