# Classification trees

Classification trees are very similar to regression trees, except that the target variable is categorical. In a regression tree, the prediction for a region is the mean of the target variable in that region. In a classification tree, the prediction is the most common class in the region. Besides the most common class, we are also interested in the proportion of each class in each region, because these proportions measure the **purity** of the region. A region is pure if it contains only one class and impure if it contains multiple classes.
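
As a small illustration of these ideas, the sketch below computes the class proportions and the majority-class prediction for a single region. The labels and the helper name `region_summary` are made up for illustration and are not taken from any dataset used later.

```python
from collections import Counter


def region_summary(labels):
    """Return (majority_class, proportions) for the labels falling in one region."""
    counts = Counter(labels)
    total = len(labels)
    proportions = {cls: n / total for cls, n in counts.items()}
    majority_class = counts.most_common(1)[0][0]
    return majority_class, proportions


# Hypothetical labels of the training samples that ended up in one region
region_labels = ["Yes", "Yes", "No", "Yes"]
majority, props = region_summary(region_labels)
print(majority, props)
```

A region whose proportions are all zero except for a single class equal to one is pure.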

## Gini index

Although classification accuracy is a good measure of classification performance, it is not sufficiently sensitive for growing classification trees: two candidate splits can leave the same number of misclassified samples while one of them produces much purer regions. A popular cost function (splitting criterion) for classification trees is therefore the [**Gini index**](https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity), a measure of the impurity of a region. Despite the shared name, it is not the same quantity as the Gini coefficient used to describe income inequality in the news.

The Gini index is defined as:

$$
G = 1 - \sum_{c=1}^C p_c^2 = \sum_{c=1}^C p_c (1 - p_c),
$$

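where $p_c$ is the proportion of class $c$ in the region and $C$ is the number of classes. The Gini index is zero if the region is pure and reaches its maximum of $1 - 1/C$ when all classes are equally represented (0.5 for two classes), so it is always less than one. It can be interpreted as the probability of misclassifying a randomly chosen sample from the region if its label were drawn at random according to the class proportions in the region. The Gini index is also known as the **Gini impurity**.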
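
A minimal sketch of the formula above, assuming the class proportions are supplied as a list that sums to one (the function name `gini_index` is our own choice). For an even mix of $C$ classes it returns $1 - 1/C$.

```python
def gini_index(proportions):
    """Gini index G = 1 - sum(p_c^2) for a list of class proportions."""
    return 1.0 - sum(p * p for p in proportions)


print(gini_index([1.0, 0.0]))  # pure region -> 0.0
print(gini_index([0.5, 0.5]))  # evenly mixed two-class region -> 0.5
print(gini_index([0.25] * 4))  # evenly mixed four-class region -> 0.75
```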

## Entropy

Another popular cost function (splitting criterion) for classification trees is the [**entropy**](https://en.wikipedia.org/wiki/Entropy_(information_theory)), which is defined as:

$$
H = - \sum_{c=1}^C p_c \log_2 p_c,
$$

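where $p_c$ is the proportion of class $c$ in the region, with $0 \log_2 0$ taken to be zero. The entropy is zero if the region is pure and reaches its maximum of $\log_2 C$ when all $C$ classes are equally represented, which is one for two classes.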
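
The entropy formula can be sketched in the same way, again assuming a list of class proportions and a function name of our own choosing. Classes with zero proportion are skipped, which corresponds to treating $0 \log_2 0$ as zero.

```python
import math


def entropy(proportions):
    """Entropy H = -sum(p_c * log2(p_c)), skipping classes with p_c = 0."""
    return sum(-p * math.log2(p) for p in proportions if p > 0)


print(entropy([1.0, 0.0]))  # pure region -> 0.0
print(entropy([0.5, 0.5]))  # evenly mixed two-class region -> 1.0
print(entropy([0.25] * 4))  # evenly mixed four-class region -> 2.0
```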

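When building classification trees, either the Gini index or the entropy is typically used to evaluate the quality of a particular split: the cost of a split is the average impurity of the resulting regions, weighted by the number of samples in each, and the split that produces the lowest cost is chosen. The two criteria usually lead to very similar trees, and the Gini index is slightly faster to compute because it needs no logarithm.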
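
To see how a split is scored, the sketch below compares two hypothetical splits of a region holding 400 samples of each of two classes. The class counts and all function names are assumptions made for illustration. Each split is scored by the sample-weighted average impurity of its child regions, using the misclassification error, the Gini index and the entropy in turn.

```python
import math


def proportions(counts):
    """Turn per-class sample counts into class proportions."""
    total = sum(counts)
    return [n / total for n in counts]


def misclassification_error(counts):
    return 1.0 - max(proportions(counts))


def gini_index(counts):
    return 1.0 - sum(p * p for p in proportions(counts))


def entropy(counts):
    return sum(-p * math.log2(p) for p in proportions(counts) if p > 0)


def split_cost(children, impurity):
    """Sample-weighted average impurity of the child regions of a split."""
    n_total = sum(sum(counts) for counts in children)
    return sum(sum(counts) / n_total * impurity(counts) for counts in children)


# Each tuple holds (samples of class 0, samples of class 1) in one child region
split_a = [(300, 100), (100, 300)]
split_b = [(200, 400), (200, 0)]

for name, impurity in [
    ("error", misclassification_error),
    ("gini", gini_index),
    ("entropy", entropy),
]:
    cost_a = split_cost(split_a, impurity)
    cost_b = split_cost(split_b, impurity)
    print(f"{name}: split A = {cost_a:.4f}, split B = {cost_b:.4f}")
```

Both splits misclassify the same fraction of samples, so the misclassification error cannot tell them apart, whereas the Gini index and the entropy both give a lower cost to split B, which produces one pure region.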

## Classification trees for heart disease diagnosis

Start by importing the classes and functions needed from the respective libraries.

In [None]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

%matplotlib inline

We use the [Heart dataset](https://github.com/pykale/transparentML/blob/main/data/Heart.csv) (click to explore) to predict whether a patient has heart disease. The target variable `AHD` is binary and indicates whether or not the patient has heart disease.

Load the dataset and _remove the rows with missing values_. Inspect the first few rows of the dataset.

In [None]:
heart_url = "https://github.com/pykale/transparentML/raw/main/data/Heart.csv"

heart_df = pd.read_csv(heart_url).dropna()
heart_df.head()

Take a look at the structure of the dataset.

In [None]:
heart_df.info()

Use the [`factorize` function](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.factorize.html) from the `pandas` library to convert categorical variables to integer codes, because the tree estimators in `sklearn` only accept numerical inputs. `ChestPain` is a categorical variable that indicates the type of chest pain. `Thal` is a categorical variable that indicates the type of thalassemia.

In [None]:
heart_df.ChestPain = pd.factorize(heart_df.ChestPain)[0]
heart_df.Thal = pd.factorize(heart_df.Thal)[0]
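
To make the effect of `pd.factorize` concrete, here is a sketch on a toy column. The values below are hypothetical and are not read from `Heart.csv`. The function returns a pair: the integer codes and the unique categories they stand for.

```python
import pandas as pd

# Toy column with made-up category labels
chest_pain = pd.Series(["typical", "asymptomatic", "typical", "nonanginal"])

codes, uniques = pd.factorize(chest_pain)
print(codes)  # one integer code per row, numbered in order of first appearance
print(uniques)  # the category that each code stands for
```

The codes carry no meaningful order, yet a tree will treat them as ordinary numbers when choosing thresholds, which is worth keeping in mind when reading the fitted splits.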

Get the feature matrix `X` and the target vector `y`.

In [None]:
X = heart_df.drop("AHD", axis=1)
y = pd.factorize(heart_df.AHD)[0]

Fit a classification tree to the data using the `DecisionTreeClassifier` class from the `sklearn.tree` module. Set the `max_leaf_nodes` parameter to 6 to limit the tree to at most 6 leaf nodes. Set the `max_features` parameter to 3 so that only 3 randomly selected features are considered when looking for the best split at each node. Because of this random selection, the fitted tree can differ from run to run unless the `random_state` parameter is fixed.

In [None]:
clf = DecisionTreeClassifier(max_leaf_nodes=6, max_features=3)
clf.fit(X, y)
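
The following sketch shows two `DecisionTreeClassifier` options that relate to the discussion above: `criterion`, which selects the Gini index (the default) or the entropy, and `random_state`, which makes the random feature selection repeatable. It uses synthetic data from `make_classification` as a stand-in, so the numbers say nothing about the Heart dataset. The helper `fit_tree` and the seed value are our own choices.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (not the Heart dataset)
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)


def fit_tree(criterion, seed):
    """Fit a small tree with the given splitting criterion and random seed."""
    model = DecisionTreeClassifier(
        criterion=criterion, max_leaf_nodes=6, max_features=3, random_state=seed
    )
    return model.fit(X_demo, y_demo)


tree_1 = fit_tree("gini", seed=42)
tree_2 = fit_tree("gini", seed=42)
tree_entropy = fit_tree("entropy", seed=42)

# With the same seed, two fits give identical predictions
same = (tree_1.predict(X_demo) == tree_2.predict(X_demo)).all()
print(same)
print(tree_1.get_n_leaves(), tree_entropy.get_n_leaves())
```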

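Compute the accuracy of the classification tree on the training data. Because the same data were used to fit the tree, this figure tends to overestimate the accuracy on new patients.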

In [None]:
clf.score(X, y)
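
Training accuracy is measured on the data the tree was fitted to. One common way to get a less optimistic estimate is to hold out part of the data, as sketched below. The sketch uses synthetic data from `make_classification` rather than the Heart dataset, and the 70/30 split and variable names are arbitrary choices.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (not the Heart dataset)
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)

# Keep 30% of the samples aside, preserving the class balance
X_train, X_test, y_train, y_test = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=0, stratify=y_demo
)

clf_demo = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf_demo.fit(X_train, y_train)

train_acc = clf_demo.score(X_train, y_train)  # accuracy on data seen during fitting
test_acc = clf_demo.score(X_test, y_test)  # accuracy on held-out data
print(train_acc, test_acc)
```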

Visualise the classification tree using the `plot_tree` function from the `sklearn.tree` module. Set the `filled` parameter to `True` to colour the nodes in the tree according to the majority class in each region. Set the `feature_names` parameter to the list of feature names. Set the `class_names` parameter to the list of class names. 

In [None]:
plt.figure(figsize=(20, 10))
plot_tree(
    clf,
    filled=True,
    feature_names=list(X.columns),  # plain list; some scikit-learn versions reject a pandas Index
    class_names=["No", "Yes"],
    fontsize=14,
)
plt.show()
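
When a figure is inconvenient, the same tree structure can be printed as indented text rules with `export_text` from `sklearn.tree`. The sketch below fits a small tree on synthetic data, not the Heart dataset, and the feature names `f0` to `f3` are hypothetical.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in data with four features (not the Heart dataset)
X_demo, y_demo = make_classification(n_samples=300, n_features=4, random_state=0)
feature_names_demo = ["f0", "f1", "f2", "f3"]  # hypothetical feature names

clf_demo = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf_demo.fit(X_demo, y_demo)

# Each line is either a split condition or the class predicted in a leaf
rules = export_text(clf_demo, feature_names=feature_names_demo)
print(rules)
```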

## Exercises

min 3 max 5

