### Decision Tree from Scratch for Classification Problem

In [1]:
from collections import Counter
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd

In [2]:
class DecisionTreeClassifier:
    """
    A Decision Tree classifier supporting:
    - Continuous and categorical features
    - Multiclass classification
    - Missing value handling
    - Post-pruning with validation data
    
    Parameters:
    -----------
    max_depth : int or None
        Maximum depth of the tree. If None, tree expands until pure leaves or no gain.
    
    min_samples_split : int
        Minimum number of samples required to split a node.
    
    Methods:
    --------
    fit(X, y, feature_names=None, X_val=None, y_val=None):
        Train the decision tree classifier.
    
    predict(X):
        Predict class labels for samples in X.
    """
    
    def __init__(self, max_depth=None, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None
        self.feature_names = None
        
    def fit(self, X, y, feature_names=None, X_val=None, y_val=None):
        """
        Build the decision tree from training data.
        
        Parameters:
        -----------
        X : list of list
            Training samples, each sample is a list of feature values.
        y : list
            Target class labels.
        feature_names : list of str, optional
            List of feature names in order.
        X_val, y_val : optional
            Validation data for post-pruning.
        """
        self.feature_names = feature_names
        self.tree = self._build_tree(X, y, depth=0)
        
        if X_val is not None and y_val is not None:
            self._prune(self.tree, X_val, y_val)
    
    def _gini(self, y):
        """Calculate Gini impurity for labels y"""
        counts = Counter(y)
        n = len(y)
        impurity = 1 - sum((count / n) ** 2 for count in counts.values())
        return impurity
    
    def _split_dataset(self, X, y, feature_index, threshold):
        """
        Split dataset based on feature at feature_index and threshold.
        
        For continuous features: left <= threshold, right > threshold.
        For categorical features: left == threshold, right != threshold.
        
        Returns:
        --------
        (X_left, y_left), (X_right, y_right)
        """
        X_left, y_left, X_right, y_right = [], [], [], []
        for xi, yi in zip(X, y):
            val = xi[feature_index]
            if val is None:  # Missing values handled by ignoring this sample in split calculation
                continue
            if isinstance(threshold, (int, float)):  # Continuous feature
                if val <= threshold:
                    X_left.append(xi)
                    y_left.append(yi)
                else:
                    X_right.append(xi)
                    y_right.append(yi)
            else:  # Categorical feature
                if val == threshold:
                    X_left.append(xi)
                    y_left.append(yi)
                else:
                    X_right.append(xi)
                    y_right.append(yi)
        return (X_left, y_left), (X_right, y_right)
    
    def _best_split(self, X, y):
        """
        Find the best split for the current node by maximizing Gini gain.
        
        Returns:
        --------
        best_feature_index, best_threshold, best_gain, best_splits
        """
        base_gini = self._gini(y) # Parent Impurity
        best_gain = 0
        best_feature, best_threshold = None, None
        best_splits = None
        
        n_features = len(X[0])
        
        for feature_index in range(n_features):
            # Get all possible values for this feature, ignoring None
            values = [x[feature_index] for x in X if x[feature_index] is not None]
            unique_values = set(values)
            
            # For continuous features: try midpoints between sorted unique values
            if all(isinstance(v, (int, float)) for v in unique_values):
                sorted_vals = sorted(unique_values)
                thresholds = [(sorted_vals[i] + sorted_vals[i+1])/2 for i in range(len(sorted_vals)-1)]
            else:
                # Categorical: thresholds are unique values themselves
                thresholds = unique_values
            
            for threshold in thresholds:
                (X_left, y_left), (X_right, y_right) = self._split_dataset(X, y, feature_index, threshold)
                
                if len(y_left) == 0 or len(y_right) == 0:
                    continue
                
                # Calculate weighted Gini impurity after split
                n = len(y_left) + len(y_right)
                gini_left = self._gini(y_left)
                gini_right = self._gini(y_right)
                weighted_gini = (len(y_left) / n) * gini_left + (len(y_right) / n) * gini_right
                
                gain = base_gini - weighted_gini
                
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature_index
                    best_threshold = threshold
                    best_splits = ((X_left, y_left), (X_right, y_right))
        
        return best_feature, best_threshold, best_gain, best_splits
    
    def _build_tree(self, X, y, depth):
        """
        Recursively build the decision tree.
        """
        num_samples = len(y)
        num_labels = len(set(y))
        
        # Stopping criteria
        if (self.max_depth is not None and depth >= self.max_depth) or (num_labels == 1) or (num_samples < self.min_samples_split):
            leaf_label = Counter(y).most_common(1)[0][0]
            return {'type': 'leaf', 'class': leaf_label}
        
        feature_index, threshold, gain, splits = self._best_split(X, y)
        
        if gain == 0 or splits is None:
            leaf_label = Counter(y).most_common(1)[0][0]
            return {'type': 'leaf', 'class': leaf_label}
        
        (X_left, y_left), (X_right, y_right) = splits
        
        left_branch = self._build_tree(X_left, y_left, depth + 1)
        right_branch = self._build_tree(X_right, y_right, depth + 1)
        
        return {
            'type': 'node',
            'feature_index': feature_index,
            'feature_name': self.feature_names[feature_index] if self.feature_names else None,
            'threshold': threshold,
            'left': left_branch,
            'right': right_branch
        }
    
    def _predict_sample(self, node, sample):
        """
        Predict class label for a single sample by traversing the tree.
        Missing values handled by traversing both branches and majority voting.
        """
        if node['type'] == 'leaf':
            return node['class']
        
        val = sample[node['feature_index']]
        threshold = node['threshold']
        
        if val is None:
            # Missing feature value: traverse both branches and return majority vote
            left_pred = self._predict_sample(node['left'], sample)
            right_pred = self._predict_sample(node['right'], sample)
            return Counter([left_pred, right_pred]).most_common(1)[0][0]
        
        if isinstance(threshold, (int, float)):
            # Continuous feature split
            if val <= threshold:
                return self._predict_sample(node['left'], sample)
            else:
                return self._predict_sample(node['right'], sample)
        else:
            # Categorical feature split
            if val == threshold:
                return self._predict_sample(node['left'], sample)
            else:
                return self._predict_sample(node['right'], sample)
    
    def predict(self, X):
        """
        Predict class labels for multiple samples.
        
        Parameters:
        -----------
        X : list of list
            Samples to predict.
        
        Returns:
        --------
        List of predicted class labels.
        """
        return [self._predict_sample(self.tree, sample) for sample in X]
    
    def _prune(self, node, X_val, y_val):
        """
        Post-pruning using reduced error pruning on validation set.
        """
        if node['type'] == 'leaf':
            return
        
        # Prune children first
        self._prune(node['left'], X_val, y_val)
        self._prune(node['right'], X_val, y_val)
        
        # If both children are leaves, try pruning this node
        if node['left']['type'] == 'leaf' and node['right']['type'] == 'leaf':
            # Current prediction error
            y_pred = [self._predict_sample(node, x) for x in X_val]
            error_before = sum(yp != yt for yp, yt in zip(y_pred, y_val))
            
            # Temporarily prune node into leaf with majority class
            combined_labels = []
            for x in X_val:
                val = x[node['feature_index']]
                if val is None:
                    continue
                if isinstance(node['threshold'], (int, float)):
                    if val <= node['threshold']:
                        combined_labels.append(y_val[X_val.index(x)])
                    else:
                        combined_labels.append(y_val[X_val.index(x)])
                else:
                    combined_labels.append(y_val[X_val.index(x)])
            if combined_labels:
                majority_class = Counter(combined_labels).most_common(1)[0][0]
            else:
                majority_class = Counter(y_val).most_common(1)[0][0]
            
            # Save original node
            original_node = node.copy()
            node.clear()
            node.update({'type': 'leaf', 'class': majority_class})
            
            # Prediction error after pruning
            y_pred_pruned = [self._predict_sample(node, x) for x in X_val]
            error_after = sum(yp != yt for yp, yt in zip(y_pred_pruned, y_val))
            
            # Revert if pruning doesn't improve error
            if error_after > error_before:
                node.clear()
                node.update(original_node)

In [3]:
# read data

data = pd.read_csv('../data\mnist_train.csv')
data.head()

Unnamed: 0,label,1x1,1x2,1x3,1x4,1x5,1x6,1x7,1x8,1x9,...,28x19,28x20,28x21,28x22,28x23,28x24,28x25,28x26,28x27,28x28
0,5,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,9,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
# constant features

constant_features = [i for i in data.columns if data[i].nunique() == 1]
constant_features

['1x1',
 '1x2',
 '1x3',
 '1x4',
 '1x5',
 '1x6',
 '1x7',
 '1x8',
 '1x9',
 '1x10',
 '1x11',
 '1x12',
 '1x17',
 '1x18',
 '1x19',
 '1x20',
 '1x21',
 '1x22',
 '1x23',
 '1x24',
 '1x25',
 '1x26',
 '1x27',
 '1x28',
 '2x1',
 '2x2',
 '2x3',
 '2x4',
 '2x25',
 '2x26',
 '2x27',
 '2x28',
 '3x1',
 '3x2',
 '3x27',
 '3x28',
 '4x1',
 '4x2',
 '4x28',
 '5x1',
 '6x1',
 '6x2',
 '7x1',
 '18x1',
 '21x1',
 '24x1',
 '24x2',
 '24x28',
 '25x1',
 '25x2',
 '25x28',
 '26x1',
 '26x2',
 '26x28',
 '27x1',
 '27x2',
 '27x3',
 '27x27',
 '27x28',
 '28x1',
 '28x2',
 '28x3',
 '28x4',
 '28x25',
 '28x26',
 '28x27',
 '28x28']

In [5]:
# Drop constant Features

data.drop(constant_features, axis= 1, inplace= True)

In [6]:
data = data.iloc[:500,:]
data.shape

(500, 718)

In [7]:
# Split the data into train and test

X_train, X_test, y_train, y_test = train_test_split(data.drop('label', axis= 1),
                                                   data['label'],
                                                   test_size= 0.2,
                                                   random_state= 0)

X_train.shape, X_test.shape

((400, 717), (100, 717))

In [8]:
# create instsnce

tree = DecisionTreeClassifier(max_depth= 5)
tree

<__main__.DecisionTreeClassifier at 0x1c4829a5f40>

In [9]:
X_train = X_train.values.tolist()
y_train = y_train.tolist()

In [10]:
# fit the model

tree.fit(X_train, y_train)

In [11]:
# predictions

tree.predict(X_test.values.tolist())

[6,
 9,
 0,
 8,
 7,
 7,
 9,
 0,
 2,
 9,
 3,
 9,
 3,
 4,
 2,
 4,
 7,
 1,
 9,
 0,
 8,
 1,
 3,
 3,
 9,
 4,
 5,
 6,
 7,
 2,
 3,
 0,
 1,
 5,
 3,
 3,
 6,
 0,
 7,
 3,
 7,
 3,
 5,
 6,
 2,
 8,
 2,
 2,
 0,
 1,
 4,
 3,
 8,
 3,
 7,
 3,
 3,
 2,
 7,
 4,
 1,
 1,
 9,
 5,
 7,
 7,
 5,
 8,
 4,
 0,
 4,
 2,
 3,
 1,
 5,
 9,
 4,
 3,
 2,
 1,
 1,
 3,
 3,
 5,
 3,
 1,
 6,
 0,
 6,
 9,
 2,
 4,
 4,
 1,
 3,
 4,
 0,
 0,
 7,
 1]

In [12]:
# view tree

tree._build_tree(X_train, y_train, 3)

{'type': 'node',
 'feature_index': 523,
 'feature_name': None,
 'threshold': 61.0,
 'left': {'type': 'node',
  'feature_index': 359,
  'feature_name': None,
  'threshold': 3.0,
  'left': {'type': 'leaf', 'class': 1},
  'right': {'type': 'leaf', 'class': 4}},
 'right': {'type': 'node',
  'feature_index': 316,
  'feature_name': None,
  'threshold': 3.0,
  'left': {'type': 'leaf', 'class': 2},
  'right': {'type': 'leaf', 'class': 0}}}

In [13]:
# Sample data
data = pd.DataFrame({
    'Outlook': ['Sunny', 'Sunny', 'Overcast', 'Rainy', 'Rainy', 'Rainy', 'Overcast', 'Sunny'],
    'Humidity': ['High', 'High', None, 'High', 'Normal', 'Normal', 'Normal', 'Normal'],
    'Temperature': [85, 80, 83, 70, 68, 65, 64, 72],
    'Windy': ['False', 'True', 'False', 'False', 'False', 'True', 'True', 'False'],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'Yes']
})

X = data.drop(columns='Play')
y = data['Play']

tree = DecisionTreeClassifier(max_depth=3)
tree.fit(X.values.tolist(), y.tolist(), feature_names=X.columns.tolist())

sample = ['Rainy', 'Normal', 70, 'False']
print("Prediction:", tree.predict([sample]))
tree._build_tree(X,y,3)

Prediction: ['Yes']


{'type': 'leaf', 'class': 'Yes'}