# Hyperparameter Tuning Methods Comparison
The cross-validation approach is based on the following [repo](https://github.com/roamanalytics/roamresearch/tree/master/BlogPosts/Hyperparameter_tuning_comparison).

In [1]:
%matplotlib inline

In [2]:
%load_ext autoreload
%aimport hpt_cmp

In [3]:
from hpt_cmp import *

In [4]:
from __future__ import print_function

from sklearn import datasets
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from dotmap import DotMap

In [5]:
# Load the Digits dataset (images of handwritten digits plus their labels in `digits.target`).
digits = datasets.load_digits()

# Classifiers expect a 2-D (n_samples, n_features) matrix. `digits.data` already
# is that matrix -- every image in `digits.images` flattened into a single row --
# so it is passed to the classifier directly and no manual reshape is needed.
n_digit_samples = len(digits.images)
assert digits.data.shape == (n_digit_samples, digits.images[0].size), \
    "expected digits.data to hold one flattened image per row"


In [6]:
# SVC search space, expressed as a list of sub-grids so that `gamma` is only
# varied for the RBF kernel (the linear kernel does not use it). An exhaustive
# grid search therefore evaluates 4 + 4 * 2 = 12 parameter combinations.
svc_c_values = (1, 10, 100, 1000)
linear_grid = {'C': list(svc_c_values), 'kernel': ['linear']}
rbf_grid = {'C': list(svc_c_values), 'gamma': [0.001, 0.0001], 'kernel': ['rbf']}
param_grid = [linear_grid, rbf_grid]

In [9]:
# One spec per hyperparameter-tuning method to compare.
# TODO: add more specs once more search functions are implemented in hpt_cmp.
hpt_specs = [
    dict(name='GridSearch', cv=grid_search, param_grid=param_grid, args={}),
]

# Wrap each spec in a DotMap, which allows attribute-style access (e.g. `obj.name`).
hpt_objs = [DotMap(spec) for spec in hpt_specs]

#### cmp_hpt_methods
Parameter description
> `htp_objs`: list of hyperparameter-tuning objects (the `hpt_objs` list built above) <br>
> `model`: sklearn estimator class to optimize, passed as the class itself (e.g. `SVC`); it needs `fit`/`predict` methods <br>
> `dataset`: tuple of `(X, y)`, e.g. `(data, target)` <br>
> `loss`: name of the sklearn loss function to use, e.g. `'log_loss'` <br>
> `metric`: sklearn metric to optimize for, e.g. `accuracy_score` <br>
> `datset_split`: `random_state` for the dataset split <br>
> `name`: currently not used *(optional)* <br>

In [8]:
# Run every tuning method in `hpt_objs` against SVC on the digits data, scoring
# with accuracy, and display the returned per-method summary as the cell output.
digits_xy = (digits.data, digits.target)
results = cmp_hpt_methods(hpt_objs, SVC, digits_xy, 'log_loss', accuracy_score)
results

Start
Here
HTP using GridSearch
<generator object _BaseKFold.split at 0x7fcd5838da98>


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))




[{'Model': 'SVC',
  'Hyper optimization method': 'GridSearch',
  'Test accuracy': [0.978021978021978,
   0.9502762430939227,
   0.9832869080779945,
   0.988795518207283,
   0.9633802816901409],
  'Best Parameters': [{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'},
   {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'},
   {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'},
   {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'},
   {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}],
  'Parameters sampled': [12, 12, 12, 12, 12],
  'Cross validation time (in s)': [7.966981410980225,
   8.186244010925293,
   8.348015546798706,
   8.93733835220337,
   8.348435401916504],
  'Mean Test accuracy': 0.9727521858182637,
  'Mean Cross validation time (in s)': 8.35740294456482,
  'Mean Parameters sampled': 12.0}]