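![scikit-learn grid search workflow: define a parameter grid, evaluate each candidate with cross-validation, pick the best parameters](https://scikit-learn.org/stable/_images/grid_search_workflow.png)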

In [80]:
import numpy as np

from sklearn.datasets import make_classification
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

import torch
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier

In [81]:
# Synthetic binary classification problem: 1500 samples x 1000 features,
# of which only 10 carry signal. random_state=0 keeps the data fixed.
X, y = make_classification(
    n_samples=1500,
    n_features=1000,
    n_informative=10,
    random_state=0,
)
# PyTorch wants float32 inputs and int64 (long) targets for nn.CrossEntropyLoss.
X, y = X.astype(np.float32), y.astype(np.int64)

In [82]:
# The feature count must equal NDD's default input width (D_in=1000);
# otherwise the first nn.Linear layer fails on the first batch.
assert X.shape[1] == 1000, X.shape
X.shape

(1500, 1000)

In [83]:
# One integer label per sample, and exactly two classes (0 and 1) to match
# NDD's default output width D_out=2.
assert y.shape == (X.shape[0],), y.shape
assert set(np.unique(y).tolist()) == {0, 1}, np.unique(y)
y.shape

(1500,)

In [84]:
class NDD(nn.Module):
    """Fully connected network with two hidden layers and dropout.

    Layout: Linear(D_in->H1) -> ReLU -> Dropout -> Linear(H1->H2) -> ReLU
    -> Dropout -> Linear(H2->D_out). The final layer's output is returned
    as-is (no softmax), i.e. raw scores suitable for nn.CrossEntropyLoss.

    Args:
        D_in: number of input features.
        H1: width of the first hidden layer.
        H2: width of the second hidden layer.
        D_out: number of output units (classes).
        drop: dropout probability applied after each hidden layer.
    """

    def __init__(self, D_in=1000, H1=500, H2=300, D_out=2, drop=0.5):
        super().__init__()
        # Affine (y = Wx + b) fully connected layers. Creation order is kept
        # as fc1, fc2, fc3, drop so parameter initialisation draws from the
        # RNG in the same sequence and state_dict keys stay the same.
        self.fc1 = nn.Linear(D_in, H1)
        self.fc2 = nn.Linear(H1, H2)
        self.fc3 = nn.Linear(H2, D_out)
        # A single Dropout module is reused after both hidden layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the two hidden blocks, then the output layer; return its raw output."""
        for hidden in (self.fc1, self.fc2):
            x = self.drop(F.relu(hidden(x)))
        return self.fc3(x)

In [85]:
# skorch wrapper that exposes the NDD module as a scikit-learn estimator.
# `lr` and `max_epochs` are deliberately not set here: the grid search below
# supplies them for every candidate, so values set here would be overridden.
net = NeuralNetClassifier(
    NDD,
    criterion=nn.CrossEntropyLoss,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

In [86]:
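# Optional: wrap the net in a scikit-learn Pipeline (e.g. to prepend
# preprocessing steps). Not used by the grid search below.
# pipe = Pipeline([
#     ('net', net),
# ])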

In [87]:
# Optional: fit the pipeline above directly and obtain class probabilities.
# pipe.fit(X, y)
# y_proba = pipe.predict_proba(X)

In [91]:
# Hyper-parameter grid. `lr` and `max_epochs` are single-valued, so they are
# effectively fixed; only the hidden-layer widths vary
# (3 x 3 = 9 candidates, each fit on cv=3 folds -> 27 training runs).
params = {
    'lr': [0.1],
    'max_epochs': [10],
    'module__H1': [500, 400, 300],
    'module__H2': [300, 200, 100],
}
# refit=False: only best_score_ / best_params_ are inspected, so the winning
# configuration is not retrained on the full data set.
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')

# Seed torch right before fitting. Weight initialisation, dropout masks and
# the shuffled training iterator all draw from torch's RNG, so without a seed
# best_score_ / best_params_ can change between runs of this (sequential)
# search.
torch.manual_seed(0)
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7109[0m       [32m0.5075[0m        [35m0.6912[0m  0.1154
      2        [36m0.6817[0m       0.5025        0.7103  0.0895
      3        [36m0.6656[0m       [32m0.6169[0m        [35m0.6651[0m  0.1226
      4        [36m0.6288[0m       [32m0.6219[0m        [35m0.6534[0m  0.0896
      5        [36m0.5882[0m       [32m0.6318[0m        [35m0.6509[0m  0.0849
      6        [36m0.5403[0m       [32m0.6368[0m        [35m0.6340[0m  0.0920
      7        [36m0.4915[0m       [32m0.6418[0m        [35m0.6200[0m  0.0920
      8        [36m0.4123[0m       [32m0.6617[0m        0.6218  0.1025
      9        [36m0.3359[0m       0.6418        0.6733  0.0771
     10        [36m0.2631[0m       [32m0.6716[0m        0.6837  0.0683
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  -------

      3        [36m0.6646[0m       0.5522        [35m0.6696[0m  0.0618
      4        [36m0.6375[0m       [32m0.5871[0m        [35m0.6589[0m  0.0581
      5        [36m0.6023[0m       [32m0.6418[0m        [35m0.6415[0m  0.0731
      6        [36m0.5636[0m       0.5622        0.6623  0.0616
      7        [36m0.4837[0m       0.6020        [35m0.6308[0m  0.0549
      8        [36m0.4353[0m       0.5920        0.6370  0.0569
      9        [36m0.3404[0m       [32m0.6517[0m        0.6653  0.0601
     10        [36m0.2734[0m       0.6318        0.6856  0.0567
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6891[0m       [32m0.5622[0m        [35m0.6850[0m  0.0633
      2        [36m0.6742[0m       [32m0.5920[0m        [35m0.6771[0m  0.0658
      3        [36m0.6625[0m       [32m0.6020[0m        [35m0.6703[0m  0.0563
      4        [36m0.6338[0m       0.5920

      6        [36m0.5801[0m       0.5871        [35m0.6427[0m  0.0511
      7        [36m0.5326[0m       [32m0.6766[0m        [35m0.6282[0m  0.0512
      8        [36m0.4646[0m       [32m0.7015[0m        [35m0.6211[0m  0.0489
      9        [36m0.4048[0m       0.6866        0.6261  0.0508
     10        [36m0.3012[0m       0.6816        0.6518  0.0564
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7014[0m       [32m0.5274[0m        [35m0.6924[0m  0.0629
      2        [36m0.6799[0m       [32m0.5672[0m        [35m0.6883[0m  0.0504
      3        [36m0.6701[0m       [32m0.5821[0m        [35m0.6793[0m  0.0506
      4        [36m0.6479[0m       [32m0.6119[0m        [35m0.6729[0m  0.0505
      5        [36m0.6230[0m       0.6020        [35m0.6689[0m  0.0507
      6        [36m0.5806[0m       [32m0.6269[0m        [35m0.6484[0m  0.0511
      7        