# The implementation in this notebook is influenced by https://www.youtube.com/watch?v=LDRbO9a6XPU

# Configure the dataset path

In [None]:
# Location of the red-wine-quality CSV; it is handed to DecisionTree, which
# loads it with pd.read_csv.
DATA_PATH = '/kaggle/input/red-wine-quality-cortez-et-al-2009/winequality-red.csv'

# Import libraries

In [None]:
import pandas as pd
from tqdm import tqdm

In [None]:
class DecisionNode(object):
    ''' Internal tree node: a (feature, value) question plus its two subtrees.
    '''
    def __init__(self, feature, value, true_branch, false_branch):
        # Subtree followed when the question holds / does not hold.
        self.true_branch = true_branch
        self.false_branch = false_branch
        # The question itself: which column is tested against which value.
        self.feature = feature
        self.value = value

    def trace(self, idx=0):
        ''' Print this node and, recursively, both subtrees.

        ``idx`` is the depth of this node; children are traced at ``idx + 1``.
        The true branch is introduced by ``/`` and the false branch by ``\\``,
        each indented by ``idx`` spaces.
        '''
        print('[{}] {} ({})'.format(idx, self.feature, self.value))

        indent = ' ' * idx
        children = (('/ ', self.true_branch), ('\\ ', self.false_branch))

        for marker, child in children:
            # No newline here: the child's own first line continues this one.
            print(indent + marker, end='')
            child.trace(idx + 1)
        
        
class Leaf(object):
    ''' Terminal tree node holding the label counts of the rows that reach it.
    '''
    def __init__(self, sub_df, label_col):
        # Mapping of label -> number of rows in sub_df carrying that label.
        label_counts = sub_df[label_col].value_counts()
        self.predictions = label_counts.to_dict()

    def trace(self, idx=0):
        ''' Print ``[idx]`` immediately followed by the most frequent label.
        '''
        winner, _ = max(self.predictions.items(), key=lambda pair: pair[1])
        print('[' + str(idx) + ']' + str(winner))

In [None]:
class DecisionTree(object):
    ''' Classification tree grown greedily by maximising Gini-impurity reduction.

    The training table is read from a CSV; every column except ``label_col``
    is a candidate feature to split on.
    '''
    def __init__(self, data_path, label_col='quality'):
        ''' Load the training data.

        data_path: passed straight to ``pd.read_csv``.
        label_col: name of the column holding the class label.
        '''
        self.label_col = label_col
        self.df = pd.read_csv(data_path)
        self.feature_cols = self.df.columns.drop(self.label_col)

    @staticmethod
    def is_numeric(value):
        ''' Return True if ``value`` should be split with ``>=`` instead of ``==``.

        Accepts built-in ints/floats as well as NumPy numeric scalars, which is
        what ``Series.unique()`` yields for numeric columns (a NumPy integer is
        not an instance of the built-in ``int``).
        '''
        is_number = pd.api.types.is_number(value)
        # Complex values count as numbers but cannot be ordered, so they are
        # kept on the equality path.
        return bool(is_number and not pd.api.types.is_complex(value))

    def _count_classes(self, sub_df):
        ''' Return a dict mapping each label in ``sub_df`` to its row count.
        '''
        # Use the configured label column rather than a hard-coded name so a
        # non-default ``label_col`` works.
        return sub_df[self.label_col].value_counts().to_dict()

    def _get_gini(self, sub_df):
        ''' https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity

        Returns 1 minus the sum of squared label frequencies in ``sub_df``
        (0 when all rows share one label).
        '''
        count_dict = self._count_classes(sub_df)
        total = len(sub_df)
        impurity = 1

        for count in count_dict.values():
            prob = count / total
            impurity -= prob ** 2

        return impurity

    def _compute_info_gain(self, left, right, current_uncertainty):
        ''' Return the impurity reduction achieved by splitting into left/right.

        The children's Gini impurities are weighted by their share of the rows
        and subtracted from ``current_uncertainty`` (the parent's impurity).
        '''
        p = len(left) / (len(left) + len(right))

        return current_uncertainty - p * self._get_gini(left) - (1 - p) * self._get_gini(right)

    def _partition(self, sub_df, feature, value):
        ''' Partitions a dataset

        Returns ``(true_rows, false_rows)``: for a numeric ``value`` the true
        side holds rows with ``feature >= value``; otherwise rows with
        ``feature == value``.
        '''
        if self.is_numeric(value):
            mask = sub_df[feature] >= value
        else:
            mask = sub_df[feature] == value

        return sub_df.loc[mask], sub_df.loc[~mask]

    def _find_best_split(self, sub_df):
        ''' Try every (feature, observed value) pair and keep the best split.

        Returns ``(gain, feature, value)``. ``gain`` is 0 when no split improves
        on the current impurity, and ``feature``/``value`` are None when no
        split leaves both sides non-empty.
        '''
        best_gain = 0
        best_feature = None
        best_value = None
        current_uncertainty = self._get_gini(sub_df)

        for feature in self.feature_cols:
            unique_values = sub_df[feature].unique()

            for value in unique_values:
                true_sub_df, false_sub_df = self._partition(sub_df, feature, value)

                # A split that sends every row to one side separates nothing.
                if len(true_sub_df) == 0 or len(false_sub_df) == 0:
                    continue

                gain = self._compute_info_gain(true_sub_df, false_sub_df, current_uncertainty)

                # ``>=`` means that among equally good splits the last one
                # examined wins.
                if gain >= best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_value = value

        return best_gain, best_feature, best_value

    def _build_tree(self, sub_df):
        ''' Recursively build the subtree for ``sub_df`` and return its root.

        Returns a Leaf when the best available gain is 0, otherwise a
        DecisionNode whose branches are built from the two partitions.
        '''
        gain, best_feature, best_value = self._find_best_split(sub_df)

        if gain == 0:
            return Leaf(sub_df, self.label_col)

        true_sub_df, false_sub_df = self._partition(sub_df, best_feature, best_value)

        true_branch = self._build_tree(true_sub_df)
        false_branch = self._build_tree(false_sub_df)

        return DecisionNode(best_feature, best_value, true_branch, false_branch)

    def build_tree(self):
        ''' Build and return the tree for the full training table.
        '''
        return self._build_tree(self.df)

In [None]:
# Load the CSV into a DecisionTree; label_col keeps its default, 'quality'.
dt = DecisionTree(data_path=DATA_PATH)

In [None]:
# Grow the tree over the whole loaded DataFrame and keep its root node.
tree = dt.build_tree()

In [None]:
# Print the tree to stdout starting at depth 0; each node prints its true
# branch before its false branch.
tree.trace()

In [None]:
# root = tree
# stack = [root]

# while True:
#     node = stack[0]
#     del stack[0]

#     if isinstance(node, DecisionNode):
#         feature = node.feature
#         value = node.value
#         true_branch = node.true_branch
#         false_branch = node.false_branch
        
#         print(feature, value)
        
#         stack.append(true_branch)
#         stack.append(false_branch)
#     else:
#         predictions = node.predictions
        
#         print(predictions)
        
#     if len(stack) == 0:
#         break