Closed
Description
🐛 Minor Bug
Passing a weights tensor to WeightedRandomSampler throws a warning:
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
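This is caused by the line in WeightedRandomSampler that builds self.weights with torch.tensor(weights, dtype=torch.double), which copy-constructs a new tensor from the one passed in.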
Reported by Wafaa_Wardah in this post.
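The warning is not specific to the sampler. Any call to torch.tensor() on an existing tensor emits it. A minimal standalone sketch that reproduces the warning without WeightedRandomSampler and compares it with the copy idiom the message recommends follows. The variable names are illustrative only, and the exact wording of the message may differ between PyTorch versions.

```python
import warnings

import torch

source = torch.rand(10)

# torch.tensor() on an existing tensor copies it and emits the UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not suppressed
    copied = torch.tensor(source, dtype=torch.double)

# The idiom suggested by the warning: an explicit, detached copy.
recommended = source.clone().detach().to(torch.double)
```

Both paths yield the same float64 values. Only the first one warns.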
To Reproduce
Steps to reproduce the behavior:
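import torch
from torch.utils.data import WeightedRandomSampler
# torch.rand keeps the weights non-negative (torch.randn can produce negative
# values, which torch.multinomial rejects once the sampler is iterated)
weights = torch.rand(10)
# On the affected version, constructing the sampler emits the UserWarning
sampler = WeightedRandomSampler(weights=weights, num_samples=10)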
Environment
- PyTorch Version (e.g., 1.0): 1.0.0.dev20190104
Possible bug fix:
I think the easiest way to fix this issue would be to change the current line from
self.weights = torch.tensor(weights, dtype=torch.double)
to
self.weights = torch.as_tensor(weights, dtype=torch.double)
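A quick standalone check of why torch.as_tensor fits here follows. It is a sketch that does not touch the sampler, and the variable names are hypothetical. torch.as_tensor accepts a Python list just as torch.tensor did, and it converts a tensor without the copy-construct warning. It only copies when it has to, for example to change the dtype to float64.

```python
import warnings

import torch

with warnings.catch_warnings():
    # Turn any warning into an error, so a regression would surface immediately.
    warnings.simplefilter("error")

    # A Python list still works, as it did with torch.tensor().
    from_list = torch.as_tensor([0.1, 0.9, 0.4], dtype=torch.double)

    # A float32 tensor is converted (and therefore copied) to float64, silently.
    float_weights = torch.rand(10)
    from_float = torch.as_tensor(float_weights, dtype=torch.double)

    # A float64 tensor already matches, so no copy is made and memory is shared.
    double_weights = torch.rand(10, dtype=torch.double)
    from_double = torch.as_tensor(double_weights, dtype=torch.double)

shares_memory = from_double.data_ptr() == double_weights.data_ptr()
```

The last case means a float64 weights tensor would be aliased rather than copied by the sampler. That only matters if callers modify their weights in place after constructing the sampler.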
CC @ssnl: Let me know if that would be sufficient, and I can create a quick PR for this.
Best,
ptrblck