In [162]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import LabelEncoder as LE

In [163]:
# Load the Titanic training set from the notebook's working directory.
data = pd.read_csv("titanic_train.csv")

In [164]:
# Peek at the first two rows to check that the file parsed as expected.
data.head(n=2)

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C


In [165]:
# Show the column labels of the loaded frame.
data.columns

Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],
      dtype='object')

In [166]:
# Replace missing ages with the mean age (NaNs are skipped when averaging).
# The filled column is assigned back to `data` instead of calling
# `fillna(..., inplace=True)` on the `data.Age` accessor: that chained
# in-place form operates on a temporary under pandas Copy-on-Write and
# leaves `data` unchanged (and triggers a FutureWarning on pandas 2.x).
data["Age"] = data["Age"].fillna(data["Age"].mean())

In [167]:
# Summarise dtypes and non-null counts per column (Age should now be complete).
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            891 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.6+ KB


In [168]:
# Feature matrix: the five columns used as predictors.
X = data.loc[:, ["Pclass", "Sex", "Age", "SibSp", "Parch"]]

In [169]:
# Target vector: the Survived column.
y = data["Survived"]

In [170]:
# Encode the string-valued Sex column as integer codes so it can be
# compared against a numeric pivot.
le = LE()
X["Sex"] = le.fit_transform(X["Sex"])

In [171]:
def entropy(columns):
    """Return the Shannon entropy, in bits, of the values in ``columns``.

    Parameters
    ----------
    columns : array-like
        Discrete labels (list, NumPy array, pandas Series, ...).

    Returns
    -------
    float
        ``-sum(p * log2(p))`` over the distinct values, where ``p`` is the
        relative frequency of each value. ``0.0`` is returned when the input
        is empty or contains a single distinct value.
    """
    _, counts = np.unique(columns, return_counts=True)

    # Empty or single-valued input carries no uncertainty. Returning early
    # also avoids reporting a negative zero (-0.0) for pure inputs.
    if counts.size <= 1:
        return 0.0

    # Normalise by the total number of counted values so the probabilities
    # always sum to one.
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))

In [172]:
# Entropy of the encoded Sex column, as a quick sanity check of entropy().
entropy(X["Sex"])

0.9362046432498521

In [173]:
def info_gain(X, y, label):
    """Information gain obtained by splitting ``y`` at the mean of ``X[label]``.

    Rows whose feature value is below the mean go to one side, rows at or
    above the mean go to the other. The gain is the entropy of ``y`` minus
    the size-weighted entropies of the two sides.

    Parameters
    ----------
    X : indexable by ``label``
        Feature table; ``X[label]`` must support comparison with its mean.
    y : boolean-indexable sequence
        Target values aligned with ``X``.
    label : hashable
        Name of the column to split on.

    Returns
    -------
    number
        The information gain, or ``-1000`` when every row falls on the same
        side of the mean (the split separates nothing).
    """
    feature = X[label]
    threshold = np.mean(feature)
    below = y[feature < threshold]
    at_or_above = y[feature >= threshold]

    # A one-sided split is useless; the large negative score keeps it from
    # ever being preferred over a real split.
    if min(len(below), len(at_or_above)) == 0:
        return -1000

    total = len(y)
    parent_entropy = entropy(y)
    weighted_below = (len(below) / total) * entropy(below)
    weighted_above = (len(at_or_above) / total) * entropy(at_or_above)
    return parent_entropy - weighted_below - weighted_above

In [174]:
# Report the information gain of a mean split on every feature column.
for label in X:
    print(label, info_gain(X, y, label))

Pclass 0.07579362743608165
Sex 0.2176601066606143
Age 0.001158644038169343
SibSp 0.009584541813400127
Parch 0.015380754493137666


In [175]:
class Node:
    """A single node of the decision tree.

    Internal nodes carry the split description (``label`` and ``value``) and
    have their ``left``/``right`` children attached by the tree builder.
    Leaves are created with only ``result`` set and keep ``label`` as None.
    """

    def __init__(self,label=None,value=None,result=None):
        """Create a node.

        Parameters
        ----------
        label : hashable, optional
            Column the node splits on; None marks a leaf.
        value : number, optional
            Pivot the column is compared against.
        result : number, optional
            Value stored at a leaf (the mean of the targets that reached it).
        """
        self.label=label
        self.value=value
        self.result=result
        # Children always exist as attributes so that reading them on a leaf
        # yields None instead of raising AttributeError.
        self.left=None
        self.right=None
        

In [188]:
class DecisionTree:
    """Binary decision tree that splits each node at the mean of one column.

    At every node the column with the highest ``info_gain`` is selected and
    the rows are divided into "below the column mean" (left child) and
    "at or above the column mean" (right child). Leaves store the mean of
    the targets that reached them.
    """

    def __init__(self, max_depth):
        """Store the depth limit; a limit of 1 (or less) yields a single leaf."""
        self.max_depth = max_depth
        # Root of the fitted tree; stays None until fit() is called.
        self.root = None

    def fit(self, X, y):
        """Build the tree from feature table ``X`` and aligned targets ``y``."""
        self.root = self.generate(X, y, self.max_depth)

    def generate(self, X, y, depth):
        """Recursively build and return the subtree for the given rows.

        Parameters
        ----------
        X : DataFrame-like
            Feature rows reaching this node (needs ``.columns`` and boolean
            row indexing).
        y : boolean-indexable sequence
            Targets aligned with ``X``.
        depth : int
            Remaining depth budget, counted in levels including this one.

        Returns
        -------
        Node
            A leaf holding ``np.mean(y)``, or an internal node with ``left``
            and ``right`` children attached.
        """
        # `<=` (not `==`) so that a non-positive max_depth still honours the
        # limit instead of recursing until the data runs out. A node whose
        # targets are all identical cannot be improved by splitting, so it
        # becomes a leaf as well.
        if depth <= 1 or len(np.unique(y)) <= 1:
            return Node(result=np.mean(y))

        gains = [(info_gain(X, y, label), label) for label in X.columns]
        if not gains:
            # No columns to split on.
            return Node(result=np.mean(y))

        # Highest gain wins; ties fall back to comparing the labels.
        selected = max(gains)[1]
        feature = X[selected]
        pivot = np.mean(feature)

        # Compute each mask once and reuse it for both X and y.
        right_mask = feature >= pivot
        left_mask = feature < pivot
        y_right = y[right_mask]
        y_left = y[left_mask]

        # Even the best column may fail to separate the rows.
        if (len(y_left) == 0) or (len(y_right) == 0):
            return Node(result=np.mean(y))

        node = Node(selected, pivot)
        node.right = self.generate(X[right_mask], y_right, depth - 1)
        node.left = self.generate(X[left_mask], y_left, depth - 1)
        return node

    def display(self, node, indent=""):
        """Print the subtree rooted at ``node``, one tab per nesting level.

        Leaves print "Died" when their stored result is below 0.5 and
        "Survived" otherwise; internal nodes print the split column and
        pivot, followed by the left and then the right subtree.
        """
        if node.label is None:
            if node.result < 0.5:
                print(indent, "Died")
            else:
                print(indent, "Survived")
            return

        print(indent, node.label, node.value)

        self.display(node.left, indent + "\t")
        self.display(node.right, indent + "\t")

In [189]:
# Tree with a depth budget of five levels.
model = DecisionTree(max_depth=5)

In [190]:
# Grow the tree on the prepared features and target.
model.fit(X, y)

In [191]:
# Print the fitted tree, starting from its root.
model.display(node=model.root)

 Sex 0.6475869809203143
	 Pclass 2.159235668789809
		 Pclass 1.4470588235294117
			 Parch 0.4574468085106383
				 Survived
				 Survived
			 Parch 0.6052631578947368
				 Survived
				 Survived
		 Parch 0.7986111111111112
			 Age 26.090266435986155
				 Survived
				 Survived
			 SibSp 1.694915254237288
				 Died
				 Died
	 Pclass 2.389948006932409
		 Pclass 1.4695652173913043
			 Age 39.287716972034715
				 Died
				 Died
			 Parch 0.2222222222222222
				 Died
				 Died
		 Parch 0.22478386167146974
			 Age 29.07305445151033
				 Died
				 Died
			 SibSp 2.607843137254902
				 Died
				 Died
