[Feature] Imbalanced batch scattering in th.nn.DataParallel #46381
Labels
enhancement
module: data parallel
triaged
🚀 Feature
When, say, 4 GPUs (e.g. Titan Xp cards with 12 GB of video memory each) are shared among a group of people, the default balanced data scattering (i.e. the sub-batches scattered across the cards are nearly equal in size) may not be flexible enough to make full use of the available GPU memory. For example, when other students or colleagues have already consumed
[1GB, 2GB, 8GB, 4GB]
of memory on the four cards, leaving [11GB, 10GB, 4GB, 8GB] free (33GB in total), we may want to set the numbers of samples scattered to these cards to
[x*11/33, x*10/33, x*4/33, x*8/33]
where x denotes the overall batch size. In other words, I think the scattering function should accept a weight vector for imbalanced scattering.

However, assuming the pre-existing workloads on the 4 cards are the same despite their different memory usage, an apparent downside is that the cards with smaller sub-batches finish earlier and have to idle longer at the synchronization barrier. At this point I'm not quite sure whether this requested feature is worthwhile.
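To make the proposal concrete, here is a minimal sketch of how per-card sub-batch sizes could be derived from a weight vector. The helper name `weighted_chunk_sizes` is hypothetical and not part of PyTorch, and the sketch assumes non-negative integer weights (here, free memory in GB). Fractional shares are rounded with the largest-remainder method so the sizes always sum to the batch size.

```python
def weighted_chunk_sizes(batch_size, weights):
    """Split batch_size into integer chunk sizes proportional to weights.

    Hypothetical helper for illustration only; assumes integer weights.
    """
    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")
    if not weights or any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a positive sum")

    total = sum(weights)
    # Work in integers: the exact share of card i is numerators[i] / total.
    numerators = [batch_size * w for w in weights]
    sizes = [n // total for n in numerators]
    remainders = [n % total for n in numerators]

    # Hand out the samples lost to flooring, largest remainder first.
    leftover = batch_size - sum(sizes)
    order = sorted(range(len(weights)), key=lambda i: remainders[i], reverse=True)
    for i in order[:leftover]:
        sizes[i] += 1
    return sizes


# Example from this issue: 12 GB cards with [1, 2, 8, 4] GB already in use.
total_memory_gb = 12
used_gb = [1, 2, 8, 4]
free_gb = [total_memory_gb - u for u in used_gb]  # [11, 10, 4, 8]

print(weighted_chunk_sizes(66, free_gb))  # [22, 20, 8, 16]
```

A list of sizes like this can be passed to `torch.split(batch, sizes, dim=0)` to cut a batch into unequal chunks. Wiring those chunks into the scatter step of `DataParallel` is the part this request is asking for, and is not shown here.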
Motivation
Better leverage the GPU resources especially when a shared GPU server is crowded.