diff --git a/redirects.py b/redirects.py
index 11332c4974..748b7f1f9c 100644
--- a/redirects.py
+++ b/redirects.py
@@ -38,4 +38,5 @@
     "recipes/recipes_index.html": "../recipes_index.html",
     "recipes/torchserve_vertexai_tutorial.html": "../index.html",
     "unstable_source/vulkan_workflow.rst": "../index.html",
+    "unstable/skip_param_init.html": "https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html",
 }
diff --git a/unstable_index.rst b/unstable_index.rst
index 5c3dd61cf6..974fc27681 100644
--- a/unstable_index.rst
+++ b/unstable_index.rst
@@ -45,15 +45,6 @@ decide if we want to upgrade the level of commitment or to fail fast.
    :link: unstable/semi_structured_sparse.html
    :tags: Model-Optimiziation
 
-.. Modules
-
-.. customcarditem::
-   :header: Skipping Module Parameter Initialization in PyTorch 1.10
-   :card_description: Describes skipping parameter initialization during module construction in PyTorch 1.10, avoiding wasted computation.
-   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
-   :link: unstable/skip_param_init.html
-   :tags: Modules
-
 .. vmap
 
 .. customcarditem::
diff --git a/unstable_source/skip_param_init.rst b/unstable_source/skip_param_init.rst
deleted file mode 100644
index 197877b4c6..0000000000
--- a/unstable_source/skip_param_init.rst
+++ /dev/null
@@ -1,127 +0,0 @@
-Skipping Module Parameter Initialization
-========================================
-
-Introduction
-------------
-
-When a module is created, its learnable parameters are initialized according
-to a default initialization scheme associated with the module type. For example, the `weight`
-parameter for a :class:`torch.nn.Linear` module is initialized from a
-`uniform(-1/sqrt(in_features), 1/sqrt(in_features))` distribution. If some other initialization
-scheme is desired, this has traditionally required re-initializing the parameters
-after module instantiation:
-
-::
-
-    from torch import nn
-
-    # Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
-    m = nn.Linear(10, 5)
-
-    # Re-initialize weight from a different distribution.
-    nn.init.orthogonal_(m.weight)
-
-In this case, the initialization done during construction is wasted computation, and it may be non-trivial if
-the `weight` parameter is large.
-
-Skipping Initialization
------------------------
-
-It is now possible to skip parameter initialization during module construction, avoiding
-wasted computation. This is easily accomplished using the :func:`torch.nn.utils.skip_init` function:
-
-::
-
-    from torch import nn
-    from torch.nn.utils import skip_init
-
-    m = skip_init(nn.Linear, 10, 5)
-
-    # Example: Do custom, non-default parameter initialization.
-    nn.init.orthogonal_(m.weight)
-
-This can be applied to any module that satisfies the conditions described in the
-:ref:`Updating` section below. Note that all modules provided by
-`torch.nn` satisfy these conditions and thus support skipping init.
-
-.. _Updating:
-
-Updating Modules to Support Skipping Initialization
----------------------------------------------------
-
-Due to the way :func:`torch.nn.utils.skip_init` is implemented (see :ref:`Details`), there are
-two requirements that a module must meet to be compatible with the function.
-You can opt in to the parameter initialization skipping functionality for your custom module
-simply by adhering to these requirements:
-
-  1. The module must accept a `device` kwarg in its constructor that is passed to any parameters
-  or buffers created during construction.
-
-  2. The module must not perform any computation on parameters or buffers in its constructor except
-  initialization (i.e. functions from `torch.nn.init`).
-
-The following example demonstrates a module updated to support the `device`
-kwarg by passing it along to any created parameters, buffers, or submodules:
-
-::
-
-    import torch
-    from torch import nn
-
-    class MyModule(torch.nn.Module):
-        def __init__(self, foo, bar, device=None):
-            super().__init__()
-
-            # ==== Case 1: Module creates parameters directly. ====
-            # Pass device along to any created parameters.
-            self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
-            self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))
-
-            # To ensure support for the meta device, avoid using ops except those in
-            # torch.nn.init on parameters in your module's constructor.
-            with torch.no_grad():
-                nn.init.kaiming_uniform_(self.param1)
-                nn.init.uniform_(self.param2)
-
-
-            # ==== Case 2: Module creates submodules. ====
-            # Pass device along recursively. All submodules will need to support
-            # it as well; this is the case for all torch.nn provided modules.
-            self.fc = nn.Linear(bar, 5, device=device)
-
-            # This also works with containers.
-            self.linears = nn.Sequential(
-                nn.Linear(5, 5, device=device),
-                nn.Linear(5, 1, device=device)
-            )
-
-
-            # ==== Case 3: Module creates buffers. ====
-            # Pass device along during buffer tensor creation.
-            self.register_buffer('some_buffer', torch.ones(7, device=device))
-
-            ...
-
-.. _Details:
-
-Implementation Details
-----------------------
-
-Behind the scenes, the :func:`torch.nn.utils.skip_init` function is implemented in terms of a two-step pattern:
-
-::
-
-    # 1. Initialize module on the meta device; all torch.nn.init ops have
-    # no-op behavior on the meta device.
-    m = nn.Linear(10, 5, device='meta')
-
-    # 2. Materialize an uninitialized (empty) form of the module on the CPU device.
-    # The result of this is a module instance with uninitialized parameters.
-    m.to_empty(device='cpu')
-
-It works by instantiating the module onto a "meta" device, which has tensor shape information
-but does not allocate any storage. The `torch.nn.init` ops are specially implemented for this meta device
-so that they have no-op behavior. This results in the parameter initialization logic being essentially skipped.
-
-Note that this pattern only works for modules that properly support a `device` kwarg during construction, as
-described in :ref:`Updating`.