
WeightedRandomSampler throws warning for copy constructor #16627


Description

@ptrblck

🐛 Minor Bug

Passing a weights tensor to WeightedRandomSampler throws a warning:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

This is due to this line of code in WeightedRandomSampler's __init__, where the weights are wrapped with torch.tensor.

Reported by Wafaa_Wardah in this post.

To Reproduce

Steps to reproduce the behavior:

import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.randn(10)
# constructing the sampler with a tensor of weights triggers the UserWarning
sampler = WeightedRandomSampler(weights=weights, num_samples=10)
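
As a possible user-side workaround until the sampler is changed, the weights can be handed over as a plain Python list; the warning is only raised when torch.tensor receives an input that is already a tensor. A minimal sketch (variable names are just for illustration):

import torch
from torch.utils.data import WeightedRandomSampler

# a list input does not hit the tensor copy-construct path, so no warning is raised
weights = torch.rand(10).tolist()  # non-negative weights, passed as a plain list
sampler = WeightedRandomSampler(weights=weights, num_samples=10)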

Environment

  • PyTorch Version: 1.0.0.dev20190104

Possible bug fix:

I think the easiest way to fix this issue would be to change the current line from

self.weights = torch.tensor(weights, dtype=torch.double)

to

self.weights = torch.as_tensor(weights, dtype=torch.double)
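
For reference, a minimal sketch of the difference between the two calls (variable names are only for illustration):

import torch

weights = torch.randn(10)

# torch.tensor always copies, and warns when its input is already a tensor
w_copy = torch.tensor(weights, dtype=torch.double)    # emits the UserWarning above

# torch.as_tensor performs the same dtype conversion without the warning;
# note that if weights were already a double tensor, it would share storage
# with the input instead of copying it
w_as = torch.as_tensor(weights, dtype=torch.double)   # no warning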

CC @ssnl. Let me know if that would be sufficient, and I can create a quick PR for this.

Best,
ptrblck
