Skip to content
This repository was archived by the owner on Mar 11, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,70 @@ template struct FillProjectiveTransform<CPUDevice, double>;

} // end namespace functor

typedef Eigen::ThreadPoolDevice CPUDevice;
using CPUDevice = Eigen::ThreadPoolDevice;
Comment thread
AakashKumarNain marked this conversation as resolved.

using functor::FillProjectiveTransform;
using generator::Extend;
using generator::Interpolation;
using generator::INTERPOLATION_BILINEAR;
using generator::INTERPOLATION_NEAREST;
using generator::ProjectiveGenerator;

bool InterpolationFromString(const string& interpolation_str,
Interpolation* interpolation) {
if (interpolation_str == "NEAREST") {
*interpolation = Interpolation::INTERPOLATION_NEAREST;
} else if (interpolation_str == "BILINEAR") {
*interpolation = Interpolation::INTERPOLATION_BILINEAR;
} else {
return false;
}
return true;
}

bool ExtendFromString(const string& extend_str, Extend* extend) {
if (extend_str == "REFLECT") {
*extend = Extend::EXTEND_REFLECT;
} else if (extend_str == "CONSTANT") {
*extend = Extend::EXTEND_CONSTANT;
} else if (extend_str == "NEAREST") {
*extend = Extend::EXTEND_NEAREST;
} else if (extend_str == "MIRROR") {
*extend = Extend::EXTEND_MIRROR;
} else if (extend_str == "WRAP") {
*extend = Extend::EXTEND_WRAP;
} else {
return false;
}
return true;
}

template <typename Device, typename T>
class ImageProjectiveTransformV2 : public OpKernel {
private:
Interpolation interpolation_;
Extend extend_;
T constant_values_;

public:
explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx)
: OpKernel(ctx) {
string interpolation_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
if (interpolation_str == "NEAREST") {
interpolation_ = INTERPOLATION_NEAREST;
} else if (interpolation_str == "BILINEAR") {
interpolation_ = INTERPOLATION_BILINEAR;
} else {
LOG(FATAL) << "Invalid interpolation " << interpolation_str
<< ". Supported types: NEAREST, BILINEAR";
}
OP_REQUIRES(
ctx, InterpolationFromString(interpolation_str, &interpolation_),
errors::InvalidArgument("Invalid interpolation ", interpolation_str,
". Supported types: NEAREST, BILINEAR."));

string extend_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("extend", &extend_str));
OP_REQUIRES(
ctx, ExtendFromString(extend_str, &extend_),
errors::InvalidArgument(
"Invalid extend ", extend_str,
". Supported types: REFLECT, CONSTANT, NEAREST, MIRROR, WRAP."));

float constant_values;
OP_REQUIRES_OK(ctx, ctx->GetAttr("constant_values", &constant_values));
constant_values_ = static_cast<T>(constant_values);
}

void Compute(OpKernelContext* ctx) override {
Expand Down Expand Up @@ -117,8 +155,9 @@ class ImageProjectiveTransformV2 : public OpKernel {
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();

(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
(FillProjectiveTransform<Device, T>(
interpolation_, extend_, constant_values_))(ctx->eigen_device<Device>(),
&output, images, transform);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ namespace addons {
namespace generator {

enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
enum Extend {
EXTEND_REFLECT,
EXTEND_CONSTANT,
EXTEND_NEAREST,
EXTEND_MIRROR,
EXTEND_WRAP
};

using Eigen::array;
using Eigen::DenseIndex;
Expand All @@ -40,15 +47,22 @@ class ProjectiveGenerator {
typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_;
const Interpolation interpolation_;
const Extend extend_;
const T constant_values_;

public:
static const int kNumParameters = 8;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
typename TTypes<float>::ConstMatrix transforms,
const Interpolation interpolation)
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
const Interpolation interpolation, const Extend extend,
const T constant_values)
: input_(input),
transforms_(transforms),
interpolation_(interpolation),
extend_(extend),
constant_values_(constant_values) {}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const {
Expand All @@ -60,9 +74,9 @@ class ProjectiveGenerator {
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
if (projection == 0) {
// Return the fill value (0) for infinite coordinates,
// Return the constant_values_ for infinite coordinates,
// which are outside the input image
return T(0);
return constant_values_;
}
const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
Expand All @@ -71,68 +85,148 @@ class ProjectiveGenerator {
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection;

const T fill_value = T(0);
switch (interpolation_) {
case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
return nearest_interpolation(coords[0], input_y, input_x, coords[3]);
case INTERPOLATION_BILINEAR:
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
return bilinear_interpolation(coords[0], input_y, input_x, coords[3]);
}
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
// or INTERPOLATION_BILINEAR.
return T(0);
return constant_values_;
}

private:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
nearest_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const {
const DenseIndex channel) const {
return read_with_fill_value(batch, DenseIndex(std::round(y)),
DenseIndex(std::round(x)), channel, fill_value);
DenseIndex(std::round(x)), channel);
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const {
const DenseIndex channel) const {
const float y_floor = std::floor(y);
const float x_floor = std::floor(x);
const float y_ceil = y_floor + 1;
const float x_ceil = x_floor + 1;
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
const float value_yfloor =
(x_ceil - x) * static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_floor), DenseIndex(x_floor),
channel, fill_value)) +
(x - x_floor) * static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_floor), DenseIndex(x_ceil),
channel, fill_value));
(x_ceil - x) *
static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_floor), DenseIndex(x_floor), channel)) +
(x - x_floor) *
static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_floor), DenseIndex(x_ceil), channel));
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
const float value_yceil =
(x_ceil - x) * static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_ceil), DenseIndex(x_floor),
channel, fill_value)) +
(x - x_floor) * static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_ceil), DenseIndex(x_ceil),
channel, fill_value));
(x_ceil - x) *
static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_ceil), DenseIndex(x_floor), channel)) +
(x - x_floor) *
static_cast<float>(read_with_fill_value(
batch, DenseIndex(y_ceil), DenseIndex(x_ceil), channel));
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
const DenseIndex channel, const T fill_value) const {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE DenseIndex
map_coordinate(const DenseIndex in, const DenseIndex dim) const {
DenseIndex out(in);

if (in < 0) {
switch (extend_) {
case EXTEND_MIRROR:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(2 * dim - 2);
out = d * (-out / d) + out;
out = (out <= 1 - dim) ? out + d : -out;
}
break;
case EXTEND_REFLECT:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(2 * dim);
if (out < -d) {
out = d * (-out / d) + out;
}
out = (out < -dim) ? out + d : -out - 1;
}
break;
case EXTEND_WRAP:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(dim - 1);
out += d * ((-out / d) + 1);
}
break;
case EXTEND_NEAREST:
out = 0;
break;
case EXTEND_CONSTANT:
out = -1;
break;
}
} else if (in >= dim) {
switch (extend_) {
case EXTEND_MIRROR:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(2 * dim - 2);
out -= d * (out / d);
out = (out >= dim) ? d - out : out;
}
break;
case EXTEND_REFLECT:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(2 * dim);
out -= d * (out / d);
out = (out >= dim) ? d - out - 1 : out;
}
break;
case EXTEND_WRAP:
if (dim <= 1) {
out = 0;
} else {
const DenseIndex d(dim - 1);
out -= d * (out / d);
}
break;
case EXTEND_NEAREST:
out = dim - 1;
break;
case EXTEND_CONSTANT:
out = -1;
break;
}
}

return out;
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
read_with_fill_value(const DenseIndex batch, const DenseIndex y,
const DenseIndex x, const DenseIndex channel) const {
// batch and channel must be correct, because they are passed unchanged from
// the input.
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
x < input_.dimension(2))
? input_(array<DenseIndex, 4>{batch, y, x, channel})
: fill_value;
const DenseIndex my(map_coordinate(y, input_.dimension(1))),
mx(map_coordinate(x, input_.dimension(2)));

return (my >= 0 && mx >= 0)
? input_(array<DenseIndex, 4>{batch, my, mx, channel})
: constant_values_;
}
};

Expand All @@ -143,6 +237,7 @@ class ProjectiveGenerator {
// some Eigen device code.
namespace functor {

using generator::Extend;
using generator::Interpolation;
using generator::ProjectiveGenerator;

Expand All @@ -152,16 +247,21 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
const Interpolation interpolation_;
const Extend extend_;
const T constant_values_;

FillProjectiveTransform(Interpolation interpolation)
: interpolation_(interpolation) {}
FillProjectiveTransform(const Interpolation interpolation,
const Extend extend, const T constant_values)
: interpolation_(interpolation),
extend_(extend),
constant_values_(constant_values) {}

EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
output->device(device) = output->generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
output->device(device) = output->generate(ProjectiveGenerator<Device, T>(
images, transform, interpolation_, extend_, constant_values_));
}
};

Expand Down
4 changes: 3 additions & 1 deletion tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ REGISTER_OP("Addons>ImageProjectiveTransformV2")
.Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Attr("extend: string")
.Attr("constant_values: float = 0.0")
.Output("transformed_images: dtype")
.SetShapeFn(ResizeShapeFn)
.Doc(kImageProjectiveTransformDoc);
Expand All @@ -151,4 +153,4 @@ REGISTER_OP("Addons>ImageConnectedComponents")
.Doc(ImageConnectedComponentsDoc);

} // end namespace addons
} // namespace tensorflow
} // namespace tensorflow
Loading