# Finetuning Bacformer for phenotypic trait prediction tutorial

This tutorial outlines how to finetune a pretrained Bacformer model to predict phenotypic labels.

We provide a dataset containing protein sequences for over `1,000` genomes across different species, each with a binary label. We show how to train and evaluate
a finetuned Bacformer model for phenotype prediction. The framework presented here is, in principle, extendable to any bacterial phenotype.

We recommend first checking out `phenotypic_traits_prediction_tutorial.ipynb`, which is significantly less computationally expensive and outlines how to train
a simple linear regression model on precomputed Bacformer embeddings for phenotype prediction. If your phenotype is challenging to predict, or you want to use
your own phenotype labels, please use this tutorial.

Before you start, make sure you have `bacformer` installed (see README.md for details) and run the notebook on a machine with a GPU.

## Step 1: Import required dependencies

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

import torch
from bacformer.modeling import (
    SPECIAL_TOKENS_DICT,
    BacformerTrainer,
    collate_genome_samples,
    compute_metrics_binary_genome_pred,
)
from bacformer.pp import dataset_col_to_bacformer_inputs
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForSequenceClassification, EarlyStoppingCallback, TrainingArguments

## Step 2: Load the dataset

We will be using the dataset for predicting `Catalase`. `Catalase` denotes whether a bacterium produces the catalase enzyme that breaks down hydrogen peroxide (H₂O₂) into water and oxygen, thereby protecting the cell from oxidative stress. The phenotypic data has been collected from [1] and is a binary classification problem.

In [None]:
dataset = load_dataset("macwiatrak/phenotypic-trait-catalase-protein-sequences", keep_in_memory=False)

## Step 3: Embed the dataset with the base protein language model (pLM)

The first step to using Bacformer is embedding the protein sequences with the base pLM, [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).

This step should take ~1.5h on a single NVIDIA A100 GPU.

In [None]:
# embed the protein sequences with the ESM-2 base model
for split_name in dataset.keys():
    # optional: uncomment to subsample each split to 30 genomes for a quick test run
    # dataset[split_name] = dataset[split_name].select(range(30))
    dataset[split_name] = dataset_col_to_bacformer_inputs(
        dataset=dataset[split_name],
        max_n_proteins=7000,
    )
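Each genome is capped at `max_n_proteins=7000` proteins here. As a rough, hedged estimate of how much data the embedding step produces per genome, the sketch below assumes one embedding vector per protein with the ESM-2 t12 35M hidden size of 480, and assumes proteins beyond the cap are dropped. The helper `genome_embedding_bytes` and the storage dtypes are illustrative assumptions, not part of the `bacformer` API.

```python
# Rough size estimate for per-genome protein embeddings.
# Assumptions (not taken from the bacformer API): one vector per protein,
# hidden size 480 (ESM-2 t12 35M), proteins beyond the cap are dropped.
ESM2_T12_35M_DIM = 480
BYTES_PER_VALUE = {"float32": 4, "bfloat16": 2}


def genome_embedding_bytes(n_proteins: int, max_n_proteins: int = 7000,
                           dim: int = ESM2_T12_35M_DIM, dtype: str = "float32") -> int:
    """Return the approximate number of bytes needed to store one genome's embeddings."""
    kept = min(n_proteins, max_n_proteins)  # hypothetical truncation at the cap
    return kept * dim * BYTES_PER_VALUE[dtype]


if __name__ == "__main__":
    # a ~4,000-protein genome vs. a large genome that hits the cap
    for n in (4000, 9000):
        mb = genome_embedding_bytes(n) / 1e6
        print(f"{n} proteins -> {mb:.2f} MB (float32)")
```

Under these assumptions a capped genome needs about 13.4 MB in float32 and half that in bfloat16, which is why large datasets are worth embedding once and reusing.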

## Step 4: Load the Bacformer model

Load the pretrained Bacformer model with a classification head on top, which we will finetune.

In [None]:
# load the Bacformer model for genome classification
# for this task we use the Bacformer model trained on masked complete genomes
config = AutoConfig.from_pretrained("macwiatrak/bacformer-masked-complete-genomes", trust_remote_code=True)
config.num_labels = 1
config.problem_type = "binary_classification"
bacformer_model = AutoModelForSequenceClassification.from_pretrained(
    "macwiatrak/bacformer-masked-complete-genomes", config=config, trust_remote_code=True
).to(torch.bfloat16)
print("Nr of parameters:", sum(p.numel() for p in bacformer_model.parameters()))
print("Nr of trainable parameters:", sum(p.numel() for p in bacformer_model.parameters() if p.requires_grad))
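The head is configured with `config.num_labels = 1`. Assuming this means the model emits a single logit per genome (the internals of the classification head are not shown in this tutorial), the usual way to turn that logit into a probability is the sigmoid, and the usual binary training objective is binary cross-entropy. A minimal pure-Python sketch of both, written from the standard formulas rather than from Bacformer's code:

```python
import math


def sigmoid(logit: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def binary_cross_entropy(logit: float, label: int) -> float:
    """Binary cross-entropy for one example, computed from the logit.

    Uses the numerically stable identity
    max(x, 0) - x * y + log(1 + exp(-|x|)) == -[y * log(p) + (1 - y) * log(1 - p)],
    where p = sigmoid(x).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


print(sigmoid(0.0))                    # 0.5
print(binary_cross_entropy(0.0, 1))    # log(2), about 0.693
```

Working from the logit instead of the probability avoids `log(0)` when the sigmoid saturates for large positive or negative logits.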

## Step 5: Set up the trainer for finetuning

Set up a trainer object for finetuning.

In [None]:
# create a trainer
# get training args
output_dir = "output/pheno_trait_pred"
os.makedirs(output_dir, exist_ok=True)
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    learning_rate=0.00015,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    seed=1,
    dataloader_num_workers=4,
    bf16=True,
    metric_for_best_model="eval_auroc",
    load_best_model_at_end=True,
    greater_is_better=True,
)

# define a collate function for the dataset
collate_genome_samples_fn = partial(collate_genome_samples, SPECIAL_TOKENS_DICT["PAD"], 7000, 1000)
trainer = BacformerTrainer(
    model=bacformer_model,
    data_collator=collate_genome_samples_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    args=training_args,
    compute_metrics=compute_metrics_binary_genome_pred,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
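`collate_genome_samples` is Bacformer's own collator, and this tutorial does not document the meaning of its positional arguments (`7000`, `1000`), so that call is left as-is. Purely as a generic illustration of what a collator for variable-length inputs typically does, the hypothetical `pad_batch` below pads each genome's list of protein vectors to the longest genome in the batch and builds an attention mask marking real positions with 1 and padding with 0. It is not the `bacformer` implementation.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def pad_batch(genomes: Sequence[Sequence[Vector]],
              pad_value: float = 0.0) -> Tuple[List[List[Vector]], List[List[int]]]:
    """Pad variable-length genomes (lists of protein vectors) to a common length.

    Hypothetical illustration only; assumes every genome has at least one protein.
    Returns the padded batch and an attention mask (1 = real protein, 0 = padding).
    """
    max_len = max(len(g) for g in genomes)
    dim = len(genomes[0][0])
    padded, mask = [], []
    for genome in genomes:
        n_pad = max_len - len(genome)
        padded.append([list(v) for v in genome] + [[pad_value] * dim for _ in range(n_pad)])
        mask.append([1] * len(genome) + [0] * n_pad)
    return padded, mask


batch, mask = pad_batch([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
print(mask)  # [[1, 0], [1, 1]]
```

With `per_device_train_batch_size=1`, as configured here, each batch holds a single genome, so padding mostly matters when the batch size is increased.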

## Step 6: Finetune the model 🎉🚂😎

Finetune the model to predict catalase production. Training should take ~15 min on a single NVIDIA A100 GPU.
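The training arguments above use `per_device_train_batch_size=1` with `gradient_accumulation_steps=8`, so gradients from 8 genomes are accumulated before each optimizer update. The arithmetic below shows the resulting effective batch size and an approximate number of optimizer steps per epoch. The 800-genome training split in the example is a hypothetical number, not the actual size of the catalase train split.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, n_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer step."""
    return per_device * grad_accum * n_devices


def optimizer_steps_per_epoch(n_train: int, per_device: int, grad_accum: int, n_devices: int = 1) -> int:
    """Approximate optimizer steps per epoch.

    Rounded up; how a final partial accumulation is handled can differ
    between transformers versions.
    """
    return math.ceil(n_train / effective_batch_size(per_device, grad_accum, n_devices))


# values from the TrainingArguments above, with a hypothetical 800-genome training split
print(effective_batch_size(1, 8))            # 8
print(optimizer_steps_per_epoch(800, 1, 8))  # 100
```

Accumulating gradients this way keeps per-step GPU memory at one genome while the optimizer sees gradients averaged over 8 genomes.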

In [None]:
# train the model
trainer.train()

## Step 7: Evaluate on the test set and run predictions

With the model trained, you can now evaluate it on the test set and run predictions.

In [None]:
# evaluate the model on the test set
test_output = trainer.predict(dataset["test"])
print("Test output:", test_output.metrics)
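The best checkpoint is selected on `eval_auroc` (`metric_for_best_model` above), and the printed test metrics come from `compute_metrics_binary_genome_pred`, whose implementation is not shown in this tutorial. For intuition only: AUROC equals the probability that a randomly chosen positive genome receives a higher score than a randomly chosen negative one, with ties counted as half. The pure-Python sketch below computes it directly from that pairwise definition. It is O(P·N) and meant for illustration, not as a replacement for the library metric.

```python
from typing import Sequence


def auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC as P(score_pos > score_neg), counting ties as 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(auroc([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0]))  # 1.0
```

Because it only depends on the ranking of scores, AUROC is the same whether it is computed from logits or from sigmoid probabilities.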

## [Optional] Step 8: Plot genome-phenotype probabilities

Plot predicted phenotype probability for test genomes.

In [None]:
# get the pandas DataFrame with data to plot
plot_df = pd.DataFrame(
    {
        'probability': test_output.predictions.squeeze().tolist(),
        'label': test_output.label_ids.tolist(),
        'genome_name': dataset["test"]["genome_name"],
    }
)

# Create the KDE plot
plt.figure(figsize=(8, 6))
sns.kdeplot(
    plot_df,
    x='probability',
    fill=True,
    hue='label',
    alpha=0.6,
)

# Add a title (seaborn adds the legend for `hue` automatically)
plt.title("Genome phenotype (Catalase) prediction", fontsize=18)

# Show the plot
plt.tight_layout()
plt.show()
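Beyond the density plot, it can help to turn the test-set probabilities into hard predictions and count errors per class. The sketch below assumes the `probability` column holds values in [0, 1] (if your model returns raw logits instead, apply a sigmoid first) and uses a 0.5 threshold, a common default rather than a tuned choice. `confusion_counts` is a hypothetical helper, not part of `bacformer`.

```python
from typing import Dict, Sequence


def confusion_counts(probs: Sequence[float], labels: Sequence[int],
                     threshold: float = 0.5) -> Dict[str, int]:
    """Count true/false positives/negatives after thresholding probabilities."""
    counts = {"tp": 0, "fp": 0, "tn": 0, "fn": 0}
    for p, y in zip(probs, labels):
        pred = 1 if p >= threshold else 0
        if pred == 1 and y == 1:
            counts["tp"] += 1
        elif pred == 1 and y == 0:
            counts["fp"] += 1
        elif pred == 0 and y == 0:
            counts["tn"] += 1
        else:
            counts["fn"] += 1
    return counts


# usage with the DataFrame built above:
# confusion_counts(plot_df["probability"], plot_df["label"])
print(confusion_counts([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 1]))
# {'tp': 1, 'fp': 1, 'tn': 1, 'fn': 1}
```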

----------------------
#### Voilà, you made it 👏! 

This example shows how to finetune Bacformer for a genome-level task; the same approach can be applied to any other phenotype for which genomes and labels are available.

If you have any issues or questions, please raise an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues.

We also provide `139` diverse phenotypic trait labels across almost 25k genomes. To use them, please see `phenotypic_traits_prediction_tutorial.ipynb`, which outlines how
to train a linear regression model on top of precomputed genome embeddings.

## References

[1] Weimann, Aaron, et al. "From genomes to phenotypes: Traitar, the microbial trait analyzer."