Skip to content

Commit

Permalink
make linter happy
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Jan 28, 2021
1 parent 3e162c3 commit 8ac029d
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,45 +993,46 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
)
return [SampleInput(tensor) for tensor in tensors]


def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad):
samples = (
SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn(M, M, device=device) > 0, make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),
args=(torch.randn(M, M, device=device) > 0,
make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),

SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn((M,), device=device) > 0, make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),
args=(torch.randn((M,), device=device) > 0,
make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),

SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(bernoulli_scalar().to(device), make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),
# Inplace variants fail
# SampleInput(make_tensor((M,), device, dtype, low=None, high=None, requires_grad=requires_grad),
# args=(torch.randn(M, M, device=device) > 0, make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),
args=(bernoulli_scalar().to(device),
make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad))),
)

return samples

def sample_inputs_masked_select(op_info, device, dtype, requires_grad):
samples = (
SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn(M, M, device=device) > 0,)),
args=(torch.randn(M, M, device=device) > 0,)),

SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn((M,), device=device) > 0,)),
args=(torch.randn((M,), device=device) > 0,)),

SampleInput(make_tensor((M,), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn((M, M), device=device) > 0,)),
args=(torch.randn((M, M), device=device) > 0,)),

SampleInput(make_tensor((M, 1, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn((M, M), device=device) > 0,)),
args=(torch.randn((M, M), device=device) > 0,)),

SampleInput(make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.tensor(1, device=device, dtype=torch.bool),)),
args=(torch.tensor(1, device=device, dtype=torch.bool),)),

SampleInput(make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.tensor(1, device=device, dtype=torch.bool),)),
args=(torch.tensor(1, device=device, dtype=torch.bool),)),

SampleInput(make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(torch.randn((M, M), device=device) > 0,)),
args=(torch.randn((M, M), device=device) > 0,)),
)

return samples
Expand Down

0 comments on commit 8ac029d

Please sign in to comment.