In [1]:
import os

import torch

In [2]:
# #define countNeighbors_pyArguments_t \
# torch::Tensor queryPositions_, torch::Tensor querySupport_, \
# torch::Tensor sortedPositions_, torch::Tensor sortedSupport_, \
# torch::Tensor domainMin_, torch::Tensor domainMax_, torch::Tensor periodicity_, \
# float_t hCell, torch::Tensor cellBegin_, torch::Tensor cellEnd_, torch::Tensor cellIndices_, torch::Tensor cellLevel_, torch::Tensor cellResolutions_, \
# std::optional<torch::Tensor> hashMapOffset_, std::optional<torch::Tensor> hashMapOccupancy_, std::optional<torch::Tensor> sortedCells_, int32_t hashMapLength, bool verbose

# #define countNeighbors_functionArguments_t \
# torch::Tensor queryPositions_, torch::Tensor querySupport_, \
# torch::Tensor sortedPositions_, torch::Tensor sortedSupport_, \
# torch::Tensor domainMin_, torch::Tensor domainMax_, torch::Tensor periodicity_, \
# float_t hCell, torch::Tensor offsets_, \
# torch::Tensor cellBegin_, torch::Tensor cellEnd_, torch::Tensor cellIndices_, torch::Tensor cellLevel_, torch::Tensor cellResolutions_, \
# std::optional<torch::Tensor> hashMapOffset_, std::optional<torch::Tensor> hashMapOccupancy_, std::optional<torch::Tensor> sortedCells_, int32_t hashMapLength, bool verbose, \
# torch::Tensor neighborCounters_, torch::Tensor neighborAccessCounters_, torch::Tensor neighborHashCollisions_, torch::Tensor neighborSynchronousCounters_, torch::Tensor neighborSupports_

In [46]:
# Example argument spec in the TOML format consumed by transformToArgument and
# generateTensorAccessors. Each top-level key is an argument name; the fields read are:
#   type      - "tensor[<elem>]" (a bare "tensor" falls back to scalar_t) or a scalar C++ type
#   dim       - tensor rank for the accessor type (defaults to 1 when absent)
#   optional  - true: wrapped in std::optional<...> and only emitted when optionals are requested
#   pythonArg - false: dropped from the list unless functionOnly=True
#   const     - false: the 'compute' format emits ptr_t<...> instead of cptr_t<...>
arguments_toml = '''

queryPositions ={type = "tensor[scalar_t]",dim = 2}

querySupports.type = "tensor[scalar_t]"

sortedPositions = {type = "tensor[scalar_t]",dim = 2}

sortedSupports.type = "tensor[scalar_t]"

domainMin.type = "tensor[scalar_t]"
domainMax.type = "tensor[scalar_t]"
periodicity.type = "tensor[bool]"

hCell.type = "scalar_t"

offsets = {type = "tensor[int32_t]", dim = 2, pythonArg = false}

cellBegin.type = "tensor[int32_t]"
cellEnd.type = "tensor[int32_t]"
cellIndices.type = "tensor[int32_t]"
cellLevel.type = "tensor[int32_t]"
cellResolutions = {type = "tensor[int32_t]", dim = 2}

hashMapOffset.type = "tensor[int32_t]"
hashMapOffset.optional = true
hashMapOccupancy.type = "tensor[int32_t]"
hashMapOccupancy.optional = true
sortedCells.type = "tensor[int32_t]"
sortedCells.optional = true
hashMapLength.type = "int32_t"
hashMapLength.optional = true

verbose.type = "bool"

neighborCounters = {type = "tensor[int32_t]", pythonArg = false, const = false}
neighborAccessCounters = {type = "tensor[int32_t]", pythonArg = false, const = false}
neighborHashCollisions = {type = "tensor[int32_t]", pythonArg = false, const = false}
neighborSynchronousCounters = {type = "tensor[int32_t]", pythonArg = false, const = false}
neighborSupports = {type = "tensor[scalar_t]", pythonArg = false, const = false}
'''

In [47]:
import toml

# Parse the example spec into a dict keyed by argument name; each value is the
# table of fields (type, dim, optional, ...) declared for that argument.
parsed = toml.loads(arguments_toml)

In [None]:
def transformToArgument(key, value, includeType = True, addUnderScore = False, includeOptional = False, functionOnly = False, typeFormat = 'pyBind'):
    """Render one TOML argument entry as a C++ parameter (or bare name) string.

    Args:
        key: argument name (the TOML table key).
        value: the argument's TOML table; reads 'type' and, when present,
            'pythonArg', 'optional', 'dim' and 'const'.
        includeType: prepend the C++ type; when False only the name is returned.
        addUnderScore: append '_' to the name.
        includeOptional: when False, entries with a truthy 'optional' are dropped.
        functionOnly: when False, entries with pythonArg == False are dropped.
        typeFormat: 'pyBind' renders torch::Tensor / std::optional<...> types;
            'compute' renders cptr_t/ptr_t<elem, dim> types (ptr_t when const == false).

    Returns:
        "type name", just "name" when includeType is False, or "" when the entry
        is filtered out.

    Raises:
        ValueError: includeType is True and typeFormat is neither 'pyBind' nor 'compute'.
    """
    isOptional = bool(value.get('optional', False))

    # Filtering comes first so dropped entries never need a valid type.
    if not functionOnly and 'pythonArg' in value and value['pythonArg'] == False:
        return ""
    if not includeOptional and isOptional:
        return ""

    name_str = f"{key}_" if addUnderScore else f"{key}"
    if not includeType:
        # No type requested: return the bare name (no stray leading space).
        return name_str

    typeName = value['type']
    isTensor = 'tensor' in typeName
    if typeFormat == 'pyBind':
        type_str = "torch::Tensor" if isTensor else f"{typeName}"
        if isOptional:
            type_str = f"std::optional<{type_str}>"
    elif typeFormat == 'compute':
        if isTensor:
            # "tensor[int32_t]" -> "int32_t"; a bare "tensor" falls back to scalar_t.
            ty = typeName.split('[')[1].split(']')[0] if '[' in typeName else 'scalar_t'
            # Entries are const unless the TOML says const = false.
            prefix = 'c' if value.get('const', True) else ''
            type_str = f"{prefix}ptr_t<{ty}, {value.get('dim', 1)}>"
        else:
            type_str = f"{typeName}"
    else:
        raise ValueError(f"unknown typeFormat {typeFormat!r}; expected 'pyBind' or 'compute'")

    return f"{type_str} {name_str}"

def generateFunctionArguments(parsedToml, **kwargs):
    """Render every entry of a parsed argument TOML through transformToArgument.

    Keyword arguments are forwarded unchanged to transformToArgument. Entries it
    renders as an empty string (filtered out) are omitted, so the result can be
    joined directly with ', '.

    Returns:
        list of rendered argument strings, in TOML declaration order.
    """
    rendered = (transformToArgument(name, spec, **kwargs) for name, spec in parsedToml.items())
    return [entry for entry in rendered if entry != ""]

# Quick check of the 'compute' rendering against the example spec parsed above.
# NOTE(review): this is not the cell's last expression, so the result is discarded;
# the line only shows the call does not raise. Consider displaying or removing it.
generateFunctionArguments(parsed, includeOptional = True, functionOnly = True, typeFormat = 'compute')


def generateTensorAccessors(parsedToml, optional = False):
    """Emit C++ statements that bind each wrapper parameter `<name>_` to a local `<name>`.

    Args:
        parsedToml: mapping of argument name -> TOML table; reads 'type' and, when
            present, 'optional' and 'dim' (default 1).
        optional: when truthy, emit statements only for entries marked optional and
            unwrap them with `.value()`; otherwise emit them only for the other entries.

    Returns:
        list of tab-indented, newline-terminated C++ statements. Tensor entries go
        through getAccessor<elem, dim>(..., "name", useCuda, verbose_); scalars are copied.
    """
    def elementType(typeName):
        # "tensor[int32_t]" -> "int32_t"; a bare "tensor" falls back to scalar_t.
        return typeName.split('[')[1].split(']')[0] if '[' in typeName else 'scalar_t'

    statements = []
    for name, spec in parsedToml.items():
        isOptional = 'optional' in spec and spec['optional']
        # Emit only entries whose optional-ness matches the requested group.
        if bool(isOptional) != bool(optional):
            continue

        source = f"{name}_.value()" if optional else f"{name}_"
        if 'tensor' in spec['type']:
            dim = spec.get('dim', 1)
            statements.append(f"\tauto {name} = getAccessor<{elementType(spec['type'])}, {dim}>({source}, \"{name}\", useCuda, verbose_);\n")
        else:
            statements.append(f"\tauto {name} = {source};\n")
    return statements

# generateTensorAccessors(parsed, optional = True)

In [163]:
# Header that carries the TOML argument spec and receives the generated macros.
fileName = 'src/torchCompactRadius/cppSrc/countNeighbors_mlm.h'
# Macro prefix = file name without directory or extension (here 'countNeighbors_mlm').
# Taking the basename before splitting on '.' keeps relative paths such as './src/...'
# or '../src/...' from collapsing the prefix to an empty string.
filePrefix = os.path.basename(fileName).split('.')[0]
with open(fileName, 'r') as f:
    lines = f.readlines()

# The spec sits between these exact marker lines; list.index raises ValueError if absent.
tomlBegin = lines.index('/** BEGIN TOML\n')
tomlEnd = lines.index('*/ // END TOML\n')
tomlDefinitions = ''.join(lines[tomlBegin + 1: tomlEnd])

parsedToml = toml.loads(tomlDefinitions)

In [164]:

# Index of the marker line that closes the generated define section. The lines
# between the TOML block and this marker are what the following cells regenerate.
endOfDefines = lines.index('/// End the definitions for auto generating the function arguments\n')

In [165]:
# Keep the header up to and including the TOML terminator, and from the
# end-of-defines marker onward; the define section in between is regenerated later.
prefixLines = lines[:tomlEnd+1]
suffixLines = lines[endOfDefines:]

# Each (begin, end) marker pair delimits a region whose contents are replaced by
# freshly generated accessor statements; the marker lines themselves are kept.
# Pairs whose begin marker is absent are skipped.
accessorRegions = [
    ('// AUTO GENERATE ACCESSORS\n', '// END AUTO GENERATE ACCESSORS\n', False),
    ('// AUTO GENERATE OPTIONAL ACCESSORS\n', '// END AUTO GENERATE OPTIONAL ACCESSORS\n', True),
]
for beginMarker, endMarker, onlyOptional in accessorRegions:
    if beginMarker not in suffixLines:
        continue
    accessorBegin = suffixLines.index(beginMarker)
    accessorEnd = suffixLines.index(endMarker)
    accessorLines = generateTensorAccessors(parsedToml, optional = onlyOptional)
    suffixLines = suffixLines[:accessorBegin + 1] + accessorLines + suffixLines[accessorEnd:]
    

In [166]:

# Number of TOML entries marked optional; decides whether the "Optional" macro
# variants need their own argument lists.
numOptionals = len([x for x in parsedToml.values() if 'optional' in x and x['optional']])

# pybind-facing signature: entries with pythonArg = false are dropped, optionals kept.
pyArguments = ', '.join(generateFunctionArguments(parsedToml, includeOptional = True, functionOnly = False, typeFormat = 'pyBind'))
# Required-only renderings: typed wrapper parameters (name_), typed compute parameters,
# and bare names without / with the trailing underscore.
fnArguments = ', '.join(generateFunctionArguments(parsedToml, includeOptional = False, functionOnly = True, typeFormat = 'pyBind', addUnderScore = True))
computeArguments = ', '.join(generateFunctionArguments(parsedToml, includeOptional = False, functionOnly = True, typeFormat = 'compute', addUnderScore = False))
arguments = ', '.join(generateFunctionArguments(parsedToml, includeType = False, includeOptional = False, functionOnly = True, typeFormat = 'pyBind', addUnderScore = False))
arguments_ = ', '.join(generateFunctionArguments(parsedToml, includeType = False, includeOptional = False, functionOnly = True, typeFormat = 'pyBind', addUnderScore = True))

if numOptionals > 0:
    # Same four renderings, now including the optional entries.
    fnArgumentsOptional = ', '.join(generateFunctionArguments(parsedToml, includeOptional = True, functionOnly = True, typeFormat = 'pyBind', addUnderScore = True))
    computeArgumentsOptional = ', '.join(generateFunctionArguments(parsedToml, includeOptional = True, functionOnly = True, typeFormat = 'compute', addUnderScore = False))
    argumentsOptional = ', '.join(generateFunctionArguments(parsedToml, includeType = False, includeOptional = True, functionOnly = True, typeFormat = 'pyBind', addUnderScore = False))
    argumentsOptional_ = ', '.join(generateFunctionArguments(parsedToml, includeType = False, includeOptional = True, functionOnly = True, typeFormat = 'pyBind', addUnderScore = True))
else:
    # No optional entries: every optional variant equals its required-only counterpart.
    fnArgumentsOptional = fnArguments
    computeArgumentsOptional = computeArguments
    argumentsOptional = arguments
    # Needed by the *_argumentsOptional_t_ define below; without it that line raises NameError.
    argumentsOptional_ = arguments_


generatedLines = []
generatedLines += ['\n', '// DEF PYTHON BINDINGS\n']
generatedLines += [f'#define {filePrefix}_pyArguments_t {pyArguments}\n']
generatedLines += ['// DEF FUNCTION ARGUMENTS\n']
generatedLines += [f'#define {filePrefix}_functionArguments_t {fnArguments}\n']
generatedLines += [f'#define {filePrefix}_functionArgumentsOptional_t {fnArgumentsOptional}\n']

generatedLines += ['// DEF COMPUTE ARGUMENTS\n']
generatedLines += [f'#define {filePrefix}_computeArguments_t {computeArguments}\n']
generatedLines += [f'#define {filePrefix}_computeArgumentsOptional_t {computeArgumentsOptional}\n']

generatedLines += ['// DEF ARGUMENTS\n']
generatedLines += [f'#define {filePrefix}_arguments_t {arguments}\n']
generatedLines += [f'#define {filePrefix}_argumentsOptional_t {argumentsOptional}\n']
generatedLines += [f'#define {filePrefix}_arguments_t_ {arguments_}\n']
generatedLines += [f'#define {filePrefix}_argumentsOptional_t_ {argumentsOptional_}\n']

generatedLines += ['\n', '// END PYTHON BINDINGS\n']

# Overwrite the header in place: untouched prefix, regenerated defines, spliced suffix.
with open(fileName, 'w') as f:
    f.writelines(prefixLines + generatedLines + suffixLines)


In [44]:
def transformToArgument(argument, includeType = True, addUnderScore = False, includeOptional = False, functionOnly = False):
    """Render one argument descriptor dict as a C++ parameter (or bare name) string.

    NOTE(review): this redefines the ``transformToArgument(key, value, ...)`` from the
    earlier cell with an incompatible signature. Running this cell shadows that version,
    after which ``generateFunctionArguments`` (which calls
    ``transformToArgument(key, value, **kwargs)``) fails. Consider renaming this prototype.

    Args:
        argument: dict with 'name' and 'type'; 'functionOnly' and 'optional' are
            optional keys treated as False when absent.
        includeType: prepend the type ('torch::Tensor' when type == 'tensor').
        addUnderScore: append '_' to the name, for tensor arguments only.
        includeOptional: when False, entries with a truthy 'optional' are dropped.
        functionOnly: when False, entries with a truthy 'functionOnly' are dropped.

    Returns:
        "type name", just "name" when includeType is False, or "" when filtered out.
    """
    if not functionOnly and argument.get('functionOnly', False):
        return ""
    if not includeOptional and argument.get('optional', False):
        return ""

    name_str = f"{argument['name']}"
    if addUnderScore and argument['type'] == 'tensor':
        name_str += "_"
    if not includeType:
        # No type requested: return the bare name (no stray leading space).
        return name_str

    type_str = "torch::Tensor" if argument['type'] == 'tensor' else f"{argument['type']}"
    return f"{type_str} {name_str}"

def print_arguments(arguments, functionOnly=False, addUnderScore=True, includeType=True, includeOptional=False):
    """Join the rendered form of each argument descriptor with ', '.

    Despite its name, this returns the joined string rather than printing it.
    Descriptors that transformToArgument renders as an empty string are left out.
    """
    rendered = (
        transformToArgument(
            arg,
            functionOnly=functionOnly,
            addUnderScore=addUnderScore,
            includeType=includeType,
            includeOptional=includeOptional,
        )
        for arg in arguments
    )
    return ", ".join(filter(None, rendered))


def generateProcessArgumentsNonOptional(arguments):
    """Print a getAccessor statement for every non-optional tensor descriptor.

    Args:
        arguments: iterable of descriptor dicts; reads 'optional', 'type' and, for
            tensor entries, 'name', 'tensor_type' and 'dim'.

    Returns None; the statements go to stdout. Optional and non-tensor entries are skipped.
    """
    for arg in arguments:
        isOptional = 'optional' in arg and arg['optional']
        if isOptional or arg['type'] != 'tensor':
            continue
        name = arg['name']
        print(f"\tauto {name} = getAccessor<{arg['tensor_type']}, {arg['dim']}>({name}_, \"{name}\", useCuda);")

def generateProcessArgumentsOptional(arguments):
    """Print a getAccessor statement, unwrapping `.value()`, for every optional tensor descriptor.

    Args:
        arguments: iterable of descriptor dicts; reads 'optional', 'type' and, for
            matching entries, 'name', 'tensor_type' and 'dim'.

    Returns None; the statements go to stdout. Non-optional and non-tensor entries are skipped.
    """
    optionalTensors = (
        arg for arg in arguments
        if 'optional' in arg and arg['optional'] and arg['type'] == 'tensor'
    )
    for arg in optionalTensors:
        name = arg['name']
        print(f"\tauto {name} = getAccessor<{arg['tensor_type']}, {arg['dim']}>({name}_.value(), \"{name}\", useCuda);")

# Self-contained example input for the legacy descriptor-based helpers above.
# A dedicated name is used because `arguments` is rebound by the code-generation
# cell to a joined string; iterating that string here raises TypeError.
# The entries are reconstructed to reproduce this cell's saved output; the C++ types
# of `dim` and `index` are illustrative only (TODO confirm against the original list).
exampleArguments = [
    {'name': 'input', 'type': 'tensor', 'tensor_type': 'scalar_t', 'dim': 2, 'functionOnly': False},
    {'name': 'dim', 'type': 'int64_t', 'functionOnly': False},
    {'name': 'index', 'type': 'int64_t', 'functionOnly': False},
    {'name': 'result', 'type': 'tensor', 'tensor_type': 'scalar_t', 'dim': 1, 'functionOnly': True},
    {'name': 'inputOptional', 'type': 'tensor', 'tensor_type': 'scalar_t', 'dim': 2, 'functionOnly': False, 'optional': True},
]

print('# Non-optional arguments')
generateProcessArgumentsNonOptional(exampleArguments)

print('# Optional arguments')
generateProcessArgumentsOptional(exampleArguments)


print(print_arguments(exampleArguments, addUnderScore=False, includeOptional=False, includeType=False))

# Non-optional arguments
	auto input = getAccessor<scalar_t, 2>(input_, "input", useCuda);
	auto result = getAccessor<scalar_t, 1>(result_, "result", useCuda);
# Optional arguments
	auto inputOptional = getAccessor<scalar_t, 2>(inputOptional_.value(), "inputOptional", useCuda);
 input,  dim,  index
