<a href="https://colab.research.google.com/github/sp8rks/MaterialsInformatics/blob/main/worked_examples/hyperparameter_opt/materials_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Grid vs. Random Search Hyperparameter Optimization

## Setup

### Installation

In [1]:
!pip install matbench
!pip install CBFV

Collecting matbench
  Downloading matbench-0.5-py3-none-any.whl (9.9 MB)
[K     |████████████████████████████████| 9.9 MB 6.8 MB/s 
[?25hCollecting monty==2021.8.17
  Downloading monty-2021.8.17-py3-none-any.whl (65 kB)
[K     |████████████████████████████████| 65 kB 2.1 MB/s 
[?25hCollecting matminer==0.7.4
  Downloading matminer-0.7.4-py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 3.7 MB/s 
[?25hCollecting scikit-learn==1.0
  Downloading scikit_learn-1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (23.1 MB)
[K     |████████████████████████████████| 23.1 MB 2.3 MB/s 
Collecting numpy>=1.21.1
  Downloading numpy-1.21.5-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
[K     |████████████████████████████████| 15.7 MB 351 kB/s 
[?25hCollecting pint>=0.17
  Downloading Pint-0.18-py2.py3-none-any.whl (209 kB)
[K     |████████████████████████████████| 209 kB 50.4 MB/s 
Collecting requests>=2.26.0
  Downloading requests-2

Collecting CBFV
  Downloading CBFV-1.1.0-py3-none-any.whl (539 kB)
[?25l[K     |▋                               | 10 kB 30.6 MB/s eta 0:00:01[K     |█▏                              | 20 kB 25.6 MB/s eta 0:00:01[K     |█▉                              | 30 kB 17.2 MB/s eta 0:00:01[K     |██▍                             | 40 kB 15.2 MB/s eta 0:00:01[K     |███                             | 51 kB 9.1 MB/s eta 0:00:01[K     |███▋                            | 61 kB 8.8 MB/s eta 0:00:01[K     |████▎                           | 71 kB 9.6 MB/s eta 0:00:01[K     |████▉                           | 81 kB 10.6 MB/s eta 0:00:01[K     |█████▌                          | 92 kB 10.3 MB/s eta 0:00:01[K     |██████                          | 102 kB 8.6 MB/s eta 0:00:01[K     |██████▊                         | 112 kB 8.6 MB/s eta 0:00:01[K     |███████▎                        | 122 kB 8.6 MB/s eta 0:00:01[K     |████████                        | 133 kB 8.6 MB/s eta 0:00:01[K   

### Imports

In [2]:
import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

from scipy.stats import randint

from matbench.bench import MatbenchBenchmark
from CBFV.composition import generate_features

### Data

In [3]:
mb = MatbenchBenchmark(subset=["matbench_expt_is_metal"])
task = list(mb.tasks)[0]
task.load()
fold0 = task.folds[0]
train_inputs, train_outputs = task.get_train_and_val_data(fold0)
test_inputs, test_outputs = task.get_test_data(fold0, include_target=True)
print(train_inputs[0:2], train_outputs[0:2])
print(train_outputs.shape, test_outputs.shape)
        

2022-02-09 21:36:57 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_expt_is_metal']
2022-02-09 21:36:57 INFO     Loading dataset 'matbench_expt_is_metal'...
Fetching matbench_expt_is_metal.json.gz from https://ml.materialsproject.org/projects/matbench_expt_is_metal.json.gz to /usr/local/lib/python3.7/dist-packages/matminer/datasets/matbench_expt_is_metal.json.gz


Fetching https://ml.materialsproject.org/projects/matbench_expt_is_metal.json.gz in MB: 0.034816MB [00:00, 40.50MB/s]                 

2022-02-09 21:36:57 INFO     Dataset 'matbench_expt_is_metal loaded.
mbid
mb-expt-is-metal-0001      Ag(AuS)2
mb-expt-is-metal-0002    Ag(W3Br7)2
Name: composition, dtype: object mbid
mb-expt-is-metal-0001    True
mb-expt-is-metal-0002    True
Name: is_metal, dtype: bool
(3936,) (985,)





In [4]:
train_inputs.describe()

count         3936
unique        3936
top       Ag(AuS)2
freq             1
Name: composition, dtype: object

In [5]:
train_outputs.describe()

count      3936
unique        2
top       False
freq       1976
Name: is_metal, dtype: object

In [6]:
train_df = pd.DataFrame({"formula": train_inputs, "target": train_outputs})
test_df = pd.DataFrame({"formula": test_inputs, "target": test_outputs})
train_df

Unnamed: 0_level_0,formula,target
mbid,Unnamed: 1_level_1,Unnamed: 2_level_1
mb-expt-is-metal-0001,Ag(AuS)2,True
mb-expt-is-metal-0002,Ag(W3Br7)2,True
mb-expt-is-metal-0003,Ag0.5Ge1Pb1.75S4,False
mb-expt-is-metal-0005,Ag2BBr,True
mb-expt-is-metal-0006,Ag2BiO3,True
...,...,...
mb-expt-is-metal-4916,ZrSiTe,True
mb-expt-is-metal-4917,ZrTaN3,False
mb-expt-is-metal-4918,ZrTe,True
mb-expt-is-metal-4920,ZrTiF6,True


In [14]:
X_train, y_train, _, _ = generate_features(train_df)
X_train

Processing Input Data: 100%|██████████| 3936/3936 [00:00<00:00, 15165.79it/s]


	Featurizing Compositions...


Assigning Features...: 100%|██████████| 3936/3936 [00:00<00:00, 8136.75it/s]


	Creating Pandas Objects...


Unnamed: 0,avg_Atomic_Number,avg_Atomic_Weight,avg_Period,avg_group,avg_families,avg_Metal,avg_Nonmetal,avg_Metalliod,avg_Mendeleev_Number,avg_l_quantum_number,avg_Atomic_Radius,avg_Miracle_Radius_[pm],avg_Covalent_Radius,avg_Zunger_radii_sum,avg_ionic_radius,avg_crystal_radius,avg_Pauling_Electronegativity,avg_MB_electonegativity,avg_Gordy_electonegativity,avg_Mulliken_EN,avg_Allred-Rockow_electronegativity,avg_metallic_valence,avg_number_of_valence_electrons,avg_gilmor_number_of_valence_electron,avg_valence_s,avg_valence_p,avg_valence_d,avg_valence_f,avg_Number_of_unfilled_s_valence_electrons,avg_Number_of_unfilled_p_valence_electrons,avg_Number_of_unfilled_d_valence_electrons,avg_Number_of_unfilled_f_valence_electrons,avg_outer_shell_electrons,avg_1st_ionization_potential_(kJ/mol),avg_polarizability(A^3),avg_Melting_point_(K),avg_Boiling_Point_(K),avg_Density_(g/mL),avg_specific_heat_(J/g_K)_,avg_heat_of_fusion_(kJ/mol)_,...,mode_families,mode_Metal,mode_Nonmetal,mode_Metalliod,mode_Mendeleev_Number,mode_l_quantum_number,mode_Atomic_Radius,mode_Miracle_Radius_[pm],mode_Covalent_Radius,mode_Zunger_radii_sum,mode_ionic_radius,mode_crystal_radius,mode_Pauling_Electronegativity,mode_MB_electonegativity,mode_Gordy_electonegativity,mode_Mulliken_EN,mode_Allred-Rockow_electronegativity,mode_metallic_valence,mode_number_of_valence_electrons,mode_gilmor_number_of_valence_electron,mode_valence_s,mode_valence_p,mode_valence_d,mode_valence_f,mode_Number_of_unfilled_s_valence_electrons,mode_Number_of_unfilled_p_valence_electrons,mode_Number_of_unfilled_d_valence_electrons,mode_Number_of_unfilled_f_valence_electrons,mode_outer_shell_electrons,mode_1st_ionization_potential_(kJ/mol),mode_polarizability(A^3),mode_Melting_point_(K),mode_Boiling_Point_(K),mode_Density_(g/mL),mode_specific_heat_(J/g_K)_,mode_heat_of_fusion_(kJ/mol)_,mode_heat_of_vaporization_(kJ/mol)_,mode_thermal_conductivity_(W/(m_K))_,mode_heat_atomization(kJ/mol),mode_Cohesive_energy
0,47.400000,113.186656,4.600000,13.000000,5.200000,0.600000,0.400000,0.000000,74.600000,1.200000,1.378000,127.200000,1.290000,1.979000,1.260000,1.034000,2.434000,1.750000,6.300400,5.684000,2.177600,4.064000,9.000000,3.600000,1.400000,1.600000,6.000000,5.600000,0.600000,4.400000,4.000000,8.400000,3.000000,902.200000,5.180000,936.270000,2125.430000,10.648000,0.382200,7.967000,...,4.0,0.0,0.0,0.0,66.0,1.0,0.88,103.0,1.02,1.100,1.00,0.43,2.54,1.19,5.8458,5.77,1.920,2.00,6.0,2.0,1.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,1.0,890.0,2.900,385.95,717.85,2.07000,0.128,1.71750,9.8000,0.26900,279.0,2.85
1,46.714286,110.931629,4.619048,13.571429,6.666667,0.333333,0.666667,0.000000,81.000000,1.238095,1.256667,133.242647,1.250000,1.694524,1.228571,1.486190,2.739524,2.449048,6.393686,6.528571,2.299048,1.910476,6.904762,5.904762,1.952381,3.333333,1.619048,4.000000,0.047619,2.666667,8.380952,10.000000,5.285714,1014.809524,5.614286,1288.445238,2034.826190,8.094286,0.363667,14.176381,...,8.0,0.0,1.0,0.0,95.0,1.0,0.94,115.0,1.14,1.200,1.15,1.82,2.96,2.83,6.4079,7.59,2.685,0.00,7.0,7.0,2.0,5.0,0.0,0.0,0.0,1.0,10.0,14.0,7.0,1140.0,3.100,265.95,331.95,3.12000,0.473,5.28600,15.4380,0.12200,112.0,1.22
2,36.275862,85.159738,4.000000,14.896552,6.172414,0.310345,0.551724,0.137931,83.482759,0.931034,1.143448,124.482759,1.191379,1.490345,1.268966,0.758966,2.285172,2.273793,5.275255,5.313793,2.279931,2.619310,5.586207,4.965517,1.931034,2.965517,3.103448,3.379310,0.068966,3.034483,6.896552,10.620690,4.896552,880.068966,4.627586,611.456897,1481.398276,5.351724,0.483448,7.980448,...,7.0,0.0,1.0,0.0,88.0,1.0,0.88,103.0,1.02,1.100,1.00,0.43,2.58,2.65,5.8458,6.22,2.589,2.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,1000.0,2.900,385.95,717.85,2.07000,0.710,1.71750,9.8000,0.26900,279.0,2.85
3,33.500000,76.612850,4.000000,13.000000,5.500000,0.500000,0.250000,0.250000,74.250000,0.500000,1.277500,133.242647,1.255000,1.686250,1.300000,1.162500,2.215000,1.717500,4.676850,5.190000,2.119000,3.470000,8.000000,3.500000,1.500000,1.500000,5.000000,0.000000,0.500000,4.500000,5.000000,14.000000,3.000000,850.750000,5.475000,1272.100000,2394.312500,6.615000,0.490750,19.521500,...,4.0,1.0,0.0,0.0,65.0,0.0,1.65,144.0,1.53,2.375,1.60,1.29,1.93,1.07,3.8522,4.44,1.870,5.44,11.0,2.0,1.0,0.0,10.0,0.0,1.0,6.0,0.0,14.0,1.0,731.0,7.900,1235.15,2485.15,10.50000,0.235,11.30000,250.5800,429.00000,284.0,2.95
4,33.500000,78.785828,3.666667,14.166667,5.666667,0.500000,0.500000,0.000000,79.500000,0.666667,1.028333,107.000000,1.118333,1.357000,1.100000,1.230000,2.700000,2.373333,6.274017,6.031667,2.763333,2.313333,7.500000,4.500000,1.666667,2.500000,5.000000,2.333333,0.333333,3.500000,5.000000,11.666667,4.166667,1017.833333,4.263167,529.783333,1178.983333,5.125715,0.558333,5.761295,...,7.0,0.0,1.0,0.0,87.0,1.0,0.48,64.0,0.73,0.465,0.60,1.21,3.44,3.32,8.3703,7.54,3.610,0.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,1314.0,0.793,54.75,90.15,0.00143,0.920,0.22259,3.4099,0.02674,249.0,2.62
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3931,35.333333,82.303167,4.333333,11.333333,5.333333,0.333333,0.666667,0.000000,70.666667,1.333333,1.466667,136.000000,1.313333,1.971667,1.350000,0.790000,1.776667,2.020000,4.359700,4.633333,1.798000,3.333333,4.666667,4.000000,2.000000,2.000000,4.000000,0.000000,0.000000,4.000000,6.000000,14.000000,2.666667,765.333333,9.600000,1510.316667,2847.083333,5.026667,0.393333,28.313333,...,4.0,0.0,0.0,0.0,44.0,1.0,1.11,110.0,1.11,1.420,1.10,0.40,1.33,1.70,3.4521,3.64,1.320,2.00,4.0,2.0,2.0,0.0,0.0,0.0,0.0,2.0,0.0,14.0,2.0,640.0,5.400,722.65,1262.95,2.33000,0.200,16.90000,52.5500,2.35000,197.0,2.19
3932,26.800000,62.838424,3.400000,10.800000,5.800000,0.400000,0.600000,0.000000,67.600000,1.400000,1.148000,103.800000,1.022000,1.447000,0.990000,0.508000,2.390000,2.438000,6.135220,5.930000,2.371600,3.600000,4.800000,4.000000,2.000000,1.800000,1.000000,2.800000,0.000000,4.200000,9.000000,11.200000,3.800000,1121.400000,6.860000,1116.810000,2116.070000,4.622750,0.706000,9.916240,...,7.0,0.0,1.0,0.0,82.0,1.0,0.56,72.0,0.75,0.540,0.65,0.30,3.04,2.85,6.8834,7.30,3.066,3.00,5.0,5.0,2.0,3.0,0.0,0.0,0.0,3.0,10.0,14.0,5.0,1402.0,1.100,63.25,77.35,0.00125,1.040,0.36040,2.7928,0.02598,473.0,4.92
3933,46.000000,109.412000,5.000000,10.000000,5.000000,0.500000,0.500000,0.000000,67.000000,1.500000,1.645000,149.000000,1.415000,2.247500,1.475000,0.985000,1.715000,2.040000,4.446950,4.565000,1.739000,3.000000,5.000000,4.000000,2.000000,2.000000,6.000000,0.000000,0.000000,4.000000,4.000000,14.000000,2.000000,754.500000,11.700000,1423.900000,2956.550000,6.375000,0.235000,17.195000,...,4.0,0.0,0.0,0.0,44.0,1.0,1.23,140.0,1.35,1.670,1.40,0.86,1.33,1.70,3.4521,3.64,1.320,2.00,4.0,2.0,2.0,0.0,2.0,0.0,0.0,2.0,0.0,14.0,2.0,640.0,5.500,722.65,1262.95,6.24000,0.200,16.90000,52.5500,2.35000,197.0,2.19
3934,14.500000,31.636802,2.625000,13.750000,7.000000,0.250000,0.750000,0.000000,80.625000,1.250000,0.792500,133.242647,0.887500,0.979375,0.743750,1.093750,3.343750,3.280000,8.387550,8.693750,3.482250,1.000000,6.250000,5.750000,2.000000,3.750000,0.500000,0.000000,0.000000,2.250000,9.500000,14.000000,5.750000,1423.125000,4.538000,547.300000,1090.075000,1.382525,0.713750,4.235150,...,8.0,0.0,1.0,0.0,93.0,1.0,0.42,115.0,0.71,0.405,0.50,1.19,3.98,3.78,10.0854,10.41,4.193,0.00,7.0,7.0,2.0,5.0,0.0,0.0,0.0,1.0,10.0,14.0,7.0,1681.0,0.634,53.35,85.05,0.00170,0.820,0.25520,3.2698,0.02790,79.0,0.84


In [15]:
X_test, y_test, _, _ = generate_features(test_df)
X_test

Processing Input Data: 100%|██████████| 985/985 [00:00<00:00, 15770.47it/s]


	Featurizing Compositions...


Assigning Features...: 100%|██████████| 985/985 [00:00<00:00, 7830.71it/s]


	Creating Pandas Objects...


Unnamed: 0,avg_Atomic_Number,avg_Atomic_Weight,avg_Period,avg_group,avg_families,avg_Metal,avg_Nonmetal,avg_Metalliod,avg_Mendeleev_Number,avg_l_quantum_number,avg_Atomic_Radius,avg_Miracle_Radius_[pm],avg_Covalent_Radius,avg_Zunger_radii_sum,avg_ionic_radius,avg_crystal_radius,avg_Pauling_Electronegativity,avg_MB_electonegativity,avg_Gordy_electonegativity,avg_Mulliken_EN,avg_Allred-Rockow_electronegativity,avg_metallic_valence,avg_number_of_valence_electrons,avg_gilmor_number_of_valence_electron,avg_valence_s,avg_valence_p,avg_valence_d,avg_valence_f,avg_Number_of_unfilled_s_valence_electrons,avg_Number_of_unfilled_p_valence_electrons,avg_Number_of_unfilled_d_valence_electrons,avg_Number_of_unfilled_f_valence_electrons,avg_outer_shell_electrons,avg_1st_ionization_potential_(kJ/mol),avg_polarizability(A^3),avg_Melting_point_(K),avg_Boiling_Point_(K),avg_Density_(g/mL),avg_specific_heat_(J/g_K)_,avg_heat_of_fusion_(kJ/mol)_,...,mode_families,mode_Metal,mode_Nonmetal,mode_Metalliod,mode_Mendeleev_Number,mode_l_quantum_number,mode_Atomic_Radius,mode_Miracle_Radius_[pm],mode_Covalent_Radius,mode_Zunger_radii_sum,mode_ionic_radius,mode_crystal_radius,mode_Pauling_Electronegativity,mode_MB_electonegativity,mode_Gordy_electonegativity,mode_Mulliken_EN,mode_Allred-Rockow_electronegativity,mode_metallic_valence,mode_number_of_valence_electrons,mode_gilmor_number_of_valence_electron,mode_valence_s,mode_valence_p,mode_valence_d,mode_valence_f,mode_Number_of_unfilled_s_valence_electrons,mode_Number_of_unfilled_p_valence_electrons,mode_Number_of_unfilled_d_valence_electrons,mode_Number_of_unfilled_f_valence_electrons,mode_outer_shell_electrons,mode_1st_ionization_potential_(kJ/mol),mode_polarizability(A^3),mode_Melting_point_(K),mode_Boiling_Point_(K),mode_Density_(g/mL),mode_specific_heat_(J/g_K)_,mode_heat_of_fusion_(kJ/mol)_,mode_heat_of_vaporization_(kJ/mol)_,mode_thermal_conductivity_(W/(m_K))_,mode_heat_atomization(kJ/mol),mode_Cohesive_energy
0,46.206897,111.032290,4.551724,14.896552,6.172414,0.310345,0.551724,0.137931,84.034483,0.931034,1.226207,132.758621,1.268621,1.592414,1.351724,0.830690,2.268621,2.213103,5.228469,5.131724,2.194414,2.619310,5.586207,4.965517,1.931034,2.965517,3.103448,3.37931,0.068966,3.034483,6.896552,10.62069,4.896552,847.517241,5.124138,668.946552,1613.977586,6.852414,0.268276,10.726103,...,7.0,0.0,1.0,0.0,89.0,1.0,1.03,118.0,1.16,1.285,1.15,0.56,2.55,2.54,5.7610,5.89,2.434,2.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,941.0,3.800,490.15,958.15,4.79000,0.320,6.69400,37.7000,0.52000,227.0,2.46
1,25.750000,57.815474,3.375000,14.000000,5.750000,0.375000,0.500000,0.125000,78.375000,0.625000,1.001250,100.375000,1.087500,1.300000,1.043750,1.178750,2.716250,2.345000,6.276225,6.097500,2.782625,2.415000,7.750000,4.375000,1.625000,2.375000,3.750000,0.00000,0.375000,3.625000,6.250000,14.00000,4.000000,1049.500000,3.896500,626.825000,1088.275000,4.654465,0.589375,7.811295,...,7.0,0.0,1.0,0.0,87.0,1.0,0.48,64.0,0.73,0.465,0.60,1.21,3.44,3.32,8.3703,7.54,3.610,0.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,1314.0,0.793,54.75,90.15,0.00143,0.920,0.22259,3.4099,0.02674,249.0,2.62
2,46.750000,107.506150,5.000000,10.750000,4.000000,1.000000,0.000000,0.000000,64.250000,0.500000,1.660000,143.500000,1.475000,2.393750,1.550000,1.162500,1.997500,1.322500,3.836150,4.442500,1.800000,5.525000,10.750000,2.250000,0.750000,0.000000,10.000000,0.00000,1.250000,6.000000,0.000000,14.00000,0.750000,749.250000,7.125000,1383.150000,2717.150000,10.875000,0.236250,12.875000,...,4.0,1.0,0.0,0.0,65.0,0.0,1.65,144.0,1.53,2.375,1.60,1.29,1.93,1.07,3.8522,4.44,1.870,5.44,11.0,2.0,1.0,0.0,10.0,0.0,1.0,6.0,0.0,14.0,1.0,731.0,7.900,1235.15,2485.15,10.50000,0.235,11.30000,250.5800,429.00000,284.0,2.95
3,27.125000,61.084025,3.500000,13.125000,5.500000,0.500000,0.500000,0.000000,74.875000,0.750000,1.081250,102.750000,1.096250,1.448750,1.062500,1.191250,2.718750,2.307500,6.088075,5.997500,2.698750,2.762500,8.125000,4.375000,1.500000,2.000000,4.625000,0.00000,0.500000,4.000000,5.375000,14.00000,3.500000,1019.875000,4.559000,813.450000,1498.650000,5.488215,0.577875,7.348795,...,7.0,0.0,1.0,0.0,87.0,1.0,0.48,64.0,0.73,0.465,0.60,1.21,3.44,3.32,8.3703,7.54,3.610,0.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,1314.0,0.793,54.75,90.15,0.00143,0.920,0.22259,3.4099,0.02674,249.0,2.62
4,42.454545,97.547122,4.636364,13.000000,5.272727,0.636364,0.363636,0.000000,74.818182,0.363636,1.419091,133.403704,1.400909,1.999545,1.454545,1.273636,2.180909,1.664545,4.592800,5.185455,2.089636,3.825455,9.363636,3.636364,1.363636,1.636364,8.181818,0.00000,0.636364,4.363636,1.818182,14.00000,3.000000,830.272727,6.463636,926.477273,1795.095455,7.954545,0.317545,8.925727,...,4.0,1.0,0.0,0.0,65.0,0.0,1.65,144.0,1.53,2.375,1.60,1.29,1.93,1.07,3.8522,4.44,1.870,5.44,11.0,2.0,1.0,0.0,10.0,0.0,1.0,6.0,0.0,14.0,1.0,731.0,7.900,1235.15,2485.15,10.50000,0.235,11.30000,250.5800,429.00000,284.0,2.95
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,24.000000,51.785333,3.666667,12.000000,6.000000,0.333333,0.666667,0.000000,73.333333,1.333333,1.273333,121.333333,1.173333,1.675000,1.183333,0.573333,2.163333,2.333333,5.047900,5.360000,2.166000,2.666667,5.333333,4.666667,2.000000,2.666667,0.666667,0.00000,0.000000,3.333333,9.333333,14.00000,4.666667,880.000000,7.900000,965.683333,2028.616667,3.550000,0.563333,6.778333,...,7.0,0.0,1.0,0.0,88.0,1.0,0.88,103.0,1.02,1.100,1.00,0.43,2.58,2.65,5.8458,6.22,2.589,2.00,6.0,6.0,2.0,4.0,0.0,0.0,0.0,2.0,10.0,14.0,6.0,1000.0,2.900,385.95,717.85,2.07000,0.710,1.71750,9.8000,0.26900,279.0,2.85
981,45.000000,104.684667,5.000000,9.000000,4.666667,0.666667,0.333333,0.000000,61.666667,1.666667,1.723333,149.000000,1.373333,2.398333,1.433333,0.860000,1.860000,1.936667,4.038067,4.330000,1.614667,4.260000,5.666667,4.000000,1.666667,1.000000,6.333333,0.00000,0.333333,5.000000,3.666667,14.00000,1.666667,728.000000,11.366667,1870.816667,3682.150000,8.533333,0.239333,20.256667,...,4.0,0.0,0.0,0.0,44.0,1.0,1.33,134.0,1.26,1.765,1.30,0.82,1.33,1.70,3.4521,3.64,1.320,3.00,4.0,2.0,1.0,0.0,2.0,0.0,0.0,3.0,0.0,14.0,1.0,640.0,6.600,904.15,2223.15,6.51000,0.210,16.90000,77.1400,22.70000,262.0,2.75
982,37.000000,85.092000,4.500000,10.000000,5.500000,0.500000,0.500000,0.000000,66.500000,1.500000,1.545000,138.000000,1.320000,2.055000,1.350000,0.710000,1.940000,2.120000,4.606550,4.765000,1.877000,3.000000,5.000000,4.000000,2.000000,2.000000,1.000000,0.00000,0.000000,4.000000,9.000000,14.00000,4.000000,790.500000,10.850000,1307.650000,2804.150000,5.650000,0.295000,11.797000,...,4.0,0.0,0.0,0.0,44.0,1.0,1.03,118.0,1.16,1.285,1.15,0.56,1.33,1.70,3.4521,3.64,1.320,2.00,4.0,2.0,2.0,0.0,0.0,0.0,0.0,2.0,8.0,14.0,2.0,640.0,3.800,490.15,958.15,4.79000,0.270,6.69400,37.7000,0.52000,227.0,2.46
983,22.666667,49.131667,3.666667,10.666667,5.333333,0.333333,0.666667,0.000000,66.666667,1.333333,1.426667,126.000000,1.233333,1.888333,1.250000,0.553333,1.710000,1.886667,3.940833,4.393333,1.717333,4.000000,4.000000,3.333333,2.000000,1.333333,0.666667,0.00000,0.000000,4.666667,9.333333,14.00000,3.333333,738.000000,9.566667,1830.483333,3302.150000,3.723333,0.563333,39.333333,...,6.0,0.0,1.0,0.0,78.0,1.0,1.11,110.0,1.11,1.420,1.10,0.40,1.90,1.98,4.1852,4.77,1.916,4.00,4.0,4.0,2.0,2.0,0.0,0.0,0.0,4.0,10.0,14.0,4.0,787.0,5.400,1683.15,2628.15,2.33000,0.710,50.55000,384.2200,148.00000,452.0,4.63


## Train

We can do hyperparameter tuning in different ways. Two common ways are grid search (less efficient) and random search (more efficient). Below are examples taken/modified from the website https://www.geeksforgeeks.org/hyperparameter-tuning/


In [16]:
#Grid search first using logistic regression classifier model

# Creating the hyperparameter grid
c_space = np.logspace(-5, 8, 15)
param_grid = {'C': c_space}
  
# Instantiating logistic regression classifier
# https://stats.stackexchange.com/a/184026/293880
logreg = LogisticRegression(max_iter=100)
  
# Instantiating the GridSearchCV object
logreg_grid = GridSearchCV(logreg, param_grid, cv = 5)
  
logreg_grid.fit(X_train, y_train)
  
# Print the tuned parameters and score
print("Grid tuned Logistic Regression Parameters: {}".format(logreg_grid.best_params_)) 
print("Best score is {}".format(logreg_grid.best_score_))

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logist

Grid tuned Logistic Regression Parameters: {'C': 19306.977288832535}
Best score is 0.8191010003934494


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


In [17]:
#Now we can try random search with logistic regression
  
# Creating the hyperparameter grid 
param_dist = {"C": randint(-5,15)}
  
# Instantiating Decision Tree classifier
logreg = LogisticRegression()
  
# Instantiating RandomizedSearchCV object
logreg_random = RandomizedSearchCV(logreg, param_dist, cv = 5)
  
logreg_random.fit(X_train, y_train)
  
# Print the tuned parameters and score
print("Random tuned Logistic Regression Parameters: {}".format(logreg_random.best_params_))
print("Best score is {}".format(logreg_random.best_score_))


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logist

Random tuned Logistic Regression Parameters: {'C': 3}
Best score is 0.820117841317346


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


We can do the same grid vs random search with another model, like a decision tree classifier

In [18]:
#grid search for decision tree hyperparameters
  
# Creating the hyperparameter grid 
param_grid = {"max_depth": range(1,10),
              "max_features": range(1,10),
              "min_samples_leaf": range(1,10),
              "criterion": ["gini", "entropy"]}

# Instantiating Decision Tree classifier
tree = DecisionTreeClassifier()
  
# Instantiating GridSearchCV object
tree_grid = GridSearchCV(tree, param_grid, cv = 5)
  
tree_grid.fit(X_train, y_train)
  
# Print the tuned parameters and score
print("Grid tuned Decision Tree Parameters: {}".format(tree_grid.best_params_))
print("Best score is {}".format(tree_grid.best_score_))


Grid tuned Decision Tree Parameters: {'criterion': 'gini', 'max_depth': 6, 'max_features': 7, 'min_samples_leaf': 5}
Best score is 0.850099652345539


In [19]:
#random search for decision tree hyperparameters
  
# Creating the hyperparameter grid 
param_dist = {"max_depth": randint(1,10),
              "max_features": randint(1,10),
              "min_samples_leaf": randint(1,10),
              "criterion": ["gini", "entropy"]}

# Instantiating Decision Tree classifier
tree = DecisionTreeClassifier()
  
# Instantiating RandomizedSearchCV object
tree_random = RandomizedSearchCV(tree, param_dist, cv = 5)
  
tree_random.fit(X_train, y_train)
  
# Print the tuned parameters and score
print("Random tuned Decision Tree Parameters: {}".format(tree_random.best_params_))
print("Best score is {}".format(tree_random.best_score_))


Random tuned Decision Tree Parameters: {'criterion': 'entropy', 'max_depth': 8, 'max_features': 5, 'min_samples_leaf': 6}
Best score is 0.8219051335470429


## Test

## Code Graveyard

In [None]:
# from sklearn.preprocessing import StandardScaler
# scaler = StandardScaler()
# scaler.fit(X_train)
# X_train = scaler.transform(X_train)
# X_test = scaler.transform(X_test)
# X_train