# Advanced usage

This notebook shows some more advanced features of `skorch`. More examples will be added over time.

### Table of contents

* [Setup](#Setup)
* [Callbacks](#Callbacks)
  * [Writing your own callback](#Writing-a-custom-callback)
  * [Accessing callback parameters](#Accessing-callback-parameters)

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
torch.manual_seed(0);

## Setup

### A toy binary classification task

We generate a toy binary classification task with `sklearn`'s `make_classification`.

In [3]:
import numpy as np
from sklearn.datasets import make_classification

In [4]:
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)

In [5]:
X.shape, y.shape, y.mean()

((1000, 20), (1000,), 0.5)

### Definition of the `pytorch` classification `module`

We define a vanilla neural network with two hidden layers. The output layer should have 2 output units, one per class. In addition, it should apply a softmax nonlinearity, because `predict_proba` uses the output of the `forward` call as the class probabilities, so that output must already be normalized.
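To see what the softmax contributes, here is a minimal pure-Python sketch (not skorch or PyTorch code) applied to one row of raw output scores: the results are non-negative and sum to 1, so they can serve directly as class probabilities. Picking the index of the largest probability then gives the predicted class, which is roughly how a classifier's `predict` relates to `predict_proba`.

```python
import math


def softmax(scores):
    """Map raw output scores (logits) to probabilities that sum to 1."""
    # Subtracting the max before exponentiating avoids overflow and
    # does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Two raw scores, one per class, as produced by a 2-unit output layer.
logits = [0.3, 1.2]
probs = softmax(logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)

print(probs)            # two non-negative values that sum to 1
print(predicted_class)  # 1, since the second score is larger
```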

In [6]:
from skorch import NeuralNetClassifier

In [7]:
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(20, num_units)  # 20 input features
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)  # one output unit per class

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        # softmax over the class dimension, so each row sums to 1
        X = F.softmax(self.output(X), dim=-1)
        return X
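As a quick sanity check on this architecture (20 input features, then `num_units`, then 10, then 2 outputs), the number of trainable parameters follows from the layer shapes: each `nn.Linear(n_in, n_out)` holds an `n_out x n_in` weight matrix plus `n_out` biases, while dropout and the nonlinearities add none. A stdlib sketch of that arithmetic (the helper names are made up for illustration):

```python
def linear_params(n_in, n_out):
    # weight matrix (n_out x n_in) plus one bias per output unit
    return n_in * n_out + n_out


def count_params(num_units=10, n_features=20, n_classes=2):
    # dense0, dense1 and output layers of ClassifierModule
    layers = [(n_features, num_units), (num_units, 10), (10, n_classes)]
    return sum(linear_params(n_in, n_out) for n_in, n_out in layers)


print(count_params())     # 342 = 210 + 110 + 22
print(count_params(100))  # 3132 = 2100 + 1010 + 22
```

With PyTorch installed, `sum(p.numel() for p in ClassifierModule().parameters())` should report the same count for the default `num_units=10`.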

## Callbacks

Callbacks are a powerful and flexible way to customize the behavior of your neural network. They are all called at specific points during the model training, e.g. when training starts, or after each batch. Have a look at the `skorch.callbacks` module to see the callbacks that are already implemented.
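The idea can be illustrated with a deliberately simplified, pure-Python model of a training loop. This is not skorch's actual implementation: `fit_loop` and `RecordingCallback` are hypothetical, and the hook names only mirror a subset of those defined on `skorch.callbacks.Callback`. The loop calls every callback's hooks at fixed points, and the callback simply records the order in which they fire.

```python
class RecordingCallback:
    """Hypothetical callback that records which hooks were called, in order."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, net, **kwargs):
        self.calls.append("train_begin")

    def on_epoch_begin(self, net, **kwargs):
        self.calls.append("epoch_begin")

    def on_batch_begin(self, net, **kwargs):
        self.calls.append("batch_begin")

    def on_batch_end(self, net, **kwargs):
        self.calls.append("batch_end")

    def on_epoch_end(self, net, **kwargs):
        self.calls.append("epoch_end")

    def on_train_end(self, net, **kwargs):
        self.calls.append("train_end")


def fit_loop(net, callbacks, max_epochs, n_batches):
    """Simplified loop that fires each callback hook at the matching point."""
    for cb in callbacks:
        cb.on_train_begin(net)
    for _ in range(max_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(net)
        for _ in range(n_batches):
            for cb in callbacks:
                cb.on_batch_begin(net)
            # ... forward pass, loss, backward pass, optimizer step ...
            for cb in callbacks:
                cb.on_batch_end(net)
        for cb in callbacks:
            cb.on_epoch_end(net)
    for cb in callbacks:
        cb.on_train_end(net)


cb = RecordingCallback()
fit_loop(net=None, callbacks=[cb], max_epochs=1, n_batches=2)
print(cb.calls)
# ['train_begin', 'epoch_begin', 'batch_begin', 'batch_end',
#  'batch_begin', 'batch_end', 'epoch_end', 'train_end']
```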

### Writing a custom callback

Although `skorch` comes with a handful of useful callbacks, you may want to write your own. Doing so is straightforward; just remember these rules:
* They should inherit from `skorch.callbacks.Callback`.
* They should implement at least one of the `on_`-methods provided by the parent class (e.g. `on_batch_begin` or `on_epoch_end`).
* As arguments, the `on_`-methods first receive the `NeuralNet` instance and, where appropriate, the local data (e.g. the data from the current batch). The methods should also accept `**kwargs` to absorb any arguments they do not use.
* *Optional*: If you have attributes that should be reset when the model is re-initialized, those attributes should be set in the `initialize` method.
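A minimal sketch of these rules that runs without skorch: `CallbackStub` is a hypothetical stand-in for `skorch.callbacks.Callback` (in real code you would inherit from that class instead), and `BatchCounter` is an invented example. It overrides two hooks, accepts `**kwargs` so that extra arguments such as the batch data are ignored safely, and creates its mutable state in `initialize`, so re-initializing the callback resets it.

```python
class CallbackStub:
    """Hypothetical stand-in for skorch.callbacks.Callback: no-op hooks."""

    def initialize(self):
        return self

    def on_batch_end(self, net, **kwargs):
        pass

    def on_epoch_end(self, net, **kwargs):
        pass


class BatchCounter(CallbackStub):
    """Counts batches per epoch; the counts are reset on (re-)initialization."""

    def initialize(self):
        # State that must be reset when the net is re-initialized is
        # (re)created here rather than in __init__.
        self.batches_per_epoch_ = []
        self._current = 0
        return self

    def on_batch_end(self, net, **kwargs):
        # kwargs may carry batch data (e.g. X, y); this callback ignores it.
        self._current += 1

    def on_epoch_end(self, net, **kwargs):
        self.batches_per_epoch_.append(self._current)
        self._current = 0


counter = BatchCounter().initialize()
for _ in range(2):          # simulate two epochs ...
    for _ in range(3):      # ... of three batches each
        counter.on_batch_end(None, X=None, y=None)
    counter.on_epoch_end(None)
print(counter.batches_per_epoch_)  # [3, 3]

counter.initialize()               # e.g. the net is re-initialized
print(counter.batches_per_epoch_)  # []
```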

Here is an example of a callback that remembers at which epoch the validation accuracy first reached a certain value. Then, when training is finished, it calls a mock Twitter API and tweets that epoch. We proceed as follows:
* We set the desired minimum accuracy during `__init__`.
* We set the critical epoch during `initialize`.
* After each epoch, if the minimum accuracy has not been reached yet, we check whether the latest validation accuracy reaches it.
* When training finishes, we send a tweet informing us whether our training was successful or not.
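Stripped of the callback machinery, these steps amount to finding the first epoch whose validation accuracy reaches the threshold. A standalone sketch of that logic (the function name `critical_epoch` is invented for illustration), applied to the `valid_acc` column printed by the first 10-epoch run in this notebook:

```python
def critical_epoch(valid_accs, min_accuracy):
    """Return the 1-based epoch at which valid_acc first reaches
    min_accuracy, or -1 if it never does (same rule as AccuracyTweet)."""
    for epoch, acc in enumerate(valid_accs, start=1):
        if acc >= min_accuracy:
            return epoch
    return -1


# valid_acc values from the first 10-epoch run in this notebook
valid_accs = [0.51, 0.55, 0.565, 0.585, 0.645, 0.66, 0.66, 0.665, 0.68, 0.68]
print(critical_epoch(valid_accs, 0.7))  # -1: the threshold is never reached
print(critical_epoch(valid_accs, 0.6))  # 5: first epoch with acc >= 0.6
```

Returning `-1` as a sentinel matches the callback's convention of initializing `critical_epoch_` to `-1`.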

In [8]:
from skorch.callbacks import Callback


def tweet(msg):
    print("~" * 60)
    print("*tweet*", msg, "#skorch #pytorch")
    print("~" * 60)


class AccuracyTweet(Callback):
    def __init__(self, min_accuracy):
        self.min_accuracy = min_accuracy

    def initialize(self):
        self.critical_epoch_ = -1
        return self

    def on_epoch_end(self, net, **kwargs):
        if self.critical_epoch_ > -1:
            return
        # look at the validation accuracy of the last epoch
        if net.history[-1, 'valid_acc'] >= self.min_accuracy:
            self.critical_epoch_ = len(net.history)

    def on_train_end(self, net, **kwargs):
        if self.critical_epoch_ < 0:
            msg = "Accuracy never reached {} :(".format(self.min_accuracy)
        else:
            msg = "Accuracy reached {} at epoch {}!!!".format(
                self.min_accuracy, self.critical_epoch_)

        tweet(msg)

Now we initialize a `NeuralNetClassifier` and pass our new callback in a list to the `callbacks` argument. After that, we train the model and see what happens.

In [9]:
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=10,
    lr=0.1,
    warm_start=True,
    callbacks=[AccuracyTweet(min_accuracy=0.7)],
)

In [10]:
net.fit(X, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7111       0.5100        0.6894  0.0994
      2        0.6928       0.5500        0.6803  0.0676
      3        0.6833       0.5650        0.6741  0.0396
      4        0.6763       0.5850        0.6674  0.0661
      5        0.6727       0.6450        0.6616  0.0557
      6        0.6606       0.6600        0.6536  0.0645
      7        0.6560       0.6600        0.6443  0.0559
      8        0.6427       0.6650        0.6354  0.0569
      9        0.6300       0.6800        0.6264  0.0606
     10        0.6289       0.6800        0.6189  0.0594
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tw

<skorch.net.NeuralNetClassifier at 0x7f7aeeae95f8>

Oh no, our model never reached a validation accuracy of 0.7. Let's train some more (this is possible because we set `warm_start=True`):

In [11]:
net.fit(X, y)

     11        0.6241       0.7150        0.6114  0.0475
     12        0.6132       0.7150        0.6017  0.0443
     13        0.5950       0.7350        0.5902  0.0523
     14        0.5914       0.7200        0.5831  0.0551
     15        0.5784       0.7300        0.5733  0.0598
     16        0.5816       0.7400        0.5665  0.0524
     17        0.5766       0.7450        0.5616  0.0539
     18        0.5636       0.7450        0.5559  0.0548
     19        0.5517       0.7350        0.5527  0.0571
     20        0.5570       0.7350        0.5492  0.0544
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy reached 0.7 at epoch 11!!! #skorch #pytorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


<skorch.net.NeuralNetClassifier at 0x7f7aeeae95f8>

Finally, the validation accuracy exceeded 0.7, first at epoch 11. Hooray!
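The tweet says epoch 11 rather than epoch 1 because, with `warm_start=True`, the second `fit` call continues training the existing model instead of re-initializing it. The net's history keeps growing, and the callback's `initialize` is not run again. Below is a hypothetical pure-Python mock of that bookkeeping (`MockNet` is invented and merges the net and the callback state for brevity; skorch's real `fit` does much more), fed with the `valid_acc` values printed by the two runs above:

```python
class MockNet:
    """Hypothetical stand-in tracking only what AccuracyTweet relies on."""

    def __init__(self, min_accuracy, warm_start):
        self.min_accuracy = min_accuracy
        self.warm_start = warm_start
        self.initialized_ = False

    def initialize(self):
        self.history = []          # one entry per epoch
        self.critical_epoch_ = -1  # what AccuracyTweet.initialize resets
        self.initialized_ = True

    def fit(self, epoch_accs):
        # Only (re-)initialize on a cold start; a warm start keeps all state.
        if not self.warm_start or not self.initialized_:
            self.initialize()
        for acc in epoch_accs:
            self.history.append(acc)
            if self.critical_epoch_ < 0 and acc >= self.min_accuracy:
                self.critical_epoch_ = len(self.history)
        return self


mock_net = MockNet(min_accuracy=0.7, warm_start=True)
mock_net.fit([0.51, 0.55, 0.565, 0.585, 0.645, 0.66, 0.66, 0.665, 0.68, 0.68])
print(mock_net.critical_epoch_)  # -1 after the first 10 epochs
mock_net.fit([0.715, 0.715, 0.735, 0.72, 0.73, 0.74, 0.745, 0.745, 0.735, 0.735])
print(len(mock_net.history), mock_net.critical_epoch_)  # 20 11
```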

### Accessing callback parameters

Say you would like to use a learning rate schedule with your neural net, but you don't know what parameters are best for that schedule. Wouldn't it be nice if you could find those parameters with a grid search? With `skorch`, this is possible. Below, we show how to access the parameters of your callbacks.

To simplify the access to your callback parameters, it is best if you give your callback a name. This is achieved by passing the `callbacks` parameter a list of *name*, *callback* tuples, such as:

    callbacks=[
        ('scheduler', LearningRateScheduler),
        ...
    ],
    
This way, you can access your callbacks using the double underscore semantics (as, for instance, in an `sklearn` `Pipeline`):

    callbacks__scheduler__epoch=50,
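This works through a naming convention: the key is split at the double underscores into a prefix, a component name, and a parameter name. The following is a hypothetical pure-Python sketch of that routing step (the helper `route_params` is invented; skorch and scikit-learn have their own internal implementations):

```python
def route_params(prefix, **params):
    """Group keys like 'callbacks__scheduler__epoch' by component name.

    Returns {component_name: {param_name: value}} for keys that start
    with '<prefix>__'; all other keys are ignored.
    """
    routed = {}
    for key, value in params.items():
        head, sep, rest = key.partition("__")
        if head != prefix or not sep:
            continue
        name, sep, param = rest.partition("__")
        if not sep:
            raise ValueError(f"expected '{prefix}__<name>__<param>', got {key!r}")
        routed.setdefault(name, {})[param] = value
    return routed


print(route_params("callbacks", callbacks__scheduler__epoch=50, lr=0.1))
# {'scheduler': {'epoch': 50}}
```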

So if you would like to perform a grid search on, say, the number of units in the hidden layer and the learning rate schedule, it could look something like this:

    param_grid = {
        'module__num_units': [50, 100, 150],
        'callbacks__scheduler__epoch': [10, 50, 100],
    }
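A grid search tries every combination of the listed values, so this grid yields 3 x 3 = 9 candidate settings, each fitted once per cross-validation fold. The sketch below enumerates them with `itertools.product`; scikit-learn's `ParameterGrid` produces the same set of combinations.

```python
from itertools import product

param_grid = {
    'module__num_units': [50, 100, 150],
    'callbacks__scheduler__epoch': [10, 50, 100],
}

# One dict of keyword arguments per combination of values
keys = sorted(param_grid)
candidates = [dict(zip(keys, values))
              for values in product(*(param_grid[k] for k in keys))]

print(len(candidates))  # 9
print(candidates[0])    # {'callbacks__scheduler__epoch': 10, 'module__num_units': 50}
```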
    
*Note*: If you would like to refresh your knowledge on grid search, look [here](http://scikit-learn.org/stable/modules/grid_search.html#grid-search), [here](http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_text_feature_extraction.html), or in the *Basic_Usage* notebook.

Below, we show how accessing the callback parameters works with our `AccuracyTweet` callback:

In [12]:
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=10,
    lr=0.1,
    warm_start=True,
    callbacks=[
        ('tweet', AccuracyTweet(min_accuracy=0.7)),
    ],
    callbacks__tweet__min_accuracy=0.6,
)

In [13]:
net.fit(X, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7261       0.5050        0.6986  0.0527
      2        0.6977       0.5350        0.6889  0.0584
      3        0.6897       0.5550        0.6844  0.0593
      4        0.6846       0.5800        0.6813  0.0579
      5        0.6788       0.5800        0.6768  0.0483
      6        0.6725       0.5800        0.6731  0.0479
      7        0.6711       0.5950        0.6689  0.0522
      8        0.6581       0.6150        0.6650  0.0542
      9        0.6648       0.6050        0.6605  0.0585
     10        0.6550       0.6150        0.6549  0.0544
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy reached 0.6 a

<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>

As you can see, passing `callbacks__tweet__min_accuracy=0.6` overrode the `min_accuracy=0.7` given to the callback's constructor, so the tweet reports the 0.6 threshold. The same can be achieved by calling the `set_params` method with the corresponding arguments:
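Conceptually, both the constructor keyword and `set_params` end up setting an attribute on the callback registered under that name, replacing whatever value was passed to the callback's own constructor. A hypothetical sketch of that final step (the helper `apply_callback_params` and the trimmed-down `AccuracyTweetStub` are invented for illustration; skorch's actual mechanism differs in detail):

```python
class AccuracyTweetStub:
    """Trimmed-down stand-in for AccuracyTweet, keeping only its parameter."""

    def __init__(self, min_accuracy):
        self.min_accuracy = min_accuracy


def apply_callback_params(named_callbacks, **params):
    """Hypothetical helper: apply 'callbacks__<name>__<param>' keys."""
    for key, value in params.items():
        if not key.startswith("callbacks__"):
            continue  # not a callback parameter
        _, name, param = key.split("__", 2)
        setattr(named_callbacks[name], param, value)


named_callbacks = {"tweet": AccuracyTweetStub(min_accuracy=0.7)}

# Like passing callbacks__tweet__min_accuracy=0.6 to NeuralNetClassifier
apply_callback_params(named_callbacks, callbacks__tweet__min_accuracy=0.6)
print(named_callbacks["tweet"].min_accuracy)  # 0.6

# Like calling net.set_params(callbacks__tweet__min_accuracy=0.75)
apply_callback_params(named_callbacks, callbacks__tweet__min_accuracy=0.75)
print(named_callbacks["tweet"].min_accuracy)  # 0.75
```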

In [14]:
net.set_params(callbacks__tweet__min_accuracy=0.75)

<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>

In [15]:
net.fit(X, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
     11        0.6435       0.6250        0.6492  0.0568
     12        0.6435       0.6350        0.6437  0.0546
     13        0.6267       0.6450        0.6375  0.0555
     14        0.6214       0.6800        0.6306  0.0390
     15        0.6185       0.6750        0.6239  0.0440
     16        0.6060       0.6750        0.6154  0.0454
     17        0.5964       0.6850        0.6061  0.0503
     18        0.5868       0.7000        0.5964  0.0579
     19        0.5693       0.7150        0.5859  0.0556
     20        0.5689       0.7200        0.5793  0.0558
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accu

<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>