# Feature Selection

This user guide shows how to use DNAMite for feature selection and feature-sparse prediction.

### Why Bother with Feature Selection?

When training a black-box machine learning model, it is common practice to use all available features, even for high-dimensional datasets, since modern ML models can easily handle many features. When training a glass-box model, however, we need to care about both predictive performance and the accuracy and utility of explanations. While glass-box models often achieve good accuracy on high-dimensional datasets, their explanations are much more likely to be impaired in such settings. In particular, when a dataset contains sets of correlated features, additive models like DNAMite run into identifiability issues: there is no unique way to spread contribution across the correlated features.
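The identifiability issue can be illustrated with a toy example. The sketch below is not DNAMite code: it assumes a hand-written additive model with linear per-feature terms (the hypothetical helper `additive_predict`) and two perfectly correlated features, and shows that very different splits of contribution produce identical predictions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two perfectly correlated features: x2 is an exact copy of x1.
x1 = rng.normal(size=200)
x2 = x1.copy()

# The target depends on the shared signal only.
y = 3.0 * x1


def additive_predict(w1, w2):
    """Toy additive model: one linear term per feature (illustrative only)."""
    return w1 * x1 + w2 * x2


# Three different ways of spreading contribution across the two features.
pred_a = additive_predict(3.0, 0.0)   # all credit to x1
pred_b = additive_predict(1.5, 1.5)   # credit split evenly
pred_c = additive_predict(-2.0, 5.0)  # opposite signs that cancel out

# All three fit the target equally well, so the data cannot tell them apart.
same_predictions = np.allclose(pred_a, pred_b) and np.allclose(pred_a, pred_c)
print(same_predictions)  # prints True
```

Each weight pair tells a different story about feature importance, yet the predictions are indistinguishable. Selecting a sparse feature set before fitting reduces the number of such ambiguous groups.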

### DNAMite Example

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
sns.set_theme()
from sklearn.model_selection import train_test_split
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

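# Load the pre-split training set; the label column is named "target".
df_train = pd.read_csv("mortality_tab_train.csv")
X_train = df_train.drop(["target"], axis=1)
y_train = df_train["target"]

# Load the held-out test set with the same layout.
df_test = pd.read_csv("mortality_tab_test.csv")
X_test = df_test.drop(["target"], axis=1)
y_test = df_test["target"]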

In [2]:
from dnamite.models import DNAMiteBinaryClassifier

# Baseline: fit DNAMite on all features, without pairwise interaction terms.
model = DNAMiteBinaryClassifier(n_features=X_train.shape[1], device=device, fit_pairs=False)
model.fit(X_train, y_train)

Discretizing features...
100%|██████████| 714/714 [00:00<00:00, 1095.73it/s]
SPlIT 0
TRAINING MAINS
Early stopping at 6 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 1
TRAINING MAINS
Early stopping at 7 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 2
TRAINING MAINS
Early stopping at 8 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 3
TRAINING MAINS
Early stopping at 8 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 4
TRAINING MAINS
Early stopping at 7 epochs: Test loss has not improved for 5 consecutive epochs.

In [3]:
from sklearn.metrics import roc_auc_score
print("TEST AUC: ", roc_auc_score(y_test, model.predict_proba(X_test)))

TEST AUC:  0.8468704806107692

In [5]:
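# Feature selection pass: the training log reports how many features remain active after each epoch.
model = DNAMiteBinaryClassifier(
    n_features=X_train.shape[1],
    device=device,
    fit_pairs=False,
    reg_param=0.008,
    gamma=0.1,
)
model.select_features(X_train, y_train)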

Discretizing features...
100%|██████████| 714/714 [00:00<00:00, 1217.36it/s]
Epoch 1 | Train loss: 2.167 | Test loss: 1.101 | Active features: 714
Epoch 2 | Train loss: 0.645 | Test loss: 0.425 | Active features: 714
Epoch 3 | Train loss: 0.362 | Test loss: 0.357 | Active features: 461
Epoch 4 | Train loss: 0.334 | Test loss: 0.348 | Active features: 287
Epoch 5 | Train loss: 0.326 | Test loss: 0.346 | Active features: 130
Epoch 6 | Train loss: 0.323 | Test loss: 0.342 | Active features: 58
Epoch 7 | Train loss: 0.318 | Test loss: 0.337 | Active features: 39
Epoch 8 | Train loss: 0.316 | Test loss: 0.337 | Active features: 22
Epoch 9 | Train loss: 0.314 | Test loss: 0.339 | Active features: 21
Epoch 10 | Train loss: 0.313 | Test loss: 0.338 | Active features: 21
Epoch 11 | Train loss: 0.312 | Test loss: 0.334 | Active features: 21
Epoch 12 | Train loss: 0.312 | Test loss: 0.333 | Active features: 21
Epoch 13 | Train loss: 0.310 | Test loss: 0.335 | Active features: 21
Epoch 14 | Train loss: 0.309 | Test loss: 0.335 | Active features: 21
Epoch 15 | Train loss: 0.308 | Test loss: 0.334 | Active features: 21
Epoch 16 | Train loss: 0.309 | Test loss: 0.339 | Active features: 20
Epoch 17 | Train loss: 0.307 | Test loss: 0.331 | Active features: 20
Epoch 18 | Train loss: 0.308 | Test loss: 0.331 | Active features: 20
Epoch 19 | Train loss: 0.306 | Test loss: 0.329 | Active features: 20
Epoch 20 | Train loss: 0.306 | Test loss: 0.331 | Active features: 20
Epoch 21 | Train loss: 0.305 | Test loss: 0.338 | Active features: 20
Epoch 22 | Train loss: 0.305 | Test loss: 0.333 | Active features: 20
Epoch 23 | Train loss: 0.305 | Test loss: 0.332 | Active features: 20
Epoch 24 | Train loss: 0.305 | Test loss: 0.329 | Active features: 20
Early stopping at 24 epochs: Test loss has not improved for 5 consecutive epochs.
Number of main features selected:  20

In [6]:
model.selected_feats

['Glascow coma scale eye opening | 50-100% | min',
 'Glascow coma scale eye opening | 75-100% | min',
 'Glascow coma scale eye opening | 90-100% | mean',
 'Glascow coma scale motor response | 0-50% | std',
 'Glascow coma scale motor response | 75-100% | max',
 'Glascow coma scale motor response | 90-100% | max',
 'Glascow coma scale verbal response | 75-100% | max',
 'Glascow coma scale verbal response | 75-100% | mean',
 'Glucose | 0-25% | len',
 'Glucose | 50-100% | min',
 'Glucose | 90-100% | mean',
 'Mean blood pressure | 0-10% | mean',
 'Oxygen saturation | 90-100% | len',
 'Respiratory rate | 0-25% | mean',
 'Respiratory rate | 90-100% | min',
 'Respiratory rate | 90-100% | std',
 'Systolic blood pressure | 50-100% | mean',
 'Temperature | 75-100% | min',
 'Temperature | 90-100% | mean',
 'pH | 0-25% | len']
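To see which underlying measurements survive selection, the selected names can be grouped by their first field. The sketch below assumes each name follows a `variable | window | statistic` pattern separated by `" | "` (an interpretation of the names above, not something the library guarantees) and copies the list printed above so it runs on its own; in the notebook, `model.selected_feats` would be used directly. The helper `base_variable` is hypothetical.

```python
from collections import Counter

# Copied from the output above; in the notebook use model.selected_feats instead.
selected_feats = [
    'Glascow coma scale eye opening | 50-100% | min',
    'Glascow coma scale eye opening | 75-100% | min',
    'Glascow coma scale eye opening | 90-100% | mean',
    'Glascow coma scale motor response | 0-50% | std',
    'Glascow coma scale motor response | 75-100% | max',
    'Glascow coma scale motor response | 90-100% | max',
    'Glascow coma scale verbal response | 75-100% | max',
    'Glascow coma scale verbal response | 75-100% | mean',
    'Glucose | 0-25% | len',
    'Glucose | 50-100% | min',
    'Glucose | 90-100% | mean',
    'Mean blood pressure | 0-10% | mean',
    'Oxygen saturation | 90-100% | len',
    'Respiratory rate | 0-25% | mean',
    'Respiratory rate | 90-100% | min',
    'Respiratory rate | 90-100% | std',
    'Systolic blood pressure | 50-100% | mean',
    'Temperature | 75-100% | min',
    'Temperature | 90-100% | mean',
    'pH | 0-25% | len',
]


def base_variable(name):
    """Return the text before the first ' | ' separator (assumed to be the measurement name)."""
    return name.split(" | ")[0]


# Count how many selected features come from each underlying measurement.
counts = Counter(base_variable(name) for name in selected_feats)

for variable, n in counts.most_common():
    print(f"{variable}: {n}")
```

Under this grouping, the 20 selected features come from 10 distinct measurements, which makes the remaining overlap between related features easier to spot.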

In [7]:
# Refit the model; fit() detects the selected features and trains on those only.
model.fit(X_train, y_train)

SPlIT 0
Found selected features. Using only those features.
TRAINING MAINS
Early stopping at 11 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 1
Found selected features. Using only those features.
TRAINING MAINS
Early stopping at 14 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 2
Found selected features. Using only those features.
TRAINING MAINS
Early stopping at 13 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 3
Found selected features. Using only those features.
TRAINING MAINS
Early stopping at 14 epochs: Test loss has not improved for 5 consecutive epochs.
SPlIT 4
Found selected features. Using only those features.
TRAINING MAINS
Early stopping at 10 epochs: Test loss has not improved for 5 consecutive epochs.

In [8]:
print("TEST AUC: ", roc_auc_score(y_test, model.predict_proba(X_test)))

TEST AUC:  0.49946187737530684

### Hyperparameters

DNAMite has multiple hyperparameters that can be set to control the 