-
Notifications
You must be signed in to change notification settings - Fork 0
/
box_filter.py
31 lines (18 loc) · 857 Bytes
/
box_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
def diff_x(input, r):
    """Turn running sums along axis 2 into border-clipped window sums of radius ``r``.

    ``input`` is expected to already be a cumulative sum along axis 2 (this is
    how ``box_filter`` calls it: ``diff_x(tf.cumsum(x, axis=2), r)``).  With
    ``C = input`` and ``n`` the length of axis 2, the output at index ``i`` is
    ``C[min(n - 1, i + r)] - C[i - r - 1]`` (the second term omitted when
    ``i - r - 1 < 0``), i.e. the sum of the original values over the window
    ``[max(0, i - r), min(n - 1, i + r)]``.  The output has the same shape as
    ``input``.

    Args:
        input: rank-4 tensor; axis 2 is the axis being windowed.
        r: non-negative integer window radius.

    Returns:
        Tensor with the same shape as ``input``.

    Raises:
        AssertionError: if ``input`` is not rank 4.
        ValueError: if ``r`` is a negative Python int, or if the static size
            of axis 2 is known and smaller than ``2 * r + 1``.  The slices
            below require at least that many elements; without this check a
            too-short axis silently yields a shorter, incorrect result.
    """
    assert input.shape.ndims == 4

    # Static validation only: tensor-valued radii and dynamic (None) sizes
    # cannot be checked here without adding graph ops.
    if isinstance(r, int):
        if r < 0:
            raise ValueError(f"radius r must be non-negative, got {r}")
        size = input.shape.as_list()[2]
        if size is not None and size < 2 * r + 1:
            raise ValueError(
                f"axis 2 has size {size}, but a box of radius {r} needs at "
                f"least {2 * r + 1} elements"
            )

    # Windows touching the leading border: prefix sum up to i + r.
    left = input[:, :, r:2 * r + 1]
    # Interior windows: prefix(i + r) - prefix(i - r - 1).
    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
    # Windows touching the trailing border: total - prefix(i - r - 1).
    # The size-1 slice holding the total broadcasts over the r-wide slice.
    right = input[:, :, -1:] - input[:, :, -2 * r - 1:-r - 1]

    return tf.concat([left, middle, right], axis=2)
def diff_y(input, r):
    """Turn running sums along axis 3 into border-clipped window sums of radius ``r``.

    ``input`` is expected to already be a cumulative sum along axis 3 (this is
    how ``box_filter`` calls it: ``diff_y(tf.cumsum(..., axis=3), r)``).  With
    ``C = input`` and ``n`` the length of axis 3, the output at index ``j`` is
    ``C[min(n - 1, j + r)] - C[j - r - 1]`` (the second term omitted when
    ``j - r - 1 < 0``), i.e. the sum of the original values over the window
    ``[max(0, j - r), min(n - 1, j + r)]``.  The output has the same shape as
    ``input``.

    Args:
        input: rank-4 tensor; axis 3 is the axis being windowed.
        r: non-negative integer window radius.

    Returns:
        Tensor with the same shape as ``input``.

    Raises:
        AssertionError: if ``input`` is not rank 4.
        ValueError: if ``r`` is a negative Python int, or if the static size
            of axis 3 is known and smaller than ``2 * r + 1``.  The slices
            below require at least that many elements; without this check a
            too-short axis silently yields a shorter, incorrect result.
    """
    assert input.shape.ndims == 4

    # Static validation only: tensor-valued radii and dynamic (None) sizes
    # cannot be checked here without adding graph ops.
    if isinstance(r, int):
        if r < 0:
            raise ValueError(f"radius r must be non-negative, got {r}")
        size = input.shape.as_list()[3]
        if size is not None and size < 2 * r + 1:
            raise ValueError(
                f"axis 3 has size {size}, but a box of radius {r} needs at "
                f"least {2 * r + 1} elements"
            )

    # Windows touching the leading border: prefix sum up to j + r.
    left = input[:, :, :, r:2 * r + 1]
    # Interior windows: prefix(j + r) - prefix(j - r - 1).
    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
    # Windows touching the trailing border: total - prefix(j - r - 1).
    # The size-1 slice holding the total broadcasts over the r-wide slice.
    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1:-r - 1]

    return tf.concat([left, middle, right], axis=3)
def box_filter(x, r):
    """Sum every element's (2r+1) x (2r+1) neighbourhood over axes 2 and 3.

    The box sum is separable.  A running sum along axis 2 is converted into
    windowed sums by ``diff_x``, and the same is then done along axis 3 with
    ``diff_y``.  Windows are clipped at the borders, so border positions sum
    over fewer elements.  The result is an unnormalized sum, not a mean.

    Args:
        x: rank-4 tensor whose axes 2 and 3 are filtered (presumably the
            spatial axes of NCHW data — confirm with callers).  Each of those
            axes is assumed to have size >= 2 * r + 1.
        r: integer window radius.

    Returns:
        Tensor with the same shape as ``x``.
    """
    assert x.shape.ndims == 4

    # Pass 1: windowed sums along axis 2, built from its running sum.
    summed_axis2 = diff_x(tf.cumsum(x, axis=2), r)
    # Pass 2: repeat along axis 3 to complete the 2-D box sum.
    return diff_y(tf.cumsum(summed_axis2, axis=3), r)