Skip to content

Commit

Permalink
Re-name cold_start to warm_start.
Browse files Browse the repository at this point in the history
This is consistent with sklearn. Change code, tests, docstrings,
documentation, notebook accordingly.
  • Loading branch information
benjamin-work authored and ottonemo committed Nov 28, 2017
1 parent 1d82974 commit fe41f9d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 65 deletions.
4 changes: 2 additions & 2 deletions docs/user/FAQ.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ By default, when you call ``fit`` more than once, the training starts
from zero instead of from where it was left. This is in line with
sklearn's behavior but not always desired. If you would like to
continue training, use ``partial_fit`` instead of
``fit``. Alternatively, there is the ``cold_start`` argument, which is
``True`` by default. Set it to ``False`` instead and you should be
``fit``. Alternatively, there is the ``warm_start`` argument, which is
``False`` by default. Set it to ``True`` instead and you should be
fine.


Expand Down
6 changes: 3 additions & 3 deletions docs/user/neuralnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,14 @@ default callbacks. This is so that user-defined callbacks can make use
of the things provided by the default callbacks. The only exception is
the default callback ``PrintLog``, which is always called last.

cold_start
warm_start
^^^^^^^^^^

This argument determines whether each ``fit`` call leads to a
re-initialization of the ``NeuralNet`` or not. By default, when
calling ``fit``, the parameters of the net are initialized, so your
previous training progress is lost (consistent with the sklearn
``fit`` calls). In contrast, with ``cold_start=False``, each ``fit``
``fit`` calls). In contrast, with ``warm_start=True``, each ``fit``
call will continue from the most recent state.

use_cuda
Expand Down Expand Up @@ -276,7 +276,7 @@ task does not have an actual ``y``, you may pass ``y=None``.
In addition to ``fit``, there is also the ``partial_fit`` method,
known from some sklearn estimators. ``partial_fit`` allows you to
continue training from your current status, even if you set
``cold_start=True``. A further use case for ``partial_fit`` is when
``warm_start=False``. A further use case for ``partial_fit`` is when
your data does not fit into memory and you thus need to have several
training steps.

Expand Down
100 changes: 51 additions & 49 deletions notebooks/Advanced_Usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
" ClassifierModule,\n",
" max_epochs=10,\n",
" lr=0.1,\n",
" cold_start=False,\n",
" warm_start=True,\n",
" callbacks=[AccuracyTweet(min_accuracy=0.7)],\n",
")"
]
Expand All @@ -300,16 +300,16 @@
"text": [
" epoch train_loss valid_acc valid_loss dur\n",
"------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.7111\u001b[0m \u001b[32m0.5100\u001b[0m \u001b[35m0.6894\u001b[0m 0.1552\n",
" 2 \u001b[36m0.6928\u001b[0m \u001b[32m0.5500\u001b[0m \u001b[35m0.6803\u001b[0m 0.0554\n",
" 3 \u001b[36m0.6833\u001b[0m \u001b[32m0.5650\u001b[0m \u001b[35m0.6741\u001b[0m 0.0534\n",
" 4 \u001b[36m0.6763\u001b[0m \u001b[32m0.5850\u001b[0m \u001b[35m0.6674\u001b[0m 0.0547\n",
" 5 \u001b[36m0.6727\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6616\u001b[0m 0.0795\n",
" 6 \u001b[36m0.6606\u001b[0m \u001b[32m0.6600\u001b[0m \u001b[35m0.6536\u001b[0m 0.0643\n",
" 7 \u001b[36m0.6560\u001b[0m 0.6600 \u001b[35m0.6443\u001b[0m 0.0532\n",
" 8 \u001b[36m0.6427\u001b[0m \u001b[32m0.6650\u001b[0m \u001b[35m0.6354\u001b[0m 0.0784\n",
" 9 \u001b[36m0.6300\u001b[0m \u001b[32m0.6800\u001b[0m \u001b[35m0.6264\u001b[0m 0.1540\n",
" 10 \u001b[36m0.6289\u001b[0m 0.6800 \u001b[35m0.6189\u001b[0m 0.0632\n",
" 1 \u001b[36m0.7111\u001b[0m \u001b[32m0.5100\u001b[0m \u001b[35m0.6894\u001b[0m 0.0994\n",
" 2 \u001b[36m0.6928\u001b[0m \u001b[32m0.5500\u001b[0m \u001b[35m0.6803\u001b[0m 0.0676\n",
" 3 \u001b[36m0.6833\u001b[0m \u001b[32m0.5650\u001b[0m \u001b[35m0.6741\u001b[0m 0.0396\n",
" 4 \u001b[36m0.6763\u001b[0m \u001b[32m0.5850\u001b[0m \u001b[35m0.6674\u001b[0m 0.0661\n",
" 5 \u001b[36m0.6727\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6616\u001b[0m 0.0557\n",
" 6 \u001b[36m0.6606\u001b[0m \u001b[32m0.6600\u001b[0m \u001b[35m0.6536\u001b[0m 0.0645\n",
" 7 \u001b[36m0.6560\u001b[0m 0.6600 \u001b[35m0.6443\u001b[0m 0.0559\n",
" 8 \u001b[36m0.6427\u001b[0m \u001b[32m0.6650\u001b[0m \u001b[35m0.6354\u001b[0m 0.0569\n",
" 9 \u001b[36m0.6300\u001b[0m \u001b[32m0.6800\u001b[0m \u001b[35m0.6264\u001b[0m 0.0606\n",
" 10 \u001b[36m0.6289\u001b[0m 0.6800 \u001b[35m0.6189\u001b[0m 0.0594\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"*tweet* Accuracy never reached 0.7 :( #skorch #pytorch\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
Expand All @@ -318,7 +318,7 @@
{
"data": {
"text/plain": [
"<skorch.net.NeuralNetClassifier at 0x7fe8a8214438>"
"<skorch.net.NeuralNetClassifier at 0x7f7aeeae95f8>"
]
},
"execution_count": 10,
Expand All @@ -334,7 +334,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh no, our model never reached a validation accuracy of 0.7. Let's train some more (this is possible because we set `cold_start=False`):"
"Oh no, our model never reached a validation accuracy of 0.7. Let's train some more (this is possible because we set `warm_start=True`):"
]
},
{
Expand All @@ -346,16 +346,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
" 11 \u001b[36m0.6241\u001b[0m \u001b[32m0.7150\u001b[0m \u001b[35m0.6114\u001b[0m 0.1174\n",
" 12 \u001b[36m0.6132\u001b[0m 0.7150 \u001b[35m0.6017\u001b[0m 0.1397\n",
" 13 \u001b[36m0.5950\u001b[0m \u001b[32m0.7350\u001b[0m \u001b[35m0.5902\u001b[0m 0.2457\n",
" 14 \u001b[36m0.5914\u001b[0m 0.7200 \u001b[35m0.5831\u001b[0m 0.2226\n",
" 15 \u001b[36m0.5784\u001b[0m 0.7300 \u001b[35m0.5733\u001b[0m 0.1469\n",
" 16 0.5816 \u001b[32m0.7400\u001b[0m \u001b[35m0.5665\u001b[0m 0.1852\n",
" 17 \u001b[36m0.5766\u001b[0m \u001b[32m0.7450\u001b[0m \u001b[35m0.5616\u001b[0m 0.1034\n",
" 18 \u001b[36m0.5636\u001b[0m 0.7450 \u001b[35m0.5559\u001b[0m 0.1243\n",
" 19 \u001b[36m0.5517\u001b[0m 0.7350 \u001b[35m0.5527\u001b[0m 0.1246\n",
" 20 0.5570 0.7350 \u001b[35m0.5492\u001b[0m 0.1318\n",
" 11 \u001b[36m0.6241\u001b[0m \u001b[32m0.7150\u001b[0m \u001b[35m0.6114\u001b[0m 0.0475\n",
" 12 \u001b[36m0.6132\u001b[0m 0.7150 \u001b[35m0.6017\u001b[0m 0.0443\n",
" 13 \u001b[36m0.5950\u001b[0m \u001b[32m0.7350\u001b[0m \u001b[35m0.5902\u001b[0m 0.0523\n",
" 14 \u001b[36m0.5914\u001b[0m 0.7200 \u001b[35m0.5831\u001b[0m 0.0551\n",
" 15 \u001b[36m0.5784\u001b[0m 0.7300 \u001b[35m0.5733\u001b[0m 0.0598\n",
" 16 0.5816 \u001b[32m0.7400\u001b[0m \u001b[35m0.5665\u001b[0m 0.0524\n",
" 17 \u001b[36m0.5766\u001b[0m \u001b[32m0.7450\u001b[0m \u001b[35m0.5616\u001b[0m 0.0539\n",
" 18 \u001b[36m0.5636\u001b[0m 0.7450 \u001b[35m0.5559\u001b[0m 0.0548\n",
" 19 \u001b[36m0.5517\u001b[0m 0.7350 \u001b[35m0.5527\u001b[0m 0.0571\n",
" 20 0.5570 0.7350 \u001b[35m0.5492\u001b[0m 0.0544\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"*tweet* Accuracy reached 0.7 at epoch 11!!! #skorch #pytorch\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
Expand All @@ -364,7 +364,7 @@
{
"data": {
"text/plain": [
"<skorch.net.NeuralNetClassifier at 0x7fe8a8214438>"
"<skorch.net.NeuralNetClassifier at 0x7f7aeeae95f8>"
]
},
"execution_count": 11,
Expand Down Expand Up @@ -432,14 +432,16 @@
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"net = NeuralNetClassifier(\n",
" ClassifierModule,\n",
" max_epochs=10,\n",
" lr=0.1,\n",
" cold_start=False,\n",
" warm_start=True,\n",
" callbacks=[\n",
" ('tweet', AccuracyTweet(min_accuracy=0.7)),\n",
" ],\n",
Expand All @@ -458,16 +460,16 @@
"text": [
" epoch train_loss valid_acc valid_loss dur\n",
"------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.7261\u001b[0m \u001b[32m0.5050\u001b[0m \u001b[35m0.6986\u001b[0m 0.1680\n",
" 2 \u001b[36m0.6977\u001b[0m \u001b[32m0.5350\u001b[0m \u001b[35m0.6889\u001b[0m 0.1411\n",
" 3 \u001b[36m0.6897\u001b[0m \u001b[32m0.5550\u001b[0m \u001b[35m0.6844\u001b[0m 0.1202\n",
" 4 \u001b[36m0.6846\u001b[0m \u001b[32m0.5800\u001b[0m \u001b[35m0.6813\u001b[0m 0.0879\n",
" 5 \u001b[36m0.6788\u001b[0m 0.5800 \u001b[35m0.6768\u001b[0m 0.0641\n",
" 6 \u001b[36m0.6725\u001b[0m 0.5800 \u001b[35m0.6731\u001b[0m 0.0881\n",
" 7 \u001b[36m0.6711\u001b[0m \u001b[32m0.5950\u001b[0m \u001b[35m0.6689\u001b[0m 0.0708\n",
" 8 \u001b[36m0.6581\u001b[0m \u001b[32m0.6150\u001b[0m \u001b[35m0.6650\u001b[0m 0.0975\n",
" 9 0.6648 0.6050 \u001b[35m0.6605\u001b[0m 0.1186\n",
" 10 \u001b[36m0.6550\u001b[0m 0.6150 \u001b[35m0.6549\u001b[0m 0.1199\n",
" 1 \u001b[36m0.7261\u001b[0m \u001b[32m0.5050\u001b[0m \u001b[35m0.6986\u001b[0m 0.0527\n",
" 2 \u001b[36m0.6977\u001b[0m \u001b[32m0.5350\u001b[0m \u001b[35m0.6889\u001b[0m 0.0584\n",
" 3 \u001b[36m0.6897\u001b[0m \u001b[32m0.5550\u001b[0m \u001b[35m0.6844\u001b[0m 0.0593\n",
" 4 \u001b[36m0.6846\u001b[0m \u001b[32m0.5800\u001b[0m \u001b[35m0.6813\u001b[0m 0.0579\n",
" 5 \u001b[36m0.6788\u001b[0m 0.5800 \u001b[35m0.6768\u001b[0m 0.0483\n",
" 6 \u001b[36m0.6725\u001b[0m 0.5800 \u001b[35m0.6731\u001b[0m 0.0479\n",
" 7 \u001b[36m0.6711\u001b[0m \u001b[32m0.5950\u001b[0m \u001b[35m0.6689\u001b[0m 0.0522\n",
" 8 \u001b[36m0.6581\u001b[0m \u001b[32m0.6150\u001b[0m \u001b[35m0.6650\u001b[0m 0.0542\n",
" 9 0.6648 0.6050 \u001b[35m0.6605\u001b[0m 0.0585\n",
" 10 \u001b[36m0.6550\u001b[0m 0.6150 \u001b[35m0.6549\u001b[0m 0.0544\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"*tweet* Accuracy reached 0.6 at epoch 8!!! #skorch #pytorch\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
Expand All @@ -476,7 +478,7 @@
{
"data": {
"text/plain": [
"<skorch.net.NeuralNetClassifier at 0x7fe8a8204668>"
"<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>"
]
},
"execution_count": 13,
Expand All @@ -503,7 +505,7 @@
{
"data": {
"text/plain": [
"<skorch.net.NeuralNetClassifier at 0x7fe8a8204668>"
"<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>"
]
},
"execution_count": 14,
Expand All @@ -526,16 +528,16 @@
"text": [
" epoch train_loss valid_acc valid_loss dur\n",
"------- ------------ ----------- ------------ ------\n",
" 11 \u001b[36m0.6435\u001b[0m \u001b[32m0.6250\u001b[0m \u001b[35m0.6492\u001b[0m 0.2017\n",
" 12 0.6435 \u001b[32m0.6350\u001b[0m \u001b[35m0.6437\u001b[0m 0.1521\n",
" 13 \u001b[36m0.6267\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6375\u001b[0m 0.0861\n",
" 14 \u001b[36m0.6214\u001b[0m \u001b[32m0.6800\u001b[0m \u001b[35m0.6306\u001b[0m 0.1819\n",
" 15 \u001b[36m0.6185\u001b[0m 0.6750 \u001b[35m0.6239\u001b[0m 0.0761\n",
" 16 \u001b[36m0.6060\u001b[0m 0.6750 \u001b[35m0.6154\u001b[0m 0.2071\n",
" 17 \u001b[36m0.5964\u001b[0m \u001b[32m0.6850\u001b[0m \u001b[35m0.6061\u001b[0m 0.1707\n",
" 18 \u001b[36m0.5868\u001b[0m \u001b[32m0.7000\u001b[0m \u001b[35m0.5964\u001b[0m 0.1347\n",
" 19 \u001b[36m0.5693\u001b[0m \u001b[32m0.7150\u001b[0m \u001b[35m0.5859\u001b[0m 0.1106\n",
" 20 \u001b[36m0.5689\u001b[0m \u001b[32m0.7200\u001b[0m \u001b[35m0.5793\u001b[0m 0.1101\n",
" 11 \u001b[36m0.6435\u001b[0m \u001b[32m0.6250\u001b[0m \u001b[35m0.6492\u001b[0m 0.0568\n",
" 12 0.6435 \u001b[32m0.6350\u001b[0m \u001b[35m0.6437\u001b[0m 0.0546\n",
" 13 \u001b[36m0.6267\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6375\u001b[0m 0.0555\n",
" 14 \u001b[36m0.6214\u001b[0m \u001b[32m0.6800\u001b[0m \u001b[35m0.6306\u001b[0m 0.0390\n",
" 15 \u001b[36m0.6185\u001b[0m 0.6750 \u001b[35m0.6239\u001b[0m 0.0440\n",
" 16 \u001b[36m0.6060\u001b[0m 0.6750 \u001b[35m0.6154\u001b[0m 0.0454\n",
" 17 \u001b[36m0.5964\u001b[0m \u001b[32m0.6850\u001b[0m \u001b[35m0.6061\u001b[0m 0.0503\n",
" 18 \u001b[36m0.5868\u001b[0m \u001b[32m0.7000\u001b[0m \u001b[35m0.5964\u001b[0m 0.0579\n",
" 19 \u001b[36m0.5693\u001b[0m \u001b[32m0.7150\u001b[0m \u001b[35m0.5859\u001b[0m 0.0556\n",
" 20 \u001b[36m0.5689\u001b[0m \u001b[32m0.7200\u001b[0m \u001b[35m0.5793\u001b[0m 0.0558\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"*tweet* Accuracy never reached 0.75 :( #skorch #pytorch\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
Expand All @@ -544,7 +546,7 @@
{
"data": {
"text/plain": [
"<skorch.net.NeuralNetClassifier at 0x7fe8a8204668>"
"<skorch.net.NeuralNetClassifier at 0x7f7aeeae94e0>"
]
},
"execution_count": 15,
Expand Down
10 changes: 5 additions & 5 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class NeuralNet(object):
``net.set_params(callbacks__print_log__keys=['epoch',
'train_loss'])``).
cold_start : bool (default=True)
warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the
module (cold start) or whether the module should be trained
further (warm start).
Expand Down Expand Up @@ -215,7 +215,7 @@ def __init__(
dataset=Dataset,
train_split=CVSplit(5),
callbacks=None,
cold_start=True,
warm_start=False,
verbose=1,
use_cuda=False,
**kwargs
Expand All @@ -231,7 +231,7 @@ def __init__(
self.dataset = dataset
self.train_split = train_split
self.callbacks = callbacks
self.cold_start = cold_start
self.warm_start = warm_start
self.verbose = verbose
self.use_cuda = use_cuda
self.gradient_clip_value = gradient_clip_value
Expand Down Expand Up @@ -604,7 +604,7 @@ def fit(self, X, y=None, **fit_params):
"""Initialize and fit the module.
If the module was already initialized, by calling fit, the
module will be re-initialized (unless ``cold_start`` is False).
module will be re-initialized (unless ``warm_start`` is True).
Parameters
----------
Expand All @@ -626,7 +626,7 @@ def fit(self, X, y=None, **fit_params):
**fit_params : currently ignored
"""
if self.cold_start or not self.initialized_:
if not self.warm_start or not self.initialized_:
self.initialize()

self.partial_fit(X, y, **fit_params)
Expand Down
12 changes: 6 additions & 6 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_net_init_one_unknown_argument(self, net_cls, module_cls):
def test_net_init_two_unknown_argument(self, net_cls, module_cls):
with pytest.raises(TypeError) as e:
net_cls(module_cls, lr=0.1, mxa_epochs=5,
cold_start=True, bathc_size=20)
warm_start=False, bathc_size=20)

expected = ("__init__() got unexpected argument(s) "
"mxa_epochs, bathc_size."
Expand Down Expand Up @@ -591,7 +591,7 @@ def test_with_initialized_module_warm_start(
self, net_cls, module_cls, data, capsys):
X, y = data
module = module_cls(num_units=123)
net = net_cls(module, max_epochs=0, cold_start=False)
net = net_cls(module, max_epochs=0, warm_start=True)
net.partial_fit(X, y)

for p0, p1 in zip(module.parameters(), net.module_.parameters()):
Expand Down Expand Up @@ -620,7 +620,7 @@ def test_call_fit_twice_retrains(self, net_cls, module_cls, data):
# fit loop, parameters have changed (because the module was
# re-initialized)
X, y = data[0][:100], data[1][:100]
net = net_cls(module_cls, cold_start=True).fit(X, y)
net = net_cls(module_cls, warm_start=False).fit(X, y)
params_before = net.module_.parameters()

net.max_epochs = 0
Expand All @@ -633,7 +633,7 @@ def test_call_fit_twice_retrains(self, net_cls, module_cls, data):

def test_call_fit_twice_warmstart(self, net_cls, module_cls, data):
X, y = data[0][:100], data[1][:100]
net = net_cls(module_cls, cold_start=False).fit(X, y)
net = net_cls(module_cls, warm_start=True).fit(X, y)
params_before = net.module_.parameters()

net.max_epochs = 0
Expand All @@ -648,11 +648,11 @@ def test_partial_fit_first_call(self, net_cls, module_cls, data):
# It should be possible to partial_fit without calling fit first.
X, y = data[0][:100], data[1][:100]
# does not raise
net_cls(module_cls, cold_start=False).partial_fit(X, y)
net_cls(module_cls, warm_start=True).partial_fit(X, y)

def test_call_partial_fit_after_fit(self, net_cls, module_cls, data):
X, y = data[0][:100], data[1][:100]
net = net_cls(module_cls, cold_start=True).fit(X, y)
net = net_cls(module_cls, warm_start=False).fit(X, y)
params_before = net.module_.parameters()

net.max_epochs = 0
Expand Down

0 comments on commit fe41f9d

Please sign in to comment.