[Feature request] Let DistributedSampler take a Sampler as input #23430
Comments
@pandeykartikey this is still pending discussion on whether it's something that we want to add in PyTorch. |
I don't mind having this. We can make the contract such that it can be e.g. an integer, or an object which supports len(). |
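A minimal sketch of what such a contract could look like (the helper name is hypothetical, not an existing PyTorch API):

# hypothetical helper illustrating the contract: accept an int, or anything
# implementing __len__ (a Dataset or a Sampler both qualify)
def _resolve_num_items(source):
    if isinstance(source, int):
        return source
    return len(source)

print(_resolve_num_items(100))             # 100
print(_resolve_num_items(list(range(7))))  # 7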
What's the status of this issue? It says triaged. Can I go ahead with this? |
@fmassa this could be helpful! While waiting for such an update, we can use a sort of proxy sampler:

import torch
from torch.utils.data.distributed import DistributedSampler


class DistributedProxySampler(DistributedSampler):
    """Sampler that restricts data loading to a subset of input sampler indices.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        The input sampler is assumed to be of constant size.

    Arguments:
        sampler: Input data sampler.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, sampler, num_replicas=None, rank=None):
        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
        self.sampler = sampler

    def __iter__(self):
        # deterministically shuffle based on epoch
        torch.manual_seed(self.epoch)
        indices = list(self.sampler)

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

        return iter(indices)

To check:
import torch
from torch.utils.data import WeightedRandomSampler

# weighted sampler over 100 items, with the first 50 twice as likely
weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 4
dist_samplers = [
    DistributedProxySampler(sampler, num_replicas=num_replicas, rank=i)
    for i in range(num_replicas)
]

torch.manual_seed(0)
true_indices = list(sampler)
true_indices[:5]  # peek at the first few indices (meaningful in a REPL)

indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)

# the union over all ranks should recover exactly the indices drawn by the input sampler
set(indices_per_rank) == set(true_indices) |
I think this implementation will be more correct in semi-degenerate edge cases:
Consider what would happen with the original implementation if you set your WeightedRandomSampler |
@elistevens the cases when the number of samples is less than the number of replicas are rare :) |
I think that the issue is triggered when |
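For illustration, a hedged sketch of one configuration where the padding step in DistributedProxySampler.__iter__ above cannot reach total_size (the numbers are made up, not taken from the truncated comments; it reuses the imports and class defined earlier in the thread):

# padding only duplicates the existing indices once, so it cannot fill the gap
# when total_size > 2 * len(list(sampler)); the RuntimeError in __iter__ then fires
tiny_sampler = WeightedRandomSampler(torch.ones(10), num_samples=1)
proxy = DistributedProxySampler(tiny_sampler, num_replicas=3, rank=0)
proxy.set_epoch(0)
try:
    list(proxy)
except RuntimeError as e:
    print("padding failed:", e)  # 2 vs 3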
It seems that this post provides a solution? |
Looking at the solution, it looks pretty good to me. Is there any plan to add something similar to PyTorch? I can just copy the Catalyst solution in the meantime, but I feel like it should be natively possible to use custom samplers with DDP in PyTorch. |
Catalyst has an implementation of this feature in their code:

from operator import itemgetter
from typing import Optional

from torch.utils.data import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """@TODO: Docs. Contribution is welcome."""
        # re-materialize the wrapped sampler's indices for this epoch, then let
        # DistributedSampler decide which positions belong to this rank
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) |
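The snippet above relies on a small DatasetFromSampler helper that Catalyst ships alongside the wrapper; it is essentially a Dataset view over the indices a sampler produces. A minimal sketch, modeled on the Catalyst implementation (details may differ between versions):

from torch.utils.data import Dataset


class DatasetFromSampler(Dataset):
    """Dataset that materializes the indices produced by a Sampler."""

    def __init__(self, sampler):
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index):
        # lazily materialize the wrapped sampler's indices on first access
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self):
        return len(self.sampler)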
Any plans to implement this in the future? |
Also stumbled on this |
@ngimel: high priority because it is needed for the DataParallel deprecation. @VitalyFedyunin -- is this the right module? |
There are also two versions of DistributedSampler in https://github.com/fundamentalvision/Deformable-DETR/blob/main/datasets/samplers.py that change how the dataset is sliced over the replicas. Maybe worth checking them out as well. |
my own stab cc @ngimel:

import torch


class DistributedSamplerProxy(torch.utils.data.DistributedSampler):
    def __iter__(self):
        indices = list(self.dataset)[:self.total_size]

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"

        # subsample (contiguous chunks; the interleaved variant is commented out)
        # indices = indices[self.rank:self.total_size:self.num_replicas]
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, 'set_epoch'):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, 'generator'):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict) |
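A hedged usage sketch of the class above (single process with explicit num_replicas/rank so no process group is required; the toy dataset and weights are made up):

from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = TensorDataset(torch.arange(100))
weights = torch.ones(100)
weights[:50] += 1

# the weighted sampler is passed where DistributedSampler expects a dataset,
# since the proxy above only iterates it and takes its len()
base_sampler = WeightedRandomSampler(weights, num_samples=100)
dist_sampler = DistributedSamplerProxy(base_sampler, num_replicas=4, rank=0)
dist_sampler.set_epoch(0)

loader = DataLoader(data, sampler=dist_sampler, batch_size=10)
for (batch,) in loader:
    print(batch.shape)  # this rank sees 25 of the 100 weighted draws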
There are two ways of splitting the samples over replicas (contiguous chunks vs interleaved), as sketched below. So if DistributedSampler supported both, it would help remove this duplicated, brittle code from users' projects. |
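For reference, a small sketch of the two slicing schemes being contrasted (index arithmetic only):

def contiguous_split(indices, rank, num_samples):
    # rank 0 gets the first num_samples indices, rank 1 the next block, ...
    offset = num_samples * rank
    return indices[offset:offset + num_samples]

def interleaved_split(indices, rank, num_replicas, total_size):
    # rank r gets indices r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]

indices = list(range(8))
print(contiguous_split(indices, rank=1, num_samples=2))                  # [2, 3]
print(interleaved_split(indices, rank=1, num_replicas=4, total_size=8))  # [1, 5]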
I will drop hi-pri as it has workarounds (thanks @vadimkantorov), and we'd better focus on improving TorchData with distributed functionality. |
@VitalyFedyunin my workaround is not the first one :) good ones above as well, i just simplified/improved on top |
🚀 Feature
Motivation
Currently, DistributedSampler assumes that it takes a Dataset as argument. But in reality, the only information it exploits from it is its len. We sometimes want to have a custom Sampler to be used in distributed mode, so it might be desirable to also let DistributedSampler take a Sampler as argument.

Potential implementation

The only difference is that in pytorch/torch/utils/data/distributed.py (lines 57 to 61 at 46224ef) we would additionally have something like the following.
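A hedged, self-contained sketch of the general idea (hypothetical class name and code; not the author's original snippet):

import torch
from torch.utils.data import Sampler
from torch.utils.data.distributed import DistributedSampler


class SamplerAwareDistributedSampler(DistributedSampler):
    # hypothetical illustration: when the constructor argument is itself a Sampler,
    # draw the base indices from it instead of from range(len(dataset))
    def __iter__(self):
        if isinstance(self.dataset, Sampler):
            torch.manual_seed(self.epoch)
            indices = list(self.dataset)
        else:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            if self.shuffle:
                indices = torch.randperm(len(self.dataset), generator=g).tolist()
            else:
                indices = list(range(len(self.dataset)))
        # pad to a length divisible by the number of replicas, then take this rank's slice
        indices += indices[:(self.total_size - len(indices))]
        return iter(indices[self.rank:self.total_size:self.num_replicas])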
Pitch

More modularity and code reuse.

Additionally, it makes writing code (in my view) clearer: instead of one form we can always have the other, which at first sight might seem very similar, but they imply different things.
Alternatives

We can integrate the functionality of DistributedSampler inside our custom sampler, but this seems redundant.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @ssnl @VitalyFedyunin @ejguan @NivekT @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang
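For context, a rough sketch of the alternative described above, i.e. a custom sampler that folds the distributed slicing logic into itself (hypothetical and illustrative only; this is the kind of duplication the proposal is meant to avoid):

import math

import torch
from torch.utils.data import Sampler


class DistributedWeightedSampler(Sampler):
    def __init__(self, weights, num_samples, num_replicas, rank):
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = math.ceil(num_samples / num_replicas)
        self.total_size = self.num_samples * num_replicas
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        # every rank draws the same weighted indices, then keeps an interleaved slice
        indices = torch.multinomial(
            self.weights, self.total_size, replacement=True, generator=g
        ).tolist()
        return iter(indices[self.rank:self.total_size:self.num_replicas])

    def __len__(self):
        return self.num_samples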