Allow support for negative dimension argument for all functions #1108
Conversation
Looks good, but we need to add tests for all those functions in CUDA and especially in autograd. Maybe it'd be possible to reuse the same code somehow (Variables have a very similar API).
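A minimal sketch of the reuse suggested above, assuming a hypothetical check_neg_dim helper shared by the Tensor and Variable suites (names here are illustrative, not the test files' real helpers):

import torch
from torch.autograd import Variable

def check_neg_dim(testcase, obj, name, pos_args, neg_args):
    # Call the same method with a positive dim and the equivalent
    # negative dim and compare; the Tensor and Variable method APIs
    # overlap enough that one helper can exercise both.
    a = getattr(obj, name)(*pos_args)
    b = getattr(obj, name)(*neg_args)
    testcase.assertEqual(a, b)

# Usage inside a TestCase (hypothetical):
#   t = torch.randn(2, 3, 4)
#   check_neg_dim(self, t, 'sum', (2,), (-1,))
#   check_neg_dim(self, Variable(t), 'sum', (2,), (-1,))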
test/test_autograd.py
Outdated
@@ -1059,6 +1059,7 @@ def prod_single_zero(dim_size):
     (Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
     (Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
     (Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
+    (Concat, (-1,), ((S, S, 1), (S, S, 2), (S, S, 3)), 'negdim'),
test/test_torch.py
Outdated
# 4th argument is the supported types of call:
# 0: method called as tensor.name(...)
# 1: method called as tensor.name_(...)
# 2: method called as torch.name(tensor, ...)
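A hypothetical sketch of how those three call types map onto actual invocations (an illustrative helper, not the test file's own dispatch code):

import torch

def call_by_type(call_type, name, tensor, *args):
    if call_type == 0:    # method: tensor.name(...)
        return getattr(tensor, name)(*args)
    elif call_type == 1:  # in-place method: tensor.name_(...)
        return getattr(tensor, name + '_')(*args)
    elif call_type == 2:  # torch-level function: torch.name(tensor, ...)
        return getattr(torch, name)(tensor, *args)
    raise ValueError("unsupported call type: %r" % (call_type,))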
test/test_torch.py
Outdated
x = torch.randn(*tensor_arg)
ndim = len(tensor_arg)

n_dim_to_test = 0
test/test_torch.py
Outdated
for dims_val in combinations(range(ndim), n_dim_to_test):
    arg = arg_constr()
    arg_neg = copy.copy(arg)
test/test_torch.py
Outdated
b = getattr(torch, name)(x, *arg_neg)
self.assertEqual(a, b)

return tmp
test/test_torch.py
Outdated
def idx_tensor(size, max_val):
    return torch.LongTensor(*size).random_() % max_val
test/test_torch.py
Outdated
x = []
for arg in tensor_arg:
    x.append(torch.randn(*arg))
    ndim = len(arg)
tools/cwrap/cwrap.py
Outdated
@@ -33,7 +33,8 @@ class cwrap(object):
     FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($call_arg);")

     DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
-                              ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
+                              ArgumentReferences, BeforeAfterCall, ReturnArguments,
+                              GILRelease, WrapDim]
tools/cwrap/plugins/WrapDim.py
Outdated
class WrapDim(CWrapPlugin):

    DIM_WRAP_TEMPLATE = Template(
        """if (${arg_to_wrap} < 0) ${arg_to_wrap} += ${arg_tensor}->nDimension;""")
tools/cwrap/plugins/WrapDim.py
Outdated
arg_tensor = arg.get('wrap_dim')

new_code.append(self.DIM_WRAP_TEMPLATE.substitute(
    arg_to_wrap="arg_" + arg.get('formal_name', arg['name']),
torch/autograd/_functions/reduce.py
Outdated
grad_idx_tuple = idx_tuple[:self.dim] + (zero_idx,) + idx_tuple[self.dim + 1:]
grad_input[grad_idx_tuple] = grad_output[idx_tuple] * input_copy.prod()
if len(single_zero_idx) > 0:
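The single-zero branch above reflects a standard fact about the gradient of a product: when exactly one factor is zero, only that factor has a nonzero partial derivative, equal to the product of the remaining factors. A plain-Python numeric check of that identity:

# For x = [2, 0, 5], d(prod)/dx_i is the product of the other entries:
# [0*5, 2*5, 2*0] = [0, 10, 0] -- only the zero entry gets a gradient.
x = [2.0, 0.0, 5.0]
grads = []
for i in range(len(x)):
    others = x[:i] + x[i + 1:]
    p = 1.0
    for v in others:
        p *= v
    grads.append(p)
assert grads == [0.0, 10.0, 0.0]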
test/test_autograd.py
Outdated
new_constructor_args[arg_idx] *= dim_perm[i]
if dim_perm[i] == -1:
    test_name += "_negdimarg" + str(arg_idx)
new_constructor_args = tuple(new_constructor_args)
test/test_autograd.py
Outdated
new_args[arg_idx] *= dim_perm[i]
if dim_perm[i] == -1:
    test_name += "_negdimarg" + str(arg_idx)
new_args = tuple(new_args)
test/test_torch.py
Outdated
def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
    def neg_dim_test(self):
        if isinstance(tensor_arg, list):
            assert 0 not in types and 1 not in types
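Stripped of the combinatorics, the generated tests reduce to a comparison like the following (a hypothetical simplification of make_neg_dim_test covering only the torch-level, single-tensor case):

import torch

def simple_neg_dim_test(name, shape, dim):
    # Run the torch-level function with a positive dim, then with
    # the equivalent negative dim, and require identical results.
    x = torch.randn(*shape)
    ndim = len(shape)
    a = getattr(torch, name)(x, dim)
    b = getattr(torch, name)(x, dim - ndim)  # e.g. 2 -> -1 when ndim == 3
    assert torch.equal(a, b)

simple_neg_dim_test('sum', (2, 3, 4), 2)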
for decl in neg_dim_tests:
    if len(decl) == 4:
        name, tensor_arg, arg_constr, types = decl
        extra_dim = 0
tools/cwrap/plugins/WrapDim.py
Outdated
"""${arg_tensor}->nDimension""") | ||
|
||
DIM_WRAP_CHECK_TEMPLATE = Template( | ||
"""THPUtils_assert(${arg_to_wrap} >= -(${ndim}) && ${arg_to_wrap} < (${ndim}), |
tools/cwrap/plugins/WrapDim.py
Outdated
"dimension out of range (expected to be in range of [%d, %d], but got %d)", | ||
-(${ndim}), (${ndim})-1, ${arg_to_wrap})""") | ||
|
||
DIM_WRAP_CODE_TEMPLATE = Template( |
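The range the check enforces is [-ndim, ndim - 1]; anything outside should fail before the wrap is applied. The same validation restated in Python (assumed semantics, reusing the template's error message):

def check_and_wrap_dim(dim, ndim):
    # Mirrors the emitted THPUtils_assert: validate first, wrap second.
    if not (-ndim <= dim < ndim):
        raise IndexError(
            "dimension out of range (expected to be in range of "
            "[%d, %d], but got %d)" % (-ndim, ndim - 1, dim))
    return dim + ndim if dim < 0 else dim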
tools/cwrap/plugins/WrapDim.py
Outdated
arg_tensor = params[0]

arg_tensor = "arg_" + arg_tensor
arg_to_wrap = "arg_" + arg.get('assign_name', arg['name'])
@@ -191,6 +191,8 @@ static PyObject * THPTensor_(select)(THPTensor *self, PyObject *args)
     return NULL;

+  int ndim = THTensor_(nDimension)(LIBRARY_STATE self->cdata);
+  if (dim < 0) dim += ndim;
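With the wrap in place, a negative dimension to select behaves like the explicit positive index; for example (usage sketch, shapes chosen for illustration):

import torch

x = torch.randn(2, 3, 4)
a = x.select(-1, 0)             # negative dim counts from the end
b = x.select(x.dim() - 1, 0)    # equivalent explicit positive dim
assert torch.equal(a, b)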
if 'wrap_dim' not in arg:
    continue

params = arg.get('wrap_dim').split("+")
Force-pushed from 5013762 to 448bf55.
Thanks a lot Alban!
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/1108
Note: Links to docs will display an error until the docs builds have been completed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
Add a WrapDim plugin to cwrap that adds support for negative dimension indices.
Add a generic test for negative dimensions.
Add support for negative dimensions to all functions that take a dimension as a parameter.
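A rough illustration of the user-facing effect (functions and shapes chosen for illustration only):

import torch

x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)

# Reductions: dim=-1 is equivalent to dim=x.dim()-1.
assert torch.equal(torch.sum(x, -1), torch.sum(x, 2))

# Concatenation along the last dimension works the same way.
assert torch.equal(torch.cat([x, y], -1), torch.cat([x, y], 2))

# Out-of-range dims now raise a clear error, e.g.
# torch.sum(x, 3) -> "dimension out of range ..."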