1
+ import tensorflow as tf
2
+ from NN .utils import extractInterpolated
3
+
4
+ def create1DGaussian (size , stds , shifts ):
5
+ B = tf .shape (stds )[0 ]
6
+ tf .assert_equal (tf .shape (shifts ), (B , ))
7
+ x = tf .linspace (- size // 2 + 1 , size // 2 + 1 , size )
8
+ x = tf .cast (x , tf .float32 )
9
+ x = tf .tile (x [None ], [B , 1 ]) + shifts [..., None ]
10
+ x = tf .nn .softmax (- (x ** 2 ) / (2.0 * (stds ** 2 )), axis = - 1 )
11
+ x = tf .reshape (x , [B , size ])
12
+ return x
13
+
14
+ def gaussian_kernel (size , stdsPx , shifts = None ):
15
+ if shifts is None :
16
+ shifts = tf .zeros ((tf .shape (stdsPx )[0 ], 2 ))
17
+
18
+ stds = tf .cast (stdsPx , tf .float32 )
19
+ B = tf .shape (stds )[0 ]
20
+ stds = tf .reshape (stds , [B , 1 ])
21
+
22
+ gX = create1DGaussian (size , stds , shifts [:, 0 ])[..., None ]
23
+ gY = create1DGaussian (size , stds , shifts [:, 1 ])[..., None , :]
24
+ gauss = tf .matmul (gX , gY )
25
+
26
+ gauss = tf .reshape (gauss , [B , size , size , 1 ])
27
+ gauss = tf .tile (gauss , [1 , 1 , 1 , 3 ])
28
+ tf .assert_equal (tf .shape (gauss ), (B , size , size , 3 ))
29
+ gauss = tf .transpose (gauss , [1 , 2 , 3 , 0 ]) # [B, size, size, 1] => [size, size, 1, B]
30
+ tf .assert_equal (tf .shape (gauss ), (size , size , 3 , B ))
31
+ return gauss
32
+
33
+ ############################
34
+ # trying to implement more efficient bluring
35
+ def shiftsPixels (HW , points ):
36
+ d = 1.0 / tf .cast (HW , tf .float32 )
37
+ return points - (tf .floor (points / d ) * d + (d / 2.0 ))
38
+
39
+ def visibleArea (points , HW , size ):
40
+ HW = tf .cast (HW , tf .float32 )
41
+ HW = tf .repeat (HW , repeats = 2 )
42
+ HW = tf .reshape (HW , (1 , 2 ))
43
+ points = points * HW
44
+ points = tf .floor (points )
45
+ points = tf .cast (points , tf .int32 )
46
+
47
+ HW = tf .cast (HW , tf .int32 )
48
+ left = tf .maximum (0 , points - size )
49
+ right = tf .minimum (HW , points + size )
50
+ return left , right
51
+
52
+ def area2indices (left , right , HW , maxN ):
53
+ B = tf .shape (left )[0 ]
54
+ LR = tf .concat ([left , right ], axis = - 1 )
55
+ tf .assert_equal (tf .shape (LR ), (B , 4 ))
56
+
57
+ def f (lr ):
58
+ l , r = lr [:2 ], lr [2 :]
59
+ wh = r - l
60
+ w , h = wh [0 ], wh [1 ]
61
+ # tf.debugging.assert_greater(0, w)
62
+ tf .debugging .assert_less_equal (w , maxN )
63
+ # tf.debugging.assert_greater(0, h)
64
+ tf .debugging .assert_less_equal (h , maxN )
65
+ indices = l [0 ] + tf .range (w ) # [minX, maxX]
66
+ indices = tf .reshape (indices , [1 , - 1 ])
67
+ indices = tf .tile (indices , [h , 1 ])
68
+ shifts = l [1 ] + tf .range (h ) # [minY, maxY]
69
+ indices = indices + shifts [:, None ] * HW
70
+ tf .assert_equal (tf .shape (indices ), (h , w ))
71
+
72
+ pad = maxN ** 2 - tf .size (indices )
73
+ indices = tf .reshape (indices , [- 1 ])
74
+ indices = tf .pad (indices , [[0 , pad ]], constant_values = - 1 )
75
+ return indices
76
+ return tf .map_fn (f , LR , dtype = tf .int32 )
77
+
78
+ def extractBluredX (img , points , R , maxR ):
79
+ img = img [None ]
80
+ B = tf .shape (points )[0 ]
81
+ tf .assert_rank (img , 4 )
82
+ tf .assert_equal (tf .shape (points ), (B , 2 ))
83
+ tf .assert_equal (tf .shape (R ), (B , 1 ))
84
+ H , W = [tf .shape (img )[i ] for i in [1 , 2 ]]
85
+ tf .assert_equal (H , W , 'Image should be square' )
86
+ gaussians = gaussian_kernel (maxR , R , shifts = shiftsPixels (H , points ))
87
+ gaussians = tf .transpose (gaussians , [3 , 0 , 1 , 2 ]) # [size, size, 3, B] => [B, size, size, 3]
88
+ gaussians = tf .reshape (gaussians , [B , - 1 , 3 ])
89
+ sz = tf .shape (gaussians )[1 ]
90
+ # extract areas around the points
91
+ # first, find the visible area for each point
92
+ left , right = visibleArea (points , H , size = maxR )
93
+ # extract the indices of the visible area
94
+ indices = area2indices (left , right , H , sz )
95
+ tf .assert_equal (tf .shape (indices ), (B , sz ** 2 ))
96
+ # extract the visible areas from the image
97
+ flatImg = tf .reshape (img , [1 , H * W , 3 ])
98
+ extracted = tf .gather (flatImg , indices , axis = 1 )[0 ]
99
+ tf .assert_equal (tf .shape (extracted ), (B , sz ** 2 , 3 ))
100
+
101
+ indicesLow = indices [:, 0 , None ]
102
+ extractedWeights = tf .gather (gaussians , indices - indicesLow , batch_dims = 1 )
103
+ tf .assert_equal (tf .shape (extractedWeights ), tf .shape (extracted ))
104
+ extracted = tf .reduce_sum (extracted * extractedWeights , axis = 1 )
105
+ tf .assert_equal (tf .shape (extracted ), (B , 3 ))
106
+ return extracted
107
+ ############################
108
+ def applyBluring (img , kernel ):
109
+ tf .assert_rank (img , 4 )
110
+ tf .assert_rank (kernel , 4 )
111
+ tf .assert_equal (tf .shape (img )[0 ], 1 )
112
+ B = tf .shape (kernel )[- 1 ]
113
+
114
+ imgG = tf .nn .depthwise_conv2d (img , kernel , strides = [1 , 1 , 1 , 1 ], padding = 'SAME' )[0 ]
115
+ H , W = [tf .shape (imgG )[i ] for i in range (2 )]
116
+ imgG = tf .reshape (imgG , [H , W , 3 , - 1 ])
117
+ imgG = tf .transpose (imgG , (3 , 0 , 1 , 2 ))
118
+ imgG = tf .reshape (imgG , (B , H , W , 3 ))
119
+ return imgG
120
+
121
+ def extractBlured (R ):
122
+ '''
123
+ R: list of bluring radiuses, (B, 1)
124
+ '''
125
+ R = tf .reshape (R , (tf .size (R ), 1 ))
126
+ maxR = tf .reduce_max (R )
127
+ maxR = tf .cast (maxR , tf .int32 ) + 1
128
+ gaussians = gaussian_kernel (maxR , R , shifts = tf .zeros ((tf .shape (R )[0 ], 2 )))
129
+ gaussiansN = tf .shape (gaussians )[- 1 ]
130
+
131
+ def f (img , points , ptR ):
132
+ img = img [None ]
133
+ tf .assert_rank (img , 4 )
134
+ B = tf .shape (points )[0 ]
135
+ tf .assert_equal (tf .shape (points ), (B , 2 ))
136
+ tf .assert_equal (tf .shape (ptR ), (B , 1 ))
137
+ blured = applyBluring (img , gaussians )
138
+ tf .assert_equal (tf .shape (blured ), (gaussiansN , tf .shape (img )[1 ], tf .shape (img )[2 ], 3 ))
139
+ # extract the blured values
140
+ blured = extractInterpolated (blured , points [None ])
141
+ tf .assert_equal (tf .shape (blured ), (gaussiansN , B , 3 ))
142
+
143
+ # blured contains the blured values for each point in each gaussian
144
+ # we need to select the blured value for each point based on its radius
145
+ correspondingG = tf .transpose (ptR == R [..., 0 ])
146
+ tf .assert_equal (tf .shape (correspondingG ), (gaussiansN , B ))
147
+
148
+ idx = tf .where (correspondingG )
149
+ tf .assert_equal (tf .shape (idx ), (B , 2 ))
150
+ blured = tf .gather_nd (blured , idx )
151
+ tf .assert_equal (tf .shape (blured ), (B , 3 ))
152
+ return blured
153
+ return f
0 commit comments