
Conversation

@apaszke apaszke commented Dec 17, 2017

This commit makes DataParallel cache the wrapped module's replicas, and brings the overhead of replicate on ResNet1001 down from ~140ms to ~42ms. The remaining time is spent mostly in broadcast_coalesced, so moving that to C++ is the next step.

This commit really starts to push Python to the limit, which can be seen in two places:

  • There's this ugly odict hack, because Python inheritance does weird things and selects a much slower implementation of __getitem__ for OrderedDict subclasses than it could. I'm going to post to Python's mailing lists to clarify why this is happening, but I'm not aware of any reason why this hack wouldn't work. It might seem like a silly thing, but removing this hack costs us 30ms on each forward pass.
  • I was forced to change the implementation of torch.jit.compile for modules and implement this poor man's inheritance-like scheme (including the __instancecheck__ hack, so these objects still appear to belong to their subclasses). I've tried a few other approaches, but I can't come up with anything else that wouldn't break. The problem is that having one superclass with __slots__ and another implemented in C++ confuses Python, which complains that it can't figure out how to lay instances out in memory. I'm happy to discuss alternative solutions.
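The second point can be illustrated with a standalone sketch (all names here are invented for illustration; the real hack lives in torch.jit.compile): a metaclass that overrides __instancecheck__ lets a wrapper object pass isinstance checks for a class it does not actually inherit from.

```python
class _FakeSubclassMeta(type):
    """Illustrative metaclass: objects carrying a marker attribute
    are treated as instances of the class, on top of the normal check."""

    def __instancecheck__(cls, obj):
        # Accept marked wrappers, then fall back to the default check.
        return getattr(obj, '_pretends_to_be', None) is cls \
            or super().__instancecheck__(obj)


class UserModule(metaclass=_FakeSubclassMeta):
    pass


class Wrapper:
    # Not a subclass of UserModule, but claims to be one.
    _pretends_to_be = UserModule


w = Wrapper()
assert isinstance(w, UserModule)       # the hook makes this pass
assert not isinstance(object(), UserModule)
```

This is only a sketch of the mechanism; the real code must also keep the wrapped module's behavior intact, which is what makes the actual change more involved.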

@apaszke apaszke requested review from ezyang and colesbury December 17, 2017 20:57
dp(i)
self.assertTrue(dp._replicas)   # replicas are cached after a forward pass
yield
self.assertFalse(dp._replicas)  # ...and cleared again after the wrapped block runs


@@ -583,6 +583,17 @@ PyObject *THPModule_userEnabledCuDNN(PyObject *_unused)
else Py_RETURN_FALSE;
}

PyObject* THPModule_inheritODictGetitem(PyObject *_unused, PyObject *_cls)
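The slowdown this C++ helper works around can be observed from plain Python (a hedged sketch; the exact gap depends on the CPython version, so no numbers are asserted here):

```python
from collections import OrderedDict
import timeit

class PlainSubclass(OrderedDict):
    pass  # adds nothing, yet CPython may dispatch __getitem__ more slowly

base = OrderedDict(a=1)
sub = PlainSubclass(a=1)

# Compare lookup cost; on the CPython builds this PR targeted, the
# subclass lookup was measurably slower, hence the inherit-getitem hack.
print('OrderedDict   :', timeit.timeit(lambda: base['a'], number=100_000))
print('PlainSubclass :', timeit.timeit(lambda: sub['a'], number=100_000))
```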

This comment was marked as off-topic.

@@ -17,6 +17,18 @@

_flatten = torch._C._jit_flatten

def _hack_compiled_function_isinstance():


from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply


# NOTE: these callbacks are one-off
# NOTE: they are deduplicated based on identity
class _CallbackOrderedDict(OrderedDict):


super(_CallbackOrderedDict, self).__init__(*args, **kwargs)

def register_modification_callback(self, cb):
if all(c is not cb for c in self.callbacks):


def __reduce__(self):
    state = super(_CallbackOrderedDict, self).__reduce__()
    # NB: element 2 of the reduce tuple is the instance dict.
    if state[2]:
        lstate = list(state)
        # Drop 'callbacks' from a copy of the instance dict so it is
        # never pickled (note: dict.pop returns the value, not the dict,
        # so the copy itself must be assigned).
        inst_dict = state[2].copy()
        inst_dict.pop('callbacks', None)
        lstate[2] = inst_dict

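Pieced together from the fragments above, a self-contained sketch of what _CallbackOrderedDict does (names and the exact mutation hooks are approximate, not the PR's verbatim code): callbacks fire once on the first mutation and are deduplicated by identity, and are stripped out before pickling.

```python
from collections import OrderedDict

class CallbackODict(OrderedDict):
    """Illustrative OrderedDict whose registered callbacks fire once,
    on the next mutation, and never survive pickling."""

    def __init__(self, *args, **kwargs):
        self.callbacks = []
        super().__init__(*args, **kwargs)

    def register_modification_callback(self, cb):
        # NOTE: deduplicated based on identity
        if all(c is not cb for c in self.callbacks):
            self.callbacks.append(cb)

    def _run_callbacks(self):
        # NOTE: one-off -- clear the list before invoking
        cbs, self.callbacks = self.callbacks, []
        for cb in cbs:
            cb()

    def __setitem__(self, key, value):
        self._run_callbacks()
        super().__setitem__(key, value)

    def __delitem__(self, key):
        self._run_callbacks()
        super().__delitem__(key)

    def __reduce__(self):
        state = super().__reduce__()
        # element 2 of the reduce tuple is the instance dict
        if state[2]:
            lstate = list(state)
            inst_dict = state[2].copy()
            inst_dict.pop('callbacks', None)  # callbacks are not picklable state
            lstate[2] = inst_dict
            return tuple(lstate)
        return state
```

Because __reduce__ drops the callbacks, a dict holding unpicklable closures still round-trips through pickle, and the reconstructed copy starts with an empty callback list.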

def replicate(self, module, device_ids):
    return replicate(module, device_ids)

if not self._replicas:
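The caching strategy behind that check can be sketched without torch at all (a hedged illustration; CachingDataParallel and the trivial replicate stand-in are invented names, not the real API):

```python
class CachingDataParallel:
    """Minimal illustration of the replica cache this PR adds."""

    def __init__(self, module, device_ids):
        self.module = module
        self.device_ids = device_ids
        self._replicas = None  # filled lazily, invalidated on mutation

    def replicate(self, module, device_ids):
        # stand-in for torch.nn.parallel.replicate: one replica per device
        return [module for _ in device_ids]

    def __call__(self, *inputs):
        # Replicate only on the first call; later calls reuse the cache,
        # which is where the ~140ms -> ~42ms win comes from.
        if not self._replicas:
            self._replicas = self.replicate(self.module, self.device_ids)
        return [replica(*inputs) for replica in self._replicas]

    def _invalidate(self):
        # What a parameter-modification callback would trigger.
        self._replicas = None
```

The callbacks registered on the module's _parameters/_buffers dicts are what keep this cache honest: any mutation invalidates the replicas so the next forward rebuilds them.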



apaszke commented Dec 18, 2017

I force-pushed, but there are no significant code changes: only the comments @ezyang requested and lint fixes.


ssnl commented Dec 20, 2017

Not about this PR, but in broadcast_coalesced, could there be a way to broadcast only the modified parameters? A valid use case is finetuning only the last layer.


apaszke commented Dec 20, 2017

To do this you'd need a mechanism to verify which parameters have changed, and that's not so simple. It's not enough to account for optimizers: someone might have loaded a different set of weights, and you need to propagate that change too.


ssnl commented Dec 20, 2017

Good point. How about when doing inference with a model? Is there a way to avoid broadcasting tensors at all in PyTorch? Could this be an option of DataParallel?


ezyang commented Dec 21, 2017

In general, I'd like a comment on top of DataParallel describing how the general replica strategy works. Otherwise LGTM.


apaszke commented Dec 21, 2017

There is a comment. In general it's all invisible to the user and doesn't need to be visible. We take care of keeping the replicas in sync. The only thing that changed (and I forgot to update it) is that __dict__s are now shared among all replicas (previously they were shallow copies).

EDIT: I thought you meant the docs, but I misunderstood that; you just wanted a comment in the code. Will add it.


apaszke commented Dec 21, 2017

@ssnl I don't think so. If someone changes parameters you'd still need to re-broadcast them even in inference mode. In general, you need this "tensor modification hook" that we don't provide, and I'm not sure if we want to leak this logic that deep.

@colesbury

I really like that we're speeding up DataParallel. I'm a bit concerned about the caching strategy. We're adding state to DataParallel, but still trying to pretend it's stateless.

Is there a way to speed up replicate without adding hidden state (the caching)? For example, we already special-case _parameters, _buffers, and _modules. Could we put a parent class of nn.Module in C++ that handles these three attributes? That would give us quick access to parameters and buffers for broadcasting. replicate could then be lightweight: create a sort of proxy module that has its own _parameters, _buffers, and _modules but shares __dict__. Sub-modules could be lazily created on access to minimize the overhead at the start of a forward pass.


ezyang commented Apr 15, 2019

I believe the new answer is to just use DistributedDataParallel.

@ezyang ezyang closed this Apr 15, 2019
@facebook-github-bot facebook-github-bot deleted the fast_dp branch July 13, 2020 17:55