
Feature Request: Decision Tree & Random Forest for Classification Tasks #1

@noahgift

Description


Feature Request: Decision Tree & Random Forest

Motivation

While migrating PMAT from linfa to aprender, we currently rely on linfa_trees::DecisionTree for mutant survivability prediction (a classification task). Aprender currently provides LinearRegression and KMeans; adding tree-based classifiers would:

  1. Ease migration: Drop-in replacement for existing DecisionTree usage
  2. Better accuracy: Tree-based models often outperform linear models for complex classification
  3. Interpretability: Decision trees provide clear decision paths
  4. Feature importance: Understand which features matter most

Proposed API

use aprender::prelude::*;
use aprender::tree::{DecisionTreeClassifier, RandomForestClassifier, SplitCriterion};

// Decision Tree
let mut tree = DecisionTreeClassifier::new()
    .with_max_depth(10)
    .with_min_samples_split(2)
    .with_criterion(SplitCriterion::Gini);

tree.fit(&x, &y)?;
let predictions = tree.predict(&x);
let probabilities = tree.predict_proba(&x)?;

// Feature importance
let importance = tree.feature_importance()?;

// Random Forest (ensemble)
let mut forest = RandomForestClassifier::new()
    .with_n_estimators(100)
    .with_max_depth(10)
    .with_random_state(42);

forest.fit(&x, &y)?;
let predictions = forest.predict(&x);

Use Case: PMAT Mutation Testing

Current (linfa):

use linfa_trees::{DecisionTree, SplitQuality};

let tree = DecisionTree::params()
    .split_quality(SplitQuality::Gini)
    .max_depth(Some(10))
    .fit(&dataset)?;

Target (aprender):

use aprender::tree::{DecisionTreeClassifier, SplitCriterion};

let mut tree = DecisionTreeClassifier::new()
    .with_max_depth(10)
    .with_criterion(SplitCriterion::Gini);

tree.fit(&x, &y)?;

Implementation Considerations

Core Requirements:

  • Classification (binary + multi-class)
  • Gini impurity & entropy split criteria
  • Max depth, min samples split, min samples leaf
  • Feature importance (Gini importance)
  • Predict class labels + probabilities
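The Gini and entropy criteria listed above reduce to small impurity functions over per-class counts. A minimal sketch of what the split-criterion math could look like (the function names `gini` and `entropy` are illustrative, not aprender's API):

```rust
/// Gini impurity: 1 - sum(p_k^2) over class probabilities p_k.
fn gini(counts: &[usize]) -> f64 {
    let total: usize = counts.iter().sum();
    if total == 0 {
        return 0.0;
    }
    let t = total as f64;
    1.0 - counts
        .iter()
        .map(|&c| {
            let p = c as f64 / t;
            p * p
        })
        .sum::<f64>()
}

/// Shannon entropy in bits: -sum(p_k * log2(p_k)), skipping empty classes.
fn entropy(counts: &[usize]) -> f64 {
    let total: usize = counts.iter().sum();
    if total == 0 {
        return 0.0;
    }
    let t = total as f64;
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / t;
            -p * p.log2()
        })
        .sum()
}

fn main() {
    // A pure node has zero impurity under both criteria.
    assert!(gini(&[10, 0]).abs() < 1e-12);
    // A 50/50 binary node: Gini = 0.5, entropy = 1 bit.
    assert!((gini(&[5, 5]) - 0.5).abs() < 1e-12);
    assert!((entropy(&[5, 5]) - 1.0).abs() < 1e-12);
}
```

The splitter would pick the threshold maximizing the weighted impurity decrease; Gini importance is then the sum of those decreases per feature.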

Nice-to-Have:

  • CART algorithm (Classification and Regression Trees)
  • Pruning (pre-pruning, post-pruning)
  • Parallel tree building for Random Forest
  • Out-of-bag error estimation
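The Random Forest side of this boils down to bootstrap sampling per tree plus majority voting at predict time (samples left out of a tree's bootstrap are what out-of-bag estimation would use). A rough sketch, with a tiny deterministic LCG standing in for a seeded RNG such as `with_random_state`; all names here are illustrative, not aprender's API:

```rust
/// Draw n indices with replacement from 0..n using a simple LCG,
/// so results are reproducible for a given seed.
fn bootstrap_indices(n: usize, seed: &mut u64) -> Vec<usize> {
    (0..n)
        .map(|_| {
            *seed = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((*seed >> 33) as usize) % n
        })
        .collect()
}

/// Aggregate per-tree class votes into a single prediction.
fn majority_vote(votes: &[usize], n_classes: usize) -> usize {
    let mut counts = vec![0usize; n_classes];
    for &v in votes {
        counts[v] += 1;
    }
    counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, c)| *c)
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let mut seed = 42u64;
    let idx = bootstrap_indices(8, &mut seed);
    assert_eq!(idx.len(), 8);
    assert!(idx.iter().all(|&i| i < 8));
    // Three trees voting [1, 0, 1] yield class 1.
    assert_eq!(majority_vote(&[1, 0, 1], 2), 1);
}
```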

API Design:

  • Follow aprender's Estimator trait pattern
  • Support fit(), predict(), score() like LinearRegression
  • Return Vector for predictions (consistent with existing API)
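One way the trait pattern could be shaped, assuming a classifier-flavoured variant of aprender's Estimator contract; the `Classifier` trait name, the slice-based signatures, and the `MajorityClass` baseline are all assumptions for illustration, not the crate's real API:

```rust
/// Hypothetical classifier contract mirroring fit()/predict()/score().
trait Classifier {
    fn fit(&mut self, x: &[Vec<f32>], y: &[usize]) -> Result<(), String>;
    fn predict(&self, x: &[Vec<f32>]) -> Vec<usize>;
    /// Default score: classification accuracy.
    fn score(&self, x: &[Vec<f32>], y: &[usize]) -> f32 {
        let preds = self.predict(x);
        let correct = preds.iter().zip(y).filter(|(p, t)| p == t).count();
        correct as f32 / y.len() as f32
    }
}

/// Trivial baseline: always predict the most frequent training label.
struct MajorityClass {
    class: usize,
}

impl Classifier for MajorityClass {
    fn fit(&mut self, _x: &[Vec<f32>], y: &[usize]) -> Result<(), String> {
        let mut counts = std::collections::HashMap::new();
        for &label in y {
            *counts.entry(label).or_insert(0usize) += 1;
        }
        self.class = counts
            .into_iter()
            .max_by_key(|&(_, c)| c)
            .map(|(l, _)| l)
            .ok_or_else(|| "empty y".to_string())?;
        Ok(())
    }

    fn predict(&self, x: &[Vec<f32>]) -> Vec<usize> {
        vec![self.class; x.len()]
    }
}

fn main() {
    let x = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
    let y = vec![1, 1, 1, 0];
    let mut m = MajorityClass { class: 0 };
    m.fit(&x, &y).unwrap();
    assert_eq!(m.predict(&x), vec![1, 1, 1, 1]);
    assert!((m.score(&x, &y) - 0.75).abs() < 1e-6);
}
```

A DecisionTreeClassifier and RandomForestClassifier implementing such a trait would slot in wherever LinearRegression's fit()/predict()/score() pattern is already used.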

Alternatives Considered

  1. Use LinearRegression: Less accurate for non-linear classification
  2. Use KMeans: Unsupervised, doesn't learn from labeled data
  3. Keep linfa: Defeats the purpose of the migration (linfa pulls in 50+ dependencies)

Priority

Medium-High: Would unblock PMAT ML migration and provide valuable classification capabilities for other users.

References


Context: Filed while creating a comprehensive aprender integration specification for PMAT (725 lines, commit 394eb523).
