Skip to content

Commit

Permalink
DOC how to prevent dictionary unpacking in forward (#941)
Browse files Browse the repository at this point in the history
skorch automatically unpacks dictionaries when passing input arguments
to the module's forward method. Sometimes, this may not be wanted. This
PR updates the docs to show how to prevent that.
  • Loading branch information
BenjaminBossan committed Mar 21, 2023
1 parent 6870e99 commit a148fed
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions docs/user/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ In addition to the types above, you can pass dictionaries or lists of
one of those data types, e.g. a dictionary of
:class:`numpy.ndarray`\s. When you pass dictionaries, the keys of the
dictionaries are used as the argument name for the
:func:`~torch.nn.Module.forward` method of the net's
:meth:`~torch.nn.Module.forward` method of the net's
``module``. Similarly, the column names of pandas ``DataFrame``\s are
used as argument names. The example below should illustrate how to use
this feature:
Expand Down Expand Up @@ -128,7 +128,7 @@ this feature:
net.fit(X, y)
Note that the keys in the dictionary ``X`` exactly match the argument
names in the :func:`~torch.nn.Module.forward` method. This way, you
names in the :meth:`~torch.nn.Module.forward` method. This way, you
can easily work with several different types of input features.

The :class:`.Dataset` from skorch makes the assumption that you always
Expand All @@ -144,3 +144,24 @@ apply your own transformation on the data, you should subclass
:class:`.Dataset` and override the
:func:`~skorch.dataset.Dataset.transform` method, then pass your
custom class to :class:`.NeuralNet` as the ``dataset`` argument.

Preventing dictionary unpacking
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As noted, when ``X`` is a dictionary, it is automatically unpacked when passed
to the module's :meth:`~torch.nn.Module.forward` method. Sometimes, you may want
to prevent this, e.g. because you're using a ``module`` from another library
that expects a dict as input, or because the exact dict keys are unknown. This
can be achieved by wrapping the original ``module`` and packing the dict again:

.. code:: python
from other_lib import ModuleExpectingDict
class WrappedModule(ModuleExpectingDict):
def forward(self, **kwargs):
# catch **kwargs, pass as a dict
return super().forward(kwargs)
Similarly, wrapping the ``module`` can be used to make any other desired changes
to the input arguments.

0 comments on commit a148fed

Please sign in to comment.