In [1]:
import os

# Work from the local scikit-learn checkout, presumably so that
# `import sklearn` in the next cell resolves to this (patched) source tree.
# Set the SKLEARN_SRC_DIR environment variable instead of editing this cell.
SKLEARN_SRC_DIR = os.environ.get("SKLEARN_SRC_DIR", "/Users/xichenshe/IAI/scikit-learn/")
os.chdir(SKLEARN_SRC_DIR)
os.getcwd()

'/Users/xichenshe/IAI/scikit-learn'

In [2]:
from abc import ABC, abstractmethod
from functools import partial
import warnings

import numpy as np
from timeit import default_timer as time

from sklearn._loss.loss import (
    _LOSSES,
    BaseLoss,
    AbsoluteError,
    HalfBinomialLoss,
    HalfMultinomialLoss,
    HalfPoissonLoss,
    HalfSquaredError,
    PinballLoss,
)
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin, is_classifier
from sklearn.utils import check_random_state, resample
from sklearn.utils.validation import (
    check_is_fitted,
    check_consistent_length,
    _check_sample_weight,
)

from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics import check_scoring

from sklearn.preprocessing import LabelEncoder

from sklearn.ensemble._hist_gradient_boosting._gradient_boosting import _update_raw_predictions
from sklearn.ensemble._hist_gradient_boosting.common import Y_DTYPE, X_DTYPE, G_H_DTYPE

from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower

from ddsketch import DDSketch
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.common import HISTOGRAM_DTYPE

from heapq import heappush, heappop
import numpy as np
import pandas as pd
from timeit import default_timer as time
import numbers

from sklearn.ensemble._hist_gradient_boosting.grower import TreeNode

### Generate data

In [3]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


# Synthetic binary task: 80/20 train/test split, then the training set is cut
# in half to simulate two clients that hold disjoint shards.
RANDOM_STATE = 23

X, y = make_classification(
    n_samples=5000,
    n_features=5,
    n_informative=4,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=RANDOM_STATE,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE
)

n_half = X_train.shape[0] // 2
X_train0, X_train1 = X_train[:n_half, :], X_train[n_half:, :]
y_train0, y_train1 = y_train[:n_half], y_train[n_half:]

### Define server and client

In [4]:
def merge_histograms(hist_lst):
    """Sum several node histograms bin by bin.

    Parameters
    ----------
    hist_lst : sequence of ndarray
        Structured arrays indexed as ``[feature_idx, bin_idx]`` with the fields
        ``count``, ``sum_gradients`` and ``sum_hessians`` (``HISTOGRAM_DTYPE``
        in practice). All histograms must have the same shape.

    Returns
    -------
    ndarray
        A new structured array with the shape and dtype of the inputs whose
        three fields hold the element-wise sums over ``hist_lst``. The inputs
        are not modified.

    Raises
    ------
    ValueError
        If ``hist_lst`` is empty or the histograms do not all share one shape.
    """
    if len(hist_lst) == 0:
        raise ValueError("hist_lst must contain at least one histogram")

    reference = hist_lst[0]
    # Reuse the inputs' dtype so the result can go straight back into the
    # splitter alongside the original histograms.
    hist_merge = np.zeros(shape=reference.shape, dtype=reference.dtype)
    for hist_ii in hist_lst:
        if hist_ii.shape != reference.shape:
            raise ValueError(
                f"histogram shapes differ: {hist_ii.shape} != {reference.shape}"
            )
        # Vectorised field-wise addition instead of a Python loop over every
        # (feature, bin) cell.
        for field in ("count", "sum_gradients", "sum_hessians"):
            hist_merge[field] += hist_ii[field]

    return hist_merge

In [5]:
class HistGbmServer:
    """Server side of the federated histogram gradient-boosting prototype.

    The server never sees raw data. It merges per-client summaries (quantile
    sketches, local baselines, node sample counts and histograms) and returns
    the global aggregates that every client uses, so that all clients grow the
    same tree.

    Parameters
    ----------
    max_bins : int, default=255
        Number of bins for non-missing values. ``max_bins - 1`` interior
        quantiles are used as bin thresholds.
    """

    def __init__(self, max_bins=255):
        self.max_bins = max_bins
        # `max_bins + 1` evenly spaced quantile levels delimit `max_bins` bins;
        # dropping the 0 and 1 endpoints keeps the `max_bins - 1` interior
        # thresholds. The missing-values bin is added by the clients' bin mapper.
        self.quantiles = np.linspace(0, 1, num=max_bins + 1)[1:-1]

    def aggregate_quantile_sketch(self, clients_skt_lst):
        """Merge per-client, per-feature sketches into global bin thresholds.

        Parameters
        ----------
        clients_skt_lst : list of list
            One list of sketches per client, with one sketch per feature in the
            same column order on every client. Each sketch must provide
            ``merge`` and ``get_quantile_value``.

        Returns
        -------
        list of ndarray
            For each feature, the merged sketch's values at ``self.quantiles``.

        Raises
        ------
        ValueError
            If ``clients_skt_lst`` is empty.
        AssertionError
            If the clients report different numbers of features.
        """
        import copy

        if len(clients_skt_lst) == 0:
            raise ValueError("clients_skt_lst must contain at least one client")

        n_features = np.array([len(client) for client in clients_skt_lst])
        assert np.all(n_features == n_features[0]), "number of features differs across clients!"

        bin_thresh = []
        for f_idx in range(n_features[0]):
            # Merge into a copy: merging into client 0's sketch in place would
            # mutate caller-owned state and double count on a repeated call.
            skt_merge = copy.deepcopy(clients_skt_lst[0][f_idx])
            for client_skt in clients_skt_lst[1:]:
                skt_merge.merge(client_skt[f_idx])
            q_merge = [skt_merge.get_quantile_value(q) for q in self.quantiles]
            bin_thresh.append(np.array(q_merge))
        return bin_thresh

    def aggregate_baseline_prediction(self, clients_baseline):
        """Combine local ``(mean, weight)`` pairs into the global weighted mean."""
        return np.average(
            [x[0] for x in clients_baseline],
            weights=[x[1] for x in clients_baseline]
        )

    def aggregate_root_attr(self, clients_root_attr):
        """Sum the clients' root-node attributes.

        Parameters
        ----------
        clients_root_attr : list
            One ``([n_samples, sum_gradients, sum_hessians], histograms)`` pair
            per client.

        Returns
        -------
        tuple
            ``([n_samples, sum_gradients, sum_hessians], merged_histograms)``
            with the sample count returned as an ``int``.
        """
        totals = np.sum([x[0] for x in clients_root_attr], axis=0)
        n_samples, sum_gradients, sum_hessians = totals
        return (
            # np.sum over mixed int/float entries yields a float count; hand
            # back an int, as `aggregate_next_node_attr` does for children.
            [int(n_samples), sum_gradients, sum_hessians],
            merge_histograms([x[1] for x in clients_root_attr]),
        )

    def aggregate_next_node_attr(self, clients_next_node_attr):
        """Sum the clients' child-node sample counts and histograms.

        Parameters
        ----------
        clients_next_node_attr : list
            One ``([n_left, hist_left], [n_right, hist_right])`` pair per client.

        Returns
        -------
        tuple of list
            ``[n_left, merged_hist_left], [n_right, merged_hist_right]`` with
            the counts as ``int``.
        """
        left_n_sample = 0
        right_n_sample = 0
        left_histogram_lst = []
        right_histogram_lst = []
        for left_attr, right_attr in clients_next_node_attr:
            left_n_sample += left_attr[0]
            right_n_sample += right_attr[0]
            left_histogram_lst.append(left_attr[1])
            right_histogram_lst.append(right_attr[1])
        return (
            [int(left_n_sample), merge_histograms(left_histogram_lst)],
            [int(right_n_sample), merge_histograms(right_histogram_lst)],
        )

In [6]:
class HistGbmClient:
    """Client side of the federated histogram gradient-boosting prototype.

    A client owns one private shard of the training data. It drives the private
    building blocks of a ``HistGradientBoostingClassifier`` step by step: input
    validation, binning, the loss, and a ``TreeGrower``. Only summaries leave
    the client:

    - per-feature quantile sketches,
    - a local baseline,
    - sample counts and histograms of tree nodes.

    The server aggregates these summaries and sends the global values back, so
    that all clients grow the same tree.

    Parameters
    ----------
    sketch_relative_accuracy : float
        ``relative_accuracy`` of the per-feature ``DDSketch`` quantile sketches.
    max_depth : int or None
        Maximum tree depth, forwarded to the wrapped estimator.
    random_state : int, RandomState instance or None
        Forwarded to the wrapped estimator.
    """

    def __init__(self, sketch_relative_accuracy, max_depth, random_state):
        self.sketch_relative_accuracy = sketch_relative_accuracy
        self.max_depth = max_depth
        self.random_state = random_state

    def quantile_sketch(self, X):
        """Return a list of sketches, one ``DDSketch`` per column of ``X``.

        NOTE: the columns must be in the same order across clients, because the
        server merges the sketches position by position.
        """
        skt_lst = []
        for f_idx in range(X.shape[1]):
            skt_ii = DDSketch(relative_accuracy=self.sketch_relative_accuracy)
            for v in X[:, f_idx]:
                skt_ii.add(v)
            skt_lst.append(skt_ii)
        return skt_lst

    def set_bin_thresh(self, bin_thresh):
        """Store the global per-feature bin thresholds computed by the server."""
        self.bin_thresh = bin_thresh

    def init_learner(self, X, y, sample_weight=None):
        """Create the wrapped estimator, then validate and encode the local data.

        Must be called after ``set_bin_thresh``, because ``max_bins`` is derived
        from the number of global thresholds.

        Returns
        -------
        X : ndarray
            Validated features (``X_DTYPE``).
        y : ndarray
            Targets encoded by the estimator's ``_encode_y``.
        sample_weight : ndarray or None
            Validated sample weights, or ``None`` when none were given.
        """
        # TODO: handle regression tasks
        self.histgbm = HistGradientBoostingClassifier(
            max_depth=self.max_depth,
            early_stopping=False,
            random_state=self.random_state,
            # The bin mapper uses len(bin_thresh) + 2 bins: len(bin_thresh) + 1
            # for non-missing values plus one bin for missing values. `max_bins`
            # excludes the missing-values bin, hence only + 1 here.
            max_bins=len(self.bin_thresh[0]) + 1,
        )
        X, y = self.histgbm._validate_data(X, y, dtype=[X_DTYPE], force_all_finite=False)
        y = self.histgbm._encode_y(y)
        check_consistent_length(X, y)

        # Do not create unit sample weights by default to later skip some
        # computation
        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float64)
            # TODO: remove when PDP supports sample weights
            self.histgbm._fitted_with_sw = True

        rng = check_random_state(self.histgbm.random_state)
        self.histgbm._random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")

        self.histgbm._validate_parameters()

        # used for validation in predict
        n_samples, self.histgbm._n_features = X.shape

        self.histgbm.is_categorical_, self.histgbm.known_categories = self.histgbm._check_categories(X)

        if isinstance(self.histgbm.loss, str):
            self.histgbm._loss = self.histgbm._get_loss(sample_weight=sample_weight)
        elif isinstance(self.histgbm.loss, BaseLoss):
            self.histgbm._loss = self.histgbm.loss

        return X, y, sample_weight

    def get_local_unlinked_baseline(self, y, sample_weight=None):
        """Return the local (weighted) target mean and this shard's weight.

        The mean is on the target scale. The link function is applied later by
        ``set_baseline``. The server combines the clients' means with a
        weighted average, so the weight is:

        - the number of samples when ``sample_weight`` is ``None``;
        - otherwise the total sample weight of the shard.
        """
        local_mean = np.average(y, weights=sample_weight, axis=0)
        if sample_weight is None:
            local_weight = y.shape[0]
        else:
            # Weighting by the row count would bias the global mean whenever
            # the shards' sample weights do not sum to their row counts.
            local_weight = np.sum(sample_weight)
        return local_mean, local_weight

    def set_baseline(self, baseline_prediction):
        """Store the global baseline mapped through the loss' link function."""
        self.histgbm._baseline_prediction = self.histgbm._loss.link.link(baseline_prediction).reshape((1, -1))

    def bin_data(self, X, y, sample_weight=None):
        """Bin ``X`` with the global thresholds set via ``set_bin_thresh``.

        ``y`` must already be encoded (the output of ``init_learner``). No
        validation split is made, so the returned validation entries are
        always ``None``.

        Returns
        -------
        X_binned_train, y_train, X_binned_val, y_val
        """
        X_train, y_train, sample_weight_train = X, y, sample_weight
        X_val = y_val = sample_weight_val = None

        # `_openmp_effective_n_threads` is used to take cgroups CPU quotes
        # into account when determine the maximum number of threads to use.
        n_threads = _openmp_effective_n_threads()
        n_bins = self.histgbm.max_bins + 1  # +1 for missing values
        self.histgbm._bin_mapper = _BinMapper(
            n_bins=n_bins,
            is_categorical=self.histgbm.is_categorical_,
            known_categories=self.histgbm.known_categories,
            random_state=self.histgbm._random_seed,
            n_threads=n_threads,
        )

        # NOTE(review): `bin_thresholds=` is assumed to be a hook of the patched
        # scikit-learn checkout that bypasses local threshold fitting; confirm.
        X_binned_train = self.histgbm._bin_data(X_train, is_training_data=True, bin_thresholds=self.bin_thresh)
        if X_val is not None:
            X_binned_val = self.histgbm._bin_data(X_val, is_training_data=False)
        else:
            X_binned_val = None

        # Uses binned data to check for missing values
        self.histgbm.has_missing_values = (
            (X_binned_train == self.histgbm._bin_mapper.missing_values_bin_idx_)
            .any(axis=0)
            .astype(np.uint8)
        )
        return X_binned_train, y_train, X_binned_val, y_val

    def get_gradient_hession(self, X_binned, y, sample_weight=None):
        """Compute the loss' gradients and hessians at the baseline prediction.

        Parameters
        ----------
        X_binned : ndarray
            Binned local data; only its number of rows is used.
        y : ndarray
            Targets, already encoded by ``init_learner``.
        sample_weight : ndarray or None, default=None
            Optional sample weights, forwarded to the loss. Pass the same
            weights given to ``init_learner``.

        Returns
        -------
        g_view, h_view : ndarray
            2-d views on the gradients and hessians with one column per tree of
            the boosting iteration.
        """
        # `y` is deliberately NOT re-encoded: calling `_encode_y` again on the
        # already-encoded labels would refit the label encoder and overwrite
        # `classes_` with the codes instead of the original labels.

        # initialize gradients and hessians (empty arrays).
        # shape = (n_samples, n_trees_per_iteration).
        n_samples = X_binned.shape[0]
        gradient, hessian = self.histgbm._loss.init_gradient_and_hessian(
            n_samples=n_samples, dtype=G_H_DTYPE, order="F"
        )

        # `_openmp_effective_n_threads` is used to take cgroups CPU quotes
        # into account when determine the maximum number of threads to use.
        n_threads = _openmp_effective_n_threads()

        raw_predictions = np.zeros(
            shape=(n_samples, self.histgbm.n_trees_per_iteration_),
            dtype=self.histgbm._baseline_prediction.dtype,
            order="F",
        )
        raw_predictions += self.histgbm._baseline_prediction
        # Update gradients and hessians, inplace
        # Note that self.histgbm._loss expects shape (n_samples,) for
        # n_trees_per_iteration = 1 else shape (n_samples, n_trees_per_iteration).
        if self.histgbm._loss.constant_hessian:
            self.histgbm._loss.gradient(
                y_true=y,
                raw_prediction=raw_predictions,
                sample_weight=sample_weight,
                gradient_out=gradient,
                n_threads=n_threads,
            )
        else:
            self.histgbm._loss.gradient_hessian(
                y_true=y,
                raw_prediction=raw_predictions,
                sample_weight=sample_weight,
                gradient_out=gradient,
                hessian_out=hessian,
                n_threads=n_threads,
            )

        # 2-d views of shape (n_samples, n_trees_per_iteration_) or (n_samples, 1)
        # on gradient and hessian to simplify the loop over n_trees_per_iteration_.
        if gradient.ndim == 1:
            g_view = gradient.reshape((-1, 1))
            h_view = hessian.reshape((-1, 1))
        else:
            g_view = gradient
            h_view = hessian

        return g_view, h_view

    # Correctly spelled alias; the original name is kept for existing callers.
    get_gradient_hessian = get_gradient_hession

    def init_grower(self, X, y, sample_weight=None):
        """Bin the local data, compute gradients/hessians and create the grower.

        ``X`` and ``y`` are the outputs of ``init_learner``. The grower's root
        node holds the local histograms that ``get_root_attr`` reports.

        Raises
        ------
        NotImplementedError
            For multi-class targets, which need one tree per class.
        """
        n_trees = self.histgbm.n_trees_per_iteration_
        if n_trees != 1:
            # Only a single `self.grower` is kept; building one grower per class
            # in a loop would silently keep only the last class's tree.
            raise NotImplementedError(
                "Multi-class targets are not supported yet "
                f"({n_trees} trees per iteration required)."
            )

        X_binned_train, y_train, _, _ = self.bin_data(X, y, sample_weight=sample_weight)
        g_view, h_view = self.get_gradient_hession(
            X_binned_train, y_train, sample_weight=sample_weight
        )
        n_threads = _openmp_effective_n_threads()

        self.grower = TreeGrower(
            X_binned=X_binned_train,
            gradients=g_view[:, 0],
            hessians=h_view[:, 0],
            n_bins=self.histgbm.max_bins + 1,
            n_bins_non_missing=self.histgbm._bin_mapper.n_bins_non_missing_,
            has_missing_values=self.histgbm.has_missing_values,
            is_categorical=self.histgbm.is_categorical_,
            monotonic_cst=self.histgbm.monotonic_cst,
            max_leaf_nodes=self.histgbm.max_leaf_nodes,
            max_depth=self.histgbm.max_depth,
            min_samples_leaf=self.histgbm.min_samples_leaf,
            l2_regularization=self.histgbm.l2_regularization,
            shrinkage=self.histgbm.learning_rate,
            n_threads=n_threads,
        )

    def get_root_attr(self):
        """Return ``([n_samples, sum_gradients, sum_hessians], histograms)`` of the local root."""
        return [self.grower.root.n_samples, self.grower.root.sum_gradients, self.grower.root.sum_hessians], self.grower.root.histograms

    def split_root(self, root_attr):
        """Find the root split from the server-aggregated root attributes.

        ``root_attr`` has the layout returned by ``get_root_attr``, summed over
        all clients.
        """
        self.grower.root.split_info = self.grower.splitter.find_node_split(
            # The summed count may arrive as a float (e.g. from np.sum).
            int(root_attr[0][0]),
            root_attr[1],
            root_attr[0][1],
            root_attr[0][2],
            self.grower.root.value,
            self.grower.root.children_lower_bound,
            self.grower.root.children_upper_bound,
        )

    def prepare_next_node(self):
        """Apply the next split locally and build the children's local histograms.

        This method:

        1. pops the next node from the grower's heap of splittable nodes;
        2. partitions the node's local samples with its (globally agreed)
           ``split_info``;
        3. creates both children.

        When the children reach ``max_depth`` they are finalized as leaves.
        Otherwise their local histograms are computed and stored, together
        with the children, on ``self.left_child_node`` /
        ``self.right_child_node``.
        """
        node = heappop(self.grower.splittable_nodes)
        (
            sample_indices_left,
            sample_indices_right,
            right_child_pos,
        ) = self.grower.splitter.split_indices(node.split_info, node.sample_indices)

        depth = node.depth + 1
        # TODO: honour `max_leaf_nodes` (count the leaves and finalize the
        # remaining splittable nodes once the limit is reached).

        left_child_node = TreeNode(
            depth,
            sample_indices_left,
            # NOTE: the sum gradient and hessian are already the aggregated versions
            node.split_info.sum_gradient_left,
            node.split_info.sum_hessian_left,
            value=node.split_info.value_left,
        )
        right_child_node = TreeNode(
            depth,
            sample_indices_right,
            node.split_info.sum_gradient_right,
            node.split_info.sum_hessian_right,
            value=node.split_info.value_right,
        )

        node.right_child = right_child_node
        node.left_child = left_child_node

        # set start and stop indices
        left_child_node.partition_start = node.partition_start
        left_child_node.partition_stop = node.partition_start + right_child_pos
        right_child_node.partition_start = left_child_node.partition_stop
        right_child_node.partition_stop = node.partition_stop

        if not self.grower.has_missing_values[node.split_info.feature_idx]:
            # If no missing values are encountered at fit time, then samples
            # with missing values during predict() will go to whichever child
            # has the most samples. Use the global counts stored in
            # `split_info` (computed from the aggregated histograms) instead of
            # the local child sizes, which differ between clients and could
            # make them choose different directions.
            # NOTE(review): `has_missing_values` is still computed from local
            # data only, so clients can disagree if only some of them observe
            # missing values for this feature.
            node.split_info.missing_go_to_left = (
                node.split_info.n_samples_left > node.split_info.n_samples_right
            )

        self.grower.n_nodes += 2
        self.grower.n_categorical_splits += node.split_info.is_categorical

        if self.grower.max_depth is not None and depth == self.grower.max_depth:
            self.grower._finalize_leaf(left_child_node)
            self.grower._finalize_leaf(right_child_node)
        else:
            # We will compute the histograms of both nodes even if one of them
            # is a leaf, since computing the second histogram is very cheap
            # (using histogram subtraction). Which child is "smallest" is a
            # purely local choice and does not need to agree across clients.
            n_samples_left = left_child_node.sample_indices.shape[0]
            n_samples_right = right_child_node.sample_indices.shape[0]
            if n_samples_left < n_samples_right:
                smallest_child = left_child_node
                largest_child = right_child_node
            else:
                smallest_child = right_child_node
                largest_child = left_child_node

            smallest_child.histograms = self.grower.histogram_builder.compute_histograms_brute(
                smallest_child.sample_indices
            )
            largest_child.histograms = (
                self.grower.histogram_builder.compute_histograms_subtraction(
                    node.histograms, smallest_child.histograms
                )
            )
        self.left_child_node = left_child_node
        self.right_child_node = right_child_node

    def get_next_node_attr(self):
        """Return ``[n_samples, histograms]`` of the local left and right children."""
        return [self.left_child_node.n_samples, self.left_child_node.histograms], [self.right_child_node.n_samples, self.right_child_node.histograms]

    def split_next_node(self, next_node_attr):
        """Evaluate and push the children's splits using aggregated counts/histograms.

        NOTE(review): relies on a patched ``_compute_best_split_and_push(node,
        n_samples, histograms)`` signature of ``TreeGrower``; confirm against
        the local scikit-learn checkout.
        """
        self.grower._compute_best_split_and_push(self.left_child_node, next_node_attr[0][0], next_node_attr[0][1])
        self.grower._compute_best_split_and_push(self.right_child_node, next_node_attr[1][0], next_node_attr[1][1])

In [7]:
# One server and two identically configured clients.
server = HistGbmServer(max_bins=255)
client_kwargs = dict(sketch_relative_accuracy=0.001, max_depth=2, random_state=23)
client_0 = HistGbmClient(**client_kwargs)
client_1 = HistGbmClient(**client_kwargs)


### Federated binning

In [8]:
# Each client sketches its own shard; only the sketches reach the server.
clients_skt_lst = [
    client.quantile_sketch(X=X_shard)
    for client, X_shard in [(client_0, X_train0), (client_1, X_train1)]
]

In [9]:
# Merge the sketches into one array of bin thresholds per feature.
bin_thresh = server.aggregate_quantile_sketch(clients_skt_lst=clients_skt_lst)

In [10]:
# First ten global thresholds of feature 0.
thresh_feature_0 = bin_thresh[0]
thresh_feature_0[:10]

array([-2.5523125 , -2.37026032, -2.23669598, -2.13187048, -2.05648814,
       -1.97585206, -1.90217832, -1.84965605, -1.79858401, -1.7454278 ])

In [11]:
# First ten global thresholds of feature 3.
thresh_feature_3 = bin_thresh[3]
thresh_feature_3[:10]

array([-4.07552661, -3.74716645, -3.56441492, -3.43837822, -3.31017116,
       -3.23167278, -3.14873216, -3.07406219, -3.00116296, -2.92999249])

In [12]:
# Every client bins its data with the same global thresholds.
for client in (client_0, client_1):
    client.set_bin_thresh(bin_thresh)

### Federated training

#### 1. Set the global baseline

In [13]:
# Validate and label-encode each shard; the returned sample weights are unused.
X_train0, y_train0, _ = client_0.init_learner(X=X_train0, y=y_train0)
X_train1, y_train1, _ = client_1.init_learner(X=X_train1, y=y_train1)

In [14]:
# Local (not yet linked) target means together with the shard weights.
clients_baseline = [
    client.get_local_unlinked_baseline(y_shard)
    for client, y_shard in [(client_0, y_train0), (client_1, y_train1)]
]

In [15]:
agg_baseline = server.aggregate_baseline_prediction(clients_baseline)
# Broadcast the global baseline; each client applies the link function locally.
for client in (client_0, client_1):
    client.set_baseline(agg_baseline)

#### 2. Initialize the root node and find its split

In [16]:
# Bin each shard, compute gradients/hessians and create the local tree growers.
for client, X_shard, y_shard in [(client_0, X_train0, y_train0), (client_1, X_train1, y_train1)]:
    client.init_grower(X_shard, y_shard)

In [17]:
# Root sample counts, gradient/hessian sums and histograms, summed by the server.
clients_root_attr = [client.get_root_attr() for client in (client_0, client_1)]
agg_root_attr = server.aggregate_root_attr(clients_root_attr)

In [18]:
# All clients evaluate the root split on the same aggregated statistics.
for client in (client_0, client_1):
    client.split_root(agg_root_attr)

In [19]:
vars(client_0.grower.root.split_info)

{'gain': 765.9620907859037,
 'feature_idx': 4,
 'bin_idx': 195,
 'missing_go_to_left': 0,
 'sum_gradient_left': 369.28099259734154,
 'sum_hessian_left': 768.2223420739174,
 'sum_gradient_right': -369.28101655840874,
 'sum_hessian_right': 231.74165672063828,
 'n_samples_left': 3073,
 'n_samples_right': 927,
 'value_left': -0.4806954606402345,
 'value_right': 1.5935029626700756,
 'is_categorical': 0,
 'left_cat_bitset': None}

#### 3. Split the remaining nodes

In [20]:
n_splittable = [len(client.grower.splittable_nodes) for client in (client_0, client_1)]
assert n_splittable[0] == n_splittable[1], "model out of sync across clients"

In [21]:
# Grow the remaining nodes in lockstep. On every iteration:
# 1. every client pops and splits the same node locally;
# 2. the server merges the children's sample counts and histograms;
# 3. every client evaluates the children's splits on the merged statistics.
clients = [client_0, client_1]
while len(clients[0].grower.splittable_nodes) > 0:
    assert len({len(c.grower.splittable_nodes) for c in clients}) == 1, "model out of sync across clients"
    for client in clients:
        client.prepare_next_node()
    # Children at `max_depth` are finalized as leaves and their histograms stay
    # unset, so there is nothing to aggregate or split for them.
    if clients[0].left_child_node.histograms is not None:
        clients_next_node_attr = [client.get_next_node_attr() for client in clients]
        agg_next_node_attr = server.aggregate_next_node_attr(clients_next_node_attr)
        for client in clients:
            client.split_next_node(agg_next_node_attr)

assert all(len(c.grower.splittable_nodes) == 0 for c in clients), "model out of sync across clients"

### Final predictors

In [24]:
predictor_0 = client_0.grower.make_predictor(bin_thresh)
pd.DataFrame(predictor_0.nodes)

Unnamed: 0,value,count,feature_idx,num_threshold,missing_go_to_left,left,right,gain,depth,is_leaf,bin_threshold,is_categorical,bitset_idx
0,0.0,2000,4,0.394159,1,1,4,765.962091,0,0,195,0,0
1,-0.480695,1519,1,-0.026596,1,2,3,775.590298,1,0,155,0,0
2,-1.278917,938,0,0.0,0,0,0,-1.0,2,1,0,0,0
3,0.784106,581,0,0.0,0,0,0,-1.0,2,1,0,0,0
4,1.593503,481,3,1.029424,1,5,6,43.643797,1,0,216,0,0
5,1.721323,443,0,0.0,0,0,0,-1.0,2,1,0,0,0
6,0.120112,38,0,0.0,0,0,0,-1.0,2,1,0,0,0


In [25]:
predictor_1 = client_1.grower.make_predictor(bin_thresh)
pd.DataFrame(predictor_1.nodes)

Unnamed: 0,value,count,feature_idx,num_threshold,missing_go_to_left,left,right,gain,depth,is_leaf,bin_threshold,is_categorical,bitset_idx
0,0.0,2000,4,0.394159,1,1,4,765.962091,0,0,195,0,0
1,-0.480695,1554,1,-0.026596,1,2,3,775.590298,1,0,155,0,0
2,-1.278917,946,0,0.0,0,0,0,-1.0,2,1,0,0,0
3,0.784106,608,0,0.0,0,0,0,-1.0,2,1,0,0,0
4,1.593503,446,3,1.029424,1,5,6,43.643797,1,0,216,0,0
5,1.721323,410,0,0.0,0,0,0,-1.0,2,1,0,0,0
6,0.120112,36,0,0.0,0,0,0,-1.0,2,1,0,0,0


### TODO
- Grow one tree per class (`n_trees_per_iteration_` trees) for multi-class targets.
- Accumulate the raw predictions after each tree and grow a sequence of trees (boosting).