## Example

In this tutorial, you will learn how to use `rofigs` on the diabetes dataset from the paper.

First, install the required dependencies with `pip install -r requirements.txt`.

```python
# Automatically reload edited modules (e.g. files under src/) without restarting the kernel
%load_ext autoreload
%autoreload 2
```

```python
import os
import sys

# Put the parent directory (the repository root) on the import path so that `src` can be found
sys.path.insert(0, os.path.abspath('..'))

from src.utils import load_final_data, load_data
from src.rofigs import ROFIGSClassifier
from sklearn.metrics import balanced_accuracy_score
```

Load the preprocessed diabetes dataset (fold 6):

```python
(X_train, y_train), (X_test, y_test) = load_final_data(dataset="diabetes", fold=6)
```

Fit two models that differ only in `beam_size` and compute their balanced accuracy on the test set:

```python
# Fit models with different beam_size values; all other hyperparameters are identical

model_4 = ROFIGSClassifier(beam_size=4, max_splits=10, min_impurity_decrease=10, random_state=12345)
model_4.fit(X_train, y_train)
accuracy = 100 * balanced_accuracy_score(y_test, model_4.predict(X_test))
print(f"Model with beam_size=4 has {model_4.count_trees()} trees, {model_4.count_splits()} splits, and {model_4.get_average_num_feat_per_split():.1f} features per split. Its balanced accuracy on the test set is: {accuracy:.1f}")

model_8 = ROFIGSClassifier(beam_size=8, max_splits=10, min_impurity_decrease=10, random_state=12345)
model_8.fit(X_train, y_train)
accuracy = 100 * balanced_accuracy_score(y_test, model_8.predict(X_test))
print(f"Model with beam_size=8 has {model_8.count_trees()} trees, {model_8.count_splits()} splits, and {model_8.get_average_num_feat_per_split():.1f} features per split. Its balanced accuracy on the test set is: {accuracy:.1f}")
```

```
Model with beam_size=4 has 2 trees, 2 splits, and 1.5 features per split. Its balanced accuracy on the test set is: 65.1
Model with beam_size=8 has 1 trees, 1 splits, and 3.0 features per split. Its balanced accuracy on the test set is: 72.5
```
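Balanced accuracy, as reported above via `balanced_accuracy_score`, is the mean of the per-class recalls. On imbalanced data, this keeps the majority class from dominating the score. The sketch below uses made-up toy labels, not the diabetes data. It recomputes the metric by hand and checks the result against scikit-learn. The helper `balanced_accuracy` is illustrative and not part of `rofigs`.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Toy labels for an imbalanced binary problem (illustrative only, not the diabetes data)
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0])


def balanced_accuracy(y_true, y_pred):
    """Hypothetical helper: mean of the recall computed separately for each class."""
    classes = np.unique(y_true)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(recalls))


# Scale to a percentage, as in the tutorial code
manual = 100 * balanced_accuracy(y_true, y_pred)
reference = 100 * balanced_accuracy_score(y_true, y_pred)
print(f"{manual:.1f} {reference:.1f}")  # prints "70.8 70.8": (4/6 + 3/4) / 2
```

Here class 0 has recall 4/6 and class 1 has recall 3/4, so the balanced accuracy is their average rather than the plain fraction of correct predictions.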


Display the fitted models:

```python
model_4
```

```python
model_8
```

Access the feature combinations (tuples of feature indices) used in each model's splits:

```python
model_4.feature_combinations
```

```
[(5, 6), (1,)]
```

```python
model_8.feature_combinations
```

```
[(1, 5, 7)]
```
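The "features per split" figures printed earlier (1.5 and 3.0) are consistent with averaging the tuple lengths in `feature_combinations`, under the assumption that each tuple corresponds to exactly one split. That assumption matches the split counts reported above (2 and 1), but it has not been checked against the `rofigs` source. The sketch below copies the printed lists and recomputes the averages. The helper `average_features_per_split` is hypothetical and not part of the library.

```python
# Feature combinations copied from the outputs above (tuples of feature indices)
feature_combinations = {
    "model_4": [(5, 6), (1,)],
    "model_8": [(1, 5, 7)],
}


def average_features_per_split(combinations):
    """Hypothetical helper: mean tuple length, assuming one tuple per split."""
    if not combinations:
        return 0.0
    return sum(len(c) for c in combinations) / len(combinations)


for name, combos in feature_combinations.items():
    print(f"{name}: {average_features_per_split(combos):.1f} features per split")
# model_4: 1.5 features per split
# model_8: 3.0 features per split
```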