In [1]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import DecisionTree_Classifier as dtc

In [3]:
# Load the Iris data and rename the last (target) column to 'label',
# the name later cells use (e.g. df.label).
df = pd.read_csv('./iris.csv')
df = df.rename(columns={df.columns[-1]: 'label'})
df.head()

   sepal_length  sepal_width  petal_length  petal_width   label
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa


In [4]:
# Seed the global RNGs so the split, and therefore the tree and rules below,
# are reproducible on Restart & Run All.
# NOTE(review): the randomness source inside dtc.train_test_split is not visible
# here; seeding both the stdlib and NumPy global generators covers either — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
train_df, test_df = dtc.train_test_split(df, test_size=0.2)

In [5]:
# Fit a decision tree on the training split (max_depth=3 presumably caps the
# tree depth — see dtc) and print its question/answer structure.
tree = dtc.decision_tree_algorithm(train_df, max_depth=3)
dtc.print_tree(tree)

    Is petal_width <= 0.8?
    --> True:
    Predict setosa
    --> False:
        Is petal_width <= 1.65?
        --> True:
                Is petal_length <= 4.95?
                --> True:
                Predict versicolor
                --> False:
                Predict virginica
        --> False:
        Predict virginica


In [6]:
# First test row as an array: feature values followed by its label.
test_df.values[0]

array([5.1, 2.5, 3.0, 1.1, 'versicolor'], dtype=object)

In [7]:
# Predict that same row with the fitted tree; compare with the label shown above.
dtc.classify_example(test_df.values[0], tree)

'versicolor'

In [8]:
# Peek at the root question's attributes; RuleQuestion copies these fields.
tree.question, tree.question.continuous, tree.question.column

(Is petal_width <= 0.8?, True, 3)

In [13]:
# Sorted array of every distinct class label in the full dataset.
all_class = np.unique(df['label'])
all_class

array(['setosa', 'versicolor', 'virginica'], dtype=object)

In [22]:
class RuleQuestion:
    """A single condition on a root-to-leaf path of the decision tree.

    Copies ``column``, ``value``, ``continuous`` and ``header`` from a tree
    question object (``dtrQuestion``). ``true`` records which side of that
    question the path took: ``True`` keeps the original test, ``False``
    negates it. Continuous questions test ``<=``; others test ``==``.
    """

    # Display operator for each (continuous, true) combination.
    _SYMBOLS = {
        (True, True): "<=",
        (True, False): ">",
        (False, True): "==",
        (False, False): "!=",
    }

    def __init__(self, dtrQuestion, true: bool):
        for field in ("column", "value", "continuous", "header"):
            setattr(self, field, getattr(dtrQuestion, field))
        self.true = true

    def match(self, row):
        """Return whether ``row`` satisfies this condition.

        ``row`` is indexed by ``self.column``; the underlying test is negated
        when ``self.true`` is falsy.
        """
        cell = row[self.column]
        holds = (cell <= self.value) if self.continuous else (cell == self.value)
        if not self.true:
            return not holds
        return holds

    def __repr__(self):
        """Render as ``<header or column> <operator> <value>``."""
        label = self.column if self.header is None else self.header
        symbol = self._SYMBOLS[(bool(self.continuous), bool(self.true))]
        return " ".join((str(label), symbol, str(self.value)))


class Rule(object):
    """Collects the root-to-leaf paths of a fitted tree as per-class rules.

    After ``depth_first_search(tree)``, ``self.rules[c]`` is a list of paths
    that end in a leaf predicting ``c``. Each path is a list of
    ``RuleQuestion`` conditions that are implicitly ANDed. The paths listed
    for one class are alternatives (ORed).

    Note: ``depth_first_search`` appends to ``self.rules`` and never resets
    it, so calling it twice on the same instance records every path twice.
    """

    def __init__(self, classes):
        """Create an empty rule list for each label in ``classes``.

        Args:
            classes: iterable of class labels; each becomes a key of
                ``self.rules`` mapped to its own empty list.
        """
        self.classes = classes
        self.rules = {clas: [] for clas in self.classes}
        # Stack of RuleQuestion conditions on the path currently being walked.
        self.nodes = []

    def depth_first_search(self, tree):
        """Walk ``tree`` depth-first (true branch first) and record each leaf's path.

        Internal nodes must expose ``question``, ``true_branch`` and
        ``false_branch``. Leaves are ``dtc.Leaf`` instances whose
        ``predictions`` value is used as the class key. A leaf label missing
        from ``classes`` gets a new entry instead of raising ``KeyError``.

        Args:
            tree: root node (or leaf) of the fitted decision tree.
        """
        if isinstance(tree, dtc.Leaf):
            # Snapshot the path: self.nodes keeps changing as the walk continues.
            self.rules.setdefault(tree.predictions, []).append(list(self.nodes))
            return

        for took_true, branch in ((True, tree.true_branch), (False, tree.false_branch)):
            self.nodes.append(RuleQuestion(tree.question, took_true))
            try:
                self.depth_first_search(branch)
            finally:
                # Always unwind, so an error deep in the tree cannot leave
                # stale conditions on the shared path stack.
                self.nodes.pop()

    def print_rules(self):
        """Print every class's rules to stdout, in ``self.rules`` order.

        Within a path, conditions are separated by ``AND`` lines. Alternative
        paths are separated by ``---OR---`` lines. The class label follows its
        rules, then a blank line.
        """
        for clas, rules in self.rules.items():
            text = '\n---OR---\n'.join(
                '\nAND\n'.join(str(qn) for qn in rule) for rule in rules
            )
            # A class with no recorded paths (or only the empty root-leaf path)
            # prints just its label, as before.
            if text:
                print(text)
            print(f'{clas} \n')

In [23]:
# One empty rule list per class; depth_first_search fills them in.
rule = Rule(all_class)

In [24]:
# Before the traversal every class maps to an empty list.
rule.rules

{'setosa': [], 'versicolor': [], 'virginica': []}

In [25]:
# Start from a fresh Rule so re-running this cell does not append every path a
# second time (depth_first_search accumulates into rule.rules).
rule = Rule(all_class)
rule.depth_first_search(tree)

In [26]:
# Each class now maps to a list of paths; each path is a list of RuleQuestion
# conditions that are ANDed, and multiple paths for a class are alternatives.
rule.rules

{'setosa': [[petal_width <= 0.8]],
 'versicolor': [[petal_width > 0.8,
   petal_width <= 1.65,
   petal_length <= 4.95]],
 'virginica': [[petal_width > 0.8, petal_width <= 1.65, petal_length > 4.95],
  [petal_width > 0.8, petal_width > 1.65]]}

In [27]:
# Re-print the tree to compare it with the extracted rules above and below.
dtc.print_tree(tree)

    Is petal_width <= 0.8?
    --> True:
    Predict setosa
    --> False:
        Is petal_width <= 1.65?
        --> True:
                Is petal_length <= 4.95?
                --> True:
                Predict versicolor
                --> False:
                Predict virginica
        --> False:
        Predict virginica


In [28]:
# Print each class's rules: conditions joined by AND, alternative paths by
# ---OR---, followed by the class name.
rule.print_rules()

petal_width <= 0.8
setosa 

petal_width > 0.8
AND
petal_width <= 1.65
AND
petal_length <= 4.95
versicolor 

petal_width > 0.8
AND
petal_width <= 1.65
AND
petal_length > 4.95
---OR---
petal_width > 0.8
AND
petal_width > 1.65
virginica 

