In [4]:
import numpy as np
import time
import random
from bart_playground.util import fast_choice, fast_choice_with_weights

# Problem size shared by the candidate-sampling benchmarks in this notebook:
# number of leaves and number of predictor columns, then thresholds per column.
n_leaves, n_vars = 10, 10
n_thresholds = 100

class DummyTree:
    """Bare-bones tree stand-in exposing only the attributes the benchmarks read.

    Both attributes are sized from the module-level ``n_leaves`` and ``n_vars``.
    """

    def __init__(self):
        # Leaf node ids 0 .. n_leaves - 1.
        self.leaves = [*range(n_leaves)]
        # One all-zero row with n_vars columns.
        self.dataX = np.zeros(shape=(1, n_vars))

# Every variable gets its own list holding the same grid 0 .. n_thresholds - 1.
possible_thresholds = dict(
    (var, [*range(n_thresholds)]) for var in range(n_vars)
)
tree = DummyTree()

# Seeded NumPy generator used by the sampling benchmarks.
rng = np.random.default_rng(seed=42)

# All candidates using list comprehension:
# materialise every (leaf, variable, threshold) triple, shuffle the whole list,
# then consume only the first 10 entries.
# time.perf_counter() is used instead of time.time(): it is the monotonic,
# high-resolution clock meant for measuring short intervals.
start = time.perf_counter()
for _ in range(100):
    all_candidates = [
        (node_id, var, threshold)
        for node_id in tree.leaves
        for var in range(tree.dataX.shape[1])
        for threshold in possible_thresholds[var]
    ]
    rng.shuffle(all_candidates)
    count = 0
    for cand in all_candidates:
        count += 1
        if count >= 10:
            break
end = time.perf_counter()
print(f"List version: {end - start:.6f} seconds, total candidates: {count}")


# Random candidate:
# build only the (leaf, variable) pairs, then draw 10 candidates by picking a
# pair and, independently, one threshold for the chosen variable.
# time.perf_counter() is used instead of time.time(): it is the monotonic,
# high-resolution clock meant for measuring short intervals.
start = time.perf_counter()
for _ in range(100):
    all_candidates = [
        (node_id, var)
        for node_id in tree.leaves
        for var in range(tree.dataX.shape[1])
    ]
    count = 0
    while count < 10:
        node_id, var = fast_choice(rng, all_candidates)
        threshold = fast_choice(rng, possible_thresholds[var])
        count += 1
end = time.perf_counter()
print(f"Random version: {end - start:.6f} seconds, total candidates: {count}")

List version: 0.292203 seconds, total candidates: 10
Random version: 0.011197 seconds, total candidates: 10


In [None]:
# generator
def candidate_generator(tree, possible_thresholds, rng=None):
    """Lazily yield every ``(node_id, var, threshold)`` candidate in a randomised order.

    The order is a *nested* shuffle, not a uniform permutation of all
    candidates: the leaves are shuffled once, the variables are shuffled once,
    and the thresholds are re-shuffled for every (leaf, variable) pair. All
    candidates belonging to one (leaf, variable) pair are therefore emitted
    consecutively.

    Parameters
    ----------
    tree
        Object exposing ``leaves`` (iterable of node ids) and ``dataX`` (an
        array whose ``shape[1]`` is the number of variables).
    possible_thresholds
        Mapping from variable index to an iterable of thresholds. It is copied
        before shuffling, so the caller's containers are left untouched.
    rng : optional
        Any object with a ``shuffle(list)`` method that shuffles in place, e.g.
        ``random.Random(seed)`` or a ``numpy.random.Generator``. When omitted,
        the module-level functions of ``random`` are used, which is the
        behaviour this function had before the parameter existed.

    Yields
    ------
    tuple
        ``(node_id, var, threshold)``; each combination appears exactly once.
    """
    # Fall back to the global `random` module so existing two-argument calls
    # behave as before; pass a seeded rng for reproducible orderings.
    shuffler = random if rng is None else rng

    node_ids = list(tree.leaves)
    vars_ = list(range(tree.dataX.shape[1]))
    shuffler.shuffle(node_ids)
    shuffler.shuffle(vars_)
    for node_id in node_ids:
        for var in vars_:
            # Fresh copy per pair: shuffling must not reorder the caller's data.
            thresholds = list(possible_thresholds[var])
            shuffler.shuffle(thresholds)
            for threshold in thresholds:
                yield (node_id, var, threshold)

# Time how long it takes to pull the first 10 candidates from the generator.
# The candidates are collected during the timed loop and printed only after
# the clock has stopped, so console I/O does not inflate the measurement.
start = time.perf_counter()
count = 0
first_candidates = []
for cand in candidate_generator(tree, possible_thresholds):
    count += 1
    first_candidates.append(cand)
    if count >= 10:
        break
end = time.perf_counter()
for cand in first_candidates:
    print(cand)
print(f"Generator version: {end - start:.6f} seconds, total candidates: {count}")

(2, 2, 49)
(2, 2, 0)
(2, 2, 67)
(2, 2, 58)
(2, 2, 39)
(2, 2, 86)
(2, 2, 89)
(2, 2, 12)
(2, 2, 85)
(2, 2, 27)
Generator version: 0.000994 seconds, total candidates: 10


In [3]:
# Baseline: cost of counting to 100 with a `for` loop that breaks early,
# repeated 1000 times. time.perf_counter() is the monotonic, high-resolution
# clock meant for timing short intervals such as this one.
start = time.perf_counter()
for _ in range(1000):
    count = 0
    for _ in range(1000):
        count += 1
        if count >= 100:
            break
end = time.perf_counter()
print(f"For loop version: {end - start:.6f} seconds, total candidates: {count}")

For loop version: 0.010324 seconds, total candidates: 100


In [4]:
# Baseline: cost of counting to 100 with a `while` loop, repeated 1000 times.
# time.perf_counter() is the monotonic, high-resolution clock meant for timing
# short intervals such as this one.
start = time.perf_counter()
for _ in range(1000):
    count = 0
    while count < 100:
        count += 1
end = time.perf_counter()
print(f"While loop version: {end - start:.6f} seconds, total candidates: {count}")

While loop version: 0.008070 seconds, total candidates: 100


In [5]:
# NOTE: this definition shadows the `fast_choice` imported from
# bart_playground.util in the first cell of the notebook.
def fast_choice(generator, array, size=1):
    """Draw uniformly at random, with replacement, from ``array``.

    Parameters
    ----------
    generator
        Object providing ``integers(low, high, size=...)``, e.g. a
        ``numpy.random.Generator``.
    array
        Non-empty sized, indexable collection to draw from.
    size : optional
        ``1`` (the default) returns a single element of ``array`` itself.
        Any other value is forwarded as ``size`` to ``generator.integers``
        and the draws are returned as a NumPy array.

    Returns
    -------
    A single element when ``size == 1``; otherwise the result of indexing
    ``np.asarray(array)`` with the drawn indices. This now also holds for a
    one-element ``array``, which previously came back as a bare scalar no
    matter what ``size`` was requested.

    Raises
    ------
    ValueError
        If ``array`` is empty.
    """
    len_arr = len(array)
    if len_arr == 0:
        raise ValueError("fast_choice cannot sample from an empty array")
    if size == 1:
        if len_arr == 1:
            # Only one possible outcome: skip the generator call entirely.
            return array[0]
        return array[generator.integers(0, len_arr)]
    # asarray avoids copying inputs that are already ndarrays.
    return np.asarray(array)[generator.integers(0, len_arr, size=size)]

In [6]:
# Time 10,000 calls of fast_choice drawing 10 values from a 3-element list.
rng = np.random.default_rng(42)
size = 10  # loop-invariant, so set once outside the timed loop
start = time.perf_counter()  # monotonic, high-resolution clock for short intervals
for _ in range(10000):
    fast_choice(rng, [1,2,3], size=size)
end = time.perf_counter()
print(f"Time taken: {end - start} seconds")

Time taken: 0.13588166236877441 seconds


In [7]:
# Indexing an ndarray with a list of positions returns the elements at those positions.
li = np.array([1, 2, 3, 4, 5])
li[[0, 1, 2]]

array([1, 2, 3])

In [8]:
# Scratch check: draw one integer from the shared generator with low=0, high=10.
rng.integers(0, 10)

7

In [9]:
# np.atleast_1d guards against fast_choice handing back a bare scalar,
# so list() always receives something iterable.
list(np.atleast_1d(fast_choice(rng, [1,2,3], size=7)))

[3, 2, 2, 1, 2, 3, 1]

In [10]:
import numpy as np
from numba import njit
import time

@njit
def _descendants_numba(node_id, vars):
    """Return every descendant of ``node_id`` in breadth-first order.

    Nodes use implicit heap indexing: the children of node ``k`` are
    ``2*k + 1`` and ``2*k + 2``. A child is considered present whenever its
    index is below ``len(vars)``; the contents of ``vars`` are not inspected.

    Parameters
    ----------
    node_id : int
        Index of the subtree root. It is not part of the result.
    vars
        Sized container; only its length (the number of node slots) is used.

    Returns
    -------
    numpy.ndarray
        Descendant indices, level by level, left child before right child.
    """
    n_nodes = len(vars)
    result = []
    current = node_id
    # `result` doubles as the BFS queue: `head` is the index of the next entry
    # whose children still have to be expanded. This replaces the former
    # `queue.pop(0)`, which shifted the whole list on every iteration.
    head = 0
    while True:
        left = current * 2 + 1
        right = current * 2 + 2
        if left < n_nodes:
            result.append(left)
        if right < n_nodes:
            result.append(right)
        if head == len(result):
            # Every discovered node has been expanded.
            break
        current = result[head]
        head += 1
    return np.array(result)

@njit
def _traverse_tree_single(X, vars, thresholds, starting_node, n_to_update):
    """Placeholder traversal: returns ``starting_node`` unchanged.

    ``X``, ``vars``, ``thresholds`` and ``n_to_update`` are accepted but never read.
    NOTE(review): presumably mirrors the signature of a real traversal routine
    defined elsewhere -- confirm.
    """
    return starting_node

@njit
def _update_n_and_leaf_id_numba(starting_node, dataX, append: bool, vars, thresholds, prev_n, prev_leaf_id):
    """Re-assign rows to nodes below ``starting_node``, writing into the inputs.

    ``prev_n`` and ``prev_leaf_id`` are modified in place and the same two
    arrays are returned.

    Parameters
    ----------
    starting_node : int
        Root of the subtree being refreshed.
    dataX
        2-D array; row ``i`` is handed to ``_traverse_tree_single``.
    append : bool
        If False, the entries of ``prev_n`` for all descendants of
        ``starting_node`` are zeroed, and only rows currently assigned to
        ``starting_node`` or one of its descendants are re-traversed.
        If True, nothing is zeroed, the rows of ``dataX`` are matched to the
        last ``dataX.shape[0]`` entries of ``prev_leaf_id``, and every row is
        re-traversed.
    vars, thresholds
        Per-node arrays forwarded to the traversal; ``len(vars)`` is the
        number of node slots.
    prev_n
        Per-node array; see ``append`` for when entries are reset.
    prev_leaf_id
        Per-row node assignment, overwritten for every re-traversed row.

    Returns
    -------
    tuple
        ``(prev_leaf_id, prev_n)`` after the in-place update.
    """
    n_rows = dataX.shape[0]
    counts = prev_n
    assignments = prev_leaf_id

    # Membership mask for the subtree strictly below starting_node.
    in_subtree = np.zeros(len(vars), np.bool_)
    for node in _descendants_numba(starting_node, vars):
        in_subtree[node] = True
        if not append:
            counts[node] = 0

    # In append mode the new rows sit at the tail of the assignment array.
    if append:
        base = assignments.shape[0] - n_rows
    else:
        base = 0

    for row in range(n_rows):
        pos = base + row
        node_now = assignments[pos]
        if append or node_now == starting_node or in_subtree[node_now]:
            assignments[pos] = _traverse_tree_single(
                dataX[row], vars, thresholds, starting_node, counts
            )
    return assignments, counts


@njit
def _update_n_and_leaf_id_numba_copy(starting_node, dataX, append: bool, vars, thresholds, prev_n, prev_leaf_id):
    """Non-mutating counterpart of ``_update_n_and_leaf_id_numba``.

    Copies ``prev_n`` and ``prev_leaf_id`` and runs the in-place update on the
    copies, so the caller's arrays are left untouched. Delegating (instead of
    repeating the body) guarantees that the two variants differ only by the
    two ``.copy()`` calls.

    Parameters are the same as for ``_update_n_and_leaf_id_numba``.

    Returns
    -------
    tuple
        ``(leaf_ids, n)``: the updated copies.
    """
    return _update_n_and_leaf_id_numba(
        starting_node, dataX, append, vars, thresholds,
        prev_n.copy(), prev_leaf_id.copy(),
    )

# Benchmark inputs: 32 node slots and 10,000 rows of 5 features.
n_nodes = 32
n_samples = 10000
vars = np.full(n_nodes, -1, dtype=np.int32)  # note: shadows the builtin `vars`
thresholds = np.zeros(n_nodes, dtype=np.float32)
prev_n = np.zeros(n_nodes, dtype=np.int32)
prev_leaf_id = np.zeros(n_samples, dtype=np.int16)
# Seeded generator so the benchmark data is the same on every run.
dataX = np.random.default_rng(42).standard_normal((n_samples, 5)).astype(np.float32)

# Warm-up calls: numba compiles each function on its first call, so do that
# outside the timed regions.
_update_n_and_leaf_id_numba(0, dataX, False, vars, thresholds, prev_n, prev_leaf_id)
_update_n_and_leaf_id_numba_copy(0, dataX, False, vars, thresholds, prev_n, prev_leaf_id)

# time.perf_counter() is the monotonic, high-resolution clock meant for
# measuring short intervals.
start = time.perf_counter()
for _ in range(1000):
    _update_n_and_leaf_id_numba(0, dataX, False, vars, thresholds, prev_n, prev_leaf_id)
print("No copy version time:", time.perf_counter() - start)

start = time.perf_counter()
for _ in range(1000):
    _update_n_and_leaf_id_numba_copy(0, dataX, False, vars, thresholds, prev_n, prev_leaf_id)
print("Copy version time:", time.perf_counter() - start)

No copy version time: 0.016204118728637695
Copy version time: 0.020060300827026367
