Add custom operator code generator script #6881

Closed
105 changes: 105 additions & 0 deletions tensorflow/tools/op_generator/README.md
@@ -0,0 +1,105 @@
# Tensorflow Custom Operator Code Outline Generator

Writing a tensorflow operator requires a fair amount of boilerplate C++ and CUDA code.
This script generates code for the CPU and GPU versions of a tensorflow operator.
More specifically, given tensorflow `inputs`, `outputs` and `attributes`, it generates:

* C++ header file that defines the operator class, templated on Device.
* C++ header file that defines the CPU implementation of the operator.
* C++ source file with the shape function, REGISTER_OP and REGISTER_KERNEL_BUILDER constructs.
* CUDA header that defines the GPU implementation of the operator, including a CUDA kernel.
* CUDA source file with the GPU REGISTER_KERNEL_BUILDER registrations for the operator.
* Python unit test case, which constructs random input data and calls the operator.
* Makefile for compiling the operator into a shared library, using g++ and nvcc.

## Requirements

The jinja2 templating engine is required to generate the code, and a tensorflow installation is required to build the operator. Building the GPU version additionally requires nvcc.

```bash
pip install jinja2
```

## Usage

The user should edit the `op_config.py` file and define the operator:

* inputs and, optionally, their shapes.
* outputs and, optionally, their shapes.
* polymorphic type attributes.
* other attributes.
* documentation.

Once complete, the script can be called as follows

```bash
$ python create_op.py --project=tensorflow --library=custom MyCustomOperator
```

to create the following directory structure

```bash
$ tree custom/
custom/
├── custom_op_op_cpu.cpp
├── custom_op_op_cpu.h
├── custom_op_op_gpu.cu
├── custom_op_op_gpu.cuh
├── custom_op_op.h
├── Makefile
└── test_custom_op.py
```

The `--project` and `--library` flags specify C++ namespaces within which the operator is created. Additionally, the Makefile will create a `custom.so` shared library that can be loaded with `tf.load_op_library('custom.so')`.
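
As an illustration, the resulting operator might be loaded and invoked as follows. This is a minimal sketch assuming a TensorFlow 1.x session API and the example inputs from `op_config.py` shown below; the snake-cased op name `custom_op` and the sizes chosen for unchecked dimensions are assumptions, not generator output:

```python
import numpy as np
import tensorflow as tf

# Load the shared library produced by the Makefile
mod = tf.load_op_library('custom.so')

# Random inputs matching the shapes configured in op_config.py;
# the second dimension of lm is None (unchecked), so any size works
uvw = np.random.random(size=(100, 10, 3)).astype(np.float32)
lm = np.random.random(size=(75, 10)).astype(np.float32)
frequency = np.random.random(size=(32,)).astype(np.float32)
mapping = np.random.randint(0, 1024, size=(1024,)).astype(np.int32)

with tf.Session() as S:
    # The op is exposed under its snake-cased name (assumed here)
    result = S.run(mod.custom_op(uvw, lm, frequency, mapping))
```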


The operator inputs and their optional shapes should be specified as a list of tuples. If concrete dimensions are specified, corresponding checks will be generated in the shape function associated with the operator. If `None` is supplied as the shape, a default shape of `(N, )` with `N=1024` is assumed.

```python
# Operator inputs and shapes
# If shape is None a default one dimensional shape of (N, ) will be given
# Shape dimensions may be None, in which case they will not be checked
op_inputs = [
("uvw: FT", (100, 10, 3)),
("lm: FT", (75, None)),
("frequency: FT", (32,)),
("mapping: int32", None),
]
```

Similarly, the operator outputs and their shapes should be specified as a list of tuples. Output dimensions may not be None, as memory for the outputs is allocated in the CPU and GPU ops.

```python
# Operator outputs and shapes
# Shape dimensions should not be None
op_outputs = [
("complex_phase: CT", (75, 100, 10, 32))
]
```

Given these inputs and outputs, CPU and GPU operators are created with named variables corresponding to the inputs and outputs. Additionally, a CUDA kernel with the given inputs and outputs is created, as well as a shape function checking the rank and dimensions of the supplied inputs.

Next, polymorphic type attributes should be supplied. The generator will template the operators on these type attributes. It will also generate concrete permutations of REGISTER_KERNEL_BUILDER for both the CPU and GPU op, using the actual types (float, double, complex64 and complex128) supplied in the type attributes below; a sketch of the permutation logic follows the example.

```python
# Attributes specifying polymorphic types
op_type_attrs = [
"FT: {float, double} = DT_FLOAT",
"CT: {complex64, complex128} = DT_COMPLEX64"]
```
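
With the two type attributes above, four concrete type permutations result. As a rough standalone sketch (mirroring the `itertools.product` logic in `create_op.py`):

```python
import itertools

ft_types = ("float", "double")
ct_types = ("complex64", "complex128")

# One REGISTER_KERNEL_BUILDER is generated per (FT, CT) permutation,
# for both the CPU and the GPU op
perms = [list(p) for p in itertools.product(ft_types, ct_types)]
print(perms)
# [['float', 'complex64'], ['float', 'complex128'],
#  ['double', 'complex64'], ['double', 'complex128']]
```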

Other attributes may be specified (and will be output in the REGISTER_OP directive), but are not catered for automatically by the generator code, as the range of attribute behaviour is complex.

```python
# Any other attributes
op_other_attrs = [
"iterations: int32 >= 2",
]
```

Finally, operator documentation may also be supplied.

```python
# Operator documentation
op_doc = """Custom Operator"""
```
180 changes: 180 additions & 0 deletions tensorflow/tools/op_generator/create_op.py
@@ -0,0 +1,180 @@
import argparse
from collections import namedtuple
import itertools
import os
import re

import jinja2
from tensorflow.python.framework.dtypes import (
_STRING_TO_TF,
_TYPE_TO_STRING,
_TF_TO_NP)

parser = argparse.ArgumentParser()
parser.add_argument('op_name')
parser.add_argument('-p', '--project', default='project')
parser.add_argument('-l', '--library', default='library')
args = parser.parse_args()

LIBRARY = args.library
PROJECT = args.project

# Set a shape for our variables
N = 1024
var_shape = (N, )

FIRST_CAP_RE = re.compile('(.)([A-Z][a-z]+)')
ALL_CAP_RE = re.compile('([a-z0-9])([A-Z])')

# Convert CamelCase op names to snake case
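# e.g. 'MyCustomOperator' -> 'my_custom_operator'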
def camel_to_snake_case(name):
s1 = FIRST_CAP_RE.sub(r'\1_\2', name)
return ALL_CAP_RE.sub(r'\1_\2', s1).lower()

# Derive a C++ header guard from the header name
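# e.g. header_guard('custom_op_op.h') -> 'CUSTOM_CUSTOM_OP_OP_H' when LIBRARY is 'custom'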
def header_guard(header_name):
guard_str = header_name.replace('.', '_')
return ''.join([LIBRARY, '_', guard_str]).upper()

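# Parse an input/output spec such as 'uvw: FT' with shape (100, 10, 3)
# into (name, type, tf type, np type, shape)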
def parse_inout(s, shape):
var, type_ = tuple(c.strip() for c in s.split(":"))

if "*" in type_:
raise ValueError("Failed to parse '{}'. "
"List lengths are not yet supported".format(s))

TF_TYPES = _TYPE_TO_STRING.values()
tf_type = "tensorflow::" + type_ if type_ in TF_TYPES else type_
np_type = ("np." + _TF_TO_NP[_STRING_TO_TF[type_]].__name__
if type_ in _STRING_TO_TF else type_)

shape = var_shape if shape is None else shape

return var, type_, tf_type, np_type, shape

# Parse a type attribute such as 'FT: {float, double} = DT_FLOAT' into
# (original, name, types, tf types, np types, default)
def parse_attr_type(s):
    var, type_ = tuple(c.strip() for c in s.split(":"))

    split = type_.split("=")
    default = split[1].strip() if len(split) > 1 else None
    types = split[0].strip()

    if types.startswith("{") and types.endswith("}"):
        types = tuple(c.strip() for c in types[1:-1].split(","))
    else:
        types = (types,)

    TF_TYPES = _TYPE_TO_STRING.values()
    tf_types = tuple("tensorflow::" + t if t in TF_TYPES else t
                     for t in types)
    # Materialised as tuples so they can be iterated more than once
    np_types = tuple("np." + _TF_TO_NP[_STRING_TO_TF[t]].__name__
                     if t in _STRING_TO_TF else t
                     for t in types)

    return s, var, types, tf_types, np_types, default

def strip_and_split(s, sep):
return (c.strip() for c in s.split(sep))

InOut = namedtuple("InOut", ["name", "type",
"tf_type", "np_type", "shape"])
Attr = namedtuple("Attr", ["original", "name", "types",
"tf_types", "np_types", "default"])

from op_config import (op_inputs, op_outputs,
op_type_attrs, op_other_attrs, op_doc)

# Parse input ops
op_inputs = [InOut(*parse_inout(i, s)) for i, s in op_inputs]

# Parse output ops
op_outputs = [InOut(*parse_inout(o, s)) for o, s in op_outputs]

# Parse type constrained attrs
op_type_attrs = [Attr(*parse_attr_type(a)) for a in op_type_attrs]

type_constraints = [list(a.np_types) for a in op_type_attrs]

# Permute the type constraints
op_type_perms = itertools.product(*(a.tf_types for a in op_type_attrs))
op_type_perms = [list(p) for p in op_type_perms]

# Snake case python version of the operator
py_op_name = camel_to_snake_case(args.op_name)

# Create dictionary with variables required for creating the templates
D = {
'op_name' : args.op_name,
'py_op_name' : py_op_name,
'project' : PROJECT,
'library' : LIBRARY,
'shared_library' : ''.join([LIBRARY, '.so']),
}

D.update({
'op_inputs' : op_inputs,
'op_outputs' : op_outputs,
'op_type_attrs' : op_type_attrs,
'op_other_attrs' : op_other_attrs,
'op_type_perms' : op_type_perms,
'type_constraints' : type_constraints,
'op_doc' : op_doc,
})

# Filenames
D.update({
'main_header_file' : ''.join([py_op_name, '_op.h']),
'cpp_header_file' : ''.join([py_op_name, '_op_cpu.h']),
'cpp_source_file' : ''.join([py_op_name, '_op_cpu.cpp']),
'cuda_header_file' : ''.join([py_op_name, '_op_gpu.cuh']),
'cuda_source_file' : ''.join([py_op_name, '_op_gpu.cu']),
'python_test_file' : ''.join(['test_', py_op_name, '.py']),
'makefile' : 'Makefile',
})

# C++ header guards
D.update({
'main_header_guard' : header_guard(D['main_header_file']),
'cpp_header_guard' : header_guard(D['cpp_header_file']),
'cuda_header_guard' : header_guard(D['cuda_header_file']),
})

NB = '_namespace_begin'
NE = '_namespace_stop'

# C++ namespace
D.update({
'project_namespace_start' : ''.join([PROJECT, NB]).upper(),
'project_namespace_stop' : ''.join([PROJECT, NE]).upper(),
'op_namespace_start' : ''.join([PROJECT, '_', py_op_name, NB]).upper(),
'op_namespace_stop' : ''.join([PROJECT, '_', py_op_name, NE]).upper(),
})

# CUDA kernel
D.update({
'kernel_name' : ''.join([LIBRARY, '_', py_op_name]),
})


jinja_loader = jinja2.FileSystemLoader('templates')
jinja_env = jinja2.Environment(loader=jinja_loader,
trim_blocks=False, lstrip_blocks=False)

# Create a filter for formatting a list
jinja_env.filters['format_list'] = lambda l, p: [p % s for s in l]
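# e.g. format_list(['uvw', 'lm'], '"%s"') -> ['"uvw"', '"lm"']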

# Create library directory if it does not exist
if not os.path.exists(LIBRARY):
os.makedirs(LIBRARY)

def render(template, output):
""" Hook to render template file to output """
with open(os.path.join(LIBRARY, D[output]), 'w') as f:
header_template = jinja_env.get_template(template)
f.write(header_template.render(**D))

render('main_header.j2', 'main_header_file')
render('cpp_header.j2', 'cpp_header_file')
render('cpp_source.j2', 'cpp_source_file')
render('cuda_header.j2', 'cuda_header_file')
render('cuda_source.j2', 'cuda_source_file')
render('test_source.j2', 'python_test_file')
render('Makefile.j2', 'makefile')
26 changes: 26 additions & 0 deletions tensorflow/tools/op_generator/op_config.py
@@ -0,0 +1,26 @@
# Operator inputs and shapes
# If shape is None a default one dimensional shape of (N, ) will be given
# Shape dimensions may be None, in which case they will not be checked
op_inputs = [
("uvw: FT", (100, 10, 3)),
("lm: FT", (75, None)),
("frequency: FT", (32,)),
("mapping: int32", None),
]

# Operator outputs and shapes
# Shape dimensions should not be None
op_outputs = [
("complex_phase: CT", (75, 100, 10, 32))
]

# Attributes specifying polymorphic types
op_type_attrs = [
"FT: {float, double} = DT_FLOAT",
"CT: {complex64, complex128} = DT_COMPLEX64"]

# Any other attributes
op_other_attrs = []

# Operator documentation
op_doc = """Custom Operator"""
76 changes: 76 additions & 0 deletions tensorflow/tools/op_generator/templates/Makefile.j2
@@ -0,0 +1,76 @@
# Tensorflow includes and defines
TF_INC=$(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_CUDA=$(shell python -c 'import tensorflow as tf; print(int(tf.test.is_built_with_cuda()))')
MB_INC=../../../../include

TF_FLAGS=-D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES -D_GLIBCXX_USE_CXX11_ABI=0

# Dependencies
DEPDIR:=.d
$(shell mkdir -p $(DEPDIR) >/dev/null)
DEPFLAGS=-MT $@ -MMD -MP -MF $(DEPDIR)/$*.d

# Define our sources, compiling CUDA code if it's enabled
ifeq ($(TF_CUDA), 1)
SOURCES=$(wildcard *.cpp *.cu)
else
SOURCES=$(wildcard *.cpp)
endif

# Define objects and shared_library
OBJECTS=$(addsuffix .o, $(basename $(SOURCES)))
LIBRARY={{shared_library}}

# Compiler flags
INCLUDES= -I $(TF_INC) -I $(MB_INC)
CPPFLAGS=-std=c++11 $(TF_FLAGS) $(INCLUDES) -fPIC -fopenmp -O2 -march=native -mtune=native
NVCCFLAGS=-std=c++11 -D GOOGLE_CUDA=$(TF_CUDA) $(TF_FLAGS) $(INCLUDES) \
-x cu --compiler-options "-fPIC" --gpu-architecture=sm_30 -lineinfo

LDFLAGS = -fPIC -fopenmp

# Compiler directives
COMPILE.cpp = g++ $(DEPFLAGS) $(CPPFLAGS) -c
COMPILE.nvcc = nvcc --compiler-options " $(DEPFLAGS)" $(NVCCFLAGS) -c

all : $(LIBRARY)

%.o : %.cpp
$(COMPILE.cpp) $<

%.o : %.cu
$(COMPILE.nvcc) $<

clean :
rm -f $(OBJECTS) $(LIBRARY)

$(LIBRARY) : $(OBJECTS)
g++ -shared $(OBJECTS) -o $(LIBRARY) $(LDFLAGS)

$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d

-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(SOURCES)))
