# Degenerate Trees Example

We only need numpy for the heavy lifting.

In [1]:
import itertools
import dataclasses
import pprint
from typing import List

import numpy as np

*An* implementation of an optimizer for degenerate trees. **Note:** this serial grid search is by far the simplest implementation; expect far better performance from a more sophisticated algorithm (e.g. simulated annealing).

In [2]:
@dataclasses.dataclass
class ModelInfo:
  """The best model found by the grid search for one cutoff on the target.

  A row is predicted positive when at least vote_threshold trees vote for it;
  a tree votes for a row when every one of its columns is strictly greater than
  the corresponding entry of parameter_thresholds.
  """
  # The threshold applied to the target variable (a row is a positive when its
  # target is strictly greater than this), copied from the invocation of
  # _optimal_coeffs.
  cutoff: float

  # The number of rows in the dataset whose target is greater than cutoff.
  num_gt: int

  # The number of rows in the dataset whose target is less than or equal to
  # cutoff.
  num_lte: int

  # The combination of thresholds (one per entry of the swept stops, in the
  # same order) that maximizes the F1 across the dataset.
  parameter_thresholds: List[float]

  # The minimum number of tree votes (between 1 and the number of trees) a row
  # needs to be predicted positive, chosen to maximize the F1 across the
  # dataset.
  vote_threshold: int

  # The F1 from the optimal set of thresholds: tp / (tp + 0.5 * (fp + fn)).
  f1: float

  # The omission error from the optimal set of thresholds: fn / (tp + fn),
  # i.e. the fraction of actual positives that were missed.
  omission_error: float

  # The commission error from the optimal set of thresholds: fp / (tp + fp),
  # i.e. the fraction of predicted positives that were wrong.
  commission_error: float


def _ratio_or_nan(numerator: int, denominator: int) -> float:
  """Returns numerator / denominator, or NaN when the denominator is zero."""
  if denominator == 0:
    return float("nan")
  return numerator / denominator


def _optimal_coeffs(
    x: np.ndarray, y: np.ndarray, threshold: float,
    stops: List[List[float]]) -> ModelInfo:
  """Determines an optimal set of thresholds and votes to predict y > threshold.

  Performs an exhaustive parameter sweep: every combination drawn from the
  cross-product of stops is reshaped to [M, COLS_N] and compared against x. A
  tree (index along axis 1 of x) votes for a row when all of its columns are
  strictly greater than their thresholds. For each combination every vote
  threshold in 1..M is scored at once, and the best-scoring one is kept.

  The F1 score is computed as tp / (tp + 0.5 * (fp + fn)). When that
  denominator is zero (no positives in y and no positive predictions) the F1
  is taken to be 0.0 rather than NaN, so that an undefined score can never
  win the argmax or block later candidates from being selected.

  Args:
    x: an array of shape [N, M, COLS_N] describing the dataset of thresholdable
      columns (where thresholds are chosen from stops).
    y: an array of shape [N] which is the target variable to be predicted by the
      intersection of thresholds on x.
    threshold: the cutoff applied to y; rows with y > threshold are the
      positives to be predicted.
    stops: a list of length M * COLS_N containing lists of thresholds to sweep
      over x (referencing the row-major last two dimensions in x).

  Returns:
    The ModelInfo of the combination with the highest F1 (the first combination
    encountered wins ties). Its omission / commission error is NaN when the
    corresponding ratio is undefined (no actual / no predicted positives).
    None is returned if any list in stops is empty, because the cross-product
    then contains no combinations.
  """
  x = np.asarray(x)
  is_positive = np.asarray(y) > threshold
  # Plain Python ints keep the ModelInfo fields consistent with their
  # annotations and its repr independent of the installed numpy version.
  num_gt = int(np.sum(is_positive))
  num_lte = int(np.sum(~is_positive))

  num_trees, num_cols = x.shape[1], x.shape[2]
  # Shape [1, M]: every candidate vote threshold, 1..M.
  vote_levels = np.arange(1, num_trees + 1)[np.newaxis, :]
  # Shape [N, 1] so the labels broadcast against the [N, M] vote sweep.
  positive_col = is_positive[:, np.newaxis]
  negative_col = ~positive_col

  best_model = None
  for candidate in itertools.product(*stops):
    candidate_array = np.reshape(np.array(candidate), [num_trees, num_cols])
    # Number of trees voting for each row; shape [N].
    votes_per_row = np.sum(np.all(x > candidate_array, axis=-1), axis=-1)

    # predicted[i, k] is True when row i gets at least k + 1 votes.
    predicted = votes_per_row[:, np.newaxis] >= vote_levels
    tp = np.sum(predicted & positive_col, axis=0)
    fp = np.sum(predicted & negative_col, axis=0)
    fn = np.sum(~predicted & positive_col, axis=0)

    # Only divide where the F1 is defined; undefined entries stay at 0.0.
    f1_denominator = tp + 0.5 * (fp + fn)
    f1_defined = f1_denominator > 0
    f1_all = np.zeros(num_trees)
    f1_all[f1_defined] = tp[f1_defined] / f1_denominator[f1_defined]

    best_index = int(np.argmax(f1_all))
    vote_thresh = best_index + 1
    f1 = float(f1_all[best_index])

    if (best_model is None) or (f1 > best_model.f1):
      selected_tp = int(tp[best_index])
      selected_fp = int(fp[best_index])
      selected_fn = int(fn[best_index])

      best_model = ModelInfo(
          cutoff=threshold,
          num_gt=num_gt,
          num_lte=num_lte,
          # The tuple yielded by itertools.product, kept as-is for callers.
          parameter_thresholds=candidate,
          vote_threshold=vote_thresh,
          f1=f1,
          omission_error=_ratio_or_nan(
              selected_fn, selected_tp + selected_fn),
          commission_error=_ratio_or_nan(
              selected_fp, selected_tp + selected_fp))

  return best_model


def cutoff_sweep(
    x: np.ndarray, y: np.ndarray, x_stops: List[List[float]],
    y_stops: List[float]) -> List[ModelInfo]:
  """Finds the best thresholds on x for each cutoff applied to y.

  Each value in y_stops turns y into a binary target (y > cutoff), and an
  independent sweep over x_stops is run against that target.

  Args:
    x: an array of shape [N, M, COLS_N] holding the thresholdable columns of
        the dataset.
    y: an array of shape [N] holding the target variable to be binarized.
    x_stops: a list of length M * COLS_N; each entry lists the thresholds to
        try for one column of x (row-major over the last two dimensions).
    y_stops: the cutoffs to sweep over y.

  Returns:
    One ModelInfo per entry of y_stops, in the same order, each describing the
    selected thresholds, vote threshold, F1 score and omission / commission
    error of the best candidate found for that cutoff.
  """
  return [
      _optimal_coeffs(x, y, threshold=y_cutoff, stops=x_stops)
      for y_cutoff in y_stops
  ]


Let's test on some fake data. We'll use an ensemble of 2 degenerate trees with some arbitrary stops on [0, 1].

In [3]:
# Candidate thresholds ("stops") for each variable of each tree. Every
# combination of one stop per list is tried, so the stops do not need to be in
# sorted order and the lists may have different lengths.

# These apply to the two variables of the first tree.
a_stops = [0, 0.1, 0.5, 0.6]
b_stops = [0.3, 0.4, 0.9, 0.95]

# These apply to the two variables of the second tree.
c_stops = [0.2, 0.4, 0.7]
d_stops = [0.1, 0.8]

# Order matters: the lists are matched row-major to the (tree, variable) axes
# of x. The sweep visits 4 * 4 * 3 * 2 = 96 combinations.
parameter_stops = [a_stops, b_stops, c_stops, d_stops]

Run the model on some fake data.

In [4]:
# We've got 20 rows x 2 trees x 2 variables per tree.
#
# Note that for a given row in our example:
#     [0, 0] is visited by a_stops for the first tree.
#     [0, 1] is visited by b_stops for the first tree.
#     [1, 0] is visited by c_stops for the second tree.
#     [1, 1] is visited by d_stops for the second tree.
#
# No seed is set in this cell, so the data (and therefore the printed results)
# will differ from run to run.

test_x = np.random.uniform(size=[20, 2, 2])

# This part is easy, just 20 random target values, one per row of test_x.

test_y = np.random.uniform(size=[20])

# Find the optimal stops for our pair of degenerate trees!

model_results = cutoff_sweep(
    test_x,
    test_y,
    parameter_stops,
    # We test 2 stops for y, and will therefore get 2 models out the other end
    # (one ModelInfo per stop, in the same order).
    [0.3333, 0.6666])

pprint.pprint(model_results)

[ModelInfo(cutoff=0.3333, num_gt=14, num_lte=6, parameter_thresholds=(0, 0.3, 0.7, 0.1), vote_threshold=1, f1=0.8125, omission_error=0.07142857142857142, commission_error=0.2777777777777778),
 ModelInfo(cutoff=0.6666, num_gt=7, num_lte=13, parameter_thresholds=(0.5, 0.3, 0.7, 0.1), vote_threshold=1, f1=0.5555555555555556, omission_error=0.2857142857142857, commission_error=0.5454545454545454)]
