-
Notifications
You must be signed in to change notification settings - Fork 30
/
ops.py
101 lines (74 loc) · 2.8 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import sys
import tensorflow as tf
import subprocess
from tensorflow.python.framework import ops
# Register ops for compilation here. Each name <n> maps to the sources
# <n>_op.cc / <n>_op.cu.cc and to the built library <n>_op.so.
OP_NAMES = [
    'backward_warp',
    'downsample',
    'correlation',
    'forward_warp',
]

# Remember the importer's working directory so it can be restored once the
# libraries are loaded. Then move to the directory that holds this file
# (symlinks resolved), and from there to the `ops` source directory at
# ../../ops. The relative file names used for building and loading resolve
# against that directory.
cwd = os.getcwd()
os.chdir(os.path.dirname(os.path.realpath(__file__)))
os.chdir("../../ops")
def compile(op=None, cuda_lib64_path="/usr/local/cuda-8.0/lib64",
            gencode="arch=compute_30,code=sm_30"):
    """Build custom-op shared libraries from their CUDA/C++ sources.

    For each op name ``n``:
      1. nvcc compiles ``n_op.cu.cc`` into ``n_op.cu.o``.
      2. g++ links that object together with ``n_op.cc`` into ``n_op.so``.

    All file names are relative, so the current working directory must be
    the ops source directory.

    The name shadows the builtin ``compile`` inside this module. It is kept
    unchanged because existing callers use it.

    Args:
        op: Name of a single op to build. If None, every name in
            ``OP_NAMES`` is built.
        cuda_lib64_path: Directory searched for ``libcudart`` at link time.
            Defaults to the CUDA 8.0 location this file originally
            hard-coded.
        gencode: Value of nvcc's ``-gencode=`` option, which selects the
            target GPU architecture. Defaults to the originally hard-coded
            sm_30 target.

    Raises:
        subprocess.CalledProcessError: if nvcc or g++ exits with a non-zero
            status.
        OSError: if the nvcc or g++ executable cannot be found.
    """
    to_compile = [op] if op is not None else OP_NAMES
    tf_inc = tf.sysconfig.get_include()
    for n in to_compile:
        base = n + "_op"
        fn_cu_cc = base + ".cu.cc"
        fn_cu_o = base + ".cu.o"
        fn_cc = base + ".cc"
        fn_so = base + ".so"

        # Pass argument lists rather than shell strings. Paths such as the
        # TensorFlow include directory may contain spaces, which would break
        # shell word splitting. Compiler output goes straight to the console.
        nvcc_cmd = [
            "nvcc", "-std=c++11", "-c", "-gencode=" + gencode,
            "-o", fn_cu_o, fn_cu_cc,
            "-I", tf_inc,
            "-D", "GOOGLE_CUDA=1",
            "-x", "cu", "-Xcompiler", "-fPIC",
        ]
        subprocess.check_call(nvcc_cmd)

        # -lcudart must come after the objects that reference it. Linkers
        # that resolve libraries in command-line order (e.g. GNU ld with
        # --as-needed) otherwise drop it, and cudart symbols stay unresolved
        # when the library is loaded.
        gcc_cmd = [
            "g++", "-std=c++11", "-shared",
            "-o", fn_so, fn_cu_o, fn_cc,
            "-I", tf_inc, "-fPIC",
            "-D", "GOOGLE_CUDA=1",
            "-L", cuda_lib64_path, "-lcudart",
        ]
        subprocess.check_call(gcc_cmd)
if __name__ == "__main__":
    # Executing this file directly rebuilds every op in OP_NAMES up front.
    # This guard does not stop the rest of the module from running afterwards.
    compile(op=None)
# Load each op library, building it on demand when it cannot be loaded.
# Each loaded library is exposed as the private module attribute
# `_<name>_module`. Relative paths resolve against the ops directory
# entered earlier.
module = sys.modules[__name__]
try:
    for n in OP_NAMES:
        lib_path = './{}_op.so'.format(n)
        try:
            op_lib = tf.load_op_library(lib_path)
        except (tf.errors.NotFoundError, RuntimeError):
            # The library is missing or could not be loaded. Rebuild it from
            # source and retry once; a second failure propagates. Other errors
            # (and KeyboardInterrupt) are not swallowed, because a rebuild
            # would not fix them.
            compile(n)
            op_lib = tf.load_op_library(lib_path)
        setattr(module, '_' + n + '_module', op_lib)
finally:
    # Always restore the importer's working directory, even when a build or
    # load fails, so a failed import does not leave the process in the ops
    # directory.
    os.chdir(cwd)
def correlation(first, second, **kwargs):
    """Apply the custom Correlation op and return only its first output.

    The underlying op produces several outputs; all but element 0 are
    discarded here. Extra keyword arguments are forwarded unchanged to the
    generated op wrapper. Presumably these are the op attributes (e.g.
    kernel_size, max_displacement, pad, stride_1, stride_2) — confirm
    against the C++ op registration.
    """
    outputs = _correlation_module.correlation(first, second, **kwargs)
    primary_output = outputs[0]
    return primary_output
# Public names for the remaining ops: direct references to the generated op
# wrappers in the loaded libraries, exposed without any wrapping.
backward_warp, downsample, forward_warp = (
    _backward_warp_module.backward_warp,
    _downsample_module.downsample,
    _forward_warp_module.forward_warp,
)
# Register op gradients
@ops.RegisterGradient("BackwardWarp")
def _BackwardWarpGrad(op, grad):
    """Gradient function for the custom ``BackwardWarp`` op.

    Passes the upstream gradient and both op inputs to the library's
    ``backward_warp_grad`` kernel. Input 0 receives no gradient (None); the
    kernel's result is returned as the gradient of input 1.

    NOTE(review): the kernel result was originally named ``grad0`` yet is
    returned in position 1. Confirm against the C++ op definition which
    input it corresponds to.
    """
    first_input = op.inputs[0]
    second_input = op.inputs[1]
    input1_grad = _backward_warp_module.backward_warp_grad(
        grad, first_input, second_input)
    return [None, input1_grad]
@ops.RegisterGradient("ForwardWarp")
def _ForwardWarpGrad(op, grad):
    """Gradient function for the custom ``ForwardWarp`` op.

    Calls the library's ``forward_warp_grad`` kernel with the upstream
    gradient and the op's first input. Returns a single-element list, so
    the op presumably has exactly one input — confirm against the C++ op.
    """
    warped_input = op.inputs[0]
    return [_forward_warp_module.forward_warp_grad(grad, warped_input)]
@ops.RegisterGradient("Correlation")
def _CorrelationGrad(op, in_grad, in_grad1, in_grad2):
    """Gradient function for the custom ``Correlation`` op.

    Only the gradient w.r.t. output 0 (``in_grad``) is used. The gradients
    for outputs 1 and 2 (``in_grad1``, ``in_grad2``) are ignored.

    Outputs 1 and 2 of the forward op are passed back into the
    ``correlation_grad`` kernel together with both inputs and the op's
    attributes. Presumably these outputs are intermediate buffers kept for
    the backward pass — confirm against the C++ op.

    Returns:
        A two-element list with the gradients for inputs 0 and 1.
    """
    attr_names = ('kernel_size', 'max_displacement', 'pad',
                  'stride_1', 'stride_2')
    attrs = {name: op.get_attr(name) for name in attr_names}

    first, second = op.inputs[0], op.inputs[1]
    saved_a, saved_b = op.outputs[1], op.outputs[2]

    grad_first, grad_second = _correlation_module.correlation_grad(
        in_grad, first, second, saved_a, saved_b, **attrs)
    return [grad_first, grad_second]
# Declare the custom Downsample op non-differentiable: TensorFlow treats its
# gradients as None, so no gradient flows back through it.
ops.NotDifferentiable("Downsample")