Skip to content

Commit 24d88d0

Browse files
make R starting from 0.0
1 parent 5ca411a commit 24d88d0

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

NN/RestorationModel/CRestorationModel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _addResiduals(self, latents, residuals):
5656

5757
return latents
5858

59-
def _addRadius(self, latents, R=None, fakeR=1e-5, training=False):
59+
def _addRadius(self, latents, R=None, fakeR=0.0, training=False):
6060
B = tf.shape(latents)[0]
6161
if self._blurRadiusEncoder is not None:
6262
if R is None:

Utils/CroppingAugm.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def SubsampleProcessor(target_crop_size, N, extras=[], sampler='uniform'):
3636
withBlur = blurConfig is not None
3737
if withBlur:
3838
blurRange = blurConfig['min'] + tf.linspace(0.0, blurConfig['max'], blurConfig['N'])
39+
minR = tf.reduce_min(blurRange)
3940
blurN = tf.size(blurRange)
4041
blurShared = blurConfig.get('shared', False)
4142
if blurShared:
@@ -60,20 +61,21 @@ def _FF(img):
6061
sobel = extractInterpolated(sobel, positions)
6162
res['sobel'] = tf.reshape(sobel, [N, 6])
6263

63-
if withBlur and blurShared:
64-
idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32)
65-
R = tf.gather(blurRange, idx)
66-
R = tf.reshape(R, (1,))
64+
if withBlur:
65+
if blurShared:
66+
idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32)
67+
R = tf.gather(blurRange, idx)
68+
R = tf.reshape(R, (1,))
69+
R = tf.fill([N, 1], R[0])
70+
else:
71+
idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32)
72+
R = tf.gather(blurRange, idx)
73+
R = tf.reshape(R, (N, 1))
74+
pass
75+
76+
tf.assert_equal(tf.shape(R), (N, 1))
6777
res['blured'] = blur(src, positions[0], R)
68-
res['blur R'] = tf.fill([N, 1], R[0])
69-
pass
70-
71-
if withBlur and not blurShared:
72-
idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32)
73-
R = tf.gather(blurRange, idx)
74-
R = tf.reshape(R, (N, 1))
75-
res['blured'] = blur(src, positions[0], R)
76-
res['blur R'] = R
78+
res['blur R'] = R - minR # ensure that R is starting from 0.0
7779
pass
7880
return res
7981
return _FF

0 commit comments

Comments
 (0)