Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1e209a0
OpInfo: use functools.partial to decrease noise in make_tensor calls
peterbell10 Sep 2, 2022
8b01e4f
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 2, 2022
935cdfb
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 2, 2022
0d88ebe
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 2, 2022
de0e2c4
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 3, 2022
8e9d7bb
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 5, 2022
187ce18
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 5, 2022
9a38cb7
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 6, 2022
da18489
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 7, 2022
004cc98
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 7, 2022
4c26abf
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 7, 2022
6a0736f
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 7, 2022
79e1de0
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 7, 2022
d5f52a6
Rebase on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 8, 2022
017a1d5
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 20, 2022
9d185ed
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 21, 2022
a557448
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 22, 2022
ec409b8
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 24, 2022
cb94a79
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 27, 2022
a105add
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 27, 2022
c2aa3cf
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 28, 2022
bef0d44
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Sep 29, 2022
6ca0a6d
Rebase on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 2, 2022
9e25fed
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 2, 2022
f649023
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 2, 2022
4cf937c
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 4, 2022
418eedf
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 4, 2022
5bb5eba
Update on "OpInfo: use functools.partial to decrease noise in make_te…
peterbell10 Oct 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
).with_metadata(broadcasts_input=broadcasts_input)

if dtype.is_complex:
args = (make_arg(input_shape), make_arg(batch1_shape), make_arg(batch2_shape))
yield SampleInput(
make_arg(input_shape),
make_arg(batch1_shape),
Expand Down
44 changes: 14 additions & 30 deletions torch/testing/_internal/opinfo/definitions/_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,22 @@ def sample_inputs_softmax_variant(


def _generate_masked_op_mask(input_shape, device, **kwargs):
make_arg = partial(
make_tensor, dtype=torch.bool, device=device, requires_grad=False
)
yield None
yield make_tensor(input_shape, dtype=torch.bool, device=device, requires_grad=False)
yield make_arg(input_shape)
if len(input_shape) > 2:
# broadcast last mask dimension:
yield make_tensor(
input_shape[:-1] + (1,),
dtype=torch.bool,
device=device,
requires_grad=False,
)
yield make_arg(input_shape[:-1] + (1,))
# broadcast middle mask dimension:
yield make_tensor(
input_shape[:1] + (1,) + input_shape[2:],
dtype=torch.bool,
device=device,
requires_grad=False,
)
yield make_arg(input_shape[:1] + (1,) + input_shape[2:])
# broadcast first mask dimension:
yield make_tensor(
(1,) + input_shape[1:], dtype=torch.bool, device=device, requires_grad=False
)
yield make_arg((1,) + input_shape[1:])
# mask.ndim < input.ndim
yield make_tensor(
input_shape[1:], dtype=torch.bool, device=device, requires_grad=False
)
yield make_arg(input_shape[1:])
# mask.ndim == 1
yield make_tensor(
input_shape[-1:], dtype=torch.bool, device=device, requires_grad=False
)
yield make_arg(input_shape[-1:])
# masks that require broadcasting of inputs (mask.ndim >
# input.ndim) will not be supported, however, we may
# reconsider this if there will be demand on this kind of
Expand Down Expand Up @@ -351,19 +338,16 @@ def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwar
list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
]

make_arg = partial(
make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
)
for shape, input_masks, other_masks in zip(
shapes, input_mask_lists, other_mask_lists
):
for input_mask, other_mask in zip(input_masks, other_masks):
input = make_tensor(
shape, dtype=dtype, device=device, requires_grad=requires_grad
)
other = make_tensor(
shape, dtype=dtype, device=device, requires_grad=requires_grad
)
yield SampleInput(
input.clone().requires_grad_(requires_grad),
args=(other.clone().requires_grad_(requires_grad),),
make_arg(shape),
args=(make_arg(shape),),
kwargs=dict(input_mask=input_mask, other_mask=other_mask),
)

Expand Down
69 changes: 22 additions & 47 deletions torch/testing/_internal/opinfo/definitions/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,15 @@ def sample_inputs_linalg_norm(
else:
matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)

make_arg = partial(
make_tensor,
dtype=dtype,
device=device,
requires_grad=requires_grad,
low=None,
high=None,
)

for test_size in test_sizes:
is_vector_norm = len(test_size) == 1
is_matrix_norm = len(test_size) == 2
Expand All @@ -383,17 +392,7 @@ def sample_inputs_linalg_norm(

for keepdim in [False, True]:
if variant != "subgradient_at_zero" and is_valid_for_p2:
yield SampleInput(
make_tensor(
test_size,
dtype=dtype,
device=device,
low=None,
high=None,
requires_grad=requires_grad,
),
kwargs=dict(keepdim=keepdim),
)
yield SampleInput(make_arg(test_size), kwargs=dict(keepdim=keepdim))

if not (is_vector_norm or is_matrix_norm):
continue
Expand Down Expand Up @@ -436,28 +435,12 @@ def sample_inputs_linalg_norm(
)
else:
yield SampleInput(
make_tensor(
test_size,
dtype=dtype,
device=device,
low=None,
high=None,
requires_grad=requires_grad,
),
args=(ord,),
kwargs=dict(keepdim=keepdim),
make_arg(test_size), args=(ord,), kwargs=dict(keepdim=keepdim)
)

if ord in ["nuc", "fro"]:
yield SampleInput(
make_tensor(
test_size,
dtype=dtype,
device=device,
low=None,
high=None,
requires_grad=requires_grad,
),
make_arg(test_size),
kwargs=dict(ord=ord, keepdim=keepdim, dim=(0, 1)),
)

Expand Down Expand Up @@ -732,16 +715,14 @@ def sample_inputs_linalg_ldl_solve(
)

# Symmetric case
make_arg = partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
for test_case in test_cases1:
factors, pivots, _ = test_case
factors.requires_grad = requires_grad
for B_batch_shape in ((), factors.shape[:-2]):
B = make_tensor(
(*B_batch_shape, factors.shape[-1], S),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
B = make_arg((*B_batch_shape, factors.shape[-1], S))
yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
clone_factors = factors.detach().clone().requires_grad_(requires_grad)
yield SampleInput(
Expand All @@ -753,12 +734,7 @@ def sample_inputs_linalg_ldl_solve(
factors, pivots, _ = test_case
factors.requires_grad = requires_grad
for B_batch_shape in ((), factors.shape[:-2]):
B = make_tensor(
(*B_batch_shape, factors.shape[-1], S),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
B = make_arg((*B_batch_shape, factors.shape[-1], S))
yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
clone_factors = factors.detach().clone().requires_grad_(requires_grad)
yield SampleInput(
Expand Down Expand Up @@ -1080,13 +1056,12 @@ def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
# a_shapes += [(0, 0, 1, 2, 3, 0)]
dimss = [None, (0, 2)]

make_arg = partial(
make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
)
for a_shape, dims in itertools.product(a_shapes, dimss):
a = make_tensor(
a_shape, dtype=dtype, device=device, requires_grad=requires_grad
)
b = make_tensor(
a_shape[:2], dtype=dtype, device=device, requires_grad=requires_grad
)
a = make_arg(a_shape)
b = make_arg(a_shape[:2])
yield SampleInput(a, args=(b,), kwargs=dict(dims=dims))


Expand Down