[Feature request] Let DistributedSampler take a Sampler as input #23430

Open
fmassa opened this issue Jul 26, 2019 · 21 comments
Labels
feature A request for a proper, new feature. has workaround module: dataloader Related to torch.utils.data.DataLoader and Sampler oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@fmassa
Member

fmassa commented Jul 26, 2019

🚀 Feature

Motivation

Currently, DistributedSampler assumes that it takes a Dataset as argument. But in reality, the only information it uses from the dataset is its length (via len).

We sometimes want to use a custom Sampler in distributed mode, so it might be desirable to also let DistributedSampler take a Sampler as argument.

Potential implementation

The only difference is that in

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)

We would additionally have something like

if isinstance(self.dataset, Sampler):
    orig_indices = list(iter(self.dataset))
    indices = [orig_indices[i] for i in indices]

return iter(indices)
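
For concreteness, here is a minimal sketch of how the whole modified __iter__ could look (ignoring shuffling for brevity and assuming the rest of DistributedSampler stays as it is today; only the isinstance branch is new):

from torch.utils.data import Sampler

def __iter__(self):
    # existing behaviour: positional indices over the length of the source
    indices = list(range(len(self.dataset)))

    # add extra samples to make the list evenly divisible across replicas
    indices += indices[:(self.total_size - len(indices))]

    # subsample for this rank
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    # proposed addition: if a Sampler was passed instead of a Dataset,
    # map the positional indices through the sampler's own ordering
    if isinstance(self.dataset, Sampler):
        orig_indices = list(iter(self.dataset))
        indices = [orig_indices[i] for i in indices]

    return iter(indices)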

Pitch

More modularity and code reuse

sampler = MyNiceSampler(dataset)
if distributed:
    sampler = DistributedSampler(sampler)

Additionally, it makes (in my view) the code clearer to write. Instead of

if distributed:
    sampler = DistributedSampler(dataset)
else:
    sampler = RandomSampler(dataset)

we can always have

sampler = RandomSampler(dataset)
if distributed:
    sampler = DistributedSampler(sampler, shuffle=False)

which might seem very similar at first sight, but the two versions imply different things.

Alternatives

We can integrate the functionality of DistributedSampler into our custom sampler, but this seems redundant.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @ssnl @VitalyFedyunin @ejguan @NivekT @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

@fmassa fmassa added module: dataloader Related to torch.utils.data.DataLoader and Sampler enhancement Not as big of a feature, but technically not a bug. Should be easy to fix small We think this is a small issue to fix. Consider knocking off high priority small issues labels Jul 26, 2019
@pytorchbot pytorchbot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jul 26, 2019
@izdeby izdeby added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 26, 2019
@pandeykartikey
Contributor

I would like to work on this feature. @fmassa @izdeby?

@fmassa
Member Author

fmassa commented Jul 30, 2019

@pandeykartikey this is still pending discussion if it's something that we want to add in PyTorch.

@apaszke
Contributor

apaszke commented Aug 1, 2019

I don't mind having this. We can make the contract such that the argument can be, e.g., an integer or an object that supports __len__, as sketched below.
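
A rough sketch of what such a contract might look like (a hypothetical helper, not an existing PyTorch API):

def _infer_num_items(source):
    # `source` may be a Dataset, a Sampler, a plain int, or any object
    # implementing __len__; DistributedSampler only needs a length.
    if isinstance(source, int):
        return source
    return len(source)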

@v0dro
Contributor

v0dro commented Nov 2, 2019

What's the status of this issue? It says triaged. Can I go ahead with this?

@vfdev-5
Collaborator

vfdev-5 commented Dec 5, 2019

@fmassa this could be helpful!

While waiting for such an update, we can use a sort of DistributedProxySampler:

import torch
from torch.utils.data.distributed import DistributedSampler

class DistributedProxySampler(DistributedSampler):
    """Sampler that restricts data loading to a subset of input sampler indices.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedProxySampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Input sampler is assumed to be of constant size.

    Arguments:
        sampler: Input data sampler.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, sampler, num_replicas=None, rank=None):        
        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
        self.sampler = sampler

    def __iter__(self):
        # deterministically shuffle based on epoch
        torch.manual_seed(self.epoch)
        indices = list(self.sampler)

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

        return iter(indices)

To check:

import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 4
dist_samplers = [
    DistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) 
    for i in range(num_replicas)
]

torch.manual_seed(0)
true_indices = list(sampler)
true_indices[:5]  # peek at the first few drawn indices

indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)

assert set(indices_per_rank) == set(true_indices)  # the union over ranks covers exactly the original draws

@elistevens
Contributor

I think this implementation will be more correct in semi-degenerate edge cases:

def __iter__(self):
    # deterministically shuffle based on epoch
    torch.manual_seed(self.epoch)
    
    indices = []
    while len(indices) < self.total_size:
        indices += list(self.sampler)

    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]
    if len(indices) != self.num_samples:
        raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

    return iter(indices)

Consider what would happen with the original implementation if you set your WeightedRandomSampler num_samples to 3 (keeping num_replicas=4 as in the example above): you'd get 4 samples in total, with the 0th and 3rd being the same, which is incorrect given that the 4th should be drawn from the entire 100 samples (see the illustration below). This implementation isn't perfect either, since it could still duplicate a result in violation of the replacement flag, but I don't think that's as much of a problem.
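
A quick runnable illustration of that edge case, using the proxy's padding and striding logic directly (assuming num_replicas=4 as in the earlier example):

import torch
from torch.utils.data import WeightedRandomSampler

# num_samples (3) is not divisible by num_replicas (4), so
# total_size = ceil(3 / 4) * 4 = 4 and the padding kicks in
weights = torch.ones(100)
sampler = WeightedRandomSampler(weights, num_samples=3)

torch.manual_seed(0)
indices = list(sampler)                 # e.g. [i0, i1, i2]
indices += indices[:4 - len(indices)]   # -> [i0, i1, i2, i0]

# strided subsampling hands rank 3 a duplicate of rank 0's index
# instead of a fresh 4th draw over all 100 weights
per_rank = [indices[r:4:4] for r in range(4)]
print(per_rank)                         # [[i0], [i1], [i2], [i0]]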

@vfdev-5
Collaborator

vfdev-5 commented Mar 10, 2020

@elistevens the cases where the number of samples is less than the number of replicas are rare :)
Yes, I agree that there is an assumption about that.

@elistevens
Contributor

I think that the issue is triggered when self.total_size is larger than sampler.num_samples, but smaller than len(sampler.weights). total_size will be larger than num_samples any time the replica count isn't an even divisor, and having your num_samples be fixed to a size smaller than your dataset seems like a reasonable thing to do if you want to have consistent epoch sizes or something like that.

@b02202050

b02202050 commented Jun 7, 2020

@schwobr

schwobr commented Dec 1, 2020

It seems that this post provides a solution:
https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/22

The solution looks pretty good to me. Is there any plan to add something similar to PyTorch? I can copy the catalyst solution in the meantime, but I feel it should be natively possible to use custom samplers with DDP in PyTorch.

@steermomo

catalyst has an implementation of this feature in their code.

from operator import itemgetter
from typing import Optional

from torch.utils.data import DistributedSampler
# DatasetFromSampler is defined elsewhere in catalyst (see the sketch below the link)

class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.
    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
              distributed training
            rank (int, optional): Rank of the current process
              within ``num_replicas``
            shuffle (bool, optional): If true (default),
              sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """@TODO: Docs. Contribution is welcome."""
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))

https://github.com/catalyst-team/catalyst/blob/ea3fadbaa6034dabeefbbb53ab8c310186f6e5d0/catalyst/data/sampler.py#L522
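
For completeness, DatasetFromSampler (imported from elsewhere in catalyst) is essentially a thin Dataset over the sampler's materialized indices; a sketch along the lines of the catalyst source:

from torch.utils.data import Dataset

class DatasetFromSampler(Dataset):
    """Dataset whose items are the indices produced by the wrapped sampler."""

    def __init__(self, sampler):
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index):
        # materialize the sampler lazily, on first access
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self):
        return len(self.sampler)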

@lukasfolle

Any plans to implement this in the future?

@ZhiyuanChen
Contributor

Any plans to implement this in the future?

@vadimkantorov
Contributor

Also stumbled on this

@ngimel
Collaborator

ngimel commented Oct 8, 2021

Marking for triage review due to user activity, and because this is desirable for #65936. cc @ejguan

@ngimel ngimel added triage review and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 8, 2021
@ngimel ngimel removed the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 11, 2021
@mruberry mruberry removed triage review enhancement Not as big of a feature, but technically not a bug. Should be easy to fix labels Oct 11, 2021
@mruberry mruberry added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module feature A request for a proper, new feature. high priority and removed small We think this is a small issue to fix. Consider knocking off high priority small issues labels Oct 11, 2021
@mruberry
Collaborator

@ngimel: high priority because needed for dataparallel deprecation

@VitalyFedyunin -- is this the right module?

@vadimkantorov
Contributor

There are also two versions of DistributedSampler in https://github.com/fundamentalvision/Deformable-DETR/blob/main/datasets/samplers.py that change how the dataset is sliced over the replicas. It may be worth checking them out as well.

@vadimkantorov
Contributor

My own stab at it, cc @ngimel:

import torch

class DistributedSamplerProxy(torch.utils.data.DistributedSampler):
    def __iter__(self):
        indices = list(self.dataset)[:self.total_size]

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size , f"{len(indices)} != {self.total_size}"

        # subsample: contiguous per-rank chunk instead of the default interleaved
        # slicing (indices[self.rank:self.total_size:self.num_replicas])
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, 'set_epoch'):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, 'generator'):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

@vadimkantorov
Contributor

There are two ways of splitting the samples over replicas (contiguous chunks vs. interleaved); a small illustration follows below. If DistributedSampler supported both, it would help remove this duplicated, brittle code from users' projects.
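
A minimal, self-contained illustration of the two schemes (toy numbers, not tied to any particular sampler):

indices = list(range(8))    # padded index list of length total_size
num_replicas, rank = 4, 1
num_samples = len(indices) // num_replicas

# interleaved: every num_replicas-th index, as DistributedSampler does today
interleaved = indices[rank:len(indices):num_replicas]   # -> [1, 5]

# contiguous chunks: one block per rank, as in the proxy above
offset = num_samples * rank
contiguous = indices[offset:offset + num_samples]        # -> [2, 3]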

@VitalyFedyunin VitalyFedyunin added oncall: distributed Add this issue/PR to distributed oncall triage queue has workaround and removed high priority labels Feb 14, 2022
@VitalyFedyunin
Contributor

I will drop hi-pri as it has workarounds (thanks @vadimkantorov), and we would rather focus on improving TorchData with distributed functionality.

@vadimkantorov
Contributor

vadimkantorov commented Feb 14, 2022

@VitalyFedyunin my workaround is not the first one :) there are good ones above as well; I just simplified/improved on top of them.
