Skip to content

Commit

Permalink
Merge 9abd6eb into 38771b1
Browse files Browse the repository at this point in the history
  • Loading branch information
copybara-service[bot] committed Dec 3, 2021
2 parents 38771b1 + 9abd6eb commit a2f5a14
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tensorflow_graphics/image/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ def sample(image: type_alias.TensorLike,
if resampling_type == ResamplingType.NEAREST:
warp = tf.math.round(warp)

if border_type == BorderType.DUPLICATE:
image_size = tf.cast(tf.shape(input=image)[1:3], dtype=warp.dtype)
height, width = tf.unstack(image_size, axis=-1)
warp_x, warp_y = tf.unstack(warp, axis=-1)
warp_x = tf.clip_by_value(warp_x, 0.0, width - 1.0)
warp_y = tf.clip_by_value(warp_y, 0.0, height - 1.0)
warp = tf.stack((warp_x, warp_y), axis=-1)

return tfa_image.resampler(image, warp)
if border_type == BorderType.ZERO:
image = tf.pad(image, ((0, 0), (1, 1), (1, 1), (0, 0)))
warp = warp + 1

warp_shape = tf.shape(warp)
flat_warp = tf.reshape(warp, (warp_shape[0], -1, 2))
flat_sampled = tfa_image.interpolate_bilinear(
image, flat_warp, indexing="xy")
output_shape = tf.concat((warp_shape[:-1], tf.shape(flat_sampled)[-1:]), 0)
return tf.reshape(flat_sampled, output_shape)


def perspective_transform(
Expand Down

0 comments on commit a2f5a14

Please sign in to comment.