
Feature: quick CLIs (#390)

* Helper functions for CLIs with almost no boilerplate.

Add a helper function parse_args that makes it very simple to build
custom CLIs. Add a usage example and extend the docs.

* Extend and adjust README and docs.

* Add CLI implementation and unit tests.

* Update dev requirements: pytest >= 3.4

* Add fire library to dev requirements.

* Remove fire from dev requirements, install in travis.

fire is not on the conda channels, so the install would fail. Also,
modify the CLI tests to be skipped if fire is not installed.

* Correct typos.

* Add the option to have custom defaults.

E.g., if you would like to use batch_size=256 as a default instead of
128, you can now pass a dict `{'batch_size': 256}` to
`parse_args`. This will not only update your model to use those
defaults but also change the help to show your custom defaults.

To achieve the latter effect, it was necessary to parse the sklearn
docstrings for default values and replace them with the new
defaults. This turned out to be more difficult than expected because
the docstring defaults are not always written in the same fashion. I
tried to catch some variants that I found, but there are certainly
more variants out there. It should, however, work fine with the way we
write docstrings in skorch.
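
A rough sketch of the kind of matching involved (hypothetical pattern,
not the exact code in skorch/cli.py):

```python
import re

# Hypothetical pattern covering some common ways numpydoc docstrings
# spell out defaults, e.g. "default=10", "default: 10", "(default=10)".
P_DEFAULT = re.compile(r"default\s*[=:]\s*(?P<default>[^\s,)]+)")

def replace_default(line, new_default):
    """Replace a documented default value in a single docstring line."""
    match = P_DEFAULT.search(line)
    if match is None:
        return line  # no recognized default notation, leave untouched
    start, end = match.span('default')
    return line[:start] + str(new_default) + line[end:]
```

For example, `replace_default("batch_size : int (default=128)", 256)`
would return `"batch_size : int (default=256)"`.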

* Fix typo in docs/user/helper.rst

Co-Authored-By: benjamin-work <benjamin.bossan@ottogroup.com>

* Update docstring, remove unnecessary try..except.

* Simplify function that matches span for docstring match.
benjamin-work authored and ottonemo committed Dec 13, 2018
1 parent df1099a commit 37123699e1b4a526029a675a7caeec46618d28fb
Showing with 1,174 additions and 2 deletions.
  1. +1 −0 .travis.yml
  2. +3 −0 CHANGES.md
  3. +159 −0 docs/user/helper.rst
  4. +144 −0 examples/cli/README.md
  5. +207 −0 examples/cli/train.py
  6. +2 −2 requirements-dev.txt
  7. +336 −0 skorch/cli.py
  8. +1 −0 skorch/helper.py
  9. +321 −0 skorch/tests/test_cli.py
@@ -30,6 +30,7 @@ install:
- source activate skorch-env
- cat requirements.txt requirements-dev.txt > reqs.txt
- conda install --file=reqs.txt
- pip install fire
- pip install .
- conda install -c pytorch pytorch-cpu==${PYTORCH_VERSION}
script:
@@ -22,10 +22,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
a re-initialization of the optimizer (#369)
- Support for scipy sparse CSR matrices as input (as, e.g., returned by sklearn's
`CountVectorizer`); note that they are cast to dense matrices during batching
- Helper functions to build command line interfaces with almost no
boilerplate, with an [example][1811191713] that shows usage

[1810251445]: https://colab.research.google.com/github/dnouri/skorch/blob/master/notebooks/Basic_Usage.ipynb
[1810261633]: https://colab.research.google.com/github/dnouri/skorch/blob/master/notebooks/Advanced_Usage.ipynb
[1811011230]: https://colab.research.google.com/github/dnouri/skorch/blob/master/notebooks/MNIST.ipynb
[1811191713]: https://github.com/dnouri/skorch/tree/master/examples/cli

### Changed

@@ -5,6 +5,7 @@ Helper
This module provides helper functions and classes for the user. They
make working with skorch easier but are not used by skorch itself.


SliceDict
---------

@@ -16,3 +17,161 @@ length of the arrays and not the number of keys, and you get a
``dict``, you would normally not be able to use sklearn
:class:`~sklearn.model_selection.GridSearchCV` and similar things;
with :class:`.SliceDict`, this works.


Command line interface helpers
------------------------------

Often you want to wrap up your experiments by writing a small script
that allows others to reproduce your work. With the help of skorch and
the fire_ library, it becomes very easy to write command line
interfaces without boilerplate. All arguments pertaining to skorch or
its PyTorch module are immediately available as command line
arguments, without the need to write a custom parser. If docstrings in
the numpydoc_ specification are available, there is also
comprehensive help for the user. Overall, this allows you to make your
work reproducible without the usual hassle.

There is an example_ in the skorch repository that shows how to use
the CLI tools. Below is a snippet that shows the output created by the
help function without writing a single line of argument parsing:

.. code:: bash

    $ python examples/cli/train.py pipeline --help

    <SelectKBest> options:
      --select__score_func : callable
        Function taking two arrays X and y, and returning a pair of arrays
        (scores, pvalues) or a single array with scores.
        Default is f_classif (see below "See also"). The default function only
        works with classification tasks.
      --select__k : int or "all", optional, default=10
        Number of top features to select.
        The "all" option bypasses selection, for use in a parameter search.

    ...

    <NeuralNetClassifier> options:
      --net__module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`. In general, the
        uninstantiated class should be passed, although instantiated
        modules will also work.
      --net__criterion : torch criterion (class, default=torch.nn.NLLLoss)
        Negative log likelihood loss. Note that the module should return
        probabilities, the log is applied during ``get_loss``.
      --net__optimizer : torch optim (class, default=torch.optim.SGD)
        The uninitialized optimizer (update rule) used to optimize the
        module
      --net__lr : float (default=0.01)
        Learning rate passed to the optimizer. You may use ``lr`` instead
        of using ``optimizer__lr``, which would result in the same outcome.
      --net__max_epochs : int (default=10)
        The number of epochs to train for each ``fit`` call. Note that you
        may keyboard-interrupt training at any time.
      --net__batch_size : int (default=128)
        ...
      --net__verbose : int (default=1)
        Control the verbosity level.
      --net__device : str, torch.device (default='cpu')
        The compute device to be used. If set to 'cuda', data in torch
        tensors will be pushed to cuda tensors before being sent to the
        module.

    <MLPClassifier> options:
      --net__module__hidden_units : int (default=10)
        Number of units in hidden layers.
      --net__module__num_hidden : int (default=1)
        Number of hidden layers.
      --net__module__nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
        Non-linearity to apply after hidden layers.
      --net__module__dropout : float (default=0)
        Dropout rate. Dropout is applied between layers.

Installation
^^^^^^^^^^^^

To use this functionality, you need some further libraries that are not
part of skorch, namely fire_ and numpydoc_. You can install them with:

.. code:: bash

    pip install fire numpydoc

Usage
^^^^^

When you write your own script, only the following bits need to be
added:

.. code:: python

    import fire
    from skorch.helper import parse_args

    # your model definition and data fetching code below
    ...

    def main(**kwargs):
        X, y = get_data()
        my_model = get_model()

        # important: wrap the model with the parsed arguments
        parsed = parse_args(kwargs)
        my_model = parsed(my_model)

        my_model.fit(X, y)

    if __name__ == '__main__':
        fire.Fire(main)

This even works if your neural net is part of an sklearn pipeline, in
which case the help extends to all other estimators of your pipeline.
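
For illustration only, a hypothetical ``get_model`` could return such a
pipeline; the module and step names below are assumptions made for this
sketch, not anything required by ``parse_args``:

.. code:: python

    import torch
    from sklearn.feature_selection import SelectKBest
    from sklearn.pipeline import Pipeline
    from skorch import NeuralNetClassifier

    class MLPClassifier(torch.nn.Module):
        # stand-in for the module defined in examples/cli/train.py
        def __init__(self, hidden_units=10):
            super().__init__()
            self.hidden = torch.nn.Linear(20, hidden_units)
            self.out = torch.nn.Linear(hidden_units, 2)

        def forward(self, X):
            X = torch.nn.functional.relu(self.hidden(X))
            return torch.nn.functional.softmax(self.out(X), dim=-1)

    def get_model():
        # the step names become the CLI prefixes: --select__k, --net__lr, ...
        return Pipeline([
            ('select', SelectKBest()),
            ('net', NeuralNetClassifier(MLPClassifier)),
        ])
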
In case you would like to change some defaults for the net (e.g. using
a ``batch_size`` of 256 instead of 128), this is also possible. You
should have a dictionary containing your new defaults and pass it as
an additional argument to ``parse_args``:
.. code:: python

    my_defaults = {'batch_size': 256, 'module__hidden_units': 30}

    def main(**kwargs):
        ...
        parsed = parse_args(kwargs, defaults=my_defaults)
        my_model = parsed(my_model)

This will update the displayed help to your new defaults, as well as
set the parameters on the net or pipeline for you. However, arguments
passed via the command line take precedence. Thus, if you additionally
pass ``--batch_size 512`` to the script, the batch size will be 512.
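
For example, assuming the script uses the ``my_defaults`` dict from
above:

.. code:: bash

    # the command line overrides the custom default of 256
    python examples/cli/train.py net --batch_size 512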

Restrictions
^^^^^^^^^^^^

Almost all arguments should work out of the box. Therefore, you get
command line arguments for the number of epochs, learning rate, batch
size, etc. for free. Moreover, you can access the module parameters
with the double-underscore notation as usual with skorch
(e.g. ``--module__num_units 100``). This should cover almost all
common cases.

Parsing command line arguments that are non-primitive Python objects
is more difficult, though. skorch's custom parsing should support
normal Python types and simple custom objects, e.g. this works:
``--module__nonlin 'torch.nn.RReLU(0.1, upper=0.4)'``. More complex
parsing might not work. E.g., it is currently not possible to add new
callbacks through the command line (but you can modify existing ones
as usual).
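
For instance, renaming an existing callback works with the usual
double-underscore notation (the same invocation appears in the example
script's README):

.. code:: bash

    python examples/cli/train.py net --callbacks__valid_acc__name train_acc
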
.. _fire: https://github.com/google/python-fire
.. _numpydoc: https://github.com/numpy/numpydoc
.. _example: https://github.com/dnouri/skorch/tree/master/examples/cli
@@ -0,0 +1,144 @@
# skorch helpers for command line interfaces (CLIs)

Often you want to wrap up your experiments by writing a small script
that allows others to reproduce your work. With the help of skorch and
the fire library, it becomes very easy to write command line
interfaces without boilerplate. All arguments pertaining to skorch or
its PyTorch module are immediately available as command line
arguments, without the need to write a custom parser. If docstrings in
the numpydoc specification are available, there is also
comprehensive help for the user. Overall, this allows you to make your
work reproducible without the usual hassle.

This example is a showcase of how easy CLIs become with skorch.

## Installation

To use this functionality, you need some further libraries that are not
part of skorch, namely fire and numpydoc. You can install them with:

```bash
pip install fire numpydoc
```

## Usage

The `train.py` file contains an example of how to write your own CLI
with the help of skorch. As you can see, this file consists almost
exclusively of the actual logic; there is no argument parsing
involved.

When you write your own script, only the following bits need to be
added:

```python
import fire
from skorch.helper import parse_args

# your model definition and data fetching code below
...

def main(**kwargs):
    X, y = get_data()
    my_model = get_model()

    # important: wrap the model with the parsed arguments
    parsed = parse_args(kwargs)
    my_model = parsed(my_model)

    my_model.fit(X, y)

if __name__ == '__main__':
    fire.Fire(main)
```

This even works if your neural net is part of an sklearn pipeline, in
which case the help extends to all other estimators of your pipeline.

In case you would like to change some defaults for the net (e.g. using
a `batch_size` of 256 instead of 128), this is also possible. You
should have a dictionary containing your new defaults and pass it as
an additional argument to `parse_args`:

```python
my_defaults = {'batch_size': 256, 'module__hidden_units': 30}

def main(**kwargs):
    ...
    parsed = parse_args(kwargs, defaults=my_defaults)
    my_model = parsed(my_model)
```

This will update the displayed help to your new defaults, as well as
set the parameters on the net or pipeline for you. However, arguments
passed via the command line take precedence. Thus, if you additionally
pass `--batch_size 512` to the script, the batch size will be 512.

For more information on how to use fire, follow [this
link](https://github.com/google/python-fire).

## Restrictions

Almost all arguments should work out of the box. Therefore, you get
command line arguments for the number of epochs, learning rate, batch
size, etc. for free. Moreover, you can access the module parameters
with the double-underscore notation as usual with skorch
(e.g. `--module__num_units 100`). This should cover almost all common
cases.

Parsing command line arguments that are non-primitive Python objects
is more difficult, though. skorch's custom parsing should support
normal Python types and simple custom objects, e.g. this works:
`--module__nonlin 'torch.nn.RReLU(0.1, upper=0.4)'`. More complex
parsing might not work. E.g., it is currently not possible to add new
callbacks through the command line (but you can modify existing ones
as usual).

## Running the script

### Getting Help

In this example, there are two variants: only the net ("net") and the
net within an sklearn pipeline ("pipeline"). To get general help for
each variant, run:

```bash
python train.py net -- --help
python train.py pipeline -- --help
```

To get help for model-specific parameters, run:

```bash
python train.py net --help
python train.py pipeline --help
```

### Training a Model

To train a model with the default parameters, run:

```bash
python train.py net       # only the net
python train.py pipeline  # net within an sklearn pipeline
```

An example with just the net and some non-default arguments:

```bash
python train.py net --n_samples 1000 --output_file 'model.pkl' --lr 0.1 \
    --max_epochs 5 --device 'cuda' --module__hidden_units 50 \
    --module__nonlin 'torch.nn.RReLU(0.1, upper=0.4)' \
    --callbacks__valid_acc__on_train --callbacks__valid_acc__name train_acc
```

Example with an sklearn pipeline:

```bash
python train.py pipeline --n_samples 1000 --net__lr 0.1 \
    --net__module__nonlin 'torch.nn.LeakyReLU()' \
    --scale__minmax__feature_range '(-2, 2)' --scale__normalize__norm l1
```