diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 2aaba65b9a79..c12e9b2003d8 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -128,6 +128,8 @@ def load_aten_declarations(path): for arg in declaration['arguments']: arg['simple_type'] = get_simple_type(arg) + for arg in declaration['schema_order_arguments']: + arg['simple_type'] = get_simple_type(arg) for ret in declaration['returns']: ret['simple_type'] = get_simple_type(ret) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 406b65079838..283f0153faae 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -245,9 +245,6 @@ UNPACK_TENSOR = CodeTemplate("""\ auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""") -UNPACK_OPTIONS = CodeTemplate("""\ -auto ${arg_name}_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);""") - LEGACY_UNPACK_OPTIONS = CodeTemplate("""\ auto ${arg_name}_ = TensorOptions(${arg_name});""") @@ -1236,7 +1233,12 @@ def requires_unpack(arg): body = [] unpacked_args = [] unpacked_args_simple_type = {} - for i, arg in enumerate(declaration['arguments']): + if declaration['use_c10_dispatcher'] == 'full': + arguments = declaration['schema_order_arguments'] + else: + assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' + arguments = declaration['arguments'] + for i, arg in enumerate(arguments): if not requires_unpack(arg): unpacked_args.append(arg['name']) unpacked_args_simple_type[arg['name']] = arg['simple_type'] @@ -1258,11 +1260,9 @@ def requires_unpack(arg): # Okay, we are abusing the definition of 'unpack' here a bit, # although it's still getting the non-variable from the variable # (in this case via TensorOptions rather than Variable/Tensor). - if declaration['use_c10_dispatcher'] == 'full': - body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name'])) - else: - assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' - body.append(LEGACY_UNPACK_OPTIONS.substitute(arg_name=arg['name'])) + assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper', \ + "VariableKernel shouldn't take TensorOptions if the op is c10-full" + body.append(LEGACY_UNPACK_OPTIONS.substitute(arg_name=arg['name'])) unpacked_args.append(arg['name'] + '_') unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']