2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -34,6 +34,8 @@ distributed

.. automodule:: torch_xla.distributed.xla_multiprocessing
.. autofunction:: spawn
.. autoclass:: MpModelWrapper
:members: to

utils
----------------------------------
48 changes: 48 additions & 0 deletions torch_xla/distributed/xla_multiprocessing.py
@@ -288,3 +288,51 @@ def spawn(fn,
join=join,
daemon=daemon,
start_method=start_method)


class MpModelWrapper(object):
"""Wraps a model to minimize host memory usage when `fork` method is used.

This class should be used together with the `spawn(..., start_method='fork')`
API to minimize the use of host memory.
  Instead of creating the model in every multiprocessing process, and hence
  replicating the model's initial host memory, the model is created once at
  global scope and then moved onto each device inside the `spawn()` target
  function.

Example::

WRAPPED_MODEL = MpModelWrapper(MyNetwork())

def _mp_fn(index, ...):
device = xm.xla_device()
model = WRAPPED_MODEL.to(device)
...

    xmp.spawn(_mp_fn, ..., start_method='fork')

  This approach has two advantages. First, it uses only one copy of the memory
  pages hosting the original model weights, and second, it serializes the moves
  of the wrapped model onto each device, which lowers the load on system memory
  during the process.
"""

def __init__(self, model):
"""Creates a new `MpModelWrapper` object.

Args:
      model (torch.nn.Module): The model to be wrapped. Should be on the PyTorch
        CPU device (which is the default when creating new models).
"""
self._model = model
self._lock = torch.multiprocessing.Lock()

def to(self, device):
"""Retrieves the model moved onto the specified device.

Args:
      device (torch.device): The device onto which the model should be moved.

Returns:
The model on the specified device.
"""
with self._lock:
self._model.to(device)
return self._model
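
For context, here is a minimal end-to-end sketch of how the new wrapper is meant to be used with fork-based spawning. `MyNetwork` and the training-loop placeholder are hypothetical illustrations, not part of this change; only `MpModelWrapper`, `xmp.spawn()` and `xm.xla_device()` come from torch_xla itself.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


class MyNetwork(nn.Module):  # hypothetical toy model for illustration

  def __init__(self):
    super(MyNetwork, self).__init__()
    self.fc = nn.Linear(16, 2)

  def forward(self, x):
    return self.fc(x)


# Created once at global scope, on CPU, so that with start_method='fork' every
# child process shares the same host memory pages holding the initial weights.
WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())


def _mp_fn(index):
  device = xm.xla_device()
  # The wrapper's internal lock serializes the per-process moves onto the
  # device, keeping the transient host memory load low.
  model = WRAPPED_MODEL.to(device)
  # ... training loop using `model` on `device` ...


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), start_method='fork')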