# K. pneumoniae with MLP

In [1]:
# Notebook setup: IPython magics only, no Python state is created here.
# Load (or re-load, if already loaded) the autoreload extension and put it in
# mode 2 so edits to imported project modules are picked up without restarting
# the kernel.
%reload_ext autoreload
%autoreload 2

# Render inline figures at retina (hidpi) resolution.
%config InlineBackend.figure_format = 'retina'

from sklearn.neural_network import MLPClassifier
from trainer import Trainer

## Training

In [2]:
# Drugs to model; the loop below fits one fresh classifier per entry.
drugs = ['Ciprofloxacin', 'Ceftriaxone', 'Cefepime', 'Meropenem', 'Tobramycin']


def make_mlp():
    """Return a new, unfitted MLPClassifier with this notebook's fixed hyperparameters.

    A new estimator is built on every call so that no state is shared between
    the per-drug training runs.
    """
    # Options that were tried but are currently disabled:
    #   verbose=1, early_stopping=True, n_iter_no_change=20, validation_fraction=0.1
    return MLPClassifier(
        hidden_layer_sizes=[512, 256, 128],
        activation='relu',
        alpha=0.0001,
        max_iter=200,
        random_state=164,
    )


trainer = Trainer(
    pathogen='Klebsiella pneumoniae',
    sites=['A'],
    years=[2015, 2016, 2017, 2018],
    # NOTE: there will be no true label in test data if you use too many folds
    n_splits=5,
)

for drug_name in drugs:
    trainer.fit(drug=drug_name, model=make_mlp())

Loading Ciprofloxacin...
Training w/o SMOTE...
Fold 1/5...
AUC=0.7367427222659323, ACC=0.8939670932358318, f1=0.5915492957746479
Fold 2/5...
AUC=0.704104379753475, ACC=0.8811700182815356, f1=0.5323741007194245
Fold 3/5...
AUC=0.6883293994230264, ACC=0.8628884826325411, f1=0.489795918367347
Fold 4/5...
AUC=0.7159454497770784, ACC=0.8756855575868373, f1=0.5405405405405406
Fold 5/5...
AUC=0.783910792512943, ACC=0.9010989010989011, f1=0.6493506493506493
Training w/ SMOTE...
Fold 1/5...
AUC=0.6987280356674535, ACC=0.8720292504570384, f1=0.5138888888888888
Fold 2/5...




AUC=0.775872016784684, ACC=0.8409506398537477, f1=0.5628140703517588
Fold 3/5...



KeyboardInterrupt



## Results Collection and Visualization

In [None]:
# Directory that receives the exported result files.
results_dir = './results/kpn_mlp/'

# Pull the accumulated results out of the trainer, then export them to csv files.
results = trainer.collect_results()
results.save_to(results_dir)

### Bar Graph

In [None]:
# Options for the bar chart, gathered in one place so they are easy to tweak.
bar_plot_options = {
    'title': 'Metric Scores for K. pneumoniae (Bar Graph)',
    'save_as': './results/kpn_mlp/bar.png',
    'nrows': 2,
    'ncols': 3,
    'sharey': True,
    'figsize': (22, 15),
}
results.bar_plot(**bar_plot_options)

### Box Graph

In [None]:
# Options for the box chart; `sharey=True` is intentionally left disabled here.
box_plot_options = {
    'title': 'Metric Scores for K. pneumoniae (Box Graph)',
    'save_as': './results/kpn_mlp/box.png',
    'nrows': 1,
    'ncols': 3,
    'figsize': (22, 6),
}
results.box_plot(**box_plot_options)
