Cache DataParallel replicas #4216
Conversation
```python
dp(i)
self.assertTrue(dp._replicas)
yield
self.assertFalse(dp._replicas)
```
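The test above exercises the new `_replicas` cache: replicas exist after a forward call and are dropped once the wrapped module is modified. A toy model of that behavior (all names here are hypothetical stand-ins, not the PR's actual implementation):

```python
class CachedParallel:
    """Toy stand-in for a DataParallel that caches replicas (illustrative only)."""

    def __init__(self, params):
        self.params = params      # stand-in for the wrapped module's state
        self._replicas = None

    def __call__(self, x):
        # Build replicas lazily on the first forward call, then reuse them.
        if self._replicas is None:
            self._replicas = list(self.params)
        return x

    def modify_params(self):
        # Any modification of the wrapped module invalidates the cache,
        # which is what the test checks after the `yield`.
        self._replicas = None


dp = CachedParallel([1, 2])
dp(0)
assert dp._replicas          # cached after a forward call
dp.modify_params()
assert not dp._replicas      # cleared once the module changes
```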
torch/csrc/Module.cpp
Outdated
```diff
@@ -583,6 +583,17 @@ PyObject *THPModule_userEnabledCuDNN(PyObject *_unused)
  else Py_RETURN_FALSE;
}

PyObject* THPModule_inheritODictGetitem(PyObject *_unused, PyObject *_cls)
```
torch/jit/__init__.py
Outdated
```diff
@@ -17,6 +17,18 @@

_flatten = torch._C._jit_flatten

def _hack_compiled_function_isinstance():
```
torch/nn/parallel/data_parallel.py
Outdated
```python
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply


# NOTE: these callbacks are one-off
# NOTE: they are deduplicated based on identity
```
torch/nn/parallel/data_parallel.py
Outdated
```python
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply


# NOTE: these callbacks are one-off
# NOTE: they are deduplicated based on identity
class _CallbackOrderedDict(OrderedDict):
```
```python
        super(_CallbackOrderedDict, self).__init__(*args, **kwargs)

    def register_modification_callback(self, cb):
        if all(c is not cb for c in self.callbacks):
```
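Putting the pieces of this review thread together, a self-contained sketch of what a callback-firing `OrderedDict` along these lines could look like (a hypothetical reconstruction, not the PR's exact code): callbacks fire on mutation, are deduplicated by identity, and are one-off.

```python
from collections import OrderedDict


class CallbackOrderedDict(OrderedDict):
    """Sketch of an OrderedDict that fires one-off callbacks on mutation."""

    def __init__(self, *args, **kwargs):
        self.callbacks = []
        super(CallbackOrderedDict, self).__init__(*args, **kwargs)

    def register_modification_callback(self, cb):
        # Deduplicate based on identity, as the NOTE above describes.
        if all(c is not cb for c in self.callbacks):
            self.callbacks.append(cb)

    def _fire(self):
        # Callbacks are one-off: fire each once, then forget them.
        callbacks, self.callbacks = self.callbacks, []
        for cb in callbacks:
            cb()

    def __setitem__(self, key, value):
        super(CallbackOrderedDict, self).__setitem__(key, value)
        self._fire()

    def __delitem__(self, key):
        super(CallbackOrderedDict, self).__delitem__(key)
        self._fire()


events = []
d = CallbackOrderedDict()
cb = lambda: events.append("changed")
d.register_modification_callback(cb)
d.register_modification_callback(cb)   # deduplicated by identity
d["w"] = 1
assert events == ["changed"]           # fired exactly once
d["b"] = 2
assert events == ["changed"]           # one-off: not fired again
```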
torch/nn/parallel/data_parallel.py
Outdated
```python
    def __reduce__(self):
        state = super(_CallbackOrderedDict, self).__reduce__()
        # NB: state[2] of the reduce tuple is the instance dict.
        if state[2]:
```
torch/nn/parallel/data_parallel.py
Outdated
# NB: 2nd element of reduce tuple is the dict. | ||
if state[2]: | ||
lstate = list(state) | ||
lstate[2] = state[2].copy().pop('callbacks') |
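The pickling concern here can be sketched end to end: `__reduce__` on `OrderedDict` subclasses returns a 5-tuple whose element at index 2 is the instance `__dict__` (or `None`), and the callback list has to be stripped from it before pickling, since closures aren't picklable. A hedged, self-contained sketch (hypothetical class name):

```python
import pickle
from collections import OrderedDict


class CallbackOrderedDict(OrderedDict):
    """Sketch: drop the unpicklable callback list when pickling."""

    def __init__(self, *args, **kwargs):
        self.callbacks = []
        super(CallbackOrderedDict, self).__init__(*args, **kwargs)

    def __reduce__(self):
        state = list(super(CallbackOrderedDict, self).__reduce__())
        # state[2] is the instance dict; remove the callbacks entry so
        # the pickled copy doesn't drag closures along with it.
        if state[2]:
            inst_dict = state[2].copy()
            inst_dict.pop('callbacks', None)
            state[2] = inst_dict
        return tuple(state)


d = CallbackOrderedDict()
d.callbacks.append(lambda: None)   # lambdas are not picklable
d['weight'] = 1
copy = pickle.loads(pickle.dumps(d))
assert dict(copy) == {'weight': 1}
assert copy.callbacks == []        # __init__ re-creates an empty list
```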
```diff
 def replicate(self, module, device_ids):
-    return replicate(module, device_ids)
+    if not self._replicas:
```
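The caching idea in the hunk above can be illustrated with a small memoizing wrapper (a hypothetical helper; the PR keeps the cache on `DataParallel` itself rather than in a closure):

```python
def make_cached_replicate(replicate_fn):
    """Wrap a replicate function so repeated calls reuse the result."""
    cache = {}

    def cached(module, device_ids):
        # Key on the module's identity and the device list; a real
        # implementation would also invalidate on parameter changes.
        key = (id(module), tuple(device_ids))
        if key not in cache:
            cache[key] = replicate_fn(module, device_ids)
        return cache[key]

    return cached


calls = []

def fake_replicate(module, device_ids):
    calls.append(module)
    return [module] * len(device_ids)

replicate = make_cached_replicate(fake_replicate)
replicate("net", [0, 1])
replicate("net", [0, 1])
assert calls == ["net"]   # replicate ran once; the second call hit the cache
```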
I force pushed, but there are no significant code changes: only the comments @ezyang requested and lint fixes.
Not about this PR, but in
To do this you'd need a mechanism to verify which parameters have changed, and it's not that simple. It's not enough to account for optimizers - someone might have loaded a different set of weights, and you need to propagate this change.
Good point. How about when doing inference with a model? Is there a way to avoid broadcasting tensors at all in PyTorch? Could this be an option of DataParallel?
In general, I'd like a comment on top of DataParallel describing how the general replica strategy works. Otherwise LGTM.
There is a comment. In general it's all invisible to the user and doesn't need to be. We're taking care of maintaining them in sync. The only thing that changed (and I forgot to update it) is that now
EDIT: I thought you meant the docs, but I think I misunderstood that and you just wanted a comment in the code. Will add this.
@ssnl I don't think so. If someone changes parameters you'd still need to re-broadcast them even in inference mode. In general, you need this "tensor modification hook" that we don't provide, and I'm not sure if we want to leak this logic that deep.
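The "tensor modification hook" the thread says PyTorch doesn't provide could, in principle, look like a version counter bumped on every in-place write. A toy sketch of the idea (not a PyTorch API; all names are made up):

```python
class VersionedParam:
    """Toy parameter whose version counter is bumped on in-place writes."""

    def __init__(self, value):
        self.value = value
        self.version = 0

    def set_(self, value):
        # Any in-place modification advances the version.
        self.value = value
        self.version += 1


def replicas_stale(params, cached_versions):
    # Re-broadcast only if some parameter changed since the last replicate.
    return [p.version for p in params] != cached_versions


p = VersionedParam(1.0)
versions = [p.version]            # snapshot taken at replicate time
assert not replicas_stale([p], versions)
p.set_(2.0)                       # e.g. an optimizer step or loading new weights
assert replicas_stale([p], versions)
```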
I really like that we're speeding up DataParallel. I'm a bit concerned about the caching strategy. We're adding state to DataParallel, but still trying to pretend it's stateless. Is there a way to speed up replicate without adding hidden state (the caching)? For example, we're already special casing
I believe the new answer is to just use DistributedDataParallel.
This commit makes `DataParallel` cache the wrapped module replicas, which brings the overhead of `replicate` on ResNet1001 down from ~140ms to 42ms. The remaining time is spent mostly in `broadcast_coalesced`, so moving it to C++ is the next step.

This commit really starts to push Python to the limit, which can be seen in two places:

- `__getitem__` for OrderedDict subclasses than it could. I'm going to post to Python's mailing lists and clarify why this is happening, but I'm not aware of any reason why this hack would not work. This might seem like a silly thing, but removing this hack costs us 30ms on each forward.
- `torch.jit.compile` for modules, and implement this poor man's inheritance-like thing (including the `__instancecheck__` hack so these objects still appear to belong to subclasses)... I've tried a few other things, but I can't come up with anything else that wouldn't break. The problem is that having one superclass with `__slots__` and another one in C++ confuses Python, and it complains that it can't figure out how to lay them out in memory. I'm happy to discuss alternative solutions.
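The `__instancecheck__` hack mentioned above can be sketched generically: `isinstance(obj, C)` is answered by `type(C).__instancecheck__`, so a metaclass hook lets wrapper objects masquerade as the class they wrap without real inheritance (hypothetical names throughout; this is the mechanism, not the PR's code):

```python
class ProxyMeta(type):
    """Metaclass whose __instancecheck__ also accepts wrapper objects."""

    def __instancecheck__(cls, obj):
        # Real instances pass the normal check; wrappers pass if they
        # record the class they stand in for.
        if type.__instancecheck__(cls, obj):
            return True
        return getattr(obj, "_original_class", None) is cls


class Module(metaclass=ProxyMeta):
    pass


class CompiledWrapper:
    """Stand-in for a 'compiled' object that is not a real Module subclass."""

    def __init__(self, original_class):
        self._original_class = original_class


w = CompiledWrapper(Module)
assert isinstance(w, Module)                   # passes via __instancecheck__
assert not type.__instancecheck__(Module, w)   # but it is not a real instance
```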