In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Notebook-wide settings, gathered in one cell so a run can be re-parameterised.
random_seed: int = 0  # value handed to the NumPy / TensorFlow seeding calls
save_to_disk: bool = True  # forwarded to evaluate_model(save_to_disk=...)
model_prefix: str = "nn"  # filename prefix of the saved models to load

In [None]:
# Seed NumPy's global random state with the notebook-wide seed.
from numpy.random import seed
seed(random_seed)
# Seed TensorFlow's global random state with the same value.
from tensorflow.random import set_seed
set_seed(random_seed)

In [None]:
import glob
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split


In [None]:
from knowledge_distillation.io import *
from knowledge_distillation.ensemble import UnbiasedAverage
from knowledge_distillation.processing import * 
from knowledge_distillation.nn import *

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

# Load data

In [None]:
# Load the raw dataset through the project's load_adult helper
# (presumably the UCI Adult census-income data -- confirm in the helper).
df = load_adult()

# Preprocessing

In [None]:
# Build the feature matrix X and the target y from the 'salary' column. The
# helper also returns target_names, which is later passed to evaluate_model.
# (Judging by its name, scale_onehot scales and one-hot encodes -- TODO confirm.)
X, y, target_names = scale_onehot(df, target='salary')

# Train/test split via the project helper (presumably with a fixed seed, going
# by its name -- verify against the helper).
X_train, X_test, y_train, y_test = split_with_seed(X, y)

  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):
  elif pd.api.types.is_categorical(cols):


# Load models

In [None]:
# Find every saved model whose path matches "<ASSETS_PATH>/<model_prefix>_*.tf";
# sorting gives a deterministic (lexicographic) order.
saved_model_pattern = f"{ASSETS_PATH / model_prefix}_*.tf"
model_paths = sorted(glob.glob(saved_model_pattern))

In [None]:
# Derive each model's name from its path: keep the final path component and
# drop everything from the first dot onwards (".../nn_1.tf" -> "nn_1").
# Path(...).name replaces the manual split on "/" so the result is also
# correct where the OS path separator is not "/" (e.g. Windows).
model_names = [Path(model_path).name.split('.')[0] for model_path in model_paths]
model_names

['nn_1', 'nn_2']

In [None]:
# Restore one classifier per discovered model name.
models = list(map(load_keras_classifier, model_names))

In [None]:
# Smoke test on the first loaded model: predicting on five test rows must
# give back exactly five predictions.
first = models[0]
assert len(first.predict(X_test.head(5))) == 5



In [None]:
# For the same five test rows, show column 1 of the predicted probabilities
# (presumably the positive class -- confirm) next to the predicted labels.
first.predict_proba(X_test.head(5))[:,1], first.predict(X_test.head(5))



(array([0.0076542 , 0.2546998 , 0.25310326, 0.0216367 , 0.05762338],
       dtype=float32),
 array([0, 0, 0, 0, 0], dtype=int8))

# Evaluate ensemble predictions

In [None]:
@dataclass
class TrainedKerasEnsemble:
    """Ensemble of already-built Keras classifiers combined by a final estimator.

    Every member of ``keras_estimators`` is queried with ``predict_proba``; the
    per-model outputs are stacked along a new axis 1 and handed to
    ``final_estimator.predict_proba``, whose result is the ensemble output.
    """

    # Member models; each must expose predict_proba(X).
    keras_estimators: List[KerasClassifier]
    # Combiner applied to the stacked member outputs; must expose predict_proba.
    final_estimator: BaseEstimator

    def predict_proba(self, X):
        """Return the final estimator's probabilities for the stacked member outputs."""
        stacked_member_probs = np.stack(
            [member.predict_proba(X) for member in self.keras_estimators],
            axis=1,
        )
        return self.final_estimator.predict_proba(stacked_member_probs)

    def predict(self, X):
        """Threshold the ensemble probabilities at 0.5 and return them as ints.

        The result has the same shape as ``predict_proba(X)``.
        NOTE(review): when predict_proba yields one column per class, this is a
        per-class 0/1 indicator array rather than a 1-D label vector -- confirm
        that callers expect that shape.
        """
        ensemble_probs = self.predict_proba(X)
        return (ensemble_probs > .5).astype(int)
    
    

In [None]:
# Wrap the loaded models in an ensemble that uses the project's
# UnbiasedAverage as the final estimator.
ensemble = TrainedKerasEnsemble(
    keras_estimators=models,
    final_estimator=UnbiasedAverage(),
)

In [None]:
# Peek at the ensemble's probability output for the first five training rows.
ensemble.predict_proba(X_train.head(5))



array([[0.910709  , 0.08929093],
       [0.42169142, 0.5783086 ],
       [0.9630044 , 0.03699556],
       [0.9044428 , 0.09555721],
       [0.9769597 , 0.02304029]], dtype=float32)

In [None]:
# Score the ensemble on the train and test splits; results are labelled
# "<model_prefix>_ensemble".
evaluate_model(
    X_train,
    X_test,
    y_train,
    y_test,
    ensemble,
    f"{model_prefix}_ensemble",
    save_to_disk=save_to_disk,
    target_names=target_names,
)




=== Train ===
              precision    recall  f1-score   support

       <=50K       0.84      0.98      0.90     19778
        >50K       0.86      0.40      0.54      6270

    accuracy                           0.84     26048
   macro avg       0.85      0.69      0.72     26048
weighted avg       0.84      0.84      0.82     26048


=== Test ===
              precision    recall  f1-score   support

       <=50K       0.84      0.98      0.90      4942
        >50K       0.86      0.40      0.55      1571

    accuracy                           0.84      6513
   macro avg       0.85      0.69      0.73      6513
weighted avg       0.84      0.84      0.82      6513




Unnamed: 0,model_name,data,accuracy,precision,recall,f1,auc
0,nn_ensemble,train,0.83922,0.856019,0.399203,0.544486,0.688958
1,nn_ensemble,test,0.840626,0.864569,0.402292,0.549088,0.69113


# Model distillation: train on ensemble output

In [None]:
# Create a single NN, identical to the ones in the ensemble
# (built by create_nn with the shared train_params; presumably these are the
# same builder and settings the ensemble members were trained with -- verify).
# verbose=1 enables per-epoch progress output during fit.
model = KerasClassifier(
    build_fn=create_nn,
    **train_params,
    verbose=1
)


In [None]:
# Ensemble probabilities for every training row; these become the training
# targets of the distilled model. The slice shows the first five rows.
y_train_pred_ensemble = ensemble.predict_proba(X_train)
y_train_pred_ensemble[:5]



array([[0.910709  , 0.08929093],
       [0.42169142, 0.5783086 ],
       [0.9630044 , 0.03699556],
       [0.9044428 , 0.09555721],
       [0.9769597 , 0.02304029]], dtype=float32)

In [None]:
# Distillation step: fit the single NN on the ensemble's predicted
# probabilities for the training rows instead of the original labels y_train.
model.fit(X_train, y_train_pred_ensemble)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x7fb6f0342e80>

# Evaluate distilled model

In [None]:
# Score the distilled model against the true labels on both splits; results
# are labelled "<model_prefix>_<random_seed>_distilled".
evaluate_model(
    X_train,
    X_test,
    y_train,
    y_test,
    model,
    f'{model_prefix}_{random_seed}_distilled',
    save_to_disk=save_to_disk,
    target_names=target_names,
)






=== Train ===
              precision    recall  f1-score   support

       <=50K       0.79      0.99      0.88     19778
        >50K       0.91      0.18      0.30      6270

    accuracy                           0.80     26048
   macro avg       0.85      0.59      0.59     26048
weighted avg       0.82      0.80      0.74     26048


=== Test ===
              precision    recall  f1-score   support

       <=50K       0.79      1.00      0.88      4942
        >50K       0.92      0.17      0.29      1571

    accuracy                           0.80      6513
   macro avg       0.86      0.58      0.59      6513
weighted avg       0.82      0.80      0.74      6513






Unnamed: 0,model_name,data,accuracy,precision,recall,f1,auc
0,nn_0_distilled,train,0.798756,0.91252,0.18134,0.302555,0.587914
1,nn_0_distilled,test,0.797328,0.922559,0.174411,0.293362,0.584879
