# Machine learning methods for the analysis of bacterial genomes: classification using Random Forest

We provide here a workflow for machine learning (ML) classification of binary outcomes (e.g., healthy/sick, presence/absence, yes/no) based on a set of features or predictors (e.g., metadata) using the random forest (RF) algorithm.

To test this workflow, we used a random subset (subset.csv) of the dataset used in Bobbo et al. (2024; https://doi.org/10.1186/s12864-024-10832-y), a study that applied ML algorithms to reproduce the accurate classification of archaea and bacteria and to identify the genomic features that drive this classification. Archaea and Bacteria are distinct domains of life adapted to a wide variety of ecological niches. Several genome-based methods have been developed to classify them accurately, yet the specific genomic features that underlie their differences are not fully understood. In Bobbo et al. (2024), we used publicly available whole-genome sequences of bacteria and archaea. From these, a set of genomic features (nucleotide frequencies and proportions; coding sequences (CDS); non-coding, ribosomal and transfer RNA genes (ncRNA, rRNA, tRNA); and Chargaff’s, topological entropy and Shannon’s entropy scores) was extracted with the GBRAP tool and used as input to develop ML models for the classification of archaea and bacteria.
For this workflow, the input dataset must contain, in this order: the binary outcome (e.g., Archaea/Bacteria), a record identifier (e.g., sample ID), and the features/predictors. The code refers to the first two columns by name, so they must be called `Outcome` and `ID`. "subset.csv" contains 363 records (109 Archaea and 254 Bacteria) and 79 genomic features (numerical variables).
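Before running the workflow, it can help to confirm that a table follows this layout. The sketch below uses a hypothetical helper, `check_input_layout()` (not part of the original workflow), and a made-up three-row table; it only checks the column names, order and number of outcome classes that the later code relies on.

```r
# Hypothetical helper (not part of the original workflow): returns TRUE when a
# data frame starts with the columns Outcome and ID, the outcome has exactly two
# classes, and at least one feature column follows.
check_input_layout <- function(df) {
  ok_names    <- ncol(df) >= 2 && identical(names(df)[1:2], c("Outcome", "ID"))
  ok_binary   <- ok_names && length(unique(df$Outcome)) == 2
  ok_features <- ncol(df) > 2
  ok_names && ok_binary && ok_features
}

# Toy table with made-up feature names and values
toy <- data.frame(
  Outcome = factor(c("Archaea", "Bacteria", "Bacteria")),
  ID      = factor(c("g1", "g2", "g3")),
  GC      = c(0.41, 0.52, 0.66),
  tRNA    = c(46, 58, 61)
)
check_input_layout(toy)                   # TRUE
check_input_layout(toy[, c(2, 1, 3, 4)])  # FALSE: ID comes before Outcome
```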

## Parameters setting
The parameters that the user has to set before the analysis are listed below:

- input_file: path to the input dataset;
- split_ratio: proportion of records assigned to the training set in the train/test split (e.g., 80:20 = 0.80, 70:30 = 0.70);
- k_folds: number of folds in (repeated) k-fold cross-validation (e.g., 3, 5, 10);
- nrepeats_cv: number of repeats of k-fold cross-validation during model training (e.g., 10, 100, 1000; 1 is recommended for a first test).

The ML analysis is performed with Tidymodels, a collection of R packages for modelling and machine learning; detailed documentation is available at https://www.tidymodels.org/.

In [1]:
input_file = "subset.csv" 
split_ratio = 0.80 
k_folds = 10 
nrepeats_cv = 3 
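These parameters determine how many models are fitted during tuning: every hyperparameter combination is fitted once per resample, and repeated k-fold cross-validation produces `k_folds * nrepeats_cv` resamples. The helper below is a hypothetical illustration (not part of Tidymodels); the grid size of 25 assumes the 5 x 5 regular grid used later in the notebook.

```r
# Hypothetical helper: number of random forests fitted by tune_grid()
# (one fit per hyperparameter combination per resample).
count_tuning_fits <- function(k_folds, nrepeats_cv, grid_size) {
  n_resamples <- k_folds * nrepeats_cv
  n_resamples * grid_size
}

# Settings used in this notebook: 10 folds, 3 repeats, 5 x 5 grid
count_tuning_fits(k_folds = 10, nrepeats_cv = 3, grid_size = 5 * 5)  # 750
```

Each of these fits grows a full forest, so starting with `nrepeats_cv = 1`, as recommended above, keeps a first run short; `last_fit()` later adds one more fit on the whole training set.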

## Install libraries
Install required libraries.

In [None]:
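# Install any required package that is not available yet
# (data.table is loaded below, so it must be installed as well)
required_pkgs <- c("vip", "randomForest", "ggplot2", "tidyverse",
                   "data.table", "tidymodels", "themis")
missing_pkgs <- setdiff(required_pkgs, rownames(installed.packages()))
if (length(missing_pkgs) > 0) install.packages(missing_pkgs)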

## Load libraries
Load required libraries.

In [None]:
library("vip")
library("randomForest")
library("ggplot2")
library("tidyverse")
library("data.table")
library("tidymodels")
library("themis")  # for step_upsample

## Import dataset
- Import the dataset: use "." as the decimal separator; treat ".", "-" and "NA" as missing values; convert character columns to factors.
- Check the dataset dimensions (number of records and columns).
- Display the first six records.
- Check the structure of the dataset: the binary outcome and ID should be factors; features can be all numeric (e.g., height, weight), all factors (e.g., sex) or a mix of both.


In [None]:
dataset <- fread(input_file, dec = ".", na.strings = c(".", "-", "NA"), stringsAsFactors = TRUE)
dim(dataset)
head(dataset)
str(dataset)
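The effect of the missing-value settings can be checked on a tiny inline table. This sketch uses base R's `read.csv()` instead of `fread()` so that it runs without data.table; both functions accept `na.strings` and `stringsAsFactors` arguments. The table values are made up.

```r
# Tiny made-up CSV: "." and "-" mark missing values, as in the import settings
csv_text <- "Outcome,ID,GC,tRNA
Archaea,g1,0.41,46
Bacteria,g2,.,58
Bacteria,g3,0.66,-"

toy <- read.csv(text = csv_text, dec = ".", na.strings = c(".", "-", "NA"),
                stringsAsFactors = TRUE)
str(toy)         # Outcome and ID are factors; GC and tRNA are numeric
sum(is.na(toy))  # 2 missing values
```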

## Dataset preprocessing : missing values
This workflow works only with complete datasets (no missing values), so a sanity check is required before running the analysis.


In [None]:
# Sanity check

if (sum(is.na(dataset)) == 0) {
  print("No missing data in the dataset: OK! Go to 'Dataset preprocessing : descriptive statistics'")
} else {
  print("Missing data in the dataset: please remove them! Go to 'Keep only complete records' cell")
}

If there are no missing values, proceed with the analysis.
If missing values are detected, remove the incomplete records by running the cell below, which keeps only complete records using complete.cases().

In [None]:
# Keep only complete records and check dataset dimension

dataset <- dataset[complete.cases(dataset), ]
dim(dataset)
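`complete.cases()` returns one logical value per row, `TRUE` when the row has no missing value in any column; indexing with it keeps only those rows. A minimal sketch on a made-up data frame:

```r
# Made-up data frame with one incomplete record (row 2)
toy <- data.frame(
  Outcome = factor(c("Archaea", "Bacteria", "Bacteria")),
  GC      = c(0.41, NA, 0.66)
)
complete.cases(toy)                  # TRUE FALSE TRUE
toy_complete <- toy[complete.cases(toy), ]
nrow(toy_complete)                   # 2
```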

## Dataset preprocessing : descriptive statistics
Descriptive statistics of all variables: frequencies for factors and distribution summaries (minimum, quartiles, mean, maximum) for numeric variables.


In [None]:
summary(dataset)
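For a binary outcome it is also worth quantifying class balance, since the preprocessing step upsamples the minority class. In the notebook, `table(dataset$Outcome)` gives the counts; the sketch below hard-codes the counts reported for subset.csv (109 Archaea, 254 Bacteria) so that it runs on its own.

```r
# Class counts reported for subset.csv (hard-coded for illustration)
class_counts <- c(Archaea = 109, Bacteria = 254)

round(prop.table(class_counts), 3)  # Archaea 0.3, Bacteria 0.7
imbalance_ratio <- max(class_counts) / min(class_counts)
round(imbalance_ratio, 2)           # 2.33 Bacteria records per Archaea record
```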

## Machine learning analysis


### Training/test split

The dataset is split into a training set, used to train and tune (validate) the model, and a test set, used to assess the model's final performance.


In [None]:
# The "split_ratio" parameter sets the proportion of records assigned to the training set (e.g., 80:20, 70:30);
# the split is stratified by Outcome so both sets keep similar class proportions.

set.seed(123) # make the split reproducible
rf_dt <- as_tibble(dataset) %>% select(-ID)
rf_split <- initial_split(rf_dt, strata = Outcome, prop = split_ratio)
rf_train <- training(rf_split)
rf_test <- testing(rf_split)

# Sanity check on outcome frequencies in train and test sets

rf_train %>% count(Outcome) %>% print()
rf_test %>% count(Outcome) %>% print()
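Stratifying by Outcome means that the split is drawn separately within each class, so the training and test sets keep roughly the Archaea/Bacteria proportions of the full dataset. A simplified base-R version is sketched below; `stratified_split()` is hypothetical and sizes each class with `floor()`, whereas rsample's `initial_split()` may round and pool small strata differently.

```r
# Hypothetical helper: simplified stratified train/test split (base R only).
# From each class, floor(prop * class size) row indices go to the training set.
stratified_split <- function(y, prop) {
  idx_by_class <- split(seq_along(y), y)   # row indices grouped by class
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), floor(prop * length(idx)))]
  }), use.names = FALSE)
  list(train = sort(train_idx), test = setdiff(seq_along(y), train_idx))
}

set.seed(123)
# Toy outcome vector with the class sizes of subset.csv
y <- factor(rep(c("Archaea", "Bacteria"), times = c(109, 254)))
s <- stratified_split(y, prop = 0.80)
table(y[s$train])  # Archaea 87, Bacteria 203
table(y[s$test])   # Archaea 22, Bacteria 51
```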

## PreprocessingWe use Tidymodels to build a recipe for data preprocessing: -   remove correlated variables -   remove non informative variables (zero variance) -   upsampling to handle output class imbalance
 -   standardize all variables

In [None]:
rf_recipe <- rf_train %>%
  recipe(Outcome ~ .) %>%
  step_corr(all_predictors(), threshold = 0.99) %>%
  step_zv(all_numeric(), -all_outcomes()) %>%
  step_normalize(all_numeric(), -all_outcomes())

prep_rf <- prep(rf_recipe)
print(prep_rf)

training_set <- juice(prep_rf) # Extract transformed training set
head(training_set)
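The preprocessing steps can be pictured with small base-R helpers. This is a simplified sketch under stated assumptions: all helpers are hypothetical, the correlation filter is a greedy approximation rather than the exact algorithm used by `step_corr()`, and the data are random toy values. Standardization is shown with base R's `scale()`. As in the recipe, upsampling should only ever be applied to training data.

```r
# Zero-variance filter: keep only columns with more than one distinct value
drop_zero_variance <- function(x) {
  keep <- vapply(x, function(col) length(unique(col)) > 1, logical(1))
  x[, keep, drop = FALSE]
}

# Greedy correlation filter: while some pair of columns has |r| > threshold,
# drop the member of the most correlated pair with the higher mean |r|
drop_correlated <- function(x, threshold = 0.99) {
  repeat {
    if (ncol(x) < 2) return(x)
    cm <- abs(cor(x))
    diag(cm) <- 0
    if (max(cm) <= threshold) return(x)
    pair <- which(cm == max(cm), arr.ind = TRUE)[1, ]
    to_drop <- pair[which.max(colMeans(cm)[pair])]
    x <- x[, -to_drop, drop = FALSE]
  }
}

# Upsampling: resample rows of the smaller classes (with replacement)
# until every class is as large as the largest one
upsample <- function(df, outcome) {
  counts <- table(df[[outcome]])
  n_max <- max(counts)
  idx <- unlist(lapply(names(counts), function(cl) {
    rows <- which(df[[outcome]] == cl)
    c(rows, rows[sample.int(length(rows), n_max - length(rows), replace = TRUE)])
  }))
  df[idx, , drop = FALSE]
}

set.seed(1)
x <- data.frame(a = rnorm(20), const = 1, c = rnorm(20))
x$b <- 2 * x$a + rnorm(20, sd = 1e-4)         # near-duplicate of a
x_f <- drop_correlated(drop_zero_variance(x)) # removes const and one of a/b
x_s <- as.data.frame(scale(x_f))              # standardize: mean 0, sd 1

df <- data.frame(Outcome = factor(rep(c("Archaea", "Bacteria"), c(3, 7))),
                 GC = seq(0.30, 0.75, length.out = 10))
table(upsample(df, "Outcome")$Outcome)        # 7 Archaea, 7 Bacteria
```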

## Model building

### Model training
We now specify the structure of our model:
-   hyperparameters to tune: `mtry` (number of features randomly sampled as split candidates at each node) and `min_n` (minimum number of data points a node must contain to be split further)
-   number of trees in the forest
-   the problem at hand (classification)
-   the engine (R package; here randomForest)

Then we put this in a workflow together with the preprocessing recipe.

In [9]:
tune_spec <- rand_forest(
  mtry = tune(),
  trees = 500,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("randomForest")

tune_wf <- workflow() %>%
  add_recipe(rf_recipe) %>%  # the recipe is re-estimated within each resample during tuning
  add_model(tune_spec)


### Tuning of hyperparameters
We use repeated k-fold cross-validation on the training set to tune the hyperparameters. The folds are built from the unprocessed training data (`rf_train`), so the preprocessing recipe is estimated only on the analysis part of each fold and the assessment part stays unseen.

In [None]:
# Parameters "k_folds" and "nrepeats_cv" set the number of folds in (repeated) k-fold cross-validation
# and the number of cross-validation repeats during model training.
# Several metrics (e.g., accuracy, AUC, MCC) will be calculated.

set.seed(123)
trees_folds <- vfold_cv(rf_train, v = k_folds, repeats = nrepeats_cv, strata = Outcome)

# In Random Forest models, mtry is the number of features randomly selected at each split.
# A common rule of thumb for classification is √p, where p is the number of predictors (excluding the target).
# ncol(training_set)-1 excludes the target variable from the count.

m <- round(sqrt(ncol(training_set)-1),0)
print(m)

rf_grid <- grid_regular(
  mtry(range = c(max(1, m - 5), m + 5)),  # keep the lower bound at 1 or more
  min_n(range = c(2, 10)),
  levels = c(5, 5) # 5 evenly spaced values for mtry and 5 for min_n (25 combinations)
)

# Displays the first few rows of the grid and the total number of combinations.

head(rf_grid)
nrow(rf_grid)

# Performs model tuning using the tune_wf workflow.
# Evaluation metrics:
# - roc_auc: Area Under the ROC Curve.
# - accuracy: Overall classification accuracy.
# - mcc: Matthews correlation coefficient (informative also when classes are imbalanced).

regular_res <- tune_grid(
  tune_wf,
  metrics = metric_set(roc_auc, accuracy, mcc),
  resamples = trees_folds,  # provides the cross-validation strategy
  grid = rf_grid
)

# Collects and prints the average performance metrics for each combination of hyperparameters

regular_res |>
  collect_metrics() |>
  print()


# Plots the MCC metric across different values of mtry, grouped by min_n.
# Helps visualize which hyperparameter combinations performed best.

library("repr")
options(repr.plot.width=14, repr.plot.height=8)

regular_res %>%
  collect_metrics() %>%
  filter(.metric == "mcc") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(mtry, mean, color = min_n)) +
  geom_line(alpha = 0.5, linewidth = 1.5) +  # 'linewidth' replaces the deprecated 'size' (ggplot2 >= 3.4.0)
  geom_point() +
  labs(y = "mcc")

# Selects the best performing parameter set based on the MCC score.

best_auc <- select_best(x = regular_res, metric = "mcc")
show_best(regular_res, metric = "mcc") # Lists the top combinations ranked by MCC.
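The resampling scheme created by `vfold_cv()` can be pictured with a simplified base-R sketch: each training record is assigned to exactly one of k folds (here stratified by class), and each fold serves once as the assessment set while the other k - 1 folds form the analysis set. The helper `assign_folds()` is hypothetical and is not rsample's implementation; with `repeats`, the whole assignment is redone with a new shuffle.

```r
# Hypothetical helper: stratified assignment of rows to k folds (one repeat).
# Within each class, rows are shuffled and dealt to folds 1..k in turn.
assign_folds <- function(y, k) {
  fold <- integer(length(y))
  for (cl in unique(as.character(y))) {
    rows <- which(y == cl)
    rows <- rows[sample.int(length(rows))]           # shuffle within class
    fold[rows] <- rep_len(seq_len(k), length(rows))  # deal rows round-robin
  }
  fold
}

set.seed(123)
# Illustrative training-set class sizes (about 80% of 109 and 254)
y <- factor(rep(c("Archaea", "Bacteria"), times = c(87, 203)))
folds <- assign_folds(y, k = 10)
table(folds)     # 28 to 30 rows per fold
table(folds, y)  # 8-9 Archaea and 20-21 Bacteria per fold

# Fold 1 as assessment set, the other nine folds as analysis set
assessment_idx <- which(folds == 1)
analysis_idx   <- which(folds != 1)
```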

### Final model

In [None]:
# finalize_model() replaces the tune() placeholders with the best values, creating the final tuned model.

final_rf <- finalize_model(
  tune_spec,  # your model specification with placeholder hyperparameters
  best_auc  # the best hyperparameter combination found during tuning (in this case, chosen using MCC metric earlier)
)

print(final_rf)


# Build the final workflow (preprocessing recipe + tuned RF model); last_fit() below fits it on the training set and evaluates it on the test set:

final_wf <- workflow() %>%
  add_recipe(rf_recipe) %>%
  add_model(final_rf)

# Fits the finalized workflow to the training portion of rf_split and evaluates the model on the test portion of rf_split.

final_res <- final_wf %>%
  last_fit(rf_split, metrics = metric_set(roc_auc, accuracy, mcc, brier_class))

# Evaluate the fine-tuned rf model

print(final_res)
final_res %>%
  collect_metrics()
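The finalized model grows `trees = 500` trees. For classification, a random forest combines its trees by majority vote, and for the randomForest engine the class probabilities (e.g., `.pred_Archaea`) are the fractions of trees voting for each class. The toy sketch below uses made-up votes from five trees for three genomes; an odd number of trees avoids ties.

```r
# Made-up votes: rows = genomes, columns = trees (matrix is filled column by column)
votes <- matrix(c("Archaea",  "Archaea",  "Bacteria",
                  "Bacteria", "Archaea",  "Bacteria",
                  "Archaea",  "Bacteria", "Bacteria",
                  "Archaea",  "Archaea",  "Bacteria",
                  "Bacteria", "Archaea",  "Bacteria"),
                nrow = 3)

prob_archaea <- rowMeans(votes == "Archaea")   # share of trees voting Archaea
pred_class   <- ifelse(prob_archaea > 0.5, "Archaea", "Bacteria")
prob_archaea  # 0.6 0.8 0.0
pred_class    # "Archaea" "Archaea" "Bacteria"
```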

## Get variable importance

In [None]:
final_res %>% 
  pluck(".workflow", 1) %>%   
  extract_fit_parsnip() %>% 
  vip(num_features = 16)
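One widely used importance measure for random forests is the mean decrease in Gini impurity: each split on a feature lowers node impurity, and these decreases are summed over all splits on that feature and averaged over trees. Which measure `vip()` displays depends on how the model was fitted (permutation importance, for example, requires `importance = TRUE` in randomForest). The sketch below only illustrates the Gini quantity on made-up values; `gini()` and `gini_decrease()` are hypothetical helpers.

```r
# Gini impurity of a set of class labels: 1 - sum(p_k^2)
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Impurity decrease of a candidate split: parent impurity minus the
# size-weighted impurity of the two child nodes
gini_decrease <- function(y, goes_left) {
  n <- length(y)
  left <- y[goes_left]
  right <- y[!goes_left]
  gini(y) - (length(left) / n) * gini(left) - (length(right) / n) * gini(right)
}

y  <- c("Archaea", "Archaea", "Bacteria", "Bacteria")
gc <- c(0.35, 0.40, 0.60, 0.65)                # made-up GC content values
gini(y)                                        # 0.5
gini_decrease(y, gc < 0.5)                     # 0.5: perfect separation
gini_decrease(y, c(TRUE, FALSE, TRUE, FALSE))  # 0: the split does not help
```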

## Predictions on test set

The predictive ability of the model on the test set (final evaluation of model performance) is assessed with several metrics derived from the confusion matrix: accuracy, sensitivity, specificity and the Matthews correlation coefficient (MCC).

In [None]:
# We collect the predictions on the test set

final_res %>%
  collect_predictions()

cm <- final_res %>%
  collect_predictions() %>%
  conf_mat(Outcome, .pred_class)

print(cm)

# Calculate classification metrics on test set
# "Bacteria" is the positive class when using event_level = "second"

test_metrics <- final_res %>%
  collect_predictions() %>%
  summarise(
    accuracy = accuracy_vec(truth = Outcome, estimate = .pred_class),
    sensitivity = sens_vec(truth = Outcome, estimate = .pred_class, event_level = "second"),
    specificity = spec_vec(truth = Outcome, estimate = .pred_class, event_level = "second"),
    mcc = mcc_vec(truth = Outcome, estimate = .pred_class)
  )

print("Test set metrics:")
print(test_metrics)


print("DONE!!")
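The test-set metrics follow directly from the four cells of the confusion matrix. The sketch below recomputes them in base R from made-up counts, taking Bacteria as the positive class, as the code above does with `event_level = "second"`; MCC is symmetric in the two classes, so it does not depend on that choice. Variable names avoid `accuracy` and `mcc` so that yardstick's functions are not masked in the session.

```r
# Made-up confusion-matrix counts, Bacteria = positive class
tp <- 48; fn <- 3   # true Bacteria predicted as Bacteria / as Archaea
tn <- 20; fp <- 2   # true Archaea predicted as Archaea / as Bacteria

acc_val  <- (tp + tn) / (tp + tn + fp + fn)
sens_val <- tp / (tp + fn)   # sensitivity (recall of Bacteria)
spec_val <- tn / (tn + fp)   # specificity (recall of Archaea)
mcc_val  <- (tp * tn - fp * fn) /
  sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

round(c(accuracy = acc_val, sensitivity = sens_val,
        specificity = spec_val, mcc = mcc_val), 3)
```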

## Save workspace

In [18]:
save.image("rf_workshop.RData")
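`save.image()` stores every object in the workspace, and the session can later be restored with `load("rf_workshop.RData")`. When only one object is needed (e.g., `final_res`), base R's `saveRDS()`/`readRDS()` is a lighter alternative. The sketch below uses a stand-in list instead of a real fitted model and writes to a temporary directory; in the notebook one would pass the actual object and a chosen file name.

```r
# Stand-in object (hypothetical); in the notebook this could be final_res
fit_stub <- list(note = "stand-in for a fitted workflow", trees = 500)

path <- file.path(tempdir(), "final_res.rds")
saveRDS(fit_stub, path)          # save a single object
restored <- readRDS(path)        # read it back under any name
identical(restored, fit_stub)    # TRUE
```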