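# Decision Tree Classifier on the Wine Dataset

An entropy-based decision tree implemented from scratch with NumPy and pandas, trained and evaluated on the 178-sample wine dataset.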

## Import Modules

In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import random
from pprint import pprint

## Load and Prepare Data

In [2]:
# Load the raw wine data; the "Wine" column holds the class label.
df = pd.read_csv("Datasets/wine.csv")
df.head()

Unnamed: 0,Wine,Alcohol,Malic.acid,Ash,Acl,Mg,Phenols,Flavanoids,Nonflavanoid.phenols,Proanth,Color.int,Hue,OD,Proline
0,1,14.23,1.71,2.43,15.6,127,2.8,3.06,0.28,2.29,5.64,1.04,3.92,1065
1,1,13.2,1.78,2.14,11.2,100,2.65,2.76,0.26,1.28,4.38,1.05,3.4,1050
2,1,13.16,2.36,2.67,18.6,101,2.8,3.24,0.3,2.81,5.68,1.03,3.17,1185
3,1,14.37,1.95,2.5,16.8,113,3.85,3.49,0.24,2.18,7.8,0.86,3.45,1480
4,1,13.24,2.59,2.87,21.0,118,2.8,2.69,0.39,1.82,4.32,1.04,2.93,735


In [3]:
# The tree helpers read the class label from the LAST column, so move "Wine"
# to the end under the name "label". The guard keeps the cell safe to re-run
# (without it, a second run fails because "Wine" has already been dropped).
if "Wine" in df.columns:
    df = df.drop(columns="Wine").assign(label=df["Wine"])
df.head()

Unnamed: 0,Alcohol,Malic.acid,Ash,Acl,Mg,Phenols,Flavanoids,Nonflavanoid.phenols,Proanth,Color.int,Hue,OD,Proline,label
0,14.23,1.71,2.43,15.6,127,2.8,3.06,0.28,2.29,5.64,1.04,3.92,1065,1
1,13.2,1.78,2.14,11.2,100,2.65,2.76,0.26,1.28,4.38,1.05,3.4,1050,1
2,13.16,2.36,2.67,18.6,101,2.8,3.24,0.3,2.81,5.68,1.03,3.17,1185,1
3,14.37,1.95,2.5,16.8,113,3.85,3.49,0.24,2.18,7.8,0.86,3.45,1480,1
4,13.24,2.59,2.87,21.0,118,2.8,2.69,0.39,1.82,4.32,1.04,2.93,735,1


In [4]:
# Column dtypes and non-null counts for every feature and the label.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 178 entries, 0 to 177
Data columns (total 14 columns):
 #   Column                Non-Null Count  Dtype  
---  ------                --------------  -----  
 0   Alcohol               178 non-null    float64
 1   Malic.acid            178 non-null    float64
 2   Ash                   178 non-null    float64
 3   Acl                   178 non-null    float64
 4   Mg                    178 non-null    int64  
 5   Phenols               178 non-null    float64
 6   Flavanoids            178 non-null    float64
 7   Nonflavanoid.phenols  178 non-null    float64
 8   Proanth               178 non-null    float64
 9   Color.int             178 non-null    float64
 10  Hue                   178 non-null    float64
 11  OD                    178 non-null    float64
 12  Proline               178 non-null    int64  
 13  label                 178 non-null    int64  
dtypes: float64(11), int64(3)
memory usage: 19.6 KB


## Train-Test Split

In [5]:
def train_test_split(df, test_size):
    """Randomly hold out ``test_size`` rows of ``df`` as a test set.

    Sampling uses Python's global ``random`` module, so seed it with
    ``random.seed`` beforehand for a reproducible split.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to split; rows are identified by their index labels.
    test_size : int or float
        Number of test rows, or, when a float, the fraction of ``len(df)``
        (converted to a row count with ``round``).

    Returns
    -------
    tuple[pandas.DataFrame, pandas.DataFrame]
        ``(train_df, test_df)``; ``test_df`` rows appear in sampling order.
    """
    n_test = round(test_size * len(df)) if isinstance(test_size, float) else test_size
    sampled_labels = random.sample(df.index.tolist(), n_test)
    return df.drop(sampled_labels), df.loc[sampled_labels]

In [6]:
# Seed Python's `random` module so the 25-row test split below is reproducible.
random.seed(1)
train_df,test_df=train_test_split(df,test_size=25)

In [7]:
# Number of rows held out for testing.
len(test_df)

25

In [8]:
# Number of rows left for training.
len(train_df)

153

In [9]:
# Peek at the first held-out rows.
test_df.head()

Unnamed: 0,Alcohol,Malic.acid,Ash,Acl,Mg,Phenols,Flavanoids,Nonflavanoid.phenols,Proanth,Color.int,Hue,OD,Proline,label
34,13.51,1.8,2.65,19.0,110,2.35,2.53,0.29,1.54,4.2,1.1,2.87,1095,1
145,13.16,3.57,2.15,21.0,102,1.5,0.55,0.43,1.3,4.0,0.6,1.68,830,3
16,14.3,1.92,2.72,20.0,120,2.8,3.14,0.33,1.97,6.2,1.07,2.65,1280,1
65,12.37,1.21,2.56,18.1,98,2.42,2.65,0.37,2.08,4.6,1.19,2.3,678,2
30,13.73,1.5,2.7,22.5,101,3.0,3.25,0.29,2.38,5.7,1.19,2.71,1285,1


## Convert to NumPy Array
The helper functions below work on a plain NumPy array whose last column is the class label.

In [10]:
# NumPy array of the training rows; the helper functions below index columns
# by position and treat the last one as the label.
data = train_df.to_numpy()
data[:13]

array([[1.320e+01, 1.780e+00, 2.140e+00, 1.120e+01, 1.000e+02, 2.650e+00,
        2.760e+00, 2.600e-01, 1.280e+00, 4.380e+00, 1.050e+00, 3.400e+00,
        1.050e+03, 1.000e+00],
       [1.316e+01, 2.360e+00, 2.670e+00, 1.860e+01, 1.010e+02, 2.800e+00,
        3.240e+00, 3.000e-01, 2.810e+00, 5.680e+00, 1.030e+00, 3.170e+00,
        1.185e+03, 1.000e+00],
       [1.437e+01, 1.950e+00, 2.500e+00, 1.680e+01, 1.130e+02, 3.850e+00,
        3.490e+00, 2.400e-01, 2.180e+00, 7.800e+00, 8.600e-01, 3.450e+00,
        1.480e+03, 1.000e+00],
       [1.324e+01, 2.590e+00, 2.870e+00, 2.100e+01, 1.180e+02, 2.800e+00,
        2.690e+00, 3.900e-01, 1.820e+00, 4.320e+00, 1.040e+00, 2.930e+00,
        7.350e+02, 1.000e+00],
       [1.420e+01, 1.760e+00, 2.450e+00, 1.520e+01, 1.120e+02, 3.270e+00,
        3.390e+00, 3.400e-01, 1.970e+00, 6.750e+00, 1.050e+00, 2.850e+00,
        1.450e+03, 1.000e+00],
       [1.439e+01, 1.870e+00, 2.450e+00, 1.460e+01, 9.600e+01, 2.500e+00,
        2.520e+00, 3.000e-01, 1

## Check Purity of Data

In [11]:
def check_purity(data):
    """Return True when every row of ``data`` carries the same label.

    The label is read from the last column. An empty array has zero
    distinct labels and is therefore reported as not pure.
    """
    return len(np.unique(data[:, -1])) == 1

In [12]:
# Purity of the whole training set (the root node of the tree).
check_purity(data)

False

## Classify Data

In [13]:
def classify_data(data):
    """Return the most frequent label (last column) in ``data``.

    ``np.unique`` returns classes sorted ascending and ``argmax`` picks the
    first maximum, so ties resolve to the smallest label.
    """
    classes, class_counts = np.unique(data[:, -1], return_counts=True)
    return classes[np.argmax(class_counts)]

In [14]:
# Majority class of the whole training set, i.e. what a depth-0 tree would predict.
classify_data(data)

2.0

## Potential Splits

In [15]:
def get_potential_splits(data):
    """Candidate thresholds for every feature column of ``data``.

    Every column except the last (the label) is a feature. Its candidate
    thresholds are the midpoints between consecutive distinct values, taken
    in ascending order.

    Returns
    -------
    dict[int, list]
        Maps each feature column index to its list of midpoints; the list is
        empty when the column holds a single distinct value.
    """
    _, n_columns = data.shape
    splits = {}
    for column_index in range(n_columns - 1):
        distinct = np.unique(data[:, column_index])  # sorted ascending
        # Element-wise (v[i-1] + v[i]) / 2 for every adjacent pair.
        splits[column_index] = list((distinct[:-1] + distinct[1:]) / 2)
    return splits

## Split Data

In [16]:
def split_data(data, split_column, split_value):
    """Partition the rows of ``data`` on one feature threshold.

    Returns ``(data_below, data_above)``: rows whose ``split_column`` value is
    ``<= split_value``, and rows whose value is ``> split_value``.
    """
    column = data[:, split_column]
    is_below = column <= split_value
    is_above = column > split_value
    return data[is_below], data[is_above]

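## Lowest Overall Entropy Function
* Calculate Entropy Function: $H(S) = -\sum_{c} p_c \log_2 p_c$, where $p_c$ is the fraction of rows in $S$ with class $c$
* Calculate Overall Entropy Function: the size-weighted entropy of a split, $\frac{|S_{\le}|}{|S|} H(S_{\le}) + \frac{|S_{>}|}{|S|} H(S_{>})$
* Determine Best Split: try every candidate threshold and keep the one with the lowest overall entropy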

In [17]:
def calculate_entropy(data):
    """Shannon entropy (base 2) of the labels in the last column of ``data``.

    Only labels that actually occur are counted, so no probability is zero
    and ``log2`` stays finite.
    """
    _, class_counts = np.unique(data[:, -1], return_counts=True)
    p = class_counts / class_counts.sum()
    # Builtin ``sum`` keeps a plain left-to-right accumulation order.
    return sum(-p * np.log2(p))

In [18]:
def calculate_overall_entropy(data_below, data_above):
    """Size-weighted average of the entropies of the two halves of a split.

    Each half contributes ``calculate_entropy(half)`` weighted by its share
    of the combined row count.
    """
    n_below, n_above = len(data_below), len(data_above)
    n_total = n_below + n_above
    weighted_below = (n_below / n_total) * calculate_entropy(data_below)
    weighted_above = (n_above / n_total) * calculate_entropy(data_above)
    return weighted_below + weighted_above

In [19]:
def determine_best_split(data, potential_splits):
    """Pick the (column, threshold) pair with the lowest overall entropy.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array whose last column holds the class labels.
    potential_splits : dict
        Maps feature column index to its candidate thresholds, in the form
        produced by ``get_potential_splits``.

    Returns
    -------
    tuple
        ``(best_split_column, best_split_value)``. On ties the candidate
        visited last wins, because the comparison uses ``<=``.

    Raises
    ------
    ValueError
        If no candidate threshold is available (e.g. every feature column is
        constant), instead of failing on an unbound local variable.
    """
    # Use +inf rather than a magic large number so any real entropy beats it.
    lowest_overall_entropy = float("inf")
    best_split_column = None
    best_split_value = None

    for column_index, candidate_values in potential_splits.items():
        for value in candidate_values:
            data_below, data_above = split_data(data, split_column=column_index, split_value=value)
            current_overall_entropy = calculate_overall_entropy(data_below, data_above)

            # `<=` (not `<`) preserves the original tie-breaking: last candidate wins.
            if current_overall_entropy <= lowest_overall_entropy:
                lowest_overall_entropy = current_overall_entropy
                best_split_column = column_index
                best_split_value = value

    if best_split_column is None:
        raise ValueError("no usable split: potential_splits offers no candidate threshold")

    return best_split_column, best_split_value

## Decision Tree Algorithm

In [20]:
def decision_tree_algorithm(df, counter=0, min_samples=2, max_depth=10):
    """Recursively grow an entropy-based decision tree.

    Parameters
    ----------
    df : pandas.DataFrame or numpy.ndarray
        On the top-level call (``counter == 0``) a DataFrame whose last column
        is the class label. Recursive calls pass the NumPy rows of the current
        node instead.
    counter : int
        Current depth; callers should leave it at 0.
    min_samples : int
        Nodes with fewer rows than this become leaves.
    max_depth : int
        Depth at which splitting stops.

    Returns
    -------
    dict or label
        A leaf label (as returned by ``classify_data``) or a dict
        ``{"<feature> <= <threshold>": [yes_subtree, no_subtree]}``.

    Side effect: the top-level call (re)binds the module-level
    ``column_headers`` to ``df.columns`` so deeper calls can name features.
    """
    global column_headers
    if counter == 0:
        column_headers = df.columns
        data = df.values
    else:
        data = df

    # Base cases: pure node, too few samples, or depth limit reached.
    if check_purity(data) or len(data) < min_samples or counter == max_depth:
        return classify_data(data)

    counter += 1

    potential_splits = get_potential_splits(data)
    # Rows that share identical feature values but carry different labels
    # leave no threshold to split on; fall back to a majority-vote leaf
    # instead of letting determine_best_split fail.
    if not any(potential_splits.values()):
        return classify_data(data)

    split_column, split_value = determine_best_split(data, potential_splits)
    data_below, data_above = split_data(data, split_column, split_value)

    feature = column_headers[split_column]
    question = "{} <= {}".format(feature, split_value)

    yes_answer = decision_tree_algorithm(data_below, counter, min_samples, max_depth)
    no_answer = decision_tree_algorithm(data_above, counter, min_samples, max_depth)

    # If both branches predict the same thing, the question is redundant:
    # collapse this node into that answer.
    if yes_answer == no_answer:
        return yes_answer
    return {question: [yes_answer, no_answer]}
         

In [21]:
# Grow the tree with the default stopping rules (min_samples=2, max_depth=10) and pretty-print it.
tree = decision_tree_algorithm( train_df )
pprint(tree)

{'Flavanoids <= 1.5750000000000002': [{'Color.int <= 3.825': [2.0,
                                                              {'Hue <= 0.97': [3.0,
                                                                               2.0]}]},
                                      {'Proline <= 722.5': [{'Proline <= 676.0': [2.0,
                                                                                  {'Hue <= 0.8899999999999999': [1.0,
                                                                                                                 2.0]}]},
                                                            {'Color.int <= 3.49': [2.0,
                                                                                   1.0]}]}]}


## Classification of Examples

In [22]:
def classify_example(example, tree):
    """Walk ``tree`` for one ``example`` and return the predicted label.

    Parameters
    ----------
    example : mapping (e.g. a pandas Series row)
        Must support ``example[feature_name]`` for every feature named in
        the tree's questions.
    tree : dict or label
        Either a leaf label, or a dict with a single key
        ``"<feature> <= <threshold>"`` mapping to ``[yes_subtree, no_subtree]``.

    Returns
    -------
    The label stored in the leaf that ``example`` reaches. A tree that is
    itself a bare leaf (e.g. grown from already-pure data) is returned as-is.
    """
    node = tree
    while isinstance(node, dict):
        question = next(iter(node))
        # Split from the right on the operator so feature names that contain
        # spaces are still parsed correctly.
        feature_name, threshold = question.rsplit(" <= ", 1)
        yes_answer, no_answer = node[question]
        node = yes_answer if example[feature_name] <= float(threshold) else no_answer
    return node

## Accuracy

In [23]:
def calculate_accuracy(df, tree):
    """Classify every row of ``df`` with ``tree`` and return the fraction correct.

    Side effect: adds (or overwrites) two columns on ``df`` in place:
    ``classification`` (the predicted label) and ``classification_correct``
    (whether it equals ``df["label"]``). This lets the frame be inspected
    afterwards.
    """
    predicted = df.apply(classify_example, axis=1, args=(tree,))
    df["classification"] = predicted
    df["classification_correct"] = df["classification"] == df["label"]
    return df["classification_correct"].mean()

In [24]:
# Note: calculate_accuracy adds "classification" and "classification_correct" columns to test_df in place.
calculate_accuracy(test_df , tree)

1.0

In [25]:
# Per-row predictions, using the columns that calculate_accuracy added in the previous cell.
test_df

Unnamed: 0,Alcohol,Malic.acid,Ash,Acl,Mg,Phenols,Flavanoids,Nonflavanoid.phenols,Proanth,Color.int,Hue,OD,Proline,label,classification,classification_correct
34,13.51,1.8,2.65,19.0,110,2.35,2.53,0.29,1.54,4.2,1.1,2.87,1095,1,1.0,True
145,13.16,3.57,2.15,21.0,102,1.5,0.55,0.43,1.3,4.0,0.6,1.68,830,3,3.0,True
16,14.3,1.92,2.72,20.0,120,2.8,3.14,0.33,1.97,6.2,1.07,2.65,1280,1,1.0,True
65,12.37,1.21,2.56,18.1,98,2.42,2.65,0.37,2.08,4.6,1.19,2.3,678,2,2.0,True
30,13.73,1.5,2.7,22.5,101,3.0,3.25,0.29,2.38,5.7,1.19,2.71,1285,1,1.0,True
126,12.43,1.53,2.29,21.5,86,2.74,3.15,0.39,1.77,3.94,0.69,2.84,352,2,2.0,True
115,11.03,1.51,2.2,21.5,85,2.46,2.17,0.52,2.01,1.9,1.71,2.87,407,2,2.0,True
120,11.45,2.4,2.42,20.0,96,2.9,2.79,0.32,1.83,3.25,0.8,3.39,625,2,2.0,True
166,13.45,3.7,2.6,23.0,111,1.7,0.92,0.43,1.46,10.68,0.85,1.56,695,3,3.0,True
97,12.29,1.41,1.98,16.0,85,2.55,2.5,0.29,1.77,2.9,1.23,2.74,428,2,2.0,True


In [26]:
# Repeat the experiment with a different seed, which gives a different 25-row test split.
random.seed(0)
train_df,test_df=train_test_split(df,test_size=25)

In [27]:
# Regrow the tree on the new training split (same default stopping rules).
tree = decision_tree_algorithm( train_df )
pprint(tree)

{'Flavanoids <= 1.5750000000000002': [{'Color.int <= 3.9': [2.0,
                                                            {'Hue <= 0.97': [3.0,
                                                                             2.0]}]},
                                      {'Proline <= 724.5': [{'Malic.acid <= 3.92': [2.0,
                                                                                    {'Proline <= 530.0': [2.0,
                                                                                                          1.0]}]},
                                                            {'Color.int <= 3.55': [2.0,
                                                                                   1.0]}]}]}


In [28]:
# Accuracy on the new test split; again adds prediction columns to test_df in place.
calculate_accuracy(test_df , tree)

0.88

In [29]:
# Per-row predictions for the seed-0 split.
test_df

Unnamed: 0,Alcohol,Malic.acid,Ash,Acl,Mg,Phenols,Flavanoids,Nonflavanoid.phenols,Proanth,Color.int,Hue,OD,Proline,label,classification,classification_correct
98,12.37,1.07,2.1,18.5,88,3.52,3.75,0.24,1.95,4.5,1.04,2.77,660,2,2.0,True
107,12.72,1.75,2.28,22.5,84,1.38,1.76,0.48,1.63,3.3,0.88,2.42,488,2,2.0,True
10,14.1,2.16,2.3,18.0,105,2.95,3.32,0.22,2.38,5.75,1.25,3.17,1510,1,1.0,True
66,13.11,1.01,1.7,15.0,78,2.98,3.18,0.26,2.28,5.3,1.12,3.18,502,2,2.0,True
130,12.86,1.35,2.32,18.0,122,1.51,1.25,0.21,0.94,4.1,0.76,1.29,630,3,3.0,True
124,11.87,4.31,2.39,21.0,82,2.86,3.03,0.21,2.91,2.8,0.75,3.64,380,2,2.0,True
103,11.82,1.72,1.88,19.5,86,2.5,1.64,0.37,1.42,2.06,0.94,2.44,415,2,2.0,True
77,11.84,2.89,2.23,18.0,112,1.72,1.32,0.43,0.95,2.65,0.96,2.52,500,2,2.0,True
122,12.42,4.43,2.73,26.5,102,2.2,2.13,0.43,1.71,2.08,0.92,3.12,365,2,2.0,True
91,12.0,1.51,2.42,22.0,86,1.45,1.25,0.5,1.63,3.6,1.05,2.65,450,2,2.0,True
