Explores a case_when based conditional performance v a short-circuiting if_then_else
implementation, as discussed in this PR: https://github.com/wandb/weave/pull/406#discussion_r1299238062

In [1]:
import timeit

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

In [59]:
def if_then_else(arr, if_fn, then_fn, else_fn):
    case = if_fn(arr)
    case_inverted = pc.invert(case)
    filt_true = pc.filter(arr, case)
    filt_false = pc.filter(arr, case_inverted)
    true_result = then_fn(filt_true)
    false_result = else_fn(filt_false)
    new_arr = pc.replace_with_mask(arr, case, true_result)
    new_arr = pc.replace_with_mask(new_arr, case_inverted, false_result)
    return new_arr

In [55]:
def cond_if_then_else(arr, case_fns, result_fns):
    result_arr = arr
    then_fn = result_fns[0]

    # have to do this because of variable capture
    def make_else(ef):
        return lambda arr: if_then_else(arr, case_fns[i - 1], result_fns[i - 1], ef)

    else_fn = result_fns[-1]
    for i in range(len(result_fns) - 1, 1, -1):
        else_fn = make_else(else_fn)
    return if_then_else(arr, case_fns[0], result_fns[0], else_fn)

In [60]:
def cond(arr, case_fns, result_fns):
    cases = [case_fn(arr) for case_fn in case_fns]
    case_names = ["%s" % i for i in range(len(case_fns))]
    results = [result_fn(arr) for result_fn in result_fns]
    return pc.case_when(pa.StructArray.from_arrays(cases, names=case_names), *results)

In [63]:
cond_args = {
    "case_fns": [lambda arr: pc.greater(arr, 0.9), lambda arr: pc.greater(arr, 0.25)],
    "result_fns": [
        lambda arr: pc.add(arr, 50),
        lambda arr: pc.subtract(arr, 5),
        lambda arr: pc.add(arr, 5),
    ],
}
# cond_res = cond(arr, **cond_args)
# if_res = cond_if_then_else(arr, **cond_args)
# cond_res == if_res

True

In [98]:
n_cases = 100


def make_case(i):
    # With this variant, we eliminate most of the options during the first step. Should be very
    # favorable to the if_then_else variant
    # return lambda arr: pc.greater(arr, (i / n_cases))

    # The "1 -" version means we only eliminate a fraction of the options at each step in our
    # tree.
    return lambda arr: pc.greater(arr, 1 - (i / n_cases))


def make_result(i):
    return lambda arr: pc.add(arr, i)


cond_args = {
    "case_fns": [make_case(i) for i in range(n_cases)],
    "result_fns": [make_result(i) for i in range(n_cases + 1)],
}

In [99]:
n_trials = 10
arr = np.random.rand(1000000)
cond_res = timeit.timeit(lambda: cond(arr, **cond_args), number=n_trials) / n_trials
print("COND RES", cond_res)
cond_if_then_else_res = (
    timeit.timeit(lambda: cond_if_then_else(arr, **cond_args), number=n_trials)
    / n_trials
)
print("IFTH RES", cond_if_then_else_res)

COND RES 0.23255444579999676
IFTH RES 0.18101482080001005
