In [1]:
% run 1-datasource.ipynb

In [2]:
import torch
from torch import nn
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

import torch.nn.functional as F
from skorch.net import NeuralNetClassifier
from sklearn.preprocessing import StandardScaler

In [3]:
scaler_1 = StandardScaler()
scaler_2 = StandardScaler()

In [4]:
for _, i in icebergs.iterrows():
    scaler_1.partial_fit(i.band_1.reshape(1, -1))
    scaler_2.partial_fit(i.band_2.reshape(1, -1))

In [5]:
y = icebergs.is_iceberg.values
x = (
    np.stack(
        [
            scaler_1.transform(np.stack(icebergs.band_1)),
            scaler_2.transform(np.stack(icebergs.band_2)),
        ],
        axis=1
    ).reshape(-1, 2, 75, 75)
).astype(np.float32)

In [6]:
x.dtype, y.dtype

(dtype('float32'), dtype('int64'))

In [7]:
type(x), type(y)

(numpy.ndarray, numpy.ndarray)

In [8]:
y.shape, x.shape

((1604,), (1604, 2, 75, 75))

In [9]:
insize1 = 2 
outsize1 = 32 
outsize2 = 32

class Net(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super().__init__()
        

        self.layer1 = nn.Sequential(
            nn.Conv2d(insize1, outsize1, kernel_size=7, stride=1, padding=2, groups=2),
            nn.BatchNorm2d(outsize1),
            nn.ReLU(),
            nn.MaxPool2d(4)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(outsize1, outsize2, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(outsize2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Linear(2592, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 2),
            nn.Softmax(1)
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [10]:
net = NeuralNetClassifier(
    Net,
    optimizer=torch.optim.Adam,
    criterion=nn.CrossEntropyLoss,
    max_epochs=10,
    batch_size=25,
    lr=0.000001,
    use_cuda=True
)

In [11]:
net.fit(x, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6931[0m       [32m0.5994[0m        [35m0.6889[0m  2.2235
      2        [36m0.6877[0m       0.5776        [35m0.6838[0m  1.1326
      3        [36m0.6838[0m       0.5745        [35m0.6794[0m  1.1357
      4        [36m0.6805[0m       0.5776        [35m0.6756[0m  1.2032
      5        [36m0.6775[0m       0.5870        [35m0.6721[0m  1.1298
      6        [36m0.6748[0m       0.5901        [35m0.6690[0m  1.1283
      7        [36m0.6723[0m       0.5901        [35m0.6660[0m  1.1263
      8        [36m0.6699[0m       [32m0.6056[0m        [35m0.6633[0m  1.1498
      9        [36m0.6677[0m       0.6025        [35m0.6606[0m  1.1278
     10        [36m0.6656[0m       0.6056        [35m0.6581[0m  1.1321


<class 'skorch.net.NeuralNetClassifier'>[initialized](
  module_=Net(
    (layer1): Sequential(
      (0): Conv2d (2, 32, kernel_size=(7, 7), stride=(1, 1), padding=(2, 2), groups=2)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), dilation=(1, 1))
    )
    (layer2): Sequential(
      (0): Conv2d (32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    )
    (fc): Sequential(
      (0): Linear(in_features=2592, out_features=1024)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=512)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=2)
      (5): Softmax()
    )
  ),
)

In [None]:
%%capture output
from sklearn.model_selection import GridSearchCV


params = {
    'lr': [0.000001, 0.000002,0.000004,0.000008],
    'max_epochs': [50,100,150, 200,250],
    'module__num_units': [5,10,15,20,40,80],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='neg_log_loss',n_jobs=1)

In [None]:
gs.fit(x, y)

Re-initializing module!
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6966[0m       [32m0.4977[0m        [35m0.6931[0m  0.7739
      2        [36m0.6891[0m       [32m0.6233[0m        [35m0.6855[0m  0.7584
      3        [36m0.6841[0m       0.5860        [35m0.6802[0m  0.7609
      4        [36m0.6804[0m       0.5581        [35m0.6760[0m  0.7506
      5        [36m0.6774[0m       0.5628        [35m0.6726[0m  0.7486
      6        [36m0.6750[0m       0.5581        [35m0.6698[0m  0.7475
      7        [36m0.6729[0m       0.5581        [35m0.6672[0m  0.7604
      8        [36m0.6710[0m       0.5581        [35m0.6650[0m  0.7591
      9        [36m0.6693[0m       0.5581        [35m0.6629[0m  0.7482
     10        [36m0.6678[0m       0.5581        [35m0.6609[0m  0.7480
     11        [36m0.6663[0m       0.5581        [35m0.6591[0m  0.

In [None]:
%%capture output2
print(gs.best_score_, gs.best_params_)

In [None]:
class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X))
        return X

In [None]:
class BCL2(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss(*args, **kwargs)
        
    def forward(self, input, target, *args, **kwargs):
        return self.loss.forward(input, target.view(-1, 1), *args, **kwargs)
    
    def backward(self, grad_output, *args, **kwargs):
        return self.loss.backward(grad_output, *args, **kwargs)

In [17]:
output2.show()

-0.355908236186 {'max_epochs': 100, 'lr': 4e-06, 'module__num_units': 80}


In [20]:
output.show()

In [43]:
import scipy.stats as sc
param_dis = {
    'lr': sc.uniform(loc=0.000001,scale=0.00001),
    'max_epochs': sc.randint(low=50,high=250)
}

In [47]:
from sklearn.model_selection import RandomizedSearchCV
gsR = RandomizedSearchCV(net, param_dis, refit=False, scoring='neg_log_loss',n_jobs=1,n_iter=5)

In [48]:
gsR.fit(x,y)

Re-initializing module!
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6882[0m       [32m0.5674[0m        [35m0.6838[0m  1.0834
      2        [36m0.6777[0m       0.5581        [35m0.6705[0m  0.7819
      3        [36m0.6727[0m       0.5581        [35m0.6640[0m  0.7701
      4        [36m0.6695[0m       0.5581        [35m0.6594[0m  0.7523
      5        [36m0.6669[0m       0.5581        [35m0.6556[0m  0.7527
      6        [36m0.6646[0m       0.5581        [35m0.6523[0m  0.7509
      7        [36m0.6625[0m       0.5581        [35m0.6494[0m  0.7562
      8        [36m0.6604[0m       0.5581        [35m0.6467[0m  0.7563
      9        [36m0.6584[0m       [32m0.5721[0m        [35m0.6442[0m  0.7594
     10        [36m0.6565[0m       0.5721        [35m0.6419[0m  0.7581
     11        [36m0.6546[0m       [32m0.5767[0m        [35m0.639

     22        [36m0.6335[0m       0.6028        [35m0.6396[0m  0.7485
     23        [36m0.6315[0m       0.6028        [35m0.6380[0m  0.7719
     24        [36m0.6295[0m       [32m0.6168[0m        [35m0.6364[0m  0.7510
     25        [36m0.6274[0m       [32m0.6215[0m        [35m0.6347[0m  0.7463
     26        [36m0.6254[0m       [32m0.6308[0m        [35m0.6331[0m  0.7447
     27        [36m0.6234[0m       0.6308        [35m0.6315[0m  0.7454
     28        [36m0.6213[0m       0.6308        [35m0.6298[0m  0.7495
     29        [36m0.6193[0m       [32m0.6355[0m        [35m0.6281[0m  0.7503
     30        [36m0.6172[0m       [32m0.6449[0m        [35m0.6264[0m  0.7507
     31        [36m0.6151[0m       [32m0.6495[0m        [35m0.6247[0m  0.7463
     32        [36m0.6130[0m       [32m0.6542[0m        [35m0.6229[0m  0.7518
     33        [36m0.6108[0m       [32m0.6589[0m        [35m0.6212[0m  0.7487
     34        [36m0.6087[

     45        [36m0.4760[0m       [32m0.7023[0m        [35m0.5559[0m  0.7486
     46        [36m0.4713[0m       0.7023        [35m0.5524[0m  0.7502
     47        [36m0.4666[0m       [32m0.7116[0m        [35m0.5490[0m  0.7536
     48        [36m0.4619[0m       [32m0.7209[0m        [35m0.5456[0m  0.7589
     49        [36m0.4571[0m       [32m0.7256[0m        [35m0.5421[0m  0.7539
     50        [36m0.4523[0m       0.7256        [35m0.5386[0m  0.7505
     51        [36m0.4475[0m       [32m0.7302[0m        [35m0.5351[0m  0.7531
     52        [36m0.4427[0m       [32m0.7349[0m        [35m0.5316[0m  0.7538
     53        [36m0.4379[0m       0.7349        [35m0.5282[0m  0.7487
     54        [36m0.4331[0m       [32m0.7442[0m        [35m0.5247[0m  0.7471
     55        [36m0.4282[0m       0.7442        [35m0.5212[0m  0.7532
     56        [36m0.4234[0m       [32m0.7488[0m        [35m0.5178[0m  0.7520
     57        [36m0.4186[

     67        [36m0.4163[0m       [32m0.8326[0m        [35m0.4426[0m  0.7548
     68        [36m0.4112[0m       0.8279        [35m0.4396[0m  0.7621
     69        [36m0.4062[0m       0.8279        [35m0.4367[0m  0.7654
     70        [36m0.4012[0m       0.8326        [35m0.4338[0m  0.7582
     71        [36m0.3962[0m       [32m0.8372[0m        [35m0.4309[0m  0.7578
     72        [36m0.3913[0m       [32m0.8419[0m        [35m0.4281[0m  0.7488
     73        [36m0.3864[0m       0.8372        [35m0.4254[0m  0.7516
     74        [36m0.3816[0m       0.8372        [35m0.4227[0m  0.7493
     75        [36m0.3768[0m       0.8419        [35m0.4201[0m  0.7597
     76        [36m0.3720[0m       0.8419        [35m0.4176[0m  0.7519
     77        [36m0.3673[0m       0.8372        [35m0.4151[0m  0.7534
     78        [36m0.3627[0m       0.8326        [35m0.4127[0m  0.7673
     79        [36m0.3580[0m       0.8326        [35m0.4102[0m  0.761

     56        [36m0.4818[0m       0.7523        [35m0.5138[0m  0.7610
     57        [36m0.4776[0m       0.7523        [35m0.5107[0m  0.7565
     58        [36m0.4733[0m       [32m0.7570[0m        [35m0.5077[0m  0.7493
     59        [36m0.4689[0m       0.7570        [35m0.5046[0m  0.7492
     60        [36m0.4646[0m       [32m0.7664[0m        [35m0.5016[0m  0.7506
     61        [36m0.4602[0m       [32m0.7710[0m        [35m0.4986[0m  0.7594
     62        [36m0.4558[0m       [32m0.7757[0m        [35m0.4956[0m  0.7706
     63        [36m0.4514[0m       0.7757        [35m0.4926[0m  0.7583
     64        [36m0.4469[0m       0.7757        [35m0.4895[0m  0.7510
     65        [36m0.4425[0m       0.7710        [35m0.4865[0m  0.7551
     66        [36m0.4380[0m       [32m0.7850[0m        [35m0.4835[0m  0.7514
     67        [36m0.4334[0m       0.7850        [35m0.4805[0m  0.7578
     68        [36m0.4288[0m       0.7850        [35

     44        [36m0.4304[0m       0.7256        [35m0.5194[0m  0.7510
     45        [36m0.4246[0m       [32m0.7349[0m        [35m0.5155[0m  0.7527
     46        [36m0.4189[0m       [32m0.7442[0m        [35m0.5116[0m  0.7508
     47        [36m0.4131[0m       [32m0.7488[0m        [35m0.5076[0m  0.7516
     48        [36m0.4074[0m       0.7488        [35m0.5038[0m  0.7556
     49        [36m0.4017[0m       0.7442        [35m0.4999[0m  0.7540
     50        [36m0.3960[0m       0.7442        [35m0.4962[0m  0.7548
     51        [36m0.3903[0m       0.7488        [35m0.4925[0m  0.7499
     52        [36m0.3847[0m       [32m0.7535[0m        [35m0.4889[0m  0.7558
     53        [36m0.3791[0m       0.7535        [35m0.4853[0m  0.7503
     54        [36m0.3735[0m       [32m0.7581[0m        [35m0.4818[0m  0.7539
     55        [36m0.3680[0m       [32m0.7628[0m        [35m0.4784[0m  0.7628
     56        [36m0.3625[0m       [32m0.77

     33        [36m0.5733[0m       [32m0.7581[0m        [35m0.5545[0m  0.7537
     34        [36m0.5694[0m       [32m0.7628[0m        [35m0.5515[0m  0.7547
     35        [36m0.5655[0m       [32m0.7674[0m        [35m0.5486[0m  0.7514
     36        [36m0.5616[0m       0.7628        [35m0.5458[0m  0.7487
     37        [36m0.5576[0m       0.7674        [35m0.5429[0m  0.7509
     38        [36m0.5535[0m       0.7674        [35m0.5401[0m  0.7506
     39        [36m0.5494[0m       [32m0.7721[0m        [35m0.5372[0m  0.7469
     40        [36m0.5453[0m       0.7721        [35m0.5344[0m  0.7491
     41        [36m0.5411[0m       0.7721        [35m0.5316[0m  0.7516
     42        [36m0.5368[0m       0.7674        [35m0.5287[0m  0.7498
     43        [36m0.5325[0m       0.7721        [35m0.5259[0m  0.7523
     44        [36m0.5282[0m       0.7674        [35m0.5231[0m  0.7516
     45        [36m0.5237[0m       0.7674        [35m0.5203[

    141        [36m0.1363[0m       0.8140        [35m0.3672[0m  0.7465
    142        [36m0.1339[0m       0.8140        [35m0.3670[0m  0.7483
    143        [36m0.1315[0m       [32m0.8186[0m        [35m0.3669[0m  0.7484
    144        [36m0.1291[0m       0.8186        [35m0.3668[0m  0.7476
    145        [36m0.1267[0m       0.8186        [35m0.3667[0m  0.7476
    146        [36m0.1244[0m       0.8186        [35m0.3667[0m  0.7536
    147        [36m0.1221[0m       0.8186        [35m0.3666[0m  0.7491
    148        [36m0.1199[0m       0.8186        [35m0.3666[0m  0.7487
    149        [36m0.1177[0m       0.8186        [35m0.3665[0m  0.7555
    150        [36m0.1155[0m       0.8186        0.3665  0.7519
    151        [36m0.1133[0m       0.8186        0.3666  0.7577
    152        [36m0.1112[0m       0.8186        0.3666  0.7534
    153        [36m0.1091[0m       0.8186        0.3666  0.7474
    154        [36m0.1070[0m       0.8186        0

     69        [36m0.3713[0m       0.8131        [35m0.4344[0m  0.7504
     70        [36m0.3661[0m       0.8131        [35m0.4310[0m  0.7534
     71        [36m0.3609[0m       [32m0.8178[0m        [35m0.4277[0m  0.7580
     72        [36m0.3557[0m       0.8178        [35m0.4244[0m  0.7625
     73        [36m0.3506[0m       0.8178        [35m0.4211[0m  0.7512
     74        [36m0.3455[0m       0.8178        [35m0.4179[0m  0.7503
     75        [36m0.3404[0m       0.8178        [35m0.4148[0m  0.7560
     76        [36m0.3354[0m       0.8178        [35m0.4118[0m  0.7549
     77        [36m0.3304[0m       [32m0.8224[0m        [35m0.4088[0m  0.7535
     78        [36m0.3254[0m       0.8224        [35m0.4058[0m  0.7499
     79        [36m0.3205[0m       0.8224        [35m0.4029[0m  0.7505
     80        [36m0.3156[0m       [32m0.8318[0m        [35m0.4001[0m  0.7482
     81        [36m0.3107[0m       0.8318        [35m0.3972[0m  0.750

Re-initializing module!
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6847[0m       [32m0.5907[0m        [35m0.6837[0m  0.7579
      2        [36m0.6674[0m       [32m0.5953[0m        [35m0.6746[0m  0.7633
      3        [36m0.6549[0m       0.5907        [35m0.6689[0m  0.7563
      4        [36m0.6447[0m       [32m0.6047[0m        [35m0.6642[0m  0.7476
      5        [36m0.6358[0m       [32m0.6093[0m        [35m0.6596[0m  0.7491
      6        [36m0.6279[0m       0.6047        [35m0.6549[0m  0.7464
      7        [36m0.6205[0m       [32m0.6140[0m        [35m0.6503[0m  0.7480
      8        [36m0.6134[0m       [32m0.6279[0m        [35m0.6455[0m  0.7462
      9        [36m0.6067[0m       [32m0.6326[0m        [35m0.6409[0m  0.7463
     10        [36m0.6002[0m       0.6279        [35m0.6363[0m  0.7474
     11        [36m0.593

    105        [36m0.1032[0m       0.8047        0.3976  0.7572
    106        [36m0.1006[0m       0.8047        0.3977  0.7621
    107        [36m0.0980[0m       0.8000        0.3979  0.7707
    108        [36m0.0955[0m       0.8000        0.3982  0.7541
    109        [36m0.0931[0m       0.8000        0.3985  0.7507
    110        [36m0.0907[0m       0.8000        0.3988  0.7501
    111        [36m0.0883[0m       0.8000        0.3992  0.7511
    112        [36m0.0859[0m       0.8000        0.3996  0.7658
    113        [36m0.0837[0m       0.8000        0.4001  0.7509
    114        [36m0.0814[0m       0.8000        0.4006  0.7588
    115        [36m0.0792[0m       0.8000        0.4012  0.7507
    116        [36m0.0771[0m       0.8000        0.4018  0.7631
    117        [36m0.0750[0m       0.8047        0.4025  0.7537
    118        [36m0.0730[0m       0.8047        0.4032  0.7523
    119        [36m0.0710[0m       0.8047        0.4039  0.7548
    120   

     40        [36m0.3907[0m       0.8233        [35m0.4080[0m  0.7527
     41        [36m0.3824[0m       0.8186        [35m0.4031[0m  0.7564
     42        [36m0.3742[0m       0.8233        [35m0.3984[0m  0.7503
     43        [36m0.3659[0m       0.8186        [35m0.3938[0m  0.7495
     44        [36m0.3579[0m       0.8186        [35m0.3894[0m  0.7501
     45        [36m0.3501[0m       0.8233        [35m0.3852[0m  0.7595
     46        [36m0.3424[0m       [32m0.8326[0m        [35m0.3811[0m  0.7540
     47        [36m0.3349[0m       0.8326        [35m0.3772[0m  0.7504
     48        [36m0.3275[0m       0.8326        [35m0.3735[0m  0.7529
     49        [36m0.3202[0m       [32m0.8372[0m        [35m0.3700[0m  0.7617
     50        [36m0.3131[0m       [32m0.8419[0m        [35m0.3666[0m  0.7558
     51        [36m0.3062[0m       0.8419        [35m0.3634[0m  0.7478
     52        [36m0.2994[0m       [32m0.8465[0m        [35m0.3603[

    157        [36m0.0124[0m       0.8512        0.3736  0.7513
    158        [36m0.0120[0m       0.8512        0.3750  0.7528
    159        [36m0.0117[0m       0.8512        0.3763  0.7477
    160        [36m0.0113[0m       0.8512        0.3775  0.7450
    161        [36m0.0109[0m       0.8512        0.3788  0.7700
    162        [36m0.0106[0m       0.8512        0.3802  0.7497
    163        [36m0.0103[0m       0.8512        0.3814  0.7555
    164        [36m0.0099[0m       0.8512        0.3828  0.7502
    165        [36m0.0096[0m       0.8512        0.3840  0.7549
    166        [36m0.0093[0m       0.8512        0.3854  0.7490
    167        [36m0.0090[0m       0.8512        0.3867  0.7592
    168        [36m0.0088[0m       0.8512        0.3880  0.7529
    169        [36m0.0085[0m       0.8512        0.3894  0.7548
    170        [36m0.0082[0m       0.8512        0.3907  0.7504
    171        [36m0.0080[0m       0.8512        0.3920  0.7490
    172   

     53        [36m0.3027[0m       0.8364        [35m0.3929[0m  0.7557
     54        [36m0.2953[0m       0.8364        [35m0.3891[0m  0.7565
     55        [36m0.2878[0m       [32m0.8411[0m        [35m0.3856[0m  0.7542
     56        [36m0.2806[0m       0.8364        [35m0.3821[0m  0.7556
     57        [36m0.2734[0m       0.8364        [35m0.3789[0m  0.7581
     58        [36m0.2665[0m       0.8318        [35m0.3756[0m  0.7587
     59        [36m0.2596[0m       0.8318        [35m0.3726[0m  0.7537
     60        [36m0.2529[0m       0.8318        [35m0.3696[0m  0.7571
     61        [36m0.2464[0m       0.8318        [35m0.3668[0m  0.7599
     62        [36m0.2400[0m       0.8364        [35m0.3641[0m  0.7557
     63        [36m0.2338[0m       0.8364        [35m0.3615[0m  0.7599
     64        [36m0.2277[0m       0.8364        [35m0.3591[0m  0.7563
     65        [36m0.2218[0m       0.8411        [35m0.3567[0m  0.7578
     66        

    170        [36m0.0070[0m       0.8411        0.3627  0.7545
    171        [36m0.0068[0m       0.8411        0.3635  0.7608
    172        [36m0.0066[0m       0.8458        0.3644  0.7606
    173        [36m0.0064[0m       0.8411        0.3653  0.7567
    174        [36m0.0062[0m       0.8411        0.3661  0.7540
    175        [36m0.0060[0m       0.8411        0.3670  0.7591
    176        [36m0.0058[0m       0.8411        0.3678  0.7585
    177        [36m0.0056[0m       0.8411        0.3687  0.7572
    178        [36m0.0055[0m       0.8411        0.3695  0.7544
    179        [36m0.0053[0m       0.8411        0.3704  0.7594
    180        [36m0.0052[0m       0.8411        0.3712  0.7590
    181        [36m0.0050[0m       0.8411        0.3721  0.7533
    182        [36m0.0049[0m       0.8411        0.3730  0.7591
    183        [36m0.0047[0m       0.8411        0.3739  0.7573
    184        [36m0.0046[0m       0.8411        0.3747  0.7650
    185   

     64        [36m0.1028[0m       0.8140        [35m0.3895[0m  0.7590
     65        [36m0.0988[0m       0.8140        [35m0.3890[0m  0.7563
     66        [36m0.0949[0m       0.8140        [35m0.3886[0m  0.7543
     67        [36m0.0911[0m       0.8140        [35m0.3883[0m  0.7538
     68        [36m0.0875[0m       0.8047        [35m0.3881[0m  0.7536
     69        [36m0.0840[0m       0.8140        [35m0.3879[0m  0.7534
     70        [36m0.0807[0m       0.8140        [35m0.3878[0m  0.7513
     71        [36m0.0775[0m       0.8140        0.3879  0.7521
     72        [36m0.0745[0m       0.8140        [35m0.3878[0m  0.7589
     73        [36m0.0716[0m       0.8140        0.3878  0.7610
     74        [36m0.0688[0m       0.8186        0.3880  0.7537
     75        [36m0.0661[0m       0.8186        0.3883  0.7539
     76        [36m0.0636[0m       0.8186        0.3886  0.7871
     77        [36m0.0611[0m       0.8186        0.3891  0.7580
    

RandomizedSearchCV(cv=None, error_score='raise',
          estimator=<class 'skorch.net.NeuralNetClassifier'>[initialized](
  module_=Net(
    (layer1): Sequential(
      (0): Conv2d (2, 32, kernel_size=(7, 7), stride=(1, 1), padding=(2, 2), groups=2)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU()
      (3): MaxPool2d(ker...    (3): ReLU()
      (4): Linear(in_features=512, out_features=2)
      (5): Softmax()
    )
  ),
),
          fit_params=None, iid=True, n_iter=5, n_jobs=1,
          param_distributions={'max_epochs': array([216, 186,  80, 180, 116]), 'lr': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7fabacef4f28>, 'module__num_units': [5, 10, 15, 20, 40, 80]},
          pre_dispatch='2*n_jobs', random_state=None, refit=False,
          return_train_score='warn', scoring='neg_log_loss', verbose=0)

In [None]:
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()
        self.num_units = num_units
        self.nonlin = nonlin
        self.dropout = dropout

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X

In [51]:
print(gs.best_score_, gs.best_params_)

-0.355908236186 {'max_epochs': 100, 'lr': 4e-06, 'module__num_units': 80}
