diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index e80206bc02184..ccea2f52ebe54 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -43,12 +43,13 @@ class RandomSampler(Sampler): Arguments: data_source (Dataset): dataset to sample from """ + _cpu = torch.device('cpu') def __init__(self, data_source): self.data_source = data_source def __iter__(self): - return iter(torch.randperm(len(self.data_source)).tolist()) + return iter(torch.randperm(len(self.data_source), device=RandomSampler._cpu).tolist()) def __len__(self): return len(self.data_source)