### Example: breast cancer

This tutorial shows how to train a `rofigs` model (`ROFIGSClassifier`) on the breast cancer dataset from scikit-learn.

Ensure that you have installed all required dependencies with `pip install -r requirements.txt`.

```python
# Automatically reload edited modules (e.g. src/rofigs) before each cell runs
%load_ext autoreload
%autoreload 2
```

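```python
import os
import sys

# Make the repository's `src` package importable; this assumes the notebook
# is run from a directory one level below the repository root.
sys.path.insert(0, os.path.abspath('..'))

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from src.rofigs import ROFIGSClassifier
```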

```python
# Load the data from scikit-learn and split it into training and test sets
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
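To see why the notebook reports balanced accuracy rather than plain accuracy, it helps to look at the class distribution. The sketch below repeats the loading and splitting step so that it runs on its own; labels follow scikit-learn's encoding (0 = malignant, 1 = benign).

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Same data and split as in the notebook cell above
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("samples, features:", X.shape)                            # (569, 30)
print("class counts (0 = malignant, 1 = benign):", np.bincount(y))  # [212 357]
print("test-set size:", len(y_test))                            # 114
print("test-set class counts:", np.bincount(y_test))
```

About 37% of the samples are malignant, so the classes are imbalanced. Balanced accuracy averages the recall of each class, which keeps a model that favours the majority class from looking better than it is.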

Fit and evaluate RO-FIGS models with various `beam_size` values

```python
# Fit models with beam_size = 1, 5, and 30 (30 = all features of this dataset)
beam_sizes = [1, 5, 30]

models = {}

for beam_size in beam_sizes:
    model = ROFIGSClassifier(beam_size=beam_size, max_splits=5, random_state=42)
    model.fit(X_train, y_train)
    models[beam_size] = model

print("Balanced accuracy of the models on the test set with:")
for beam_size, model in models.items():
    accuracy = 100 * balanced_accuracy_score(y_test, model.predict(X_test))
    print(f"\t beam_size = {beam_size}: {accuracy:.1f}")
```

```
Balanced accuracy of the models on the test set with:
	 beam_size = 1: 93.9
	 beam_size = 5: 97.0
	 beam_size = 30: 94.9
```
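Balanced accuracy is the mean of the per-class recalls. A minimal sketch on a small hand-made label vector (not the notebook's data) shows the computation and checks it against scikit-learn's `balanced_accuracy_score`; the helper `balanced_accuracy_manual` is hypothetical and only for illustration.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score


def balanced_accuracy_manual(y_true, y_pred):
    """Hypothetical helper: average of the recall of every class in y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


# Toy labels: 4 samples of class 0, 6 samples of class 1
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 1, 1, 0]

# Recall of class 0 = 3/4, recall of class 1 = 5/6.
# Plain accuracy is 8/10 = 0.8; balanced accuracy is (0.75 + 0.833...) / 2 ≈ 0.792.
print(balanced_accuracy_manual(y_true, y_pred))
print(balanced_accuracy_score(y_true, y_pred))
```

Both calls print the same value. The notebook multiplies it by 100 to report a percentage.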


```python
# Display the fitted model with beam_size = 1
models[1]
```

Get the number of trees, splits, and average number of features per split

```python
for beam_size, model in models.items():
    print(f"Model with beam_size = {beam_size} has: {model.count_trees()} tree(s), "
          f"{model.count_splits()} splits, and {model.get_average_num_feat_per_split()} features per split.")
```

```
Model with beam_size = 1 has: 1 tree(s), 5 splits, and 1.0 features per split.
Model with beam_size = 5 has: 1 tree(s), 5 splits, and 2.2 features per split.
Model with beam_size = 30 has: 1 tree(s), 5 splits, and 10.0 features per split.
```
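The last column is the average number of features combined in each split. Assuming it counts the nonzero coefficients of each split's weight vector (our reading of the metric, not confirmed by the source), a value of 2.2 over 5 splits corresponds to 11 nonzero coefficients in total, and 1.0 means every split uses a single feature (an axis-aligned split). The sketch below reproduces that arithmetic on made-up toy weights, not on the fitted models' internals; `average_features_per_split` is a hypothetical helper.

```python
import numpy as np


def average_features_per_split(split_weights):
    """Hypothetical helper: mean count of nonzero coefficients per split.

    Each row of `split_weights` is assumed to hold the weights of one
    oblique split (w . x <= threshold); a zero weight means the feature is unused.
    """
    weights = np.asarray(split_weights)
    return float(np.count_nonzero(weights, axis=1).mean())


rng = np.random.default_rng(0)
n_features = 30  # same dimensionality as the breast cancer data

# Toy weights, NOT read from a fitted ROFIGSClassifier: five splits using
# 1, 2, 3, 2 and 3 features respectively (11 nonzero coefficients in total).
toy_oblique = np.zeros((5, n_features))
for row, k in enumerate([1, 2, 3, 2, 3]):
    cols = rng.choice(n_features, size=k, replace=False)
    toy_oblique[row, cols] = rng.uniform(0.1, 1.0, size=k)  # strictly nonzero

# Axis-aligned splits: exactly one nonzero weight per split.
toy_axis_aligned = np.eye(5, n_features)

print(average_features_per_split(toy_oblique))       # 2.2
print(average_features_per_split(toy_axis_aligned))  # 1.0
```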
