In [1]:
from time import perf_counter

import numpy as np
import pandas as pd
import sklearn.metrics as sk_metrics
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import random_split

from nam.wrapper import NAMClassifier, MultiTaskNAMClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single seed reused by the train/test split and by the NAM classifier below.
random_state = 2016

In [3]:
# The raw file is space-separated and has no header row, so the column
# names are attached after loading.
column_names = [
    "age",
    "race",
    "sex",
    "priors_count",
    "length_of_stay",
    "c_charge_degree",
    "two_year_recid",
]
dataset = pd.read_csv('nam/data/data_recidivism_recid.data', sep=' ', header=None)
dataset.columns = column_names

In [4]:
# Peek at the first rows to check that the column names line up with the data.
dataset.head()

Unnamed: 0,age,race,sex,priors_count,length_of_stay,c_charge_degree,two_year_recid
0,69,6,2,0,1,1,0
1,34,1,2,0,10,1,1
2,24,1,2,4,1,1,1
3,44,6,2,0,1,2,0
4,41,3,2,14,6,1,1


In [5]:
# Feature groups used by the preprocessing cell below:
#   binary -> shifted down by 1
#   other  -> rescaled to [-1, 1] with MinMaxScaler
binary = [
    'sex',
    'c_charge_degree',
]
other = [
    'age',
    'race',
    'priors_count',
    'length_of_stay',
]

In [6]:
# Rescale the non-binary features to the range [-1, 1].
# NOTE(review): the scaler is fitted on every row, i.e. before the
# train/test split performed in a later cell, so test rows influence the
# scaling parameters.
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_other = scaler.fit_transform(dataset[other])
dataset[other] = scaled_other

# Shift the binary columns down by one.
# NOTE(review): this overwrites `dataset` in place, so running the cell a
# second time shifts these columns again.
dataset[binary] = dataset[binary].sub(1)

In [7]:
# Inspect the frame after scaling/shifting (pandas truncates the display).
dataset

Unnamed: 0,age,race,sex,priors_count,length_of_stay,c_charge_degree,two_year_recid
0,0.307692,1.0,1,-1.000000,-0.9975,0,0
1,-0.589744,-1.0,1,-1.000000,-0.9750,0,1
2,-0.846154,-1.0,1,-0.789474,-0.9975,0,1
3,-0.333333,1.0,1,-1.000000,-0.9975,1,0
4,-0.410256,-0.2,1,-0.263158,-0.9850,0,1
...,...,...,...,...,...,...,...
6167,-0.871795,-1.0,1,-1.000000,-0.9950,0,0
6168,-0.871795,-1.0,1,-1.000000,-0.9950,0,0
6169,0.000000,1.0,1,-1.000000,-0.9975,0,0
6170,-0.615385,-1.0,0,-0.842105,-0.9975,1,0


In [8]:
# 80/20 split of the preprocessed frame, seeded with random_state.
feature_columns = other + binary
target_column = 'two_year_recid'

data_train, data_test = train_test_split(
    dataset,
    train_size=0.8,
    test_size=0.2,
    random_state=random_state,
)

X_train = data_train[feature_columns]
y_train = data_train[target_column]
X_test = data_test[feature_columns]
y_test = data_test[target_column]

## Single-Task NAM Classification

In [9]:
# Single-task baseline -- currently disabled: everything in this cell except
# the perf_counter import is commented out and does not run.
# model = NAMClassifier(
#             num_epochs=100,
#             num_learners=1,
#             metric='auroc',
#             early_stop_mode='max',
#             monitor_loss=False,
#             n_jobs=1,
#             random_state=random_state
#         )
# Live import: perf_counter is also used by the multi-task training cell below.
from time import perf_counter
# print("training")
# start = perf_counter()
# model.fit(X_train, y_train)
# print(perf_counter()-start)

In [10]:
# Evaluation of the single-task baseline -- disabled along with its training cell.
# pred = model.predict_proba(X_test)
# sk_metrics.roc_auc_score(y_test, pred)

## Multi-Task NAM Classification

In [11]:
def make_gender_mtl_data(X, y):
    """Split one label vector into two NaN-masked label columns, one per task.

    The value of ``X['sex']`` decides which column keeps a row's label; the
    label is replaced by NaN (meaning "missing for this task") in the other
    column. For a 0/1-coded ``sex`` column every row therefore has a label in
    exactly one of the two columns.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature frame containing a ``'sex'`` column.
    y : pandas.Series
        Labels, aligned with ``X`` by index. Not modified.

    Returns
    -------
    pandas.DataFrame
        Two columns, both carrying ``y``'s name:
        column 0 keeps the labels of rows where ``sex != 0``,
        column 1 keeps the labels of rows where ``sex != 1``.
    """
    sex = X['sex']

    # Series.where returns a new Series that keeps values where the condition
    # holds and puts NaN elsewhere. Unlike assigning NaN into a copy of an
    # integer Series, it does not rely on the deprecated silent int -> float
    # upcast on setitem.
    # NOTE(review): the male/female names assume sex == 1 means "female" and
    # sex == 0 means "male" after preprocessing -- confirm against the
    # dataset's encoding.
    y_female = y.where(sex != 0)
    y_male = y.where(sex != 1)
    return pd.concat([y_female, y_male], axis=1)

In [12]:
# Build the two-column, NaN-masked label frames for the train and test splits.
y_train_mtl = make_gender_mtl_data(X=X_train, y=y_train)
y_test_mtl = make_gender_mtl_data(X=X_test, y=y_test)

In [13]:
# 'sex' now decides which task a row belongs to, so it is removed from the
# model inputs.
X_train_mtl = X_train.drop('sex', axis='columns')
X_test_mtl = X_test.drop('sex', axis='columns')

In [14]:
# NaN marks a missing label: a row's label is kept in only one of the two
# task columns and is NaN in the other.
y_train_mtl

Unnamed: 0,two_year_recid,two_year_recid.1
4819,0.0,
1581,,1.0
0,0.0,
1575,0.0,
1159,1.0,
...,...,...
4604,0.0,
653,0.0,
4691,0.0,
5386,,1.0


In [15]:
# Multi-task NAM: one output per column of y_train_mtl.
# `torch` and `perf_counter` are imported in the notebook's import cell.
# NOTE(review): cross_entropy is passed as the loss for 0/1 labels that
# contain NaN for missing tasks -- confirm against the nam documentation that
# this is the intended loss for this setup.
model = MultiTaskNAMClassifier(
    num_learners=1,
    patience=60,
    num_epochs=100,
    num_subnets=1,
    metric='auroc',
    monitor_loss=False,
    early_stop_mode='max',
    n_jobs=1,
    random_state=random_state,
    loss_func=torch.nn.functional.cross_entropy,
)

print("training")
# Start the clock right before fit() so the printed value is the wall-clock
# training time in seconds.
start = perf_counter()
model.fit(X_train_mtl, y_train_mtl)
print(perf_counter()-start)

training


  0%|                                                                                          | 0/100 [00:00<?, ?it/s]
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00, 12.06it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(0):                                                                                                              [A
                    Training Loss: 0.355 |
Epoch(0):           Validation Loss: 0.246 | AUROC: 0.520:   0%| |
                    Training Loss: 0.355 |
                    Validation Loss: 0.246 | AUROC: 0.520:   1%| |
  0%|                                        

Epoch(9):                                                                                                              [A
                    Training Loss: 0.208 |
Epoch(9):           Validation Loss: 0.196 | AUROC: 0.558:   9%| |
                    Training Loss: 0.208 |
                    Validation Loss: 0.196 | AUROC: 0.558:  10%| |
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 30.96it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(10):                                                                                                             [A
                    Training Loss: 0.187 |

                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(19):                                                                                                             [A
                    Training Loss: 0.209 |
Epoch(19):          Validation Loss: 0.192 | AUROC: 0.591:  19%|▏
                    Training Loss: 0.209 |
                    Validation Loss: 0.192 | AUROC: 0.591:  20%|▏
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 29.41it/s][A
                                                                                                                       [A
  0%|                                       

  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 29.94it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(29):                                                                                                             [A
                    Training Loss: 0.178 |
Epoch(29):          Validation Loss: 0.191 | AUROC: 0.591:  29%|▎
                    Training Loss: 0.178 |
                    Validation Loss: 0.191 | AUROC: 0.591:  30%|▎
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|███████████████████████████████████████

Epoch(38):                                                                                                             [A
                    Training Loss: 0.195 |
Epoch(38):          Validation Loss: 0.191 | AUROC: 0.609:  38%|▍
                    Training Loss: 0.195 |
                    Validation Loss: 0.191 | AUROC: 0.609:  39%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 28.40it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(39):                                                                                                             [A
                    Training Loss: 0.180 |
E

                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(48):                                                                                                             [A
                    Training Loss: 0.182 |
Epoch(48):          Validation Loss: 0.190 | AUROC: 0.622:  48%|▍
                    Training Loss: 0.182 |
                    Validation Loss: 0.190 | AUROC: 0.622:  49%|▍
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00, 29.40it/s][A
                                                                                                                       [A
  0%|                                       

                    Validation Loss: 0.191 | AUROC: 0.616:  58%|▌
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 34.36it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(58):                                                                                                             [A
                    Training Loss: 0.190 |
Epoch(58):          Validation Loss: 0.191 | AUROC: 0.603:  58%|▌
                    Training Loss: 0.190 |
                    Validation Loss: 0.191 | AUROC: 0.603:  59%|▌
  0%|                                                                                            | 0/

Epoch(67):          Validation Loss: 0.191 | AUROC: 0.613:  67%|▋
                    Training Loss: 0.182 |
                    Validation Loss: 0.191 | AUROC: 0.613:  68%|▋
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 33.01it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(68):                                                                                                             [A
                    Training Loss: 0.178 |
Epoch(68):          Validation Loss: 0.191 | AUROC: 0.609:  68%|▋
                    Training Loss: 0.178 |
                    Validation Loss: 0.191 | AUROC: 0.609:

  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(77):                                                                                                             [A
                    Training Loss: 0.180 |
Epoch(77):          Validation Loss: 0.190 | AUROC: 0.606:  77%|▊
                    Training Loss: 0.180 |
                    Validation Loss: 0.190 | AUROC: 0.606:  78%|▊
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 37.57it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(78):                                  

 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 28.25it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(87):                                                                                                             [A
                    Training Loss: 0.187 |
Epoch(87):          Validation Loss: 0.191 | AUROC: 0.616:  87%|▊
                    Training Loss: 0.187 |
                    Validation Loss: 0.191 | AUROC: 0.616:  88%|▉
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 28.19it/s][A
                                            

                    Validation Loss: 0.191 | AUROC: 0.597:  97%|▉
  0%|                                                                                            | 0/5 [00:00<?, ?it/s][A
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 33.60it/s][A
                                                                                                                       [A
  0%|                                                                                            | 0/1 [00:00<?, ?it/s][A
Epoch(97):                                                                                                             [A
                    Training Loss: 0.183 |
Epoch(97):          Validation Loss: 0.191 | AUROC: 0.598:  97%|▉
                    Training Loss: 0.183 |
                    Validation Loss: 0.191 | AUROC: 0.598:  98%|▉
  0%|                                                                                            | 0/

21.598452600184828


In [24]:
# Scores for the test set; the output below has two values per row.
# The evaluation cell further down flattens `pred` next to `y_test_mtl`, so it
# assumes the prediction columns are in the same order as the label columns
# -- TODO confirm against the nam documentation.
# NOTE(review): this cell's recorded execution count (24) is out of sequence
# with its neighbours; re-run the notebook top to bottom before relying on
# the stored output.
pred = model.predict_proba(X_test_mtl)
pred

array([[0.68433465, 0.31496906],
       [0.74925618, 0.24770968],
       [0.7117841 , 0.28846996],
       ...,
       [0.68433465, 0.31496906],
       [0.68608045, 0.31305955],
       [0.6843353 , 0.31496755]])

In [17]:
# Flatten labels and predictions in the same row-major order so that entry i
# of one corresponds to entry i of the other.
y_test_mtl_flat = y_test_mtl.to_numpy().ravel()
pred_flat = pred.ravel()

# Boolean mask (despite the name): True where a label is present, i.e. not NaN.
non_nan_indices = ~np.isnan(y_test_mtl_flat)
y_test_mtl_flat, pred_flat = y_test_mtl_flat[non_nan_indices], pred_flat[non_nan_indices]

In [18]:
# AUROC pooled over both tasks, using only the entries that have a label.
sk_metrics.roc_auc_score(y_true=y_test_mtl_flat, y_score=pred_flat)

0.6255289120818726