Skip to content

Commit cabb839

Browse files
implemented a way to extract blured pixels from an image
1 parent 5861df0 commit cabb839

File tree

3 files changed

+249
-34
lines changed

3 files changed

+249
-34
lines changed

NN/utils.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,9 @@
11
import tensorflow as tf
22
import tensorflow_addons as tfa
3-
import tensorflow_probability as tfp
43
import tensorflow.keras.layers as L
54
import tensorflow as tf
65
import tensorflow_probability as tfp
76

8-
def gaussian_kernel(size, stdsPx):
9-
stds = tf.cast(stdsPx, tf.float32) / size
10-
B = tf.shape(stds)[0]
11-
stds = tf.reshape(stds, [B, 1])
12-
x = tf.linspace(-size // 2 + 1, size // 2 + 1, size)
13-
x = tf.cast(x ** 2, tf.float32)
14-
x = tf.tile(x[None], [B, 1])
15-
x = tf.nn.softmax(-x / (2.0 * (stds**2)))
16-
x = tf.matmul(x[:, :, None], x[:, None, :])
17-
gauss = tf.reshape(x, [B, size, size, 1])
18-
gauss = tf.repeat(gauss, 3, axis=-1)
19-
gauss = tf.repeat(gauss, 3, axis=0)
20-
gauss = tf.transpose(gauss, [1, 2, 3, 0]) # [B, size, size, 1] => [size, size, 1, B]
21-
return gauss
22-
237
def masked(x, mask):
248
'''
259
very weird hack to apply a mask to the tensor and ensure that the shape is preserved
@@ -156,21 +140,4 @@ def is_namedtuple(obj) -> bool:
156140
isinstance(obj, tuple) and
157141
hasattr(obj, '_asdict') and
158142
hasattr(obj, '_fields')
159-
)
160-
161-
if '__main__' == __name__:
162-
print('Utils test')
163-
print('All tests passed successfully!')
164-
import cv2
165-
import tensorflow_addons as tfa
166-
167-
img = cv2.imread('d:\photo_2024-03-26_15-48-33.jpg')
168-
img = img.astype('float32') / 255.0
169-
gaussians = gaussian_kernel(48, tf.constant([10., 20., 30.]))
170-
imgG = tf.nn.conv2d(img[None], gaussians, strides=[1, 1, 1, 1], padding='SAME')[0]
171-
172-
for i in [3, 6, 9]:
173-
g = imgG[..., i-3:i].numpy()
174-
g = (g * 255.0).astype('uint8')
175-
cv2.imshow('Gaussian %i' % i, g)
176-
cv2.waitKey(0)
143+
)

NN/utils_bluring.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import tensorflow as tf
2+
from NN.utils import extractInterpolated
3+
4+
def create1DGaussian(size, stds, shifts):
5+
B = tf.shape(stds)[0]
6+
tf.assert_equal(tf.shape(shifts), (B, ))
7+
x = tf.linspace(-size // 2 + 1, size // 2 + 1, size)
8+
x = tf.cast(x, tf.float32)
9+
x = tf.tile(x[None], [B, 1]) + shifts[..., None]
10+
x = tf.nn.softmax(-(x ** 2) / (2.0 * (stds ** 2)), axis=-1)
11+
x = tf.reshape(x, [B, size])
12+
return x
13+
14+
def gaussian_kernel(size, stdsPx, shifts=None):
15+
if shifts is None:
16+
shifts = tf.zeros((tf.shape(stdsPx)[0], 2))
17+
18+
stds = tf.cast(stdsPx, tf.float32)
19+
B = tf.shape(stds)[0]
20+
stds = tf.reshape(stds, [B, 1])
21+
22+
gX = create1DGaussian(size, stds, shifts[:, 0])[..., None]
23+
gY = create1DGaussian(size, stds, shifts[:, 1])[..., None, :]
24+
gauss = tf.matmul(gX, gY)
25+
26+
gauss = tf.reshape(gauss, [B, size, size, 1])
27+
gauss = tf.tile(gauss, [1, 1, 1, 3])
28+
tf.assert_equal(tf.shape(gauss), (B, size, size, 3))
29+
gauss = tf.transpose(gauss, [1, 2, 3, 0]) # [B, size, size, 1] => [size, size, 1, B]
30+
tf.assert_equal(tf.shape(gauss), (size, size, 3, B))
31+
return gauss
32+
33+
############################
34+
# trying to implement more efficient bluring
35+
def shiftsPixels(HW, points):
36+
d = 1.0 / tf.cast(HW, tf.float32)
37+
return points - (tf.floor(points / d) * d + (d / 2.0))
38+
39+
def visibleArea(points, HW, size):
40+
HW = tf.cast(HW, tf.float32)
41+
HW = tf.repeat(HW, repeats=2)
42+
HW = tf.reshape(HW, (1, 2))
43+
points = points * HW
44+
points = tf.floor(points)
45+
points = tf.cast(points, tf.int32)
46+
47+
HW = tf.cast(HW, tf.int32)
48+
left = tf.maximum(0, points - size)
49+
right = tf.minimum(HW, points + size)
50+
return left, right
51+
52+
def area2indices(left, right, HW, maxN):
53+
B = tf.shape(left)[0]
54+
LR = tf.concat([left, right], axis=-1)
55+
tf.assert_equal(tf.shape(LR), (B, 4))
56+
57+
def f(lr):
58+
l, r = lr[:2], lr[2:]
59+
wh = r - l
60+
w, h = wh[0], wh[1]
61+
# tf.debugging.assert_greater(0, w)
62+
tf.debugging.assert_less_equal(w, maxN)
63+
# tf.debugging.assert_greater(0, h)
64+
tf.debugging.assert_less_equal(h, maxN)
65+
indices = l[0] + tf.range(w) # [minX, maxX]
66+
indices = tf.reshape(indices, [1, -1])
67+
indices = tf.tile(indices, [h, 1])
68+
shifts = l[1] + tf.range(h) # [minY, maxY]
69+
indices = indices + shifts[:, None] * HW
70+
tf.assert_equal(tf.shape(indices), (h, w))
71+
72+
pad = maxN**2 - tf.size(indices)
73+
indices = tf.reshape(indices, [-1])
74+
indices = tf.pad(indices, [[0, pad]], constant_values=-1)
75+
return indices
76+
return tf.map_fn(f, LR, dtype=tf.int32)
77+
78+
def extractBluredX(img, points, R, maxR):
79+
img = img[None]
80+
B = tf.shape(points)[0]
81+
tf.assert_rank(img, 4)
82+
tf.assert_equal(tf.shape(points), (B, 2))
83+
tf.assert_equal(tf.shape(R), (B, 1))
84+
H, W = [tf.shape(img)[i] for i in [1, 2]]
85+
tf.assert_equal(H, W, 'Image should be square')
86+
gaussians = gaussian_kernel(maxR, R, shifts=shiftsPixels(H, points))
87+
gaussians = tf.transpose(gaussians, [3, 0, 1, 2]) # [size, size, 3, B] => [B, size, size, 3]
88+
gaussians = tf.reshape(gaussians, [B, -1, 3])
89+
sz = tf.shape(gaussians)[1]
90+
# extract areas around the points
91+
# first, find the visible area for each point
92+
left, right = visibleArea(points, H, size=maxR)
93+
# extract the indices of the visible area
94+
indices = area2indices(left, right, H, sz)
95+
tf.assert_equal(tf.shape(indices), (B, sz ** 2))
96+
# extract the visible areas from the image
97+
flatImg = tf.reshape(img, [1, H * W, 3])
98+
extracted = tf.gather(flatImg, indices, axis=1)[0]
99+
tf.assert_equal(tf.shape(extracted), (B, sz ** 2, 3))
100+
101+
indicesLow = indices[:, 0, None]
102+
extractedWeights = tf.gather(gaussians, indices - indicesLow, batch_dims=1)
103+
tf.assert_equal(tf.shape(extractedWeights), tf.shape(extracted))
104+
extracted = tf.reduce_sum(extracted * extractedWeights, axis=1)
105+
tf.assert_equal(tf.shape(extracted), (B, 3))
106+
return extracted
107+
############################
108+
def applyBluring(img, kernel):
109+
tf.assert_rank(img, 4)
110+
tf.assert_rank(kernel, 4)
111+
tf.assert_equal(tf.shape(img)[0], 1)
112+
B = tf.shape(kernel)[-1]
113+
114+
imgG = tf.nn.depthwise_conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME')[0]
115+
H, W = [tf.shape(imgG)[i] for i in range(2)]
116+
imgG = tf.reshape(imgG, [H, W, 3, -1])
117+
imgG = tf.transpose(imgG, (3, 0, 1, 2))
118+
imgG = tf.reshape(imgG, (B, H, W, 3))
119+
return imgG
120+
121+
def extractBlured(R):
122+
'''
123+
R: list of bluring radiuses, (B, 1)
124+
'''
125+
R = tf.reshape(R, (tf.size(R), 1))
126+
maxR = tf.reduce_max(R)
127+
maxR = tf.cast(maxR, tf.int32) + 1
128+
gaussians = gaussian_kernel(maxR, R, shifts=tf.zeros((tf.shape(R)[0], 2)))
129+
gaussiansN = tf.shape(gaussians)[-1]
130+
131+
def f(img, points, ptR):
132+
img = img[None]
133+
tf.assert_rank(img, 4)
134+
B = tf.shape(points)[0]
135+
tf.assert_equal(tf.shape(points), (B, 2))
136+
tf.assert_equal(tf.shape(ptR), (B, 1))
137+
blured = applyBluring(img, gaussians)
138+
tf.assert_equal(tf.shape(blured), (gaussiansN, tf.shape(img)[1], tf.shape(img)[2], 3))
139+
# extract the blured values
140+
blured = extractInterpolated(blured, points[None])
141+
tf.assert_equal(tf.shape(blured), (gaussiansN, B, 3))
142+
143+
# blured contains the blured values for each point in each gaussian
144+
# we need to select the blured value for each point based on its radius
145+
correspondingG = tf.transpose(ptR == R[..., 0])
146+
tf.assert_equal(tf.shape(correspondingG), (gaussiansN, B))
147+
148+
idx = tf.where(correspondingG)
149+
tf.assert_equal(tf.shape(idx), (B, 2))
150+
blured = tf.gather_nd(blured, idx)
151+
tf.assert_equal(tf.shape(blured), (B, 3))
152+
return blured
153+
return f

tests/test_utils_bluring.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import pytest
2+
import numpy as np
3+
import tensorflow as tf
4+
from NN.utils_bluring import shiftsPixels, visibleArea, area2indices, \
5+
gaussian_kernel, applyBluring, extractBlured
6+
from NN.utils import extractInterpolated
7+
8+
# test shiftsPixels
9+
def test_shiftsPixels():
10+
H = 3
11+
d = (1.0 / H)
12+
points = np.array([
13+
[d * 0.9, d * 0.5], [d * 1.1, d * 0.6], [d * 2.3, d * 0.7], [d * 2.8, d * 10.],
14+
]).astype(np.float32)
15+
correct = np.array([
16+
[0.4, 0.0], [-0.4, 0.1], [-0.2, 0.2], [0.3, 0.5],
17+
]).astype(np.float32)
18+
19+
shifts = shiftsPixels(H, points).numpy() / d
20+
diff = np.abs(shifts - correct)
21+
assert diff.max() < 1e-5, '%s != %s' % (shifts, correct)
22+
return
23+
24+
# test visibleArea
25+
def test_visibleArea():
26+
HW = 10
27+
size = 7
28+
points = np.array([
29+
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2]
30+
]).astype(np.float32)
31+
left, right = visibleArea(points, HW, size)
32+
left = left.numpy()
33+
right = right.numpy()
34+
correctLeft = np.array([
35+
[2, 2], [0, 0], [6, 6], [0, 6], [6, 0], [4, 0],
36+
])
37+
correctRight = np.array([
38+
[ 8, 8], [ 4, 4], [10, 10], [ 4, 10], [10, 4], [10, 5]
39+
])
40+
assert np.allclose(left, correctLeft), '%s != %s' % (left, correctLeft)
41+
assert np.allclose(right, correctRight), '%s != %s' % (right, correctRight)
42+
return
43+
44+
# test area2indices
45+
def test_area2indices():
46+
HW = 10
47+
indices = area2indices([[0, 0], [2, 1]], [[2, 3], [4, 5]], HW)
48+
correct = np.array([
49+
[ 0, 1,
50+
0 + HW * 1, 1 + HW * 1,
51+
0 + HW * 2, 1 + HW * 2,
52+
] + (HW ** 2 - 6) * [-1],
53+
[ 2 + HW * 1, 3 + HW * 1,
54+
2 + HW * 2, 3 + HW * 2,
55+
2 + HW * 3, 3 + HW * 3,
56+
2 + HW * 4, 3 + HW * 4,
57+
] + (HW ** 2 - (4 - 2) * (5 - 1)) * [-1],
58+
])
59+
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct):
60+
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b)
61+
continue
62+
return
63+
64+
# test area2indices full
65+
def test_area2indices_full():
66+
hwOld = 10
67+
indices = area2indices([[0, 0]], [[hwOld, hwOld]], hwOld)
68+
correct = np.array([range(hwOld ** 2)])
69+
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct):
70+
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b)
71+
continue
72+
return
73+
######################################################################
74+
# extractBlured same as extractInterpolated after applying gaussians
75+
# TODO: find out why the results are so numerically instable
76+
def test_extractBlured_same():
77+
H = 10
78+
img = np.random.rand(H, H, 3).astype(np.float32)
79+
points = np.array([
80+
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2]
81+
]).astype(np.float32)
82+
R = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
83+
imgBlured = applyBluring(img[None], gaussian_kernel(H, R))
84+
assert imgBlured.shape == (6, H, H, 3)
85+
86+
blur = extractBlured(R)
87+
blured = blur(img, points, ptR=np.full((6, 1), 1.0))
88+
89+
extracted = extractInterpolated(imgBlured[0, None], points[None])[0]
90+
assert extracted.shape == blured.shape
91+
for i, a, b in zip(range(extracted.shape[0]), extracted, blured):
92+
diff = np.abs(a - b).max()
93+
assert diff < 5e-2, '%d | %s != %s' % (i, a, b)
94+
continue
95+
return

0 commit comments

Comments
 (0)