In [29]:
import numpy as np
import pandas as pd
from anytree import Node, RenderTree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.tree import _tree
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder
from rich.tree import Tree as RichTree
from rich import print
from rich.console import Console
import json

In [None]:
class CustomDecisionTreeClassifier(ClassifierMixin, BaseEstimator):
    """Multiway decision tree for categorical (e.g. label-encoded) features.

    Each split node creates one branch per distinct value of the chosen
    feature, and a feature is used at most once on any root-to-leaf path.
    Features are ranked by negated weighted Gini impurity. Among features
    whose quality is within ``similarity_threshold`` of the best one, the
    choice can be delegated to the user (``interactive``) or to a
    ``preference_list``.

    Parameters
    ----------
    max_depth : int or None, default=None
        Maximum depth of the tree; ``None`` means no depth limit.
    similarity_threshold : float, default=0.01
        Features whose quality differs from the best quality by at most this
        value are treated as equally good candidates.
    preference_list : list of int or None, default=None
        Feature indices in order of preference. The first entry that is among
        the candidates is chosen. ``None`` or an empty list disables it.
    interactive : bool, default=False
        When true, and more than one candidate exists, ask on stdin which
        candidate to split on.
    interactive_threshold : float, default=0.01
        The prompt is shown only if the best and the worst candidate differ by
        at most this value.
    random_state : object, default=None
        Stored for API compatibility; no method of this class reads it.
    min_samples_split : int, default=2
        Nodes with fewer samples than this become leaves.

    Attributes
    ----------
    tree_ : dict or None
        Nested dict representation of the fitted tree, ``None`` before ``fit``.
        Leaves are ``{'type': 'leaf', 'class': label}``; split nodes are
        ``{'type': 'split', 'feature': index, 'branches': {value: node},
        'default_class': label}``.
    classes_ : ndarray
        Sorted unique labels seen in ``fit``.
    n_features_in_ : int
        Number of columns of the training matrix.

    Notes
    -----
    Labels must be non-negative integers (for instance ``LabelEncoder``
    output) because the majority class is computed with ``np.bincount``.
    """

    def __init__(self, max_depth=None, similarity_threshold=0.01, preference_list=None,
                 interactive=False, interactive_threshold=0.01, random_state=None, min_samples_split=2):
        self.max_depth = max_depth
        self.similarity_threshold = similarity_threshold
        # Stored unmodified: sklearn.base.clone verifies that the constructor keeps
        # the very object it was given, so "preference_list or []" would break
        # cloning (and therefore cross-validation / grid search).
        self.preference_list = preference_list
        self.interactive = interactive
        self.interactive_threshold = interactive_threshold
        self.random_state = random_state
        self.min_samples_split = min_samples_split
        self.tree_ = None

    def _require_fitted(self):
        """Raise ``NotFittedError`` when no tree has been built yet."""
        if getattr(self, 'tree_', None) is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This CustomDecisionTreeClassifier instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

    def _calculate_split_quality(self, X, y, feature_index):
        """Return the negated weighted Gini impurity of a multiway split.

        The samples are partitioned by the distinct values of column
        ``feature_index``; higher (closer to zero) results mean purer children.
        """
        values = X[:, feature_index]
        n_samples = len(y)
        weighted_gini = 0.0
        for val in np.unique(values):
            subset = y[values == val]
            _, counts = np.unique(subset, return_counts=True)
            proportions = counts / len(subset)
            weighted_gini += len(subset) / n_samples * (1.0 - np.sum(proportions ** 2))
        return -weighted_gini

    def _choose_split(self, qualities):
        """Pick the feature to split on.

        Parameters
        ----------
        qualities : list of (feature_index, quality)
            Must be non-empty and sorted by quality in descending order.

        Returns
        -------
        The chosen feature index.

        Raises
        ------
        ValueError
            If ``qualities`` is empty.
        """
        if not qualities:
            raise ValueError("qualities must contain at least one (feature, quality) pair")

        best_quality = qualities[0][1]
        candidates = [fq for fq in qualities if abs(fq[1] - best_quality) <= self.similarity_threshold]

        # Asking the user only makes sense when there is an actual choice to make.
        ask_user = (
            self.interactive
            and len(candidates) > 1
            and abs(candidates[0][1] - candidates[-1][1]) <= self.interactive_threshold
        )
        if ask_user:
            print("Smiliar quality for features:")
            for idx, (f, q) in enumerate(candidates):
                print(f"{idx}: Feature {f} (Quality: {q:.4f})")
            try:
                choice = int(input("Choose feature (index): "))
            except ValueError:
                # Non-numeric answer: handled like an out-of-range index below.
                choice = -1
            if not 0 <= choice < len(candidates):
                print("Invalid choice, using the first candidate.")
                choice = 0
            return candidates[choice][0]

        if self.preference_list:
            candidate_features = {f for f, _ in candidates}
            for preferred in self.preference_list:
                if preferred in candidate_features:
                    return preferred
        return candidates[0][0]

    def _build_tree(self, X, y, depth=0, used_features=None):
        """Recursively build the nested-dict tree for the samples ``X``, ``y``.

        ``used_features`` holds the feature indices already split on along the
        path from the root; they are not considered again.
        """
        if used_features is None:
            used_features = set()

        # Majority label of this node: the prediction of a leaf, and the
        # fallback of a split node for feature values unseen during training.
        majority_class = np.bincount(y).argmax()
        remaining = [i for i in range(X.shape[1]) if i not in used_features]

        # Stop conditions:
        # 1. the maximum depth has been reached,
        # 2. all labels are identical,
        # 3. fewer samples than min_samples_split,
        # 4. every feature has already been used on this path.
        if (
            (self.max_depth is not None and depth >= self.max_depth)
            or np.unique(y).size == 1
            or len(y) < self.min_samples_split
            or not remaining
        ):
            return {'type': 'leaf', 'class': majority_class}

        # sorted() is stable, so equally good features keep ascending index order.
        qualities = sorted(
            ((i, self._calculate_split_quality(X, y, i)) for i in remaining),
            key=lambda fq: -fq[1],
        )
        feature = self._choose_split(qualities)
        child_used = set(used_features) | {feature}

        column = X[:, feature]
        branches = {}
        # Values come from this node's own samples, so no branch can be empty.
        for val in np.unique(column):
            mask = column == val
            branches[val] = self._build_tree(X[mask], y[mask], depth + 1, child_used)

        return {'type': 'split', 'feature': feature, 'branches': branches, 'default_class': majority_class}

    def fit(self, X, y):
        """Build the tree from training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Categorical feature values (anything ``np.asarray`` turns into a
            2-D array, including a DataFrame).
        y : array-like of shape (n_samples,)
            Non-negative integer class labels.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``y`` is not 1-D, their lengths differ, or the
            dataset is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-dimensional, got {X.ndim} dimension(s)")
        if y.ndim != 1:
            raise ValueError(f"y must be 1-dimensional, got {y.ndim} dimension(s)")
        if X.shape[0] != y.shape[0]:
            raise ValueError(f"X and y have inconsistent lengths: {X.shape[0]} != {y.shape[0]}")
        if y.shape[0] == 0:
            raise ValueError("Cannot fit a tree on an empty dataset")

        self.classes_ = np.unique(y)
        self.n_features_in_ = X.shape[1]
        self.tree_ = self._build_tree(X, y)
        return self

    def _predict_sample(self, x, node):
        """Walk from ``node`` down to a leaf for one sample and return its class.

        If the sample's value for a split feature has no branch, the split
        node's ``default_class`` is returned.
        """
        while node['type'] != 'leaf':
            branch = node['branches'].get(x[node['feature']])
            if branch is None:
                return node.get('default_class', 0)
            node = branch
        return node['class']

    def predict(self, X):
        """Predict a class label for every row of ``X``.

        Accepts the same array-likes as ``fit`` (DataFrame, ndarray, nested
        list). Raises ``NotFittedError`` before ``fit`` and ``ValueError`` when
        ``X`` is not 2-D.
        """
        self._require_fitted()
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-dimensional, got {X.ndim} dimension(s)")
        return np.array([self._predict_sample(x, self.tree_) for x in X])

    def print_tree(self, node=None):
        """Return a ``rich.tree.Tree`` rendering of the (sub)tree at ``node``.

        With ``node=None`` the whole fitted tree is rendered; in that case a
        ``NotFittedError`` is raised if ``fit`` has not been called.
        """
        if node is None:
            self._require_fitted()
            node = self.tree_

        def render(current):
            if current['type'] == 'leaf':
                return RichTree(f"[green]Klasa: {current['class']}[/green]")

            tree = RichTree(f"[cyan]Cecha {current['feature']}[/cyan]")
            for val, child in current['branches'].items():
                # One intermediate node per feature value, with the child
                # subtree attached beneath it.
                subtree = tree.add(f"[yellow]Wartość = {val}[/yellow]")
                subtree.children.append(render(child))
            return tree

        return render(node)



In [31]:

# --- Data preparation ---
data = fetch_openml(name="mushroom", version=1, as_frame=True)

# Label-encode every categorical feature column and the target labels
X = data.data.apply(LabelEncoder().fit_transform)
y = LabelEncoder().fit_transform(data.target)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# --- Model training ---
# The preference list contains every feature index, highest index first.
feature_order = list(range(X.shape[1]))
clf = CustomDecisionTreeClassifier(
    max_depth=4,
    min_samples_split=5,
    similarity_threshold=0.01,
    preference_list=list(reversed(feature_order)),
)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print("Dokładność:", accuracy_score(y_test, y_pred))

# Alternative anytree rendering, left disabled (the classifier above defines
# no to_anytree method):
# root = clf.to_anytree()
# for pre, fill, node in RenderTree(root):
#     print(f"{pre}{node.name}")

# Render the fitted tree with rich
console = Console()
console.print(clf.print_tree())
