In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import DecisionTree_Classifier as dtc

In [4]:
# Load the dataset from a CSV file located next to the notebook.
df = pd.read_csv('./iris.csv')

# Rename only the last column to 'label'; every other column name is kept.
cols = list(df.columns)
cols[-1] = 'label'
df.columns = cols
# Show the first rows to confirm the rename took effect.
print(df.head())

   sepal_length  sepal_width  petal_length  petal_width   label
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa


In [5]:
# Split the frame into training and test parts using the project helper.
# NOTE(review): presumably test_size=0.2 holds out 20% of the rows at random,
# and no seed is set in this notebook, so the split may differ between runs --
# confirm against DecisionTree_Classifier.train_test_split.
train_df, test_df = dtc.train_test_split(df, test_size=0.2)

In [6]:
# Build a decision tree from the training frame (max_depth=3) and print it.
tree = dtc.decision_tree_algorithm(train_df, max_depth=3)
dtc.print_tree(tree)

    Is petal_width <= 0.8?
    --> True:
    Predict setosa
    --> False:
        Is petal_width <= 1.55?
        --> True:
                Is petal_length <= 5.25?
                --> True:
                Predict versicolor
                --> False:
                Predict virginica
        --> False:
        Predict virginica


In [7]:
# Classify the first row of the test frame (passed as a plain value array,
# label column included) with the tree built above.
dtc.classify_example(test_df.values[0], tree)

'versicolor'

# Extracting rules from the decision tree

In [144]:
class RuleQuestion:
    """One condition of a rule: a tree question plus the branch that was taken.

    The column, threshold/category value, continuous flag and header are
    copied from ``dtrQuestion``. ``true`` records whether the rule follows the
    question's "true" branch (condition must hold) or its "false" branch
    (condition must NOT hold).
    """

    def __init__(self, dtrQuestion, true:bool):
        self.true = true
        self.header = dtrQuestion.header
        self.continuous = dtrQuestion.continuous
        self.value = dtrQuestion.value
        self.column = dtrQuestion.column

    def match(self, row):
        """Return whether ``row`` satisfies this condition.

        ``row`` is indexed with ``self.column``. Continuous questions test
        ``<=`` against ``self.value``; other questions test ``==``. The
        outcome is negated when this condition stands for the false branch.
        """
        cell = row[self.column]
        holds = (cell <= self.value) if self.continuous else (cell == self.value)
        if self.true:
            return holds
        return not holds

    def __repr__(self):
        # Readable form such as "petal_width <= 0.8"; the operator is flipped
        # for the false branch ("<=" becomes ">", "==" becomes "!=").
        operators = ("<=", ">") if self.continuous else ("==", '!=')
        condition = operators[0] if self.true else operators[1]

        # Fall back to the raw column key when no header name is available.
        header = self.column if self.header is None else self.header
        return "%s %s %s" % (header, condition, str(self.value))


class Rule(object):
    """Flatten a decision tree into per-class lists of rules.

    ``self.rules`` maps each class to a list of rules. A rule is a list of
    question objects that must ALL match (logical AND); the rules of one class
    are alternatives (logical OR).

    Parameters
    ----------
    classes : iterable
        Class labels. Each one becomes a key of ``self.rules`` with an
        initially empty rule list.
    """

    def __init__(self, classes):
        self.rules = {}
        self.classes = classes
        for clas in self.classes:
            self.rules[clas] = []

        # Questions on the path from the root to the node currently visited
        # by depth_first_search.
        self.nodes = []

    def depth_first_search(self, tree):
        """Walk ``tree`` and record one rule per leaf.

        At a leaf, a copy of the current question path is appended to
        ``self.rules[tree.predictions]`` (a ``KeyError`` is raised if that key
        was not among the classes given at construction). At any other node
        the true branch is visited before the false branch.

        Rules are appended and never cleared, so calling this twice on the
        same instance records every rule twice.
        """
        if isinstance(tree, dtc.Leaf):
            # Copy: self.nodes keeps being mutated while the walk continues.
            question_till_now = self.nodes.copy()
            self.rules[tree.predictions].append(question_till_now)
            return

        self.nodes.append(RuleQuestion(tree.question, True))
        self.depth_first_search(tree.true_branch)
        self.nodes.pop()

        self.nodes.append(RuleQuestion(tree.question, False))
        self.depth_first_search(tree.false_branch)
        self.nodes.pop()

    def print_rules(self):
        """Print every class's rules.

        Questions of a rule are separated by ``AND`` lines, alternative rules
        by ``---OR---`` lines, and each group ends with the class label
        followed by a blank line.
        """
        for clas, rules in self.rules.items():
            for j, rule in enumerate(rules):
                # Separators go *between* items, never after the last one.
                if j:
                    print('---OR---')
                for k, qn in enumerate(rule):
                    if k:
                        print('AND')
                    print(f'{qn}')
            print(f'{clas} \n')

    def predict(self, example):
        """Classify ``example`` with the recorded rules.

        Classes are tried in the order of ``self.rules`` and rules in the
        order they were recorded; the first rule whose questions all match
        wins. A rule with no questions (produced when the tree is a single
        leaf) matches every example.

        Returns
        -------
        tuple
            ``(classification, rule_satisfied)``. ``rule_satisfied`` is the
            whole list of rules stored for the winning class, not only the
            rule that matched. Both are ``None`` when no rule matches.
        """
        for clas, rules in self.rules.items():
            for rule in rules:
                # all() stops at the first failing question and is True for an
                # empty rule, so a leaf-only tree still yields its class.
                if all(qn.match(example) for qn in rule):
                    return clas, rules
        return None, None
    

In [145]:
# Collect the distinct values of the label column and create a Rule object
# that starts with an empty rule list for each of them.
all_class = np.unique(df.label.values)
rule = Rule(all_class)

In [146]:
# Before the tree is traversed, every class maps to an empty list.
rule.rules

{'setosa': [], 'versicolor': [], 'virginica': []}

In [147]:
# Walk the tree to record one rule per leaf, then inspect the result.
# Running this cell again appends the same rules a second time, because
# depth_first_search only ever appends to rule.rules.
rule.depth_first_search(tree)
rule.rules

{'setosa': [[petal_width <= 0.8]],
 'versicolor': [[petal_width > 0.8,
   petal_width <= 1.55,
   petal_length <= 5.25]],
 'virginica': [[petal_width > 0.8, petal_width <= 1.55, petal_length > 5.25],
  [petal_width > 0.8, petal_width > 1.55]]}

In [148]:
# Print the rules per class: AND joins the questions of one rule,
# ---OR--- separates alternative rules of the same class.
rule.print_rules()

petal_width <= 0.8
setosa 

petal_width > 0.8
AND
petal_width <= 1.55
AND
petal_length <= 5.25
versicolor 

petal_width > 0.8
AND
petal_width <= 1.55
AND
petal_length > 5.25
---OR---
petal_width > 0.8
AND
petal_width > 1.55
virginica 



In [152]:
# Classify the first test row with the extracted rules.
example = test_df.iloc[0]
# predict returns the class and the full list of rules stored for that class.
classified, rule_followed = rule.predict(example.values)
print(classified)
print(rule_followed)
print(example)

versicolor
[[petal_width > 0.8, petal_width <= 1.55, petal_length <= 5.25]]
sepal_length           6.9
sepal_width            3.1
petal_length           4.9
petal_width            1.5
label           versicolor
Name: 52, dtype: object


In [106]:
# Peek at the first rows of the test frame.
test_df.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,label
52,6.9,3.1,4.9,1.5,versicolor
141,6.9,3.1,5.1,2.3,virginica
122,7.7,2.8,6.7,2.0,virginica
113,5.7,2.5,5.0,2.0,virginica
133,6.3,2.8,5.1,1.5,virginica
