# Fine-tuning with the ProkBERT Model Family
This notebook demonstrates how to utilize ProkBERT's pre-trained models for transfer learning tasks. We will apply the model to identify promoter sequences, framed as a binary classification problem where each segment is assigned a label.

The main steps include:
- Preparing the dataset to outline the labels for each segment.
- Tokenizing nucleotide sequences.
- Creating splits and PyTorch datasets.
- Configuring training parameters such as learning rate, epochs, batch size, etc.
- Training and evaluating the model.


## Setting Up the Environment

While ProkBERT can run on CPUs, a GPU significantly accelerates training and inference. Google Colab offers free GPU usage (subject to time and memory limits), making it a convenient platform for experimenting with ProkBERT models.

## Enabling and Testing the GPU (if you are using Google Colab)

First, you'll need to enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select GPU from the Hardware Accelerator drop-down

Next, we'll install the ProkBERT package with `pip` and import the modules used in this notebook:


In [None]:
# ProkBERT
!pip install prokbert

# Imports
import torch
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
from prokbert.training_utils import (
    get_default_pretrained_model_parameters,
    get_torch_data_from_segmentdb_classification,
    compute_metrics_eval_prediction,
)
from prokbert.models import BertForBinaryClassificationWithPooling
from prokbert.prok_datasets import ProkBERTTrainingDatasetPT
from prokbert.config_utils import ProkBERTConfig
from os.path import join


Next, we'll confirm that PyTorch can connect to the GPU:

In [None]:

# Check if CUDA (GPU support) is available
if not torch.cuda.is_available():
    raise SystemError('GPU device not found')
else:
    device_name = torch.cuda.get_device_name(0)
    print(f'Found GPU at: {device_name}')

## Sequence Data Preparation

In this project, we will work with prokaryotic promoter sequences. The positive examples are known promoter sequences drawn from a prokaryotic promoter database. Each promoter is an 80 bp sequence with the Transcription Start Site (TSS) located at position 60. Below is an illustration of the data structure we'll be working with.

Labels are stored in a column named `label`, where `1` marks a known promoter sequence and `0` marks any other sequence. It's crucial to ensure the sequence data is clean and segmented appropriately: each entry should contain only nucleotide characters, and no sequence should be empty.
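
The checks described above can be sketched in plain Python. This is a minimal illustration rather than ProkBERT's own preprocessing: `find_invalid_segments` is a hypothetical helper, and restricting the alphabet to `A`, `C`, `G`, `T` is an assumption (your data may legitimately contain ambiguity codes such as `N`).

```python
# Hypothetical helper for illustration; it is not part of the prokbert package.
VALID_NUCLEOTIDES = set("ACGT")  # assumed alphabet; extend it if your data allows other codes


def find_invalid_segments(segments):
    """Return (index, reason) pairs for segments that fail basic sanity checks."""
    problems = []
    for i, seq in enumerate(segments):
        if not seq:
            problems.append((i, "empty sequence"))
        elif not set(seq.upper()) <= VALID_NUCLEOTIDES:
            problems.append((i, "non-ACGT character"))
    return problems


example_segments = ["ATGCGT", "", "ATGNNA", "ttgaca"]
print(find_invalid_segments(example_segments))
# [(1, 'empty sequence'), (2, 'non-ACGT character')]
```

With a pandas DataFrame, the same function could be applied to the sequence column converted to a list, and the reported rows dropped or inspected before tokenization.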

For detailed steps on sequence preprocessing, refer to the [segmentation notebook](https://github.com/nbrg-ppcu/prokbert/blob/main/examples/Segmentation.ipynb). For large-scale preprocessing, check out [this notebook](https://github.com/nbrg-ppcu/prokbert/blob/main/examples/Tokenization.ipynb).

### Loading the Dataset

We'll start by loading a predefined dataset of bacterial promoters:



In [None]:
# Loading the predefined dataset
dataset = load_dataset("neuralbioinfo/bacterial_promoters")

train_set = dataset["train"]
test_sigma70_set = dataset["test_sigma70"]
multispecies_set = dataset["test_multispecies"]


train_db = train_set.to_pandas()
test_sigma70_db = test_sigma70_set.to_pandas()
test_ms_db = multispecies_set.to_pandas()

train_db.head()

## Loading the Pretrained Model

At this stage, we will load the pretrained ProkBERT model from Hugging Face. For comprehensive details about the pretrained model and its architecture, please refer to the relevant documentation.

Traditionally, models like `...SequenceClassification` classify sequences based on the hidden representation of the `[CLS]` or starting token. However, in our approach, we utilize the base model enhanced with a pooling layer that integrates information across all nucleotides in the sequence. The function `get_default_pretrained_model_parameters` is used here to load the model along with its corresponding tokenizer. It's crucial to ensure that the tokenizer's parameters, specifically LCA (Local Context Aware) tokenization settings, are aligned with those used during the model's pretraining phase. For our purposes, we adopt a k-mer size of 6 and a shift of 1.
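
To make the k-mer size and shift concrete, the sketch below splits a sequence into overlapping k-mers. It only illustrates the windowing idea: `kmerize` is a hypothetical function, and the real ProkBERT tokenizer also handles special tokens and vocabulary lookup, which are omitted here.

```python
def kmerize(sequence, k=6, shift=1):
    """Split a sequence into overlapping k-mers, advancing `shift` bases per step."""
    last_start = len(sequence) - k  # last index at which a full k-mer still fits
    return [sequence[i:i + k] for i in range(0, last_start + 1, shift)]


tokens = kmerize("ATGCATGCA", k=6, shift=1)
print(tokens)
# ['ATGCAT', 'TGCATG', 'GCATGC', 'CATGCA']

# An 80 bp segment yields 80 - 6 + 1 = 75 overlapping 6-mers with shift 1.
print(len(kmerize("A" * 80, k=6, shift=1)))
# 75
```

Because the windows overlap, neighboring tokens share five of their six bases when the shift is 1, which is why the tokenizer settings must match those used during pretraining.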

Here's how to load the ProkBERT model along with its tokenizer:


In [None]:
model_name_path = 'neuralbioinfo/prokbert-mini'


pretrained_model, tokenizer = get_default_pretrained_model_parameters(
    model_name=model_name_path, 
    model_class='MegatronBertModel', 
    output_hidden_states=False, 
    output_attentions=False,
    move_to_gpu=False
)
fine_tuned_model = BertForBinaryClassificationWithPooling(pretrained_model)



## Tokenization and Dataset Creation

In this phase, we proceed to tokenize the nucleotide sequences from our dataset. This process converts each sequence into a format that the ProkBERT model can understand and process. To ensure that our model pays attention only to meaningful tokens, we will pad the arrays and employ the `AddAttentionMask` flag. This flag helps the model distinguish between informative tokens and padding or non-informative tokens, allowing it to focus on relevant sequence parts during training and evaluation.
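
The relationship between padding and the attention mask can be shown with a small, library-free sketch. The helper `pad_with_mask` and the padding ID of `0` are assumptions made for this illustration; in the notebook, `ProkBERTTrainingDatasetPT` takes care of this when `AddAttentionMask=True`.

```python
PAD_ID = 0  # assumed padding token ID, for this illustration only


def pad_with_mask(token_id_lists, pad_id=PAD_ID):
    """Pad every list to the longest length and build a matching 0/1 attention mask."""
    max_len = max(len(ids) for ids in token_id_lists)
    padded, masks = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        padded.append(list(ids) + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding that the model should ignore
        masks.append([1] * len(ids) + [0] * n_pad)
    return padded, masks


padded, masks = pad_with_mask([[5, 9, 7], [5, 8]])
print(padded)  # [[5, 9, 7], [5, 8, 0]]
print(masks)   # [[1, 1, 1], [1, 1, 0]]
```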

### Creating Datasets

We start by processing the training, test, and validation datasets. Each database is tokenized using the ProkBERT tokenizer, and the resulting token arrays are prepared along with their corresponding labels. In the cell below, the multispecies set serves as the test set and the sigma70 set serves as the validation set:


In [None]:

# Creating datasets
print('Processing train database!')
X_train, y_train, torchdb_train = get_torch_data_from_segmentdb_classification(tokenizer, train_db)
# The multispecies set is used as the held-out test set
print('Processing test database!')
X_test, y_test, torchdb_test = get_torch_data_from_segmentdb_classification(tokenizer, test_ms_db)
# The sigma70 set is used for validation during training
print('Processing validation database!')
X_val, y_val, torchdb_val = get_torch_data_from_segmentdb_classification(tokenizer, test_sigma70_db)

train_ds = ProkBERTTrainingDatasetPT(X_train, y_train, AddAttentionMask=True)
test_ds = ProkBERTTrainingDatasetPT(X_test, y_test, AddAttentionMask=True)
val_ds = ProkBERTTrainingDatasetPT(X_val, y_val, AddAttentionMask=True)




## Training Configuration Setup

We are setting up the configuration for fine-tuning a specific ProkBERT model. The configurations are divided into several categories to manage different aspects of the training process. Below is a brief overview of the parameters used in this example and their significance:

### Model Parameters
- **`model_outputpath`, `model_name`, `resume_or_initiation_model_path`**: We designate the model's output directory and name as `prokbert_mini_promoter`. These parameters ensure that the model's training outputs, including checkpoints, are saved under a specific directory named after the model. The model initiation path is also set to the same name, indicating where the model's initial weights are loaded from.
- **`ResumeTraining`**: Set to `False` so that a new fine-tuning run is started instead of resuming from a checkpoint of a previous training session.

### Training Parameters
- **`output_dir`**: Specifies the directory where training artifacts like model checkpoints will be saved. It combines a base directory `finetuned_models` with the model name.
- **`warmup_steps`**: The number of warmup steps for the learning rate scheduler is set to 1, indicating minimal warmup before reaching the full learning rate.
- **`save_steps` and `eval_steps`**: Both are set to 50, dictating how frequently the model should be saved and evaluated.
- **`save_total_limit`**: Limits the total number of model checkpoints to keep to 10, helping manage storage efficiently.
- **`learning_rate`**: The learning rate for fine-tuning is set at 0.0001.
- **`per_device_train_batch_size`**: Defines the batch size for training as 128.
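- **`num_train_epochs`**: Set to 0.1 in the configuration cell below, so the model sees only about a tenth of the training data. This keeps the demonstration short; increase the value for a full fine-tuning run.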
- **`evaluation_strategy`**: Set to 'steps', indicating that evaluation will occur based on the number of steps defined.
- **`per_device_eval_batch_size`**: The evaluation batch size is set to twice the training batch size, enhancing evaluation throughput.
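
To see how these settings interact, the following sketch estimates how many optimizer steps and evaluations a run will perform. It is a rough approximation under stated assumptions: a single device, no gradient accumulation, and the last partial batch kept. `planned_steps` is a hypothetical helper, and the dataset size of 200,000 is a made-up number, not the size of the promoter dataset.

```python
import math


def planned_steps(n_examples, batch_size, num_train_epochs, eval_steps):
    """Estimate steps per epoch, total training steps, and number of evaluations."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    total_steps = math.ceil(steps_per_epoch * num_train_epochs)
    n_evals = total_steps // eval_steps  # one evaluation every `eval_steps` steps
    return steps_per_epoch, total_steps, n_evals


# 200_000 is a hypothetical dataset size used only for this example
print(planned_steps(n_examples=200_000, batch_size=128, num_train_epochs=0.1, eval_steps=50))
# (1563, 157, 3)
```

An estimate like this helps when choosing `eval_steps` and `save_steps`: if the total number of steps is smaller than `eval_steps`, no intermediate evaluation or checkpoint is produced during the run.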

### ProkBERT Configuration

These parameters are crucial for customizing the training process, allowing for specific training, evaluation strategies, and resource management tailored to the task and available computational resources. 


## Configuration Parameters Overview

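The table below lists the key training and dataset configuration parameters of ProkBERT with their descriptions, types, and default values. These parameters were defined for pretraining; this notebook passes the same training parameters through `get_and_set_pretraining_parameters` for fine-tuning.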

| Section | Parameter | Description | Type | Default |
|---------|-----------|-------------|------|---------|
| **training** | | | | |
| | `output_dir` | Output directory for training artifacts. | string | './train_output' |
| | `num_train_epochs` | Total number of training epochs. | float | 1 |
| | `save_steps` | Save model checkpoint every N steps. | integer | 1000 |
| | `save_total_limit` | Maximum number of total checkpoints to keep. | integer | 20 |
| | `logging_steps` | Log metrics every N steps. | integer | 50 |
| | `logging_first_step` | Whether to log metrics for the first step. | boolean | True |
| | `per_device_train_batch_size` | Batch size for training. | integer | 48 |
| | `dataloader_num_workers` | Number of subprocesses for data loading. | integer | 1 |
| | `learning_rate` | Learning rate for training. | float | 0.0005 |
| | `adam_epsilon` | Epsilon for the Adam optimizer. | float | 5e-05 |
| | `warmup_steps` | Number of warmup steps for learning rate scheduler. | integer | 500 |
| | `weight_decay` | Weight decay for optimizer. | float | 0.1 |
| | `adam_beta1` | Beta1 hyperparameter for the Adam optimizer. | float | 0.95 |
| | `adam_beta2` | Beta2 hyperparameter for the Adam optimizer. | float | 0.98 |
| | `gradient_accumulation_steps` | Number of steps to accumulate gradients before updating weights. | integer | 1 |
| | `optim` | Optimizer to use for training. | string | "adamw_torch" |
| | `ignore_data_skip` | When resuming training, whether to skip fast-forwarding the data loader to the position reached in the previous run. | boolean | True |
| **dataset** | | | | |
| | `dataset_path` | Path to the dataset. It triggers an error if empty. | string | '' |
| | `pretraining_dataset_data` | The raw dataset data. | list | [[]] |
| | `dataset_class` | The class of the dataset to be used. | string | 'IterableProkBERTPretrainingDataset' |
| | `input_batch_size` | Batch size to be loaded into memory from the disk for HDF datasets. | integer | 10000 |
| | `dataset_iteration_batch_offset` | The offset value for dataset iteration start. | integer | 0 |
| | `max_iteration_over_dataset` | Maximum times to iterate over a dataset. | integer | 10 |

In [None]:
finetuned_model_name = 'prokbert_mini_promoter'
default_ft_model_dir = 'finetuned_models'


lr_rate = 0.0001
batch_size = 128
eval_steps = 50
warmup_steps = 1
num_train_epochs = 0.1


model_params = {'model_outputpath': finetuned_model_name,
                'model_name' : finetuned_model_name,
                'resume_or_initiation_model_path' : finetuned_model_name, 
                'ResumeTraining' : False}
dataset_params = {}
training_params = {'output_dir': join(default_ft_model_dir, finetuned_model_name),
                'warmup_steps' : warmup_steps,
                'save_steps' : eval_steps,
                'save_total_limit' : 10,
                'learning_rate' : lr_rate,
                'per_device_train_batch_size': batch_size,
                'num_train_epochs': num_train_epochs,
                'eval_steps' : eval_steps,
                'logging_steps' : eval_steps,
                'evaluation_strategy': 'steps',
                'per_device_eval_batch_size': batch_size*2
                }
prokbert_config = ProkBERTConfig()
prokbert_config.default_torchtype = torch.long

_ = prokbert_config.get_and_set_model_parameters(model_params)
_ = prokbert_config.get_and_set_dataset_parameters(dataset_params)
_ = prokbert_config.get_and_set_pretraining_parameters(training_params)

_ = prokbert_config.get_and_set_tokenization_parameters(tokenizer.tokenization_params)
_ = prokbert_config.get_and_set_segmentation_parameters(tokenizer.segmentation_params)
_ = prokbert_config.get_and_set_computation_params(tokenizer.comp_params)

final_model_output = join(prokbert_config.model_params['model_outputpath'], prokbert_config.model_params['model_name'])




In [None]:
training_args = TrainingArguments(**prokbert_config.pretraining_params)
trainer = Trainer(
                model=fine_tuned_model,
                args=training_args,
                train_dataset=train_ds,
                eval_dataset = val_ds,
                compute_metrics=compute_metrics_eval_prediction,
            )
trainer.train()
# Saving the final model
print(f'Saving the model to: {final_model_output}')
fine_tuned_model.save_pretrained(final_model_output)


# Fine-tuned Model

The final fine-tuned model is available at the specified path, ready for deployment or further evaluation. While the current setup provides a good foundation, there's always room for improvement by experimenting with different hyperparameters. Fine-tuning these parameters can help greatly improve the model's performance on specific tasks or datasets.

## Considerations for Further Optimization

- **Experiment with Hyperparameters**: Adjust learning rate, batch size, number of epochs, and other training parameters to find the optimal configuration for your specific use case.
- **Cross-validation**: Use cross-validation techniques to ensure that your model generalizes well across different subsets of your data.
- **Data Augmentation**: Explore data augmentation strategies for sequence data, such as introducing random mutations or utilizing synthetic data generation, to increase the robustness of your model.
- **Advanced Architectures**: Consider experimenting with different model architectures or adding layers (e.g., convolutional layers) to improve the model's capacity to capture complex patterns in the data.
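
As one possible starting point for the data augmentation idea above, the sketch below introduces random point substitutions into a sequence. `mutate_sequence` is a hypothetical helper, and the mutation rate is an arbitrary choice. Whether a mutated promoter should keep its original label is an assumption you would need to validate for your data.

```python
import random

NUCLEOTIDES = "ACGT"


def mutate_sequence(sequence, mutation_rate=0.02, rng=None):
    """Return a copy of `sequence` with random point substitutions."""
    rng = rng or random.Random()
    mutated = []
    for base in sequence:
        if rng.random() < mutation_rate:
            # Pick a different nucleotide so every mutation is a real change
            choices = [n for n in NUCLEOTIDES if n != base]
            mutated.append(rng.choice(choices))
        else:
            mutated.append(base)
    return "".join(mutated)


rng = random.Random(42)  # fixed seed so the example is reproducible
original = "TTGACAATTAATCATCGGCTCGTATAATGTGTGGA"
augmented = mutate_sequence(original, mutation_rate=0.1, rng=rng)
print(original)
print(augmented)
```

Substitutions keep the sequence length unchanged, so the augmented segments can be tokenized with the same settings as the originals.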

## Closing Remarks

Fine-tuning a pre-trained model like ProkBERT offers a powerful approach to leveraging large language models for biological sequence analysis. By carefully selecting and optimizing your model's hyperparameters, you can achieve significant improvements in performance.
