Last chapter we implemented a [decision tree](../05_decision_tree/decision_tree.ipynb). Our decision tree performed slightly differently from the `sklearn` decision tree. This is in part due to the different loss functions the two models use.

A loss function is a way to compare model predictions to actual values. That should sound familiar. We've been using accuracy to compare model predictions to actual values from the start. Loss functions are used during model training to _tell the model how it's doing_, allowing the learning algorithm to adjust the model accordingly and make it better.

We've used accuracy to train and benchmark our models, but in practice accuracy is usually used as a benchmarking metric and not for training. There's nothing wrong with using accuracy as a loss function, especially with rule-based models like the ones we've been working with, but as we move towards neural networks, more effective loss functions become possible.

We will implement a different loss function for our decision tree in this chapter, not because this new loss function is better, but because it's important to understand that there's nothing magical about the loss function _as long as it's informative for the model!_

# Loss so far

A brief refresher on how our loss function works. We construct a binary tree by repeatedly splitting the data on a feature value, where all data points less than that value go in one branch and all data points greater than that value go in the other branch. The best feature value to split the data on is determined by the accuracy of the branches. If the accuracy on all the data points without a split is better than the accuracy of the branches produced by the best feature value, then binary tree construction on that set of data points stops.
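As a reminder of what that looks like in code, here is a minimal sketch of accuracy-based split selection. It is not the implementation from the last chapter: the helper names (`majority_correct`, `split_accuracy`, `best_split`) are made up for illustration, values equal to the threshold are assumed to go to the right branch, and the sketch assumes we stop whenever no split beats the unsplit accuracy.

```python
from collections import Counter


def majority_correct(labels):
    """Number of labels a majority-vote prediction gets right."""
    if not labels:
        return 0
    return Counter(labels).most_common(1)[0][1]


def split_accuracy(values, labels, threshold):
    """Accuracy when each branch predicts its own most common label."""
    # Assumption: values equal to the threshold go to the right branch.
    left = [y for x, y in zip(values, labels) if x < threshold]
    right = [y for x, y in zip(values, labels) if x >= threshold]
    return (majority_correct(left) + majority_correct(right)) / len(labels)


def best_split(values, labels):
    """Return (threshold, accuracy) for the best split, or None to stop."""
    unsplit = majority_correct(labels) / len(labels)
    candidates = [
        (split_accuracy(values, labels, t), t) for t in sorted(set(values))
    ]
    accuracy, threshold = max(candidates)
    # Assumption: stop when splitting does not improve on the unsplit accuracy.
    if accuracy <= unsplit:
        return None
    return threshold, accuracy


print(best_split([1, 2, 3, 4], [0, 0, 1, 1]))
print(best_split([1, 2, 3, 4], [1, 1, 1, 1]))
```

The only place accuracy shows up is inside `split_accuracy` and the stopping check, which is why swapping in a different loss function is a small change.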

We've done all the hard work. Changing the loss function is easy: it's just a matter of swapping out accuracy for an alternative.

# Gini impurity

The Gini impurity is the typical loss function I've seen used by decision trees. In order to understand this loss function you need to understand probability. But you already understand probability.

## Probability

Think back to our [baseline classifier](../02_baseline/baseline_classifier.ipynb) which always predicted the most common label in the training set. What is the accuracy of that model on the same training set? It is the number of times the most common label appears divided by the total number of labels. That happens to be the probability of the most common label.
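A quick check of that claim, using a small made-up list of labels (three positive, one negative) rather than the real review data:

```python
from collections import Counter

# Hypothetical training labels: 3 positive (1), 1 negative (0).
labels = [1, 1, 1, 0]

# Probability of the most common label: its count over the total count.
most_common_label, count = Counter(labels).most_common(1)[0]
probability = count / len(labels)

# Accuracy of a baseline that always predicts the most common label.
predictions = [most_common_label] * len(labels)
accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

print(probability, accuracy)
```

Both numbers come out the same because the baseline is right exactly when the true label is the most common one.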

At its core probability is asking the question, "how likely is event X to happen?" It's commonly taught using simple scenarios because they're easy to grasp. Flip a coin. What's the chance it will be heads? 1/2. Roll a die. What's the chance it will be 2? 1/6. Play the lotto. What's the chance you will win? Very, very small.

That's really all you need to know, and now we can implement the Gini impurity.

## How it works

We have a bunch of reviews. They have positive (1) or negative (0) labels. The Gini impurity is the probability of a positive label times the probability of _not_ a positive label plus the probability of a negative label times the probability of _not_ a negative label. That's a mouthful. Maybe we should start over and work up piece by piece with something more concrete.

Let's say 75% of the reviews are positive.

In [1]:
positive_prob = 0.75

The thing about probability is there is a 100% chance of some outcome. In our case the review must be positive or negative, so if 75% of the reviews are positive then the rest must be negative.

In [2]:
negative_prob = 1 - positive_prob
negative_prob

0.25

Now let's compute the Gini impurity using these numbers.

In [3]:
(
    positive_prob * (1 - positive_prob)
    + negative_prob * (1 - negative_prob)
)

0.375

Wait a second. This is just the positive probability times the negative probability, counted twice!

In [4]:
positive_prob * negative_prob + negative_prob * positive_prob

0.375

Can't we just multiply the positive and negative probabilities once and be done with it? It will work the same as a loss function either way.
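To see why, here is a small sketch comparing the two versions on a handful of made-up positive probabilities. The function names `gini_binary` and `product_once` are just for illustration. A loss function only needs to rank the candidates, and doubling every value does not change the ranking.

```python
def gini_binary(p):
    """Binary Gini impurity: both terms, as in the cell above."""
    return p * (1 - p) + (1 - p) * p


def product_once(p):
    """The same product, counted only once."""
    return p * (1 - p)


# Hypothetical positive-label probabilities for a few candidate branches.
candidates = [0.5, 0.6, 0.75, 0.9, 1.0]

# Sort from purest (lowest impurity) to least pure under each measure.
by_gini = sorted(candidates, key=gini_binary)
by_product = sorted(candidates, key=product_once)

print(by_gini == by_product)
```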

But does it matter when we look at 3 or more classes? Imagine we have a third category of reviews with a neutral label, and say the probabilities are 50% positive, 30% negative, and 20% neutral. That will give us this Gini impurity.

In [5]:
positive_prob, negative_prob, neutral_prob = 0.5, 0.3, 0.2
(
    positive_prob * (1 - positive_prob)
    + negative_prob * (1 - negative_prob)
    + neutral_prob * (1 - neutral_prob)
)

0.62

Let's go back to the earlier definition.

> The Gini impurity is the probability of a positive label times the probability of not a positive label...

In the binary classification problem (two labels) the probability of "not a positive label" is just the probability of a negative label, but in the ternary classification problem (three labels) the probability of "not a positive label" is the probability of a _negative or neutral_ label. Our above equation can be rewritten as

In [6]:
(
    positive_prob * (negative_prob + neutral_prob)
    + negative_prob * (positive_prob + neutral_prob)
    + neutral_prob * (positive_prob + negative_prob)
)
# We can further simplify the above equation.
# 1.
(
    positive_prob * negative_prob
    + positive_prob * neutral_prob
    + negative_prob * positive_prob
    + negative_prob * neutral_prob
    + neutral_prob * positive_prob
    + neutral_prob * negative_prob
)
# 2.
(
    2 * positive_prob * negative_prob
    + 2 * positive_prob * neutral_prob
    + 2 * negative_prob * neutral_prob
)

0.62

After expanding the terms and simplifying the equation we still see each combination of probabilities twice. Again I assert we can multiply them just once and our loss function will work the same.

But there is a trend. As we add more labels, we must multiply every pair of label probabilities. The number of pairs grows quadratically with the number of labels: k labels give k(k - 1)/2 pairs. If we use the first form of the Gini impurity we have one term per label, so the number of terms grows only linearly. From a practical standpoint this doesn't make a big difference as most classification problems have a limited number of labels, but it's very easy to code the first form of the Gini impurity.
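Here is a rough sketch that counts the terms in each form and checks that the two forms agree. The names `gini_pairwise` and `gini_first_form` are made up for this comparison, and the agreement relies on the probabilities summing to 1.

```python
from itertools import combinations
from math import isclose


def gini_pairwise(probabilities):
    """Pairwise form: twice the product of every unordered pair."""
    return 2 * sum(p * q for p, q in combinations(probabilities, 2))


def gini_first_form(probabilities):
    """First form: one term per label."""
    return sum(p * (1 - p) for p in probabilities)


for k in (2, 3, 5, 10):
    uniform = [1 / k] * k  # k equally likely labels
    n_pairs = len(list(combinations(uniform, 2)))
    agree = isclose(gini_pairwise(uniform), gini_first_form(uniform))
    print(f"{k} labels: {k} terms vs {n_pairs} pairs, forms agree: {agree}")
```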

In [7]:
def gini(probabilities):
    return sum(p * (1 - p) for p in probabilities)


gini([0.75, 0.25]), gini([0.5, 0.3, 0.2])

(0.375, 0.62)
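To use this in a decision tree we need probabilities from actual labels, and a single score for a split with two branches. The sketch below is one way to do that, not necessarily how our tree will end up doing it: `label_probabilities` and `split_gini` are hypothetical helpers, and weighting each branch by its share of the data points is an assumption (though a common convention).

```python
from collections import Counter


def gini(probabilities):
    return sum(p * (1 - p) for p in probabilities)


def label_probabilities(labels):
    """Fraction of the labels that belong to each class."""
    counts = Counter(labels)
    return [count / len(labels) for count in counts.values()]


def split_gini(left_labels, right_labels):
    """Gini impurity of two branches, weighted by branch size (assumption)."""
    total = len(left_labels) + len(right_labels)
    return sum(
        len(branch) / total * gini(label_probabilities(branch))
        for branch in (left_labels, right_labels)
        if branch  # an empty branch contributes nothing
    )


print(gini(label_probabilities([1, 1, 1, 0])))
print(split_gini([0, 0], [1, 1]))
```

Unlike accuracy, lower is better here: a split that separates the labels perfectly has an impurity of 0.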

## Gini impurity vs accuracy



In [8]:
# TODO: code up plot showing gini vs accuracy.
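One way to get a feel for how the two measures relate is to tabulate them for a binary problem as the positive probability goes from 0 to 1. This is only a sketch of the numbers such a comparison could use. "Accuracy" here means the accuracy of always predicting the most common label, as with the baseline classifier, which is an assumption of this sketch.

```python
def gini(probabilities):
    return sum(p * (1 - p) for p in probabilities)


rows = []
for step in range(11):
    p = step / 10  # probability of a positive label
    accuracy = max(p, 1 - p)  # always predict the most common label
    rows.append((p, gini([p, 1 - p]), accuracy))

for p, impurity, accuracy in rows:
    print(f"p={p:.1f}  gini={impurity:.3f}  accuracy={accuracy:.3f}")
```

The Gini impurity is highest where accuracy is lowest, at an even 50/50 mix, and drops to 0 where accuracy reaches 1. The `rows` list can be handed to any plotting library to draw the two curves.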