<font size="6"><b>RECURSIVE PARTITIONING TREES: BASICS</b></font>

<font size="5"><b>Serhat Çevikel</b></font>

In [None]:
library(tidyverse)
library(data.table)
library(MASS) # for generating random samples from multivariate normal distribution
library(rethinking) # for generating random correlation matrix from LKJ distribution
library(rpart) # for recursive partitioning trees
library(rpart.plot) # for plotting recursive partitioning trees
library(visNetwork) # for better plotting recursive partioning trees
library(igraph) # for creating graphs from data.table's
library(vip) # for visualizing variable importance

In [None]:
options(repr.matrix.max.rows=20, repr.matrix.max.cols=15) # for limiting the number of top and bottom rows of tables printed 

![xkcd](../imagesba/map_age_guide_large.png)

(https://xkcd.com/1688/)

This session is about another approach to classification problems: decision trees, or more specifically, recursive partitioning trees.

Recursive partitioning is a statistical method for multivariable analysis. Recursive partitioning creates a decision tree that strives to correctly classify members of the population by splitting it into sub-populations based on several dichotomous independent variables. The process is termed recursive because each sub-population may in turn be split an indefinite number of times until the splitting process terminates after a particular stopping criterion is reached.

(https://en.wikipedia.org/wiki/Recursive_partitioning)

A continuous feature can also be treated as a dichotomous (two-valued or binary) variable by discretizing it at a cut point. Similarly, a multi-category variable can be treated as binary by grouping some of the categories into one value and the remaining categories into the other.
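
As a minimal sketch of both conversions, the snippet below uses a small made-up data set. The variable names, the cut point of 30 and the grouping of colors are assumptions chosen only for illustration.

```r
# Hypothetical toy data: one continuous and one multi-category feature
age <- c(4, 12, 25, 37, 58, 71)
color <- factor(c("red", "green", "blue", "red", "yellow", "green"))

# Continuous -> binary: values at or below the cut point vs. values above it
cut_point <- 30
age_bin <- cut(age, breaks = c(-Inf, cut_point, Inf), labels = c("low", "high"))

# Multi-category -> binary: one group of categories vs. all remaining ones
warm <- c("red", "yellow")
color_bin <- factor(ifelse(color %in% warm, "warm", "other"),
                    levels = c("other", "warm"))

data.frame(age, age_bin, color, color_bin)
```

A tree algorithm would typically try many such cut points and groupings and keep the one that separates the classes best.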

A simple decision tree based on Titanic data may look as follows:

![tree1](../imagesba/titanic_tree.png)

(https://en.wikipedia.org/wiki/Recursive_partitioning)

Here, at each decision node, the data is split in two based on a value or cut point of a selected variable. Each split creates a subtree, and splitting continues until the branches end in leaf nodes rather than further decisions. Each leaf node predicts a binary class. In this example, the leaf nodes classify the Titanic passengers as either "died" or "survived".
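
A fitted tree of this kind is equivalent to a set of nested if/else rules. The sketch below writes such rules by hand. The function name, the variable names and the cut points are illustrative assumptions, not values estimated from the Titanic data.

```r
# Illustrative rules only: names and cut points are assumptions
classify_passenger <- function(sex, age, sibsp) {
  if (sex == "female") return("survived")  # first decision node
  if (age > 9.5) return("died")            # second decision node
  if (sibsp > 2.5) return("died")          # third decision node
  "survived"                               # remaining leaf
}

classify_passenger("female", 30, 0)
classify_passenger("male", 40, 0)
classify_passenger("male", 5, 1)
classify_passenger("male", 5, 4)
```

Each call walks from the root to exactly one leaf, so every observation receives exactly one predicted class.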

Let's see the geometric meaning of decision trees with another example. Here we have two continuous variables, X1 and X2, to predict a response:

![tree2](../imagesba/dectree2.png)

(https://link.springer.com/article/10.1007/s10044-014-0399-1)

(https://media.springernature.com/lw685/springer-static/image/art%3A10.1007%2Fs10044-014-0399-1/MediaObjects/10044_2014_399_Fig1_HTML.gif)

In a two dimensional setting, with each decision node the responses are further split into partitions by adding new vertical or horizontal decision boundaries.

How is the split made in the decision node? It is done so that node impurity is minimized. So what is impurity?

Gini impurity measures how often a randomly chosen element of a set would be incorrectly labeled if it were labeled randomly and independently according to the distribution of labels in the set. It reaches its minimum (zero) when all cases in the node fall into a single target category.

(https://en.wikipedia.org/wiki/Decision_tree_learning)

![impurity](../imagesba/impurity.png)

(https://www.baeldung.com/cs/impurity-entropy-gini-index)

(https://www.baeldung.com/wp-content/uploads/sites/4/2022/06/impurity.png)

Here the node on the left has a high level of impurity: it consists of equal proportions of both classes. The impurity is lower in the middle node, which consists mostly of red class observations and only a few blue class observations. The right node has zero impurity: all observations are of the red class.
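
To put numbers on these three cases, here is a small sketch that computes the Gini impurity from class counts. The counts (5/5, 8/2 and 10/0) and the helper name `gini_from_counts` are made up for illustration and are not read from the figure.

```r
# Gini impurity from a vector of class counts
gini_from_counts <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Hypothetical class counts for three nodes
nodes <- list(
  left   = c(red = 5,  blue = 5),   # equal shares
  middle = c(red = 8,  blue = 2),   # mostly red
  right  = c(red = 10, blue = 0)    # only red
)

impurities <- sapply(nodes, gini_from_counts)
impurities
```

With these counts the impurities are 0.5, 0.32 and 0, so the measure falls as a node becomes dominated by one class.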

When is logistic regression or a recursive partitioning tree the more suitable approach for a classification problem?

![tree3](../imagesba/8_7.png)

(James et al. 2023. An Introduction to Statistical Learning with Applications in R, Second Edition, Corrected Printing, p.339)

(https://www.amazon.com/Introduction-Statistical-Learning-Applications-Statistics/dp/3031387465)

(https://www.statlearning.com/resources-second-edition)

In the top row, the true decision boundary is linear, so an approach like logistic regression can separate the classes successfully (top left). The decision tree approach (top right), which can only draw boundaries that cut one dimension at a time at a cut point, is not very successful.

In the bottom row, the true decision boundary is non-linear, hence a linear model such as logistic regression (bottom left) cannot capture the separation between classes very successfully. However, a decision tree approach (bottom right) is more successful in this setting.
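
The bottom-row situation can be reproduced with a small simulation. This is only a sketch under assumed settings: the box-shaped boundary, the sample size and all object names are hypothetical, and accuracy is measured in-sample for simplicity.

```r
library(rpart)

set.seed(42)
n <- 1000
sim <- data.frame(x1 = runif(n, -2, 2), x2 = runif(n, -2, 2))

# Non-linear (box-shaped) true boundary: "in" inside the square, "out" elsewhere
sim$y <- factor(ifelse(abs(sim$x1) < 1 & abs(sim$x2) < 1, "in", "out"))

# Logistic regression with a linear predictor; it models P(y == "out")
logit_fit <- glm(y ~ x1 + x2, data = sim, family = binomial)
logit_pred <- ifelse(predict(logit_fit, type = "response") > 0.5, "out", "in")

# Classification tree with axis-parallel splits
tree_fit <- rpart(y ~ x1 + x2, data = sim, method = "class")
tree_pred <- predict(tree_fit, type = "class")

acc_logit <- mean(logit_pred == sim$y)
acc_tree <- mean(tree_pred == sim$y)
c(logistic = acc_logit, tree = acc_tree)
```

Because the box is bounded by lines parallel to the axes, a handful of tree splits can describe it, whereas a single linear boundary cannot.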

# Recursive Partitioning Tree Algorithm on a Toy Dataset

We will simulate a toy dataset to demonstrate the decision tree algorithm for classification. We will sample continuous values from a multivariate normal distribution and then discretize them to get factor variables.

## Data generation and preparation

Let's first create a random correlation matrix using the relevant LKJ distribution.
And let's sample some correlated random values from multivariate normal distribution.

The first column is the response variable and the others are independent variables. All variables are discretized into factors with "yes" and "no" values for simplicity:

In [None]:
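set.seed(1)
matcor <- rlkjcorr(1, 3, 0.1) # one random 3x3 correlation matrix from the LKJ distribution (eta = 0.1)
set.seed(1)
vals <- mvrnorm(2e2, rep(0, 3), matcor) # 200 draws from a zero-mean multivariate normal
vals_dt <- as.data.table(vals)
setnames(vals_dt, c("dep", "ind1", "ind2")) # the first column is the response
vals_dt <- vals_dt %>% mutate_all(cut, c(-Inf, 0, Inf), c("no", "yes")) # discretize at 0: "no" if <= 0, "yes" if > 0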

In [None]:
head(vals_dt)

Check the correlation among classes:

In [None]:
cor(vals_dt %>% mutate_all(as.integer))

Let's visualize the possible splits:

In [None]:
vals_dt %>%
mutate_at(c("ind1", "ind2"), as.integer) %>%
ggplot(aes(x = ind1, y = ind2, col = dep)) +
geom_jitter() +
geom_hline(yintercept = 1.5) +
geom_vline(xintercept = 1.5)

## Entropy/impurity

Either the Shannon entropy or the Gini impurity measure can be used to assess how mixed the classes are in each node.

Let's formalize this through the entropy measure:

${\displaystyle \mathrm {H} (X):=-\sum _{x\in {\mathcal {X}}}p(x)\log p(x),}$

(https://en.wikipedia.org/wiki/Entropy_(information_theory))

In [None]:
entrop <- function(x)
{
    props <- prop.table(table(as.character(x)))
    sum(-props * log2(props))
}

And the gini impurity measure:

${\displaystyle \operatorname {I} _{G}(p)=1-\sum _{i=1}^{J}p_{i}^{2}.}$

(https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity)

In [None]:
ginix <- function(x)
{
    props <- prop.table(table(as.character(x)))
    1- sum(props^2)
}

Now let's make a digression to understand entropy.

For the two-class case, the entropy values for different proportions of one class are:

In [None]:
ps <- seq(0, 1, 0.1) %>% pmax(1e-22) %>% pmin(1-(1e-16))
ps

In [None]:
psdt <- data.table(ps)
psdt[, ent := sapply(ps, function(x) -sum(c(x, 1-x) * log2(c(x, 1-x))))]
psdt

In [None]:
psdt %>%
ggplot(aes(x = ps, y = ent)) +
geom_line()

The entropy is at its maximum at ps = 0.5, where both classes have equal shares. What does that mean?

One interpretation of entropy is uncertainty, disorderliness or impurity.

But why?

Let's assume we toss a fair coin 10 times and calculate the counts of heads we get.

What are the probabilities of getting each count?

In [None]:
data.table(count = 0:10, prob = dbinom(0:10, 10, 0.5)) %>%
ggplot(aes(x = count, y = prob)) +
geom_line()

A very similar figure.

What is the probability of getting a certain sequence of head or tail values for each count of heads?

In [None]:
0.5^(0:10) * 0.5^(10:0)

Because both sides are equally likely, all sequences or configurations have the same probability.

How many configurations can we have? The size of the power set of the 10 tosses (each subset marks the tosses that come up heads):

In [None]:
2^10

Total probability is:

In [None]:
0.5^10 * 2^10

Quite obvious:

$$(0.5 * 2)^{10} = 1^{10} = 1$$

Now, how many configurations or sequences yield each of the head counts?

In [None]:
choose(10, 0:10)

Or:

In [None]:
factorial(10) / (factorial(10 - 0:10) * factorial(0:10))

And let's multiply this number of different ways to get a certain count and the probabilities of each configuration:

In [None]:
data.table(count = 0:10, prob = choose(10, 0:10) * 0.5^(0:10) * 0.5^(10:0)) %>%
ggplot(aes(x = count, y = prob)) +
geom_line()

That's the same figure we get from dbinom, because we used the same formulation for dbinom:

$${\displaystyle f(k,n,p)=\Pr(X=k)={\binom {n}{k}}p^{k}(1-p)^{n-k}}$$

(https://en.wikipedia.org/wiki/Binomial_distribution)

Basically, values near the middle of this distribution are more probable because **there are more ways, or combinations, to get those head counts**.

Probability is basically a matter of counting the number of ways things can happen.
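
The counting argument can be checked by brute force. The sketch below enumerates every possible sequence of 10 tosses and counts how many of them give each number of heads. The object names are illustrative.

```r
# All 2^10 possible head/tail sequences of 10 tosses (1 = head, 0 = tail)
seqs <- expand.grid(rep(list(0:1), 10))
nrow(seqs)

# Number of sequences that produce each head count from 0 to 10
ways <- as.vector(table(rowSums(seqs)))
ways

# Dividing the number of ways by the total number of sequences
# reproduces the binomial probabilities
probs <- ways / nrow(seqs)
all.equal(probs, dbinom(0:10, 10, 0.5))
```

Since every sequence is equally likely for a fair coin, the probability of a head count is just its share of all sequences.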

## Partitioning

Select the independent variables to split across:

In [None]:
vars <- c("ind1", "ind2")

See the weighted entropy values of the response for each of the two variables:

In [None]:
ents <- sapply(vars, function(x) vals_dt[, .(N = .N, en = entrop(dep)), by = get(x)][, sum(N * en)/sum(N)])

In [None]:
ents

Let's understand the steps behind this calculation:

In the first step, the entropy of the response variable is calculated for each class label of the first and second predictor variables, along with their respective counts:

In [None]:
ents1 <- lapply(vars, function(x) vals_dt[, .(N = .N, en = entrop(dep)), by = get(x)])

In [None]:
ents1

Then the entropy for each predictor variable is weighted by the respective counts of each class:

In [None]:
sapply(ents1, function(x) x[, sum(N * en)/sum(N)])
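
The same two steps can be followed by hand on a tiny example. The data below is made up, and `entropy_of` is a hypothetical helper written in base R so that the snippet runs on its own.

```r
# Shannon entropy (base 2) of a vector of class labels
entropy_of <- function(labels) {
  p <- prop.table(table(as.character(labels)))
  sum(-p * log2(p))
}

# Hypothetical toy data: 8 observations, one binary predictor
toy <- data.frame(
  pred = c("no", "no", "no", "no", "yes", "yes", "yes", "yes"),
  resp = c("no", "no", "no", "yes", "yes", "yes", "yes", "yes")
)

# Step 1: entropy of the response and group size for each predictor value
groups <- split(toy$resp, toy$pred)
en <- sapply(groups, entropy_of)
n <- sapply(groups, length)

# Step 2: average the entropies, weighting each group by its size
weighted_entropy <- sum(n * en) / sum(n)
weighted_entropy
```

Here the "no" group has entropy of about 0.81 and the pure "yes" group has entropy 0, so the weighted value is about 0.41.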

Now let's recalculate using gini impurity values:

In [None]:
ginis <- sapply(vars, function(x) vals_dt[, .(N = .N, en = ginix(dep)), by = get(x)][, sum(N * en)/sum(N)])

In [None]:
ginis

Both measures rank the variables in the same order.

Select the variable that yields the lowest entropy:

In [None]:
splitvar1 <- names(ents[which.min(ents)])
splitvar1

Split the data.table across this variable's class values:

In [None]:
vals_dt_l1 <- split(vals_dt, f = vals_dt[, .(get(splitvar1))])[c("no", "yes")]

In [None]:
lapply(vals_dt_l1, head)

And repeat the entropy calculation for both splits across the other variable (there is only one variable left, anyway, nothing to compare):

In [None]:
lapply(vals_dt_l1, function(y)
    {
    sapply(setdiff(vars, splitvar1), function(x) y[, .(N = .N, en = entrop(dep)), by = get(x)][, sum(N * en)/sum(N)])
    }
)

Now let's see the information gain, the reduction in entropy, at the beginning and after each split:

First at the root:

In [None]:
counts_dt0 <- vals_dt[, .N, by = c("dep")][, prop := N / sum(N)][]
setorder(counts_dt0, dep)

In [None]:
counts_dt0

Let's visualize our tree as a graph:

In [None]:
tree1 <- make_empty_graph()
v1_lab <- counts_dt0[, paste(paste(dep, N, sep = ": "), collapse = "\n")]
tree2 <- tree1 %>% add_vertices(1, attr = list(label = v1_lab, size = 40))
plot(tree2, layout = layout_as_tree(tree2))

Get the entropy:

In [None]:
ent0 <- counts_dt0[, sum(-setdiff(prop, 0) * log2(setdiff(prop, 0)))]
ent0

And the error rate:

In [None]:
er0 <- counts_dt0[, sum((N != max(N))*N) / sum(N)]
er0

Now, after the first split:

In [None]:
counts_dt1 <- vals_dt[, .N, by = c("dep", splitvar1)][, prop := N / sum(N), by = splitvar1][]
setorder(counts_dt1, ind1, dep)

In [None]:
counts_dt1

Update our tree and visualize again:

In [None]:
v2_lab <- counts_dt1[, paste(paste(dep, N, sep = ": "), collapse = "\n"), by = ind1]
e2_lab <- paste(splitvar1, v2_lab[, get(splitvar1)], sep = "\n=\n")
tree3 <- tree2 %>% add_vertices(2, attr = list(label = v2_lab$V1, size = 40)) %>%
add_edges(c(1,2,1,3), attr = list(label = e2_lab))
plot(tree3, layout = layout_as_tree(tree3))

The entropy value:

In [None]:
ent1 <- counts_dt1[, .(N = sum(N), ent = sum(-setdiff(prop, 0) * log2(setdiff(prop, 0)))), by = splitvar1][, sum(N * ent) / sum(N)]
ent1

Entropy is reduced by:

In [None]:
ent0 - ent1

The error rate:

In [None]:
er1 <- counts_dt1[, sum((dep != ind1) * N) / sum(N)]
er1

And relative decrease in error:

In [None]:
1 - er1 / er0

Keep that value in mind!

Now the second split:

In [None]:
counts_dt2 <- vals_dt[, .N, by = c("dep", vars)][, prop := N / sum(N), by = vars][]
setorder(counts_dt2, ind1, ind2, dep)

In [None]:
counts_dt2

Again let's update our tree and visualize it:

In [None]:
v3_lab <- counts_dt2[, paste(paste(dep, N, sep = ": "), collapse = "\n"), by = c("ind1", "ind2")]
e3_lab <- paste("ind2", v3_lab[, ind2], sep = "\n=\n")
tree4 <- tree3 %>% add_vertices(4, attr = list(label = v3_lab$V1, size = 40)) %>%
add_edges(c(2,4,2,5,3,6,3,7), attr = list(label = e3_lab))
plot(tree4, layout = layout_as_tree(tree4))

Let's look at the error rate for each split:

First, when the ind1 = "yes" node is not split further (the predicted labels are chosen so that the classification error is minimized):

In [None]:
min(
counts_dt2[ind1 == "yes"][, sum((dep == ind1) * N) / sum(N)],
counts_dt2[ind1 == "yes"][, sum((dep != ind1) * N) / sum(N)])

And when the node is split further by ind2:

In [None]:
min(
counts_dt2[ind1 == "yes"][, sum((dep == ind2) * N) / sum(N)],
counts_dt2[ind1 == "yes"][, sum((dep != ind2) * N) / sum(N)])

So for the ind1 == "yes" node, the error rate increases with a further split on the second variable.

Now let's repeat it for the ind1 == "no" node:

In [None]:
min(
counts_dt2[ind1 == "no"][, sum((dep == ind1) * N) / sum(N)],
counts_dt2[ind1 == "no"][, sum((dep != ind1) * N) / sum(N)])

And when the node is split further by ind2:

In [None]:
min(
counts_dt2[ind1 == "no"][, sum((dep == ind2) * N) / sum(N)],
counts_dt2[ind1 == "no"][, sum((dep != ind2) * N) / sum(N)])

The error rate decreases for that split.

Let's delete the second level split on the ind1 == "yes" node:

In [None]:
counts_dt2b <- copy(counts_dt2)
counts_dt2b[ind1 == "yes", ind2 := NA]
counts_dt2b <- counts_dt2b[, .(N = sum(N)), by = c("dep", "ind1", "ind2")][, prop := N / sum(N), by = vars][]
setorder(counts_dt2b, ind1, ind2, dep)

In [None]:
counts_dt2b

And visualize the tree again:

In [None]:
v3b_lab <- counts_dt2b[!is.na(ind2), paste(paste(dep, N, sep = ": "), collapse = "\n"), by = c("ind1", "ind2")]
e3b_lab <- paste("ind2", v3b_lab[, ind2], sep = "\n=\n")
tree4b <- tree3 %>% add_vertices(2, attr = list(label = v3b_lab$V1, size = 40)) %>%
add_edges(c(2,4,2,5), attr = list(label = e3b_lab))
plot(tree4b, layout = layout_as_tree(tree4b))

Calculate the entropy:

In [None]:
ent2b <- counts_dt2b[, .(N = sum(N), ent = sum(-setdiff(prop, 0) * log2(setdiff(prop, 0)))), by = vars][, sum(N * ent) / sum(N)]
ent2b

Entropy is reduced by:

In [None]:
ent1 - ent2b

In short, at each node the next splitting variable is the one that reduces entropy the most, and splitting continues only **as long as the relative classification error decreases sufficiently**.
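
The whole procedure can be condensed into a short recursive function. This is a simplified sketch, not the algorithm implemented in `rpart`: it handles only categorical predictors, uses each predictor at most once per branch, and stops whenever a split fails to lower the classification error at all (a stand-in for "decreases sufficiently"). All function names and the toy data are hypothetical.

```r
entropy_of <- function(y) {
  p <- prop.table(table(as.character(y)))
  sum(-p * log2(p))
}

# Share of observations outside the majority class
error_of <- function(y) 1 - max(table(y)) / length(y)

# Weighted entropy of y after splitting on predictor x
split_entropy <- function(y, x) {
  groups <- split(y, as.character(x))
  sum(sapply(groups, function(g) length(g) * entropy_of(g))) / length(y)
}

grow <- function(data, response, predictors) {
  y <- data[[response]]
  node <- list(n = nrow(data), predict = names(which.max(table(as.character(y)))))
  # Stop at pure nodes or when no predictors are left
  if (length(predictors) == 0 || error_of(y) == 0) return(node)

  # Choose the predictor with the lowest weighted entropy
  ents <- sapply(predictors, function(v) split_entropy(y, data[[v]]))
  best <- predictors[which.min(ents)]

  # Keep the split only if it lowers the classification error
  parts <- split(data, as.character(data[[best]]))
  err_after <- sum(sapply(parts, function(d) nrow(d) * error_of(d[[response]]))) / nrow(data)
  if (err_after >= error_of(y)) return(node)

  node$split <- best
  node$children <- lapply(parts, grow, response = response,
                          predictors = setdiff(predictors, best))
  node
}

# Hypothetical toy data with two binary predictors
toy <- data.frame(
  a = rep(c("no", "yes"), each = 6),
  b = rep(rep(c("no", "yes"), each = 3), times = 2),
  y = c("no", "no", "no", "yes", "yes", "no", rep("yes", 6))
)

tree <- grow(toy, "y", c("a", "b"))
str(tree)
```

On this toy data the root is split on `a`, the `a = "no"` branch is split again on `b`, and the pure `a = "yes"` branch stays a leaf.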

# rpart

Now let's make the `rpart` function from the `rpart` package do the heavy lifting:

In [None]:
rptree <- rpart(dep ~ ., data = vals_dt)

How the splits are done:

In [None]:
rptree

Node numbers follow a breadth-first order: the nodes on the same level are numbered first, before moving on to the next level.
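
In the numbering convention used in the `rpart` output, the children of node k are numbered 2k and 2k + 1, so a node's number also encodes its depth and its parent. The sketch below tabulates this for the first seven positions. The object name `numbering` is illustrative.

```r
node <- 1:7
numbering <- data.frame(
  node   = node,
  depth  = floor(log2(node)),                   # the root is at depth 0
  parent = ifelse(node == 1, NA, node %/% 2),   # the root has no parent
  left   = 2 * node,                            # number of the left child
  right  = 2 * node + 1                         # number of the right child
)
numbering
```

Numbers refer to positions in a full binary tree, so some of them are skipped when a branch is not split further. For example, nodes 6 and 7 do not appear if node 3 is a leaf.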

Note the order of the labels in the dep column of the input data:

In [None]:
levels(vals_dt$dep)

Dissecting the first row:

- 1): Node number
- root: split name, either the root or the decision rule
- 200: Total number of observations at the node
- 97: Loss, that is, the number of misclassified observations when the majority class is predicted
- yes: Majority class (yval)
- (0.485 0.515): Probabilities of the classes (yprob), following the order of the levels in the factor variable
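
These quantities follow directly from the class counts at the node. The sketch below uses the counts implied by the row above (200 observations, 97 of them outside the majority class "yes"). The object names are illustrative.

```r
# Class counts at the root, in the order of the factor levels
counts <- c(no = 97, yes = 103)

n <- sum(counts)                    # number of observations at the node
yval <- names(which.max(counts))    # majority class
loss <- n - max(counts)             # observations outside the majority class
yprob <- counts / n                 # class probabilities

list(n = n, loss = loss, yval = yval, yprob = yprob)
```

The loss divided by n is the error rate of the node, here 97 / 200 = 0.485.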

And we can compare the first line with the initial data.table of the root node:

In [None]:
counts_dt0

Nodes 2 and 3 are in line with the next level table created:

In [None]:
counts_dt1

And Nodes 4 and 5 can be compared with the last table created:

In [None]:
na.omit(counts_dt2b)

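In general, variable importance reflects how much predictive accuracy is lost when the variable in question is removed. In `rpart`, it is computed as the sum of the goodness-of-split measures of the splits in which the variable is the primary variable, plus adjusted contributions from splits in which it acts as a surrogate.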

(https://stats.stackexchange.com/a/6485)

In [None]:
rptree$variable.importance

With `vip` package, we can extract the variable importance values as a table:

In [None]:
vi(rptree)

Or visualize those values:

In [None]:
vip(rptree)

We see that ind1 is more important than ind2.

Summary of complexity parameters (CP) table:

In [None]:
printcp(rptree)

Recall the first value in the CP column: it is the relative decrease in error we calculated before.

Relative error is the error at the given depth level divided by the error at the root node (before any splits).

xerror is the cross validation error: the data is internally split into 10 parts, and in each of the 10 runs one part is held out for out-of-sample testing. The cross validation error is also in relative terms, normalized with the error at the root node. The idea is that, when a tree gets too deep and complicated, there is an increasing risk of overfitting, that is, of being too flexible, which increases the variance of the prediction accuracy on unseen data. xstd is the estimated standard error of xerror.

While `rel error` always decreases (as long as the algorithm keeps creating new nodes that decrease the training error), the cross validation error (test error) may start to increase again after reaching a minimum. To simplify the tree, it can be pruned at the point where xerror is at its minimum.
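
A common way to apply this rule is to read the CP value from the row of the CP table with the lowest xerror and pass it to `prune`. The sketch below does this on `kyphosis`, an example dataset shipped with `rpart`. It is used here only so that the snippet is self-contained, and the object names are illustrative.

```r
library(rpart)

set.seed(1)  # the cross validation folds are assigned at random
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# Row of the CP table with the lowest cross validation error
best_row <- which.min(fit$cptable[, "xerror"])
best_cp <- fit$cptable[best_row, "CP"]

# Prune the tree back to the size of that row
pruned <- prune(fit, cp = best_cp)

# Number of nodes before and after pruning
c(full = nrow(fit$frame), pruned = nrow(pruned$frame))
```

If the lowest xerror already belongs to the largest tree, pruning leaves the tree unchanged.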

We can plot the relative cross validation error along with its standard deviation:

In [None]:
plotcp(rptree)

Let's extract the CP table:

In [None]:
cpdt <- rptree$cptable %>% as.data.table

In [None]:
cpdt

Complexity parameter (CP) is the change in relative error if further splits are made, divided by the increase in number of splits:

In [None]:
cpdt[, -diff(`rel error`) / diff(nsplit)]

So while more splits decrease the relative error, the complexity parameter penalizes this reduction by the number of splits needed to achieve it.
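
A small worked example makes the arithmetic concrete. The table below is hypothetical and does not come from the fitted tree. Its column names only mimic the `nsplit` and `rel error` columns of a CP table.

```r
# Hypothetical CP table: number of splits and relative error at each tree size
cp_toy <- data.frame(nsplit = c(0, 1, 3), rel_error = c(1.00, 0.60, 0.50))

# Drop in relative error per additional split between consecutive rows
cp_values <- -diff(cp_toy$rel_error) / diff(cp_toy$nsplit)
cp_values
```

The first split lowers the relative error by 0.40 on its own, while the next two splits together lower it by only 0.10, that is, 0.05 per split.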

Let's visualize the tree:

In [None]:
rpart.plot(rptree)

or:

In [None]:
visTree(rptree)

Quite similar to our own tree:

In [None]:
plot(tree4b, layout = layout_as_tree(tree4b))

While we don't need to prune this tree, since xerror always decreases, let's try pruning for demonstration purposes.

Let's say we want to prune the tree from the second split. Let's get the complexity parameter at that split:

In [None]:
prunecp <- rptree$cptable[2, "CP"]
prunecp

Let's prune the tree by setting the cp parameter to the CP value of second row:

In [None]:
prune.tree <- prune(rptree, cp = prunecp)

In [None]:
prune.tree

We have a simpler tree now:

In [None]:
rpart.plot(prune.tree)

or:

In [None]:
visTree(prune.tree)

# Resources

- Lantz 2015, Machine Learning with R, Second Edition, Ch. 5
- James et al. 2023, An Introduction to Statistical Learning with Applications in R, Second Edition, Corrected Printing, Ch. 8
- Nokeri 2021, Data Science Revealed, Ch. 8
- Yu-Wei 2015, Machine Learning with R Cookbook, Ch. 5