## Random Forest for multiclass classification (tidymodels inside)

We now move from binary to multiclass classification, this time using the `tidymodels` framework throughout. We use the same diabetes and metabolomics dataset that we used for the Lasso model with `tidymodels`.

In [1]:
library("vip")
library("ggplot2")
library("tidyverse")
library("tidymodels")
library("data.table")
library("randomForest")


Attaching package: ‘vip’


The following object is masked from ‘package:utils’:

    vi


── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──

✔ tibble  3.1.6     ✔ dplyr   1.0.6
✔ tidyr   1.1.3     ✔ stringr 1.4.0
✔ readr   2.1.2     ✔ forcats 0.5.1
✔ purrr   0.3.4     

── Conflicts ──────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()

Registered S3 method overwritten by 'tune':
  method  

In [2]:
mtbsl1 <- fread("../data/MTBSL1.tsv")
names(mtbsl1)[c(4:ncol(mtbsl1))] <- paste("mtbl",seq(1,ncol(mtbsl1)-3), sep = "_")

#### Creating the multinomial variable 

We combine the variables `Gender` and `Metabolic_syndrome` to create a synthetic outcome variable with four classes:

In [3]:
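# store the outcome as a factor: classification models and yardstick metrics expect a factor
mtbsl1$gender_status <- factor(paste(mtbsl1$Gender, mtbsl1$Metabolic_syndrome, sep = "_"))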
mtbsl1 %>% group_by(gender_status) %>%
    summarise(N=n())

| gender_status | N |
|:---|---:|
| Female_Control Group | 28 |
| Female_diabetes mellitus | 26 |
| Male_Control Group | 56 |
| Male_diabetes mellitus | 22 |


#### Data splitting

We first split the data into training and test sets, stratifying by the categorical outcome:

In [4]:
diab_dt <- select(mtbsl1, -c(`Primary ID`, Gender, Metabolic_syndrome))
mtbsl1_split <- initial_split(diab_dt, strata = gender_status, prop = 0.75)
mtbsl1_train <- training(mtbsl1_split)
mtbsl1_test <- testing(mtbsl1_split)

nrow(mtbsl1_train)
nrow(mtbsl1_test)
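Stratification means that the 75/25 split is drawn separately within each class. As a result, the four classes keep roughly the same proportions in both sets. Below is a minimal base-R sketch of the idea. `stratified_split()` is a hypothetical helper written for illustration, not the `rsample` implementation, which has additional logic (for example for very small strata). It uses a toy outcome with the class sizes from the table above:

```r
# Hypothetical helper: sample a fraction `prop` of the row indices within each class
stratified_split <- function(y, prop = 0.75) {
  idx_by_class <- split(seq_along(y), y)
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    # index via sample.int() to avoid sample()'s surprise when length(idx) == 1
    idx[sample.int(length(idx), size = floor(prop * length(idx)))]
  }), use.names = FALSE)
  sort(train_idx)
}

# Toy outcome with the same class sizes as the table above
y <- factor(rep(c("Female_Control Group", "Female_diabetes mellitus",
                  "Male_Control Group", "Male_diabetes mellitus"),
                times = c(28, 26, 56, 22)))

set.seed(1)
train_idx <- stratified_split(y, prop = 0.75)
table(y[train_idx])   # per-class counts in the training set
table(y[-train_idx])  # per-class counts in the test set
```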

#### Preprocessing

We use tidymodels to build a recipe for data preprocessing:

- remove correlated variables
- remove non-informative variables (zero variance)
- standardize all numeric predictors
- impute missing data (the `randomForest` package does not handle missing values)

In [5]:
diab_recipe <- mtbsl1_train %>%
  recipe(gender_status ~ .) %>%
  step_corr(all_predictors(), threshold = 0.9) %>%
  step_zv(all_numeric(), -all_outcomes()) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  step_impute_knn(all_numeric(), neighbors = 5) # no missing data here, but kept as a safeguard

In [6]:
prep_diab <- prep(diab_recipe)
print(prep_diab)

Data Recipe

Inputs:

      role #variables
   outcome          1
 predictor        188

Training data contained 98 data points and no missing data.

Operations:

Correlation filter removed mtbl_12, mtbl_27, mtbl_39, ... [trained]
Zero variance filter removed no terms [trained]
Centering and scaling for mtbl_1, mtbl_2, mtbl_3, mtbl_4, mtbl_5, ... [trained]
K-nearest neighbor imputation for mtbl_2, mtbl_3, mtbl_4, mtbl_5, mtbl_6, ... [trained]


In [7]:
training_set <- juice(prep_diab)
head(training_set)

| mtbl_1 | mtbl_2 | mtbl_3 | mtbl_4 | mtbl_5 | mtbl_6 | mtbl_7 | mtbl_8 | mtbl_9 | mtbl_10 | ⋯ | mtbl_174 | mtbl_175 | mtbl_177 | mtbl_178 | mtbl_180 | mtbl_181 | mtbl_182 | mtbl_184 | mtbl_186 | gender_status |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | ⋯ | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<fct>` |
| 0.848389 | -0.51527203 | -0.6659165 | -0.2974594 | -0.3921862 | -0.3841426 | -0.4831229 | -1.15946671 | -0.7027864 | -0.4415499 | ⋯ | -0.3975454 | -0.6502834 | -0.9914318 | -0.3636709 | -0.3895601 | -0.356787 | -0.4000826 | -0.1833678 | -0.1632953 | Female_Control Group |
| -0.7004617 | -0.51527203 | -0.6659165 | -0.2974594 | -0.3921862 | -0.3841426 | -0.4831229 | 0.32415278 | -0.7027864 | -0.4415499 | ⋯ | -0.3975454 | -0.6502834 | -0.9914318 | -0.3636709 | -0.3895601 | -0.356787 | -0.4000826 | -0.1833678 | -0.1632953 | Female_Control Group |
| 1.9628136 | -0.51527203 | 0.3467648 | -0.2974594 | -0.3921862 | -0.3841426 | -0.4831229 | -1.11536714 | -0.7027864 | -0.4415499 | ⋯ | -0.3975454 | -0.6502834 | -0.9914318 | -0.3636709 | -0.3895601 | -0.356787 | -0.4000826 | -0.1833678 | -0.1632953 | Female_Control Group |
| 2.9997977 | 0.07792547 | 2.9066275 | 4.5197979 | 0.3712267 | 0.2017653 | 5.0664146 | 0.05415026 | 0.9020492 | 0.8545959 | ⋯ | -0.3975454 | 0.746297 | -0.7430283 | -0.3636709 | -0.3895601 | -0.356787 | -0.4000826 | -0.1833678 | -0.1632953 | Female_Control Group |
| -0.7103493 | -0.51527203 | -0.6659165 | -0.2974594 | -0.3921862 | -0.3841426 | -0.4831229 | 0.51894333 | -0.4544876 | -0.4415499 | ⋯ | 0.6136194 | 1.3724225 | 1.1346307 | 0.3828432 | 0.2182378 | 0.9800042 | -0.4000826 | -0.1258766 | -0.1632953 | Female_Control Group |
| 0.3723048 | 0.03711451 | -0.6659165 | -0.2974594 | -0.3921862 | -0.3841426 | -0.4831229 | -1.15946671 | -0.7027864 | -0.4415499 | ⋯ | -0.3975454 | -0.6502834 | -0.9914318 | -0.3636709 | -0.3895601 | -0.356787 | -0.4000826 | -0.1833678 | -0.1632953 | Female_Control Group |
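The filtering and standardisation steps of the recipe can be illustrated in base R on a small simulated data frame. This is a simplified sketch, not the `recipes` implementation. The correlation filter below is a greedy variant of the common rule "from the most correlated pair, drop the variable with the larger mean absolute correlation". It also runs the zero-variance filter first, so that `cor()` never sees a constant column. kNN imputation is left out. `drop_zv()` and `drop_corr()` are hypothetical helpers:

```r
set.seed(1)
n  <- 50
x1 <- rnorm(n)
toy <- data.frame(
  x1 = x1,
  x2 = x1 + rnorm(n, sd = 0.05),  # almost a copy of x1: correlation close to 1
  x3 = rnorm(n),                   # independent predictor
  x4 = rep(3, n)                   # constant column: zero variance
)

# Zero-variance filter: keep only columns with more than one distinct value
drop_zv <- function(df) {
  keep <- vapply(df, function(col) length(unique(col)) > 1, logical(1))
  df[, keep, drop = FALSE]
}

# Greedy correlation filter: while some pair exceeds the threshold, take the
# most correlated pair and drop the member with the larger mean |correlation|
drop_corr <- function(df, threshold = 0.9) {
  while (ncol(df) > 1) {
    cm <- abs(cor(df))
    diag(cm) <- 0
    if (max(cm) <= threshold) break
    pair <- which(cm == max(cm), arr.ind = TRUE)[1, ]
    mean_cor <- colMeans(cm)
    to_drop <- pair[which.max(mean_cor[pair])]
    df <- df[, -to_drop, drop = FALSE]
  }
  df
}

filtered <- drop_corr(drop_zv(toy))
# Standardisation: centre each column to mean 0 and scale to standard deviation 1
standardized <- scale(filtered)
colnames(standardized)
```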


#### Model building

We now specify the structure of our model:

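- hyperparameters to tune: `mtry` (number of predictors randomly sampled as candidates at each split) and `min_n` (minimum number of data points in a node required for the node to be split further)
- number of trees in the forest
- the mode (classification)
- the engine (R package)

Then we put the model specification in a workflow together with a formula. Tuning is run on the already preprocessed training set (`training_set`), so the recipe is not added to this workflow. It is added back in the final workflow.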

In [None]:
tune_spec <- rand_forest(
  mtry = tune(),
  trees = 100,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("randomForest")

In [None]:
tune_wf <- workflow() %>%
  add_formula(gender_status ~ .) %>%
  add_model(tune_spec)
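For the `randomForest` engine, parsnip maps `mtry` to `mtry`, `trees` to `ntree` and `min_n` to `nodesize`. `parsnip::translate(tune_spec)` prints the resulting call template. Below is a minimal sketch of an equivalent direct call. It is fitted on R's built-in three-class `iris` data purely for illustration, with arbitrary values standing in for the two tuned parameters:

```r
library(randomForest)

set.seed(42)
rf_direct <- randomForest(
  Species ~ ., data = iris,
  ntree    = 100,  # parsnip `trees`
  mtry     = 2,    # parsnip `mtry`: predictors sampled as candidates at each split
  nodesize = 5     # parsnip `min_n` (randomForest: minimum size of terminal nodes)
)
print(rf_direct)   # shows the OOB error estimate and the OOB confusion matrix
```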

#### Tuning the hyperparameters

We use repeated k-fold cross-validation (5 folds, repeated 5 times) on the preprocessed training set to tune the hyperparameters:

In [None]:
# stratified folds keep the class proportions similar across folds
trees_folds <- vfold_cv(training_set, v = 5, repeats = 5, strata = gender_status)

In [None]:
print(trees_folds)
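With `v = 5` and `repeats = 5`, the training set is partitioned into 5 folds five times, each time after a fresh shuffle. This gives 25 analysis/assessment splits: each model is fitted on 4 folds and assessed on the held-out fold. Below is a base-R sketch of the fold assignment, using the 98 training rows reported by the recipe. `repeated_vfold_ids()` is a hypothetical helper; rsample's own implementation differs and, for instance, also supports stratification:

```r
# Hypothetical helper: one vector of fold labels (1..v) per repeat
repeated_vfold_ids <- function(n, v = 5, repeats = 5) {
  lapply(seq_len(repeats), function(r) sample(rep(seq_len(v), length.out = n)))
}

set.seed(2024)
folds <- repeated_vfold_ids(n = 98, v = 5, repeats = 5)

# Split "repeat 1, fold 1": assess on fold 1, fit on the other four folds
assessment_rows <- which(folds[[1]] == 1)
analysis_rows   <- which(folds[[1]] != 1)

n_resamples <- length(folds) * 5  # repeats x v = 25
```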

In [None]:
doParallel::registerDoParallel()

tune_res <- tune_grid(
  tune_wf,
  resamples = trees_folds,
  grid = 20
)


In [None]:
print(tune_res)

In [None]:
library("repr")
options(repr.plot.width=14, repr.plot.height=8)

tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  select(mean, min_n, mtry) %>%
  pivot_longer(min_n:mtry,
               values_to = "value",
               names_to = "parameter"
  ) %>%
  ggplot(aes(value, mean, color = parameter)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~parameter, scales = "free_x") +
  labs(x = NULL, y = "AUC")
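For a multiclass outcome, `yardstick::roc_auc()` computes by default the Hand-Till generalisation of the AUC:
$M = \frac{2}{c(c-1)}\sum_{i<j}\frac{A(i \mid j) + A(j \mid i)}{2}$.

Here $c$ is the number of classes, and $A(i \mid j)$ is the AUC for separating class $i$ from class $j$ using the predicted probability of class $i$. Only observations from those two classes are used. Below is a base-R sketch of that definition; `auc_binary()` and `hand_till_auc()` are hypothetical helpers, not yardstick code:

```r
# Binary AUC via the Mann-Whitney statistic: probability that a random
# positive scores higher than a random negative (ties count as 1/2)
auc_binary <- function(score, is_pos) {
  r <- rank(score)                    # average ranks handle ties
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hand-Till multiclass AUC: mean over class pairs (i, j) of
# (A(i|j) + A(j|i)) / 2, using only the observations of classes i and j
hand_till_auc <- function(truth, probs) {
  pairs <- combn(levels(truth), 2)
  pair_auc <- apply(pairs, 2, function(pr) {
    keep <- truth %in% pr
    a_ij <- auc_binary(probs[keep, pr[1]], truth[keep] == pr[1])
    a_ji <- auc_binary(probs[keep, pr[2]], truth[keep] == pr[2])
    (a_ij + a_ji) / 2
  })
  mean(pair_auc)
}

# Toy example: three classes, each observation gives 0.8 to its true class
truth <- factor(c("a", "a", "b", "b", "c", "c"))
probs <- matrix(0.1, nrow = 6, ncol = 3, dimnames = list(NULL, levels(truth)))
probs[cbind(1:6, as.integer(truth))] <- 0.8
hand_till_auc(truth, probs)  # returns 1: the classes are perfectly separated
```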

We now refine the search with a regular grid centred on $\sqrt{p}$, the usual default for `mtry` in classification problems ($p$ is the number of predictors):

In [8]:
m <- round(sqrt(ncol(training_set)-1),0)
print(m)
rf_grid <- grid_regular(
  mtry(range = c(m-2, m+2)),
  min_n(range = c(8, 12)),
  levels = 3
)

[1] 12
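`grid_regular()` with `levels = 3` takes three equally spaced values from each range and crosses them, which gives a grid of $3 \times 3 = 9$ rows. Note that `randomForest()` itself defaults to `floor(sqrt(p))` for classification, whereas the code above rounds, so the two centres can differ by one. As a sketch, the same grid can be reproduced in base R with `expand.grid()`. Like the grid printed by `print(rf_grid)`, it varies the first column fastest:

```r
m <- 12  # round(sqrt(p)), as printed above

mtry_values  <- seq(m - 2, m + 2, length.out = 3)  # 10, 12, 14
min_n_values <- seq(8, 12, length.out = 3)         # 8, 10, 12

# Cross the two sets of values: the first column varies fastest
rf_grid_base <- expand.grid(mtry = mtry_values, min_n = min_n_values)
print(rf_grid_base)
```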


In [9]:
print(rf_grid)

    # A tibble: 9 × 2
       mtry min_n
      <int> <int>
    1    10     8
    2    12     8
    3    14     8
    4    10    10
    5    12    10
    6    14    10
    7    10    12
    8    12    12
    9    14    12


In [11]:
regular_res <- tune_grid(
  tune_wf,
  resamples = trees_folds,
  grid = rf_grid
)


In [10]:
print(regular_res)


In [None]:
regular_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(mtry, mean, color = min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(y = "AUC")
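`collect_metrics()` averages each metric over the resamples for every hyperparameter combination; this is the `mean` column plotted above. `select_best()` then keeps the combination with the best average, which for `roc_auc` is the highest one. Below is a base-R sketch of that logic on made-up per-resample AUC values. The numbers are illustrative only, not results from this dataset:

```r
# Made-up per-resample AUC values for three hypothetical configurations
per_fold <- data.frame(
  mtry  = rep(c(10, 12, 14), each = 3),
  min_n = 8,
  auc   = c(0.80, 0.82, 0.78,  0.85, 0.83, 0.84,  0.81, 0.79, 0.83)
)

# Average over resamples for each configuration
summary_auc <- aggregate(auc ~ mtry + min_n, data = per_fold, FUN = mean)

# Keep the configuration with the highest mean AUC
best <- summary_auc[which.max(summary_auc$auc), c("mtry", "min_n")]
print(best)
```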

#### Final model

We now select the best hyperparameters from the tuning step, finalise the model, fit it on the training set and evaluate it on the test set:

1. select the best model based on AUC:

In [None]:
best_auc <- select_best(tune_res, metric = "roc_auc")  # use regular_res to select from the regular grid instead
print(best_auc)

2. finalise the model:

In [None]:
final_rf <- finalize_model(
  tune_spec,
  best_auc
)

print(final_rf)

3. finalise the workflow, adding back the preprocessing recipe, and run `last_fit()` on the initial split (fit on the training data, evaluation on the test data):

In [None]:
final_wf <- workflow() %>%
  add_recipe(diab_recipe) %>%
  add_model(final_rf)

final_res <- final_wf %>%
  last_fit(mtbsl1_split)

4. evaluate the tuned RF model on the test set:

In [None]:
print(final_res)
final_res %>%
  collect_metrics()

5. get variable importance:

In [None]:
final_res %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%   # replaces the deprecated pull_workflow_fit()
  vip(num_features = 25)      # alternative: vip(num_features = 20, geom = "point")
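When a forest is fitted with `randomForest()` without `importance = TRUE` (the package default), the only importance score stored is the mean decrease in Gini impurity. Permutation importance requires `importance = TRUE`. Below is a minimal sketch of extracting and ranking these scores directly with `randomForest::importance()`, again on the built-in `iris` data for illustration only:

```r
library(randomForest)

set.seed(42)
rf_iris <- randomForest(Species ~ ., data = iris, ntree = 100)

imp <- importance(rf_iris)  # a single column: MeanDecreaseGini
imp_ranked <- imp[order(imp[, "MeanDecreaseGini"], decreasing = TRUE), , drop = FALSE]
print(imp_ranked)
```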

#### Predictions

We collect the predictions on the test set. For each test observation we get the predicted class (`.pred_class`) and the probabilities of belonging to each of the four classes.

In [None]:
final_res %>%
  collect_predictions()
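With the `randomForest` engine, the class probabilities are the fractions of trees voting for each class. The hard prediction is the majority vote, that is, the class with the highest probability (ties aside). Below is a base-R sketch of that argmax step on a hypothetical probability matrix for three observations:

```r
# Hypothetical class probabilities for three test observations (rows sum to 1)
probs <- matrix(
  c(0.70, 0.10, 0.15, 0.05,
    0.20, 0.25, 0.40, 0.15,
    0.10, 0.50, 0.10, 0.30),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("Female_Control Group", "Female_diabetes mellitus",
                          "Male_Control Group", "Male_diabetes mellitus"))
)

# Predicted class = column with the largest probability in each row
pred_class <- colnames(probs)[max.col(probs, ties.method = "first")]
pred_class
```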

In [None]:
cm <- final_res %>%
  collect_predictions() %>%
  conf_mat(gender_status, .pred_class)

print(cm)

In [None]:
autoplot(cm, type = "heatmap")
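The confusion matrix cross-tabulates predicted and true classes; `conf_mat()` puts predictions in rows and the truth in columns. Below is a base-R sketch with `table()` on toy labels. It derives overall accuracy (the diagonal over the total) and per-class recall (the diagonal over each column total):

```r
# Toy labels for six observations and three classes
truth <- factor(c("A", "A", "B", "B", "C", "C"), levels = c("A", "B", "C"))
pred  <- factor(c("A", "B", "B", "B", "C", "A"), levels = c("A", "B", "C"))

# Rows = prediction, columns = truth (same orientation as conf_mat())
toy_cm <- table(Prediction = pred, Truth = truth)

accuracy <- sum(diag(toy_cm)) / sum(toy_cm)  # correctly classified / total
recall   <- diag(toy_cm) / colSums(toy_cm)   # per-class sensitivity
toy_cm
accuracy
recall
```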