In [1]:
from time import perf_counter

import numpy as np
import pandas as pd
import sklearn.metrics as sk_metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import random_split

from nam.wrapper import NAMClassifier, MultiTaskNAMClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed passed to train_test_split and to both NAM classifiers further down.
random_state = 2016

In [3]:
# Read the space-separated recidivism table. The file carries no header row,
# so the column names are attached right after loading.
dataset = pd.read_csv(
    'nam/data/data_recidivism_recid.data',
    sep=' ',
    header=None,
)
dataset.columns = [
    "age",
    "race",
    "sex",
    "priors_count",
    "length_of_stay",
    "c_charge_degree",
    "two_year_recid",
]

In [4]:
# Preview the first rows of the table as loaded, before any scaling.
dataset.head()

Unnamed: 0,age,race,sex,priors_count,length_of_stay,c_charge_degree,two_year_recid
0,69,6,2,0,1,1,0
1,34,1,2,0,10,1,1
2,24,1,2,4,1,1,1
3,44,6,2,0,1,2,0
4,41,3,2,14,6,1,1


In [5]:
# Feature columns that get shifted down by one during preprocessing.
binary = ["sex", "c_charge_degree"]
# Feature columns that get min-max scaled during preprocessing.
# NOTE(review): `race` looks like an integer category code but is scaled like a
# numeric feature -- confirm this is intended.
other = ["age", "race", "priors_count", "length_of_stay"]

In [6]:
# Rescale the `other` columns into [-1, 1] and shift the `binary` columns down by one.
# NOTE(review): the scaler is fitted on the whole table, i.e. before the
# train/test split below -- confirm this is acceptable.
# NOTE(review): both assignments overwrite `dataset` in place, so running this
# cell a second time shifts the binary columns again.
scaler = MinMaxScaler(feature_range=(-1, 1))
dataset[other] = scaler.fit_transform(dataset[other])
dataset[binary] = dataset[binary].sub(1)

In [7]:
# Show the table after preprocessing (scaled `other` columns, shifted `binary` columns).
dataset

Unnamed: 0,age,race,sex,priors_count,length_of_stay,c_charge_degree,two_year_recid
0,0.307692,1.0,1,-1.000000,-0.9975,0,0
1,-0.589744,-1.0,1,-1.000000,-0.9750,0,1
2,-0.846154,-1.0,1,-0.789474,-0.9975,0,1
3,-0.333333,1.0,1,-1.000000,-0.9975,1,0
4,-0.410256,-0.2,1,-0.263158,-0.9850,0,1
...,...,...,...,...,...,...,...
6167,-0.871795,-1.0,1,-1.000000,-0.9950,0,0
6168,-0.871795,-1.0,1,-1.000000,-0.9950,0,0
6169,0.000000,1.0,1,-1.000000,-0.9975,0,0
6170,-0.615385,-1.0,0,-0.842105,-0.9975,1,0


In [8]:
# 80/20 split of the preprocessed table, seeded with `random_state`.
data_train, data_test = train_test_split(
    dataset,
    train_size=0.8,
    test_size=0.2,
    random_state=random_state,
)

# Features are the scaled columns followed by the binary ones; the label is `two_year_recid`.
X_train = data_train[other + binary]
y_train = data_train['two_year_recid']
X_test = data_test[other + binary]
y_test = data_test['two_year_recid']

## Single-Task NAM Classification

In [18]:
# Single-task NAM classifier trained on all features, including `sex`.
# NOTE(review): with monitor_loss=False and early_stop_mode='max', early stopping
# presumably tracks the AUROC metric rather than the loss -- confirm against nam.wrapper.
model = NAMClassifier(
    num_epochs=100,
    num_learners=1,
    metric='auroc',
    early_stop_mode='max',
    monitor_loss=False,
    n_jobs=1,
    random_state=random_state,
)

# Time the fit and print the elapsed seconds. `perf_counter` is imported in the
# notebook's import cell instead of mid-notebook, so every timing cell can rely on it.
print("training")
start = perf_counter()
model.fit(X_train, y_train)
print(perf_counter() - start)

training


  0%|                                                                                          | 0/100 [00:00<?, ?it/s]
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 23.83it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(0):                                                                                                              [A
                    Training Loss: 1.688 |
Epoch(0):           Validation Loss: 0.782 | AUROC: 0.491:   0%| |
                    Training Loss: 1.688 |
                    Validation Loss: 0.782 | AUROC: 0.491:   1%| |
  0%|                                        

Epoch(9):                                                                                                              [A
                    Training Loss: 0.678 |
Epoch(9):           Validation Loss: 0.677 | AUROC: 0.597:   9%| |
                    Training Loss: 0.678 |
                    Validation Loss: 0.677 | AUROC: 0.597:  10%| |
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 25.11it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(10):                                                                                                             [A
                    Training Loss: 0.677 |

                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(19):                                                                                                             [A
                    Training Loss: 0.671 |
Epoch(19):          Validation Loss: 0.671 | AUROC: 0.617:  19%|▏
                    Training Loss: 0.671 |
                    Validation Loss: 0.671 | AUROC: 0.617:  20%|▏
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 22.92it/s][A
                                                                                                                       [A
  0%|                                       

  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00, 19.96it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 24.72it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(29):                                                                                                             [A
                    Training Loss: 0.677 |
Epoch(29):          Validation Loss: 0.670 | AUROC: 0.621:  29%|▎
                    Training Loss: 0.677 |
                    Validation Loss: 0.670 | AUROC: 0.621:  30%|▎
  0%|                                       

  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(38):                                                                                                             [A
                    Training Loss: 0.671 |
Epoch(38):          Validation Loss: 0.670 | AUROC: 0.626:  38%|▍
                    Training Loss: 0.671 |
                    Validation Loss: 0.670 | AUROC: 0.626:  39%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 26.08it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(39):                                  

 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 24.18it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(48):                                                                                                             [A
                    Training Loss: 0.675 |
Epoch(48):          Validation Loss: 0.671 | AUROC: 0.620:  48%|▍
                    Training Loss: 0.675 |
                    Validation Loss: 0.671 | AUROC: 0.620:  49%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 24.64it/s][A
                                            

                    Validation Loss: 0.670 | AUROC: 0.618:  58%|▌
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 21.02it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(58):                                                                                                             [A
                    Training Loss: 0.676 |
Epoch(58):          Validation Loss: 0.677 | AUROC: 0.624:  58%|▌
                    Training Loss: 0.676 |
                    Validation Loss: 0.677 | AUROC: 0.624:  59%|▌
  0%|                                                                                            | 0/

 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00, 19.34it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 18.83it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(67):                                                                                                             [A
                    Training Loss: 0.675 |
Epoch(67):          Validation Loss: 0.668 | AUROC: 0.623:  67%|▋
                    Training Loss: 0.675 |
                    Validation Loss: 0.668 | AUROC: 0.623:  68%|▋
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|███████████████████████████████████████

Epoch(76):                                                                                                             [A
                    Training Loss: 0.669 |
Epoch(76):          Validation Loss: 0.671 | AUROC: 0.623:  76%|▊
                    Training Loss: 0.669 |
                    Validation Loss: 0.671 | AUROC: 0.623:  77%|▊
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 21.17it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(77):                                                                                                             [A
                    Training Loss: 0.673 |
E

  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 20.75it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(86):                                                                                                             [A
                    Training Loss: 0.675 |
Epoch(86):          Validation Loss: 0.669 | AUROC: 0.622:  86%|▊
                    Training Loss: 0.675 |
                    Validation Loss: 0.669 | AUROC: 0.622:  87%|▊
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 40%|█████████████████████████████████▌     

Epoch(95):                                                                                                             [A
                    Training Loss: 0.665 |
Epoch(95):          Validation Loss: 0.667 | AUROC: 0.632:  95%|▉
                    Training Loss: 0.665 |
                    Validation Loss: 0.667 | AUROC: 0.632:  96%|▉
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 20.78it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(96):                                                                                                             [A
                    Training Loss: 0.668 |
E

24.113917600014247


In [19]:
# Score the held-out split and report the AUROC of the single-task model.
pred = model.predict_proba(X_test)
sk_metrics.roc_auc_score(y_true=y_test, y_score=pred)

0.6492038972750023

## Multi-Task NAM Classification

In [20]:
def make_gender_mtl_data(X, y):
    """Split one label vector into two task columns keyed on ``X['sex']``.

    Args:
        X: DataFrame containing a ``sex`` column.
        y: Series of labels sharing ``X``'s index.

    Returns:
        A two-column DataFrame; both columns inherit ``y``'s name. The first
        column is NaN wherever ``sex == 0`` and the second is NaN wherever
        ``sex == 1``, so with a 0/1-coded ``sex`` every row keeps its label in
        exactly one column. ``y`` itself is not modified.

    NOTE(review): the local names follow the original code -- ``y_female`` ends
    up holding the labels of the ``sex == 1`` rows and ``y_male`` those of the
    ``sex == 0`` rows. Confirm against the dataset's coding that these names
    are not swapped.
    """
    # Series.mask returns a new Series with NaN where the condition holds. This
    # avoids assigning NaN into an integer Series in place, which depends on
    # silent upcasting that pandas has deprecated.
    y_female = y.mask(X['sex'] == 0)
    y_male = y.mask(X['sex'] == 1)
    return pd.concat([y_female, y_male], axis=1)

In [21]:
# Build the two-column (one per task) label frames for both splits.
y_train_mtl, y_test_mtl = (
    make_gender_mtl_data(X_train, y_train),
    make_gender_mtl_data(X_test, y_test),
)

In [22]:
# `sex` now selects the task, so it is removed from the model inputs.
X_train_mtl, X_test_mtl = (
    X_train.drop(columns=['sex']),
    X_test.drop(columns=['sex']),
)

In [23]:
# NaN marks a missing label: each row carries its label only in the column of
# the task selected by its `sex` value.
y_train_mtl

Unnamed: 0,two_year_recid,two_year_recid.1
4819,0.0,
1581,,1.0
0,0.0,
1575,0.0,
1159,1.0,
...,...,...
4604,0.0,
653,0.0,
4691,0.0,
5386,,1.0


In [24]:
# Multi-task NAM classifier: one output per column of `y_train_mtl`.
# Rebinding `model` here replaces the single-task model trained earlier.
# NOTE(review): `patience=60` is set only for this model, not for the
# single-task one -- confirm the asymmetry is intended.
model = MultiTaskNAMClassifier(
    # training budget / architecture
    num_epochs=100,
    num_learners=1,
    num_subnets=1,
    patience=60,
    # model selection
    metric='auroc',
    early_stop_mode='max',
    monitor_loss=False,
    # execution
    n_jobs=1,
    random_state=random_state,
)

# Time the fit and print the elapsed seconds.
start = perf_counter()
print("training")
model.fit(X_train_mtl, y_train_mtl)
print(perf_counter() - start)

training


  0%|                                                                                          | 0/100 [00:00<?, ?it/s]
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 25.92it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(0):                                                                                                              [A
                    Training Loss: 0.841 |
Epoch(0):           Validation Loss: 0.723 | AUROC: 0.496:   0%| |
                    Training Loss: 0.841 |
                    Validation Loss: 0.723 | AUROC: 0.496:   1%| |
  0%|                                        

Epoch(9):                                                                                                              [A
                    Training Loss: 0.678 |
Epoch(9):           Validation Loss: 0.674 | AUROC: 0.567:   9%| |
                    Training Loss: 0.678 |
                    Validation Loss: 0.674 | AUROC: 0.567:  10%| |
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 26.79it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(10):                                                                                                             [A
                    Training Loss: 0.666 |

                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(19):                                                                                                             [A
                    Training Loss: 0.661 |
Epoch(19):          Validation Loss: 0.652 | AUROC: 0.656:  19%|▏
                    Training Loss: 0.661 |
                    Validation Loss: 0.652 | AUROC: 0.656:  20%|▏
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 25.84it/s][A
                                                                                                                       [A
  0%|                                       

  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 24.35it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(29):                                                                                                             [A
                    Training Loss: 0.640 |
Epoch(29):          Validation Loss: 0.635 | AUROC: 0.674:  29%|▎
                    Training Loss: 0.640 |
                    Validation Loss: 0.635 | AUROC: 0.674:  30%|▎
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|███████████████████████████████████████

Epoch(38):                                                                                                             [A
                    Training Loss: 0.650 |
Epoch(38):          Validation Loss: 0.639 | AUROC: 0.663:  38%|▍
                    Training Loss: 0.650 |
                    Validation Loss: 0.639 | AUROC: 0.663:  39%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00,  8.77it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.59it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(39):                                  

 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 23.45it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(48):                                                                                                             [A
                    Training Loss: 0.650 |
Epoch(48):          Validation Loss: 0.632 | AUROC: 0.668:  48%|▍
                    Training Loss: 0.650 |
                    Validation Loss: 0.632 | AUROC: 0.668:  49%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 25.68it/s][A
                                            

Epoch(57):          Validation Loss: 0.632 | AUROC: 0.674:  57%|▌
                    Training Loss: 0.647 |
                    Validation Loss: 0.632 | AUROC: 0.674:  58%|▌
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 23.98it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(58):                                                                                                             [A
                    Training Loss: 0.646 |
Epoch(58):          Validation Loss: 0.638 | AUROC: 0.671:  58%|▌
                    Training Loss: 0.646 |
                    Validation Loss: 0.638 | AUROC: 0.671:

  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(67):                                                                                                             [A
                    Training Loss: 0.648 |
Epoch(67):          Validation Loss: 0.635 | AUROC: 0.670:  67%|▋
                    Training Loss: 0.648 |
                    Validation Loss: 0.635 | AUROC: 0.670:  68%|▋
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 20%|████████████████▊                                                                   | 1/5 [00:00<00:00,  4.12it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00,  9.63it/s][A
                                                                                                                       [A
  0%|                                       

Epoch(76):                                                                                                             [A
                    Training Loss: 0.650 |
Epoch(76):          Validation Loss: 0.634 | AUROC: 0.679:  76%|▊
                    Training Loss: 0.650 |
                    Validation Loss: 0.634 | AUROC: 0.679:  77%|▊
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00, 18.03it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 18.27it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(77):                                  

  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 22.45it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(86):                                                                                                             [A
                    Training Loss: 0.638 |
Epoch(86):          Validation Loss: 0.631 | AUROC: 0.684:  86%|▊
                    Training Loss: 0.638 |
                    Validation Loss: 0.631 | AUROC: 0.684:  87%|▊
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|███████████████████████████████████████

  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(95):                                                                                                             [A
                    Training Loss: 0.640 |
Epoch(95):          Validation Loss: 0.625 | AUROC: 0.673:  95%|▉
                    Training Loss: 0.640 |
                    Validation Loss: 0.625 | AUROC: 0.673:  96%|▉
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 21.95it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(96):                                  

24.090321699972264


In [25]:
# Predictions of the multi-task model on the test features (the `sex` column removed).
pred = model.predict_proba(X_test_mtl)

In [26]:
# Flatten labels and predictions in the same (row-major) order, then keep only
# the positions that actually carry a label.
# NOTE(review): assumes `pred` has one column per task, in the same column
# order as `y_test_mtl` -- TODO confirm against MultiTaskNAMClassifier.predict_proba.
y_test_mtl_flat = y_test_mtl.to_numpy().reshape(-1)
pred_flat = pred.reshape(-1)

# Explicit NaN test instead of relying on the `x == x` self-comparison trick.
non_nan_indices = ~np.isnan(y_test_mtl_flat)
y_test_mtl_flat = y_test_mtl_flat[non_nan_indices]
pred_flat = pred_flat[non_nan_indices]

In [27]:
# AUROC pooled over the labelled entries of both task columns.
sk_metrics.roc_auc_score(y_true=y_test_mtl_flat, y_score=pred_flat)

0.6968139883572644