From 1df40115efb7936a6e46c9f7bc9fd1d1e9985881 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Mon, 4 May 2020 07:34:11 -0700 Subject: [PATCH] Added utility class to wrap a model, to be used together with fork-based multiprocessing in order to minimize host memory load. --- docs/source/index.rst | 2 + torch_xla/distributed/xla_multiprocessing.py | 48 ++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/docs/source/index.rst b/docs/source/index.rst index 68e2375295a..9a6af6054ab 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,6 +34,8 @@ distributed .. automodule:: torch_xla.distributed.xla_multiprocessing .. autofunction:: spawn +.. autoclass:: MpModelWrapper + :members: to utils ---------------------------------- diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index 3db133d678c..418a4dcf689 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -288,3 +288,51 @@ def spawn(fn, join=join, daemon=daemon, start_method=start_method) + + +class MpModelWrapper(object): + """Wraps a model to minimize host memory usage when `fork` method is used. + + This class should be used together with the `spawn(..., start_method='fork')` + API to minimize the use of host memory. + Instead of creating models on each multiprocessing process, hence replicating + the model's initial host memory, the model is created once at global scope, + and then moved into each device inside the `spawn()` target function. + Example:: + + WRAPPED_MODEL = MpModelWrapper(MyNetwork()) + + def _mp_fn(index, ...): + device = xm.xla_device() + model = WRAPPED_MODEL.to(device) + ... + + xmp.spawn(_mp_fn, ..., start_method='fork') + + This method has two advantages.
First it uses only one copy of the memory + pages to host the original model weights, and second it serializes the move + of the wrapped model into each device, by lowering the load onto the system + memory during the process. + """ + + def __init__(self, model): + """Creates a new `MpModelWrapper` object. + + Args: + model (torch.nn.Module): The model to be wrapped. Should be on PyTorch CPU + device (which is the default when creating new models). + """ + self._model = model + self._lock = torch.multiprocessing.Lock() + + def to(self, device): + """Retrieves the model moved onto the specified device. + + Args: + device (torch.device): The device where the model should be moved onto. + Returns: + The model on the specified device. + """ + with self._lock: + self._model.to(device) + return self._model