In [108]:
#training_data=[
#    ['Green',3,'Apple'],
#    ['Yellow',3,'Apple'],
#    ['Red',1,'Grape'],
#    ['Red',1,'Grape'],
#    ['Yellow',3,'Lemon']
#]

training_data=[['Sunny','Hot','High','Weak','No'],
['Sunny','Hot','High','Strong','No'],
['Overcast','Hot','High','Weak','Yes'],
['Rain','Mild','High','Weak','Yes'],
['Rain','Cool','Normal','Weak','Yes'],
['Rain','Cool','Normal','Strong','No'],
['Overcast','Cool','Normal','Strong','Yes'],
['Sunny','Mild','High','Weak','No'],
['Sunny','Cool','Normal','Weak','Yes'],
['Rain','Mild','Normal','Weak','Yes'],
['Sunny','Mild','Normal','Strong','Yes'],
['Overcast','Mild','High','Strong','Yes'],
['Overcast','Hot','Normal','Weak','Yes'],
['Rain','Mild','High','Strong','No'],]

In [109]:
#header=["Color","Diameter","Label"]

header=["outlook","temp","humidity","wind","play"]

In [110]:
def unique_vals(Data,col):
    return set([row[col] for row in Data])

In [111]:
unique_vals(training_data, 1)

{'Cool', 'Hot', 'Mild'}

In [112]:
unique_vals(training_data, 0)

{'Overcast', 'Rain', 'Sunny'}

In [113]:
def class_counts(Data):
    counts={}
    for row in Data:
        label=row[-1]
        if label not in counts:
            counts[label]=0
        counts[label]+=1
    return counts

In [114]:
class_counts(training_data)

{'No': 5, 'Yes': 9}

In [115]:
def is_numeric(value):
    return isinstance(value,int) or isinstance(value, float)

In [116]:
is_numeric(5)

True

In [117]:
class Question:
    def __init__(self,column,value):
        self.column=column
        self.value=value
    def match(self,example):
        val=example[self.column]
        if is_numeric(val):
            #return val>=self.value
            return val==self.value
        else:
            return val==self.value
    def __repr__(self):
        condition="=="
        if is_numeric(self.value):
            condition=">="
        return "Is %s %s %s?"% (header[self.column],condition, str(self.value))

In [141]:
Question(0,'Sunny')

Is outlook == Sunny?

In [142]:
q=Question(1,'HOT')
q

Is temp == HOT?

In [143]:
q.match(training_data[2])

False

In [144]:
def partition(rows,question):
    true_rows,false_rows=[],[]
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
    return true_rows,false_rows

In [145]:
true_rows,false_rows=partition(training_data,Question(1,'Hot'))

In [146]:
true_rows

[['Sunny', 'Hot', 'High', 'Weak', 'No'],
 ['Sunny', 'Hot', 'High', 'Strong', 'No'],
 ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
 ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes']]

In [147]:
false_rows

[['Rain', 'Mild', 'High', 'Weak', 'Yes'],
 ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
 ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
 ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
 ['Sunny', 'Mild', 'High', 'Weak', 'No'],
 ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
 ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
 ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
 ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
 ['Rain', 'Mild', 'High', 'Strong', 'No']]

In [148]:
def gini(rows):
    counts=class_counts(rows)
    impurity=1
    for lbl in counts:
        prob_of_lbl=counts[lbl]/float(len(rows))
        impurity-=prob_of_lbl**2
    return impurity

In [149]:
x=[['apple'],['apple'],['grapes'],['grapes']]
gini(x)

0.5

In [150]:
gini(training_data)

0.4591836734693877

In [151]:
def info_gain(left,right,current_uncertainity):
    p=float(len(left))/(len(left)+len(right))
    q=float(len(right))/(len(left)+len(right))
    return current_uncertainity-p*gini(left)-q*gini(right)

In [155]:
true_rows,false_rows=partition(training_data,Question(1,'Hot'))

In [156]:
true_rows

[['Sunny', 'Hot', 'High', 'Weak', 'No'],
 ['Sunny', 'Hot', 'High', 'Strong', 'No'],
 ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
 ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes']]

In [157]:
false_rows

[['Rain', 'Mild', 'High', 'Weak', 'Yes'],
 ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
 ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
 ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
 ['Sunny', 'Mild', 'High', 'Weak', 'No'],
 ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
 ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
 ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
 ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
 ['Rain', 'Mild', 'High', 'Strong', 'No']]

In [158]:
info_gain(true_rows,false_rows,0.639)

0.19614285714285712

In [159]:
def find_best_split(rows):
    best_gain=0
    best_question=None
    current_uncertainity=gini(rows)
    n_features=len(rows[0])-1 #no of column
    for col in range(n_features): #for each feature
        values=set([row[col]for row in rows])#unique values in the dataset
        for val in values: #for each value
            question=Question(col,val)
            true_rows,false_rows=partition(rows,question)
            if len(true_rows)==0 or len(false_rows)==0:
                continue
            gain=info_gain(true_rows,false_rows,current_uncertainity)
            if gain>=best_gain:
                best_gain,best_question=gain,question
    return best_gain,best_question

In [160]:
best_gain,best_question=find_best_split(training_data)
print(best_question)
print(best_gain)

Is outlook == Overcast?
0.10204081632653056


In [161]:
class Leaf:
    def __init__(self,rows):
        self.predictions=class_counts(rows)

In [162]:
class Decision_Node:
    def __init__(self,question,true_branch,false_branch):
        self.question=question
        self.true_branch=true_branch
        self.false_branch=false_branch

In [163]:
def build_tree(rows):
    gain,question=find_best_split(rows)
    if gain==0:
        return Leaf(rows)
    true_rows,false_rows=partition(rows,question)
    true_branch=build_tree(true_rows)
    false_branch=build_tree(false_rows)
    return Decision_Node(question,true_branch,false_branch)

In [164]:
my_tree=build_tree(training_data)
print(my_tree)

<__main__.Decision_Node object at 0x0000028A0889EF98>


In [165]:
def print_tree(node,spacing=""):
    if isinstance(node,Leaf):
        print(spacing+"Predict",node.predictions)
        return
    #Print the question at this node
    print(spacing+str(node.question))
    #Call this function recursively on the true branch
    print(spacing+"-->True:")
    print_tree(node.true_branch,spacing+" ")
    #Call this function recursively on the false branch
    print(spacing+"-->False:")
    print_tree(node.false_branch,spacing+" ")

In [140]:
print_tree(my_tree)

Is outlook == Overcast?
-->True:
 Predict {'Yes': 4}
-->False:
 Is humidity == High?
 -->True:
  Is outlook == Rain?
  -->True:
   Is wind == Weak?
   -->True:
    Predict {'Yes': 1}
   -->False:
    Predict {'No': 1}
  -->False:
   Predict {'No': 3}
 -->False:
  Is wind == Weak?
  -->True:
   Predict {'Yes': 3}
  -->False:
   Is temp == Mild?
   -->True:
    Predict {'Yes': 1}
   -->False:
    Predict {'No': 1}
