diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc index f1b18cc52e..a0c24cd597 100644 --- a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc +++ b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc @@ -43,32 +43,70 @@ template struct FillProjectiveTransform; } // end namespace functor -typedef Eigen::ThreadPoolDevice CPUDevice; +using CPUDevice = Eigen::ThreadPoolDevice; using functor::FillProjectiveTransform; +using generator::Extend; using generator::Interpolation; -using generator::INTERPOLATION_BILINEAR; -using generator::INTERPOLATION_NEAREST; using generator::ProjectiveGenerator; +bool InterpolationFromString(const string& interpolation_str, + Interpolation* interpolation) { + if (interpolation_str == "NEAREST") { + *interpolation = Interpolation::INTERPOLATION_NEAREST; + } else if (interpolation_str == "BILINEAR") { + *interpolation = Interpolation::INTERPOLATION_BILINEAR; + } else { + return false; + } + return true; +} + +bool ExtendFromString(const string& extend_str, Extend* extend) { + if (extend_str == "REFLECT") { + *extend = Extend::EXTEND_REFLECT; + } else if (extend_str == "CONSTANT") { + *extend = Extend::EXTEND_CONSTANT; + } else if (extend_str == "NEAREST") { + *extend = Extend::EXTEND_NEAREST; + } else if (extend_str == "MIRROR") { + *extend = Extend::EXTEND_MIRROR; + } else if (extend_str == "WRAP") { + *extend = Extend::EXTEND_WRAP; + } else { + return false; + } + return true; +} + template class ImageProjectiveTransformV2 : public OpKernel { private: Interpolation interpolation_; + Extend extend_; + T constant_values_; public: explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx) : OpKernel(ctx) { string interpolation_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); - if (interpolation_str == "NEAREST") { - interpolation_ = INTERPOLATION_NEAREST; - } else if (interpolation_str == "BILINEAR") { - interpolation_ = INTERPOLATION_BILINEAR; - } else { - LOG(FATAL) << "Invalid interpolation " << interpolation_str - << ". Supported types: NEAREST, BILINEAR"; - } + OP_REQUIRES( + ctx, InterpolationFromString(interpolation_str, &interpolation_), + errors::InvalidArgument("Invalid interpolation ", interpolation_str, + ". Supported types: NEAREST, BILINEAR.")); + + string extend_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("extend", &extend_str)); + OP_REQUIRES( + ctx, ExtendFromString(extend_str, &extend_), + errors::InvalidArgument( + "Invalid extend ", extend_str, + ". Supported types: REFLECT, CONSTANT, NEAREST, MIRROR, WRAP.")); + + float constant_values; + OP_REQUIRES_OK(ctx, ctx->GetAttr("constant_values", &constant_values)); + constant_values_ = static_cast(constant_values); } void Compute(OpKernelContext* ctx) override { @@ -117,8 +155,9 @@ class ImageProjectiveTransformV2 : public OpKernel { auto images = images_t.tensor(); auto transform = transform_t.matrix(); - (FillProjectiveTransform(interpolation_))( - ctx->eigen_device(), &output, images, transform); + (FillProjectiveTransform( + interpolation_, extend_, constant_values_))(ctx->eigen_device(), + &output, images, transform); } }; diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.h b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.h index 2a90168d07..e4538a6bbb 100644 --- a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.h +++ b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.h @@ -30,6 +30,13 @@ namespace addons { namespace generator { enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR }; +enum Extend { + EXTEND_REFLECT, + EXTEND_CONSTANT, + EXTEND_NEAREST, + EXTEND_MIRROR, + EXTEND_WRAP +}; using Eigen::array; using Eigen::DenseIndex; @@ -40,6 +47,8 @@ class ProjectiveGenerator { typename TTypes::ConstTensor input_; typename TTypes::ConstMatrix transforms_; const Interpolation interpolation_; + const Extend extend_; + const T constant_values_; public: static const int kNumParameters = 8; @@ -47,8 +56,13 @@ class ProjectiveGenerator { EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ProjectiveGenerator(typename TTypes::ConstTensor input, typename TTypes::ConstMatrix transforms, - const Interpolation interpolation) - : input_(input), transforms_(transforms), interpolation_(interpolation) {} + const Interpolation interpolation, const Extend extend, + const T constant_values) + : input_(input), + transforms_(transforms), + interpolation_(interpolation), + extend_(extend), + constant_values_(constant_values) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const array& coords) const { @@ -60,9 +74,9 @@ class ProjectiveGenerator { : &transforms_.data()[transforms_.dimension(1) * coords[0]]; float projection = transform[6] * output_x + transform[7] * output_y + 1.f; if (projection == 0) { - // Return the fill value (0) for infinite coordinates, + // Return the constant_values_ for infinite coordinates, // which are outside the input image - return T(0); + return constant_values_; } const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / @@ -71,32 +85,29 @@ class ProjectiveGenerator { (transform[3] * output_x + transform[4] * output_y + transform[5]) / projection; - const T fill_value = T(0); switch (interpolation_) { case INTERPOLATION_NEAREST: // Switch the order of x and y again for indexing into the image. - return nearest_interpolation(coords[0], input_y, input_x, coords[3], - fill_value); + return nearest_interpolation(coords[0], input_y, input_x, coords[3]); case INTERPOLATION_BILINEAR: - return bilinear_interpolation(coords[0], input_y, input_x, coords[3], - fill_value); + return bilinear_interpolation(coords[0], input_y, input_x, coords[3]); } // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST // or INTERPOLATION_BILINEAR. - return T(0); + return constant_values_; } private: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T nearest_interpolation(const DenseIndex batch, const float y, const float x, - const DenseIndex channel, const T fill_value) const { + const DenseIndex channel) const { return read_with_fill_value(batch, DenseIndex(std::round(y)), - DenseIndex(std::round(x)), channel, fill_value); + DenseIndex(std::round(x)), channel); } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T bilinear_interpolation(const DenseIndex batch, const float y, const float x, - const DenseIndex channel, const T fill_value) const { + const DenseIndex channel) const { const float y_floor = std::floor(y); const float x_floor = std::floor(x); const float y_ceil = y_floor + 1; @@ -104,35 +115,118 @@ class ProjectiveGenerator { // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) const float value_yfloor = - (x_ceil - x) * static_cast(read_with_fill_value( - batch, DenseIndex(y_floor), DenseIndex(x_floor), - channel, fill_value)) + - (x - x_floor) * static_cast(read_with_fill_value( - batch, DenseIndex(y_floor), DenseIndex(x_ceil), - channel, fill_value)); + (x_ceil - x) * + static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), channel)) + + (x - x_floor) * + static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), channel)); // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) const float value_yceil = - (x_ceil - x) * static_cast(read_with_fill_value( - batch, DenseIndex(y_ceil), DenseIndex(x_floor), - channel, fill_value)) + - (x - x_floor) * static_cast(read_with_fill_value( - batch, DenseIndex(y_ceil), DenseIndex(x_ceil), - channel, fill_value)); + (x_ceil - x) * + static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), channel)) + + (x - x_floor) * + static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), channel)); // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value( - const DenseIndex batch, const DenseIndex y, const DenseIndex x, - const DenseIndex channel, const T fill_value) const { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE DenseIndex + map_coordinate(const DenseIndex in, const DenseIndex dim) const { + DenseIndex out(in); + + if (in < 0) { + switch (extend_) { + case EXTEND_MIRROR: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(2 * dim - 2); + out = d * (-out / d) + out; + out = (out <= 1 - dim) ? out + d : -out; + } + break; + case EXTEND_REFLECT: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(2 * dim); + if (out < -d) { + out = d * (-out / d) + out; + } + out = (out < -dim) ? out + d : -out - 1; + } + break; + case EXTEND_WRAP: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(dim - 1); + out += d * ((-out / d) + 1); + } + break; + case EXTEND_NEAREST: + out = 0; + break; + case EXTEND_CONSTANT: + out = -1; + break; + } + } else if (in >= dim) { + switch (extend_) { + case EXTEND_MIRROR: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(2 * dim - 2); + out -= d * (out / d); + out = (out >= dim) ? d - out : out; + } + break; + case EXTEND_REFLECT: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(2 * dim); + out -= d * (out / d); + out = (out >= dim) ? d - out - 1 : out; + } + break; + case EXTEND_WRAP: + if (dim <= 1) { + out = 0; + } else { + const DenseIndex d(dim - 1); + out -= d * (out / d); + } + break; + case EXTEND_NEAREST: + out = dim - 1; + break; + case EXTEND_CONSTANT: + out = -1; + break; + } + } + + return out; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + read_with_fill_value(const DenseIndex batch, const DenseIndex y, + const DenseIndex x, const DenseIndex channel) const { // batch and channel must be correct, because they are passed unchanged from // the input. - return (0 <= y && y < input_.dimension(1) && 0 <= x && - x < input_.dimension(2)) - ? input_(array{batch, y, x, channel}) - : fill_value; + const DenseIndex my(map_coordinate(y, input_.dimension(1))), + mx(map_coordinate(x, input_.dimension(2))); + + return (my >= 0 && mx >= 0) + ? input_(array{batch, my, mx, channel}) + : constant_values_; } }; @@ -143,6 +237,7 @@ class ProjectiveGenerator { // some Eigen device code. namespace functor { +using generator::Extend; using generator::Interpolation; using generator::ProjectiveGenerator; @@ -152,16 +247,21 @@ struct FillProjectiveTransform { typedef typename TTypes::ConstTensor InputType; typedef typename TTypes::ConstTensor TransformsType; const Interpolation interpolation_; + const Extend extend_; + const T constant_values_; - FillProjectiveTransform(Interpolation interpolation) - : interpolation_(interpolation) {} + FillProjectiveTransform(const Interpolation interpolation, + const Extend extend, const T constant_values) + : interpolation_(interpolation), + extend_(extend), + constant_values_(constant_values) {} EIGEN_ALWAYS_INLINE void operator()(const Device& device, OutputType* output, const InputType& images, const TransformsType& transform) const { - output->device(device) = output->generate( - ProjectiveGenerator(images, transform, interpolation_)); + output->device(device) = output->generate(ProjectiveGenerator( + images, transform, interpolation_, extend_, constant_values_)); } }; diff --git a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc index 86f1ae85ba..e19a8f48df 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc @@ -135,6 +135,8 @@ REGISTER_OP("Addons>ImageProjectiveTransformV2") .Input("output_shape: int32") .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") + .Attr("extend: string") + .Attr("constant_values: float = 0.0") .Output("transformed_images: dtype") .SetShapeFn(ResizeShapeFn) .Doc(kImageProjectiveTransformDoc); @@ -151,4 +153,4 @@ REGISTER_OP("Addons>ImageConnectedComponents") .Doc(ImageConnectedComponentsDoc); } // end namespace addons -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow_addons/image/tests/transform_ops_test.py b/tensorflow_addons/image/tests/transform_ops_test.py index d634455548..2a93e6d08a 100644 --- a/tensorflow_addons/image/tests/transform_ops_test.py +++ b/tensorflow_addons/image/tests/transform_ops_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from tensorflow_addons.image import transform_ops +from scipy import ndimage from skimage import transform _DTYPES = { @@ -214,6 +215,31 @@ def test_rotate_odd(dtype): ) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("extend", ["CONSTANT", "MIRROR", "NEAREST", "WRAP"]) +def test_rotate_extend(extend): + image = tf.constant( + [ + [0, 0, 0, 0, 0], + [0, 1, 1, 1, 0], + [0, 1, 0, 1, 0], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 0], + ], + tf.float32, + ) + + transformed = transform_ops.rotate( + image, np.pi / 4.0, interpolation="BILINEAR", extend=extend + ) + np.testing.assert_allclose( + transformed.numpy(), + ndimage.rotate(image.numpy(), 45, order=1, mode=extend.lower(), reshape=False), + rtol=1e-6, + atol=5e-6, + ) + + @pytest.mark.usefixtures("maybe_run_functions_eagerly") @pytest.mark.parametrize("dtype", _DTYPES) def test_compose_rotate(dtype): diff --git a/tensorflow_addons/image/transform_ops.py b/tensorflow_addons/image/transform_ops.py index 75ade66760..fdee0f2d67 100644 --- a/tensorflow_addons/image/transform_ops.py +++ b/tensorflow_addons/image/transform_ops.py @@ -40,6 +40,8 @@ def transform( transforms: TensorLike, interpolation: str = "NEAREST", output_shape: Optional[list] = None, + extend: str = "CONSTANT", + constant_values: TensorLike = 0.0, name: Optional[str] = None, ) -> tf.Tensor: """Applies the given transform(s) to the image(s). @@ -60,7 +62,9 @@ def transform( Supported values: "NEAREST", "BILINEAR". output_shape: Output dimesion after the transform, [height, width]. If None, output is the same size as input image. - + extend: Extend mode. Supported values: "REFLECT", + "CONSTANT", "NEAREST", "MIRROR", "WRAP". + constant_values: The fill value to use in "CONSTANT" extend mode. name: The name of the op. Returns: @@ -113,6 +117,8 @@ def transform( output_shape=output_shape, transforms=transforms, interpolation=interpolation.upper(), + extend=extend.upper(), + constant_values=constant_values, ) return img_utils.from_4D_image(output, original_ndims) @@ -277,6 +283,8 @@ def _image_projective_transform_grad(op, grad): images = op.inputs[0] transforms = op.inputs[1] interpolation = op.get_attr("interpolation") + extend = op.get_attr("extend") + constant_values = op.get_attr("constant_values") image_or_images = tf.convert_to_tensor(images, name="images") transform_or_transforms = tf.convert_to_tensor( @@ -305,6 +313,8 @@ def _image_projective_transform_grad(op, grad): transforms=transforms, output_shape=tf.shape(image_or_images)[1:3], interpolation=interpolation, + extend=extend, + constant_values=constant_values, ) return [output, None, None] @@ -313,6 +323,8 @@ def rotate( images: TensorLike, angles: TensorLike, interpolation: str = "NEAREST", + extend: str = "CONSTANT", + constant_values: TensorLike = 0.0, name: Optional[str] = None, ) -> tf.Tensor: """Rotate image(s) counterclockwise by the passed angle(s) in radians. @@ -327,6 +339,9 @@ def rotate( batch. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + extend: Extend mode. Supported values: "REFLECT", + "CONSTANT", "NEAREST", "MIRROR", "WRAP". + constant_values: The fill value to use in "CONSTANT" extend mode. name: The name of the op. Returns: @@ -349,6 +364,8 @@ def rotate( images, angles_to_projective_transforms(angles, image_height, image_width), interpolation=interpolation, + extend=extend, + constant_values=constant_values, ) return img_utils.from_4D_image(output, original_ndims)