Skip to content

Commit

Permalink
Improve Declarations.yaml: (pytorch#81)
Browse files Browse the repository at this point in the history
* Improve Declarations.yaml:

 - translate defaults to C++ values
 - include names of returned values
 - mark keyword-only arguments

* Add comment to translate_default
  • Loading branch information
colesbury authored and zdevito committed Nov 2, 2017
1 parent f248536 commit 2c6136a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
32 changes: 25 additions & 7 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,31 @@ def to_return_type(arg, option):
if not is_mutable_formal_argument(arg, option):
rt = 'const ' + rt
return {
'name': arg['name'],
'type': rt,
'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']),
}


def create_generic(top_env, declarations):
# translates defaults from cwrap types to C++ values
def translate_default(argument, type_str, default):
if default is None:
# cause the default constructor for the object to run
return '{}'
if 'if_true' in argument:
return argument['default'] == argument['if_true']
for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
default = re.sub(pattern, replacement, str(default))
if type_str in {'Scalar', 'int64_t'}:
return int(default)
elif type_str == 'double':
return float(default)
elif type_str == 'bool':
assert default.lower() in ['true', 'false']
return default.lower() == 'true'
else:
return default

# change from THTensor* to Tensor & so we get how it will appear
# in the aten argument list...
Expand All @@ -233,12 +252,11 @@ def translate_formal(argument, option):
'type': type_str,
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
}
if 'kwarg_only' in argument:
translated['kwarg_only'] = argument['kwarg_only']
if 'default' in argument:
if 'if_true' in argument:
val = argument['default'] == argument['if_true']
translated['default'] = str(val).lower()
else:
translated['default'] = argument['default']
default = argument['default']
translated['default'] = translate_default(argument, type_str, default)
if argument.get('output'):
translated['output'] = True
return translated
Expand Down Expand Up @@ -315,8 +333,8 @@ def formal_with_default(f):
v = f.get('default')
if v is None:
return s
for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
v = re.sub(pattern, replacement, str(v))
if isinstance(v, bool):
v = str(v).lower()
return '{}={}'.format(s, v)

def get_broadcast_argument(option):
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/templates/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ATen/SparseTensorRef.h"
#include "ATen/ScalarType.h"
#include "ATen/Scalar.h"
#include "ATen/Tensor.h"

// To solve the conflict of s_addr in inaddr.h
#ifdef _MSC_VER
Expand Down

0 comments on commit 2c6136a

Please sign in to comment.