# Decision Tree with CART using Information gain, Gini Impurity.

In [1]:
# Toy dataset.
# Format: each row is an example.
# The last column is the label.
# The first two columns are features.
# Feel free to play with it by adding more features & examples.
# Interesting note: I've written this so the 2nd and 5th examples
# have the same features, but different labels - so we can see how the
# tree handles this case.
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [3]:
# Column labels.
# These are used only to print the tree.
header = ["color", "diameter", "label"]

In [6]:
print(header)
training_data

['color', 'diameter', 'label']


[['Green', 3, 'Apple'],
 ['Yellow', 3, 'Apple'],
 ['Red', 1, 'Grape'],
 ['Red', 1, 'Grape'],
 ['Yellow', 3, 'Lemon']]

In [27]:
def unique_vals(rows, col):
    """Find the unique values for a column in a given dataset."""
    return set([row_element[col] for row_element in rows])

In [28]:
# get the unique entires in the required cols
unique_vals(training_data,0)

{'Green', 'Red', 'Yellow'}

In [33]:
def class_counts(rows):
    """ Counts the number of observations that belong to the final target class. 
        This function assumes the target label is the last column always. """
    
    # class : observation counts
    counts = {}
    
    # for every row check if the element is present & update accordingly.
    for row_element in rows:
        if row_element[-1] not in counts:
            counts[row_element[-1]] = 1
        counts[row_element[-1]]+=1
    return counts

In [34]:
class_counts(training_data)

{'Apple': 3, 'Grape': 3, 'Lemon': 2}

In [60]:
def is_numeric(value):
    """checks if an entry is a number."""
    return isinstance(value, int) or isinstance(value, float)

In [77]:
class Question:
    """A Question is used to partition the dataset.

    This class supports to ask questions on the dataset.
    match() checks if the dataset specified  meets the required condition.
    The condition is usually to check if the datapoint > number in the case of a numerical,
    or equals in case of a string.
    repr() makes the question more understandable - readable.
    """
    
    def __init__(self, column, value):
        self.column = column
        self.value = value
    
    def match(self, example):
        """
        example: Sample datapoint to check on.
        """
        val = example[self.column]
        if is_numeric(val):
            return val >= self.value
        return val == self.value
    
    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        condition = '=='
        
        if is_numeric(self.value):
            condition = '>='
        
        return "Is %s %s %s?" % (
            header[self.column],condition,str(self.value))
            

In [88]:
q = Question(0,'Red')

In [89]:
Question(0,'Yellow').match(training_data[0])

False

In [90]:
def partition(rows, question):
    """Partitions a dataset.

    For each row in the dataset, check if it matches the question. If
    so, add it to 'true rows', otherwise, add it to 'false rows'.
    """
    true_rows = []
    false_rows = []
    
    for row_element in rows:
        if question.match(row_element):
            true_rows.append(row_element)
        else:
            false_rows.append(row_element)
    
    return true_rows,false_rows

In [92]:
true_rows, false_rows = partition(training_data, Question(0, 'Red'))

In [93]:
false_rows

[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

In [94]:
true_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]