Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tfjs-backend-cpu/src/kernels/FlipLeftRight.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, NumericDataType, TypedArray} from '@tensorflow/tfjs-core';
import {FlipLeftRight, FlipLeftRightInputs, util} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';

export const flipLeftRightConfig: KernelConfig = {
kernelName: FlipLeftRight,
backendName: 'cpu',
kernelFunc: ({inputs, attrs, backend}) => {
const {image} = inputs as FlipLeftRightInputs;
const cpuBackend = backend as MathBackendCPU;

const output = util.getTypedArrayFromDType(
image.dtype as NumericDataType, util.sizeFromShape(image.shape));
const [batch, imageHeight, imageWidth, numChannels] = image.shape;

const imageVals = cpuBackend.data.get(image.dataId).values as TypedArray;

for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;

for (let row = 0; row < imageHeight; row++) {
const rowOffset = row * (imageWidth * numChannels);

for (let col = 0; col < imageWidth; col++) {
const colOffset = col * numChannels;

for (let channel = 0; channel < numChannels; channel++) {
const coords = [batch, row, col, channel];

const x = coords[2];

const coordX = Math.round(imageWidth - x);
const outIdx = batchOffset + rowOffset + colOffset + channel;

let outputValue = imageVals[outIdx];
// If the coordinate position falls within the image boundaries...
if (coordX >= 0 && coordX < imageWidth) {
// set the output to the image value at the coordinate position.
const rotatedColOffset = coordX * numChannels;
const imageIdx =
batchOffset + rowOffset + rotatedColOffset + channel;
outputValue = imageVals[imageIdx];
}
output[outIdx] = outputValue;
}
}
}
}

const dataId = cpuBackend.write(output, image.shape, image.dtype);
return {dataId, shape: image.shape, dtype: image.dtype};
}
};
8 changes: 5 additions & 3 deletions tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {dilation2dConfig} from './kernels/Dilation2D';
import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter';
import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput';
import {divConfig} from './kernels/Div';
import {flipLeftRightConfig} from './kernels/FlipLeftRight';
import {maxConfig} from './kernels/Max';
import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';
import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4';
Expand All @@ -35,9 +36,10 @@ import {transposeConfig} from './kernels/Transpose';
// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
dilation2dConfig, dilation2dBackpropInputConfig,
dilation2dBackpropFilterConfig, divConfig, maxPoolWithArgmaxConfig, maxConfig,
nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, rotateWithOffsetConfig,
squareConfig, squaredDifferenceConfig, transposeConfig
dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig,
maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config,
nonMaxSuppressionV5Config, rotateWithOffsetConfig, squareConfig,
squaredDifferenceConfig, transposeConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ tfjs_cc_library(
":Div",
":Equal",
":Exp",
":FlipLeftRight",
":FloorDiv",
":FusedBatchNorm",
":FusedConv2D",
Expand Down Expand Up @@ -454,6 +455,15 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "FlipLeftRight",
srcs = ["kernels/FlipLeftRight.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "FloorDiv",
srcs = ["kernels/FloorDiv.cc"],
Expand Down
78 changes: 78 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/FlipLeftRight.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <math.h>
#include <cmath>
#include <cstddef>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace tfjs {
namespace wasm {

extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

void FlipLeftRight(const size_t image_id, const size_t batch,
const size_t image_height, const size_t image_width,
const size_t num_channels, const size_t out_id) {
auto& image_info = backend::get_tensor_info(image_id);
auto& out_info = backend::get_tensor_info_out(out_id);

const float* image_buf = image_info.f32();
float* out_buf = out_info.f32_write();

for (size_t batch_idx = 0; batch_idx < batch; ++batch_idx) {
const size_t batch_offset =
batch_idx * image_width * image_height * num_channels;
for (size_t row = 0; row < image_height; ++row) {
const size_t row_offset = row * (image_width * num_channels);
for (size_t col = 0; col < image_width; ++col) {
const size_t col_offset = col * num_channels;

for (size_t channel = 0; channel < num_channels; ++channel) {
const size_t x = col;
const size_t coord_x = image_width - x;
const size_t image_idx =
batch_offset + row_offset + col_offset + channel;

float output_value = image_buf[image_idx];
// If the coordinate position falls within the image boundaries...
if (coord_x >= 0 && coord_x < image_width) {
const size_t flipped_col_offset = coord_x * num_channels;
const size_t rotated_image_idx =
batch_offset + row_offset + flipped_col_offset + channel;
output_value = image_buf[rotated_image_idx];
}

*out_buf = output_value;
out_buf++;
}
}
}
}
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
58 changes: 58 additions & 0 deletions tfjs-backend-wasm/src/kernels/FlipLeftRight.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {FlipLeftRight, FlipLeftRightInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

let wasmFlipLeftRight: (
xId: number, batch: number, imageHeight: number, imageWidth: number,
numChannels: number, outId: number) => void;

function setup(backend: BackendWasm) {
wasmFlipLeftRight = backend.wasm.cwrap(FlipLeftRight, null /* void */, [
'number', // xId
'number', // batch
'number', // imageHeight
'number', // imageWidth
'number', // numChannels
'number', // outId
]);
}

export function flipLeftRight(
args: {inputs: FlipLeftRightInputs, backend: BackendWasm}): TensorInfo {
const {inputs, backend} = args;
const {image} = inputs;

const out = backend.makeOutput(image.shape, image.dtype);
const imageId = backend.dataIdMap.get(image.dataId).id;
const outId = backend.dataIdMap.get(out.dataId).id;

const [batch, imageHeight, imageWidth, numChannels] = image.shape;

wasmFlipLeftRight(
imageId, batch, imageHeight, imageWidth, numChannels, outId);
return out;
}

export const flipLeftRightConfig: KernelConfig = {
kernelName: FlipLeftRight,
backendName: 'wasm',
kernelFunc: flipLeftRight as {} as KernelFunc,
setupFunc: setup
};
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import {divConfig} from './kernels/Div';
import {equalConfig} from './kernels/Equal';
import {expConfig} from './kernels/Exp';
import {fillConfig} from './kernels/Fill';
import {flipLeftRightConfig} from './kernels/FlipLeftRight';
import {floorDivConfig} from './kernels/FloorDiv';
import {fusedBatchNormConfig} from './kernels/FusedBatchNorm';
import {fusedConv2DConfig} from './kernels/FusedConv2D';
Expand Down Expand Up @@ -111,6 +112,7 @@ const kernelConfigs: KernelConfig[] = [
equalConfig,
expConfig,
fillConfig,
flipLeftRightConfig,
floorDivConfig,
fusedMatMulConfig,
fusedBatchNormConfig,
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ const TEST_FILTERS: TestFilter[] = [
},
{include: 'slice '},
{include: 'rotate '},
{include: 'flipLeftRight '},
{include: 'square '},
{
startsWith: 'min ',
Expand Down
45 changes: 45 additions & 0 deletions tfjs-backend-webgl/src/flip_left_right_gpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {GPGPUProgram} from './gpgpu_math';

export class FlipLeftRightProgram implements GPGPUProgram {
variableNames = ['Image'];
outputShape: number[] = [];
userCode: string;

constructor(imageShape: [number, number, number, number]) {
const imageWidth = imageShape[2];
this.outputShape = imageShape;

this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];

int coordX = ${imageWidth} - x;
float outputValue;
if(coordX >= 0 && coordX < ${imageWidth}) {
outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
} else {
outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
}
setOutput(outputValue);
}
`;
}
}
35 changes: 35 additions & 0 deletions tfjs-backend-webgl/src/kernels/FlipLeftRight.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, Tensor4D} from '@tensorflow/tfjs-core';
import {FlipLeftRight, FlipLeftRightInputs} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {FlipLeftRightProgram} from '../flip_left_right_gpu';

export const flipLeftRightConfig: KernelConfig = {
kernelName: FlipLeftRight,
backendName: 'webgl',
kernelFunc: ({inputs, backend}) => {
const {image} = inputs as FlipLeftRightInputs;
const webglBackend = backend as MathBackendWebGL;

const program = new FlipLeftRightProgram((image as Tensor4D).shape);
const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
return output;
}
};
5 changes: 3 additions & 2 deletions tfjs-backend-webgl/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';

import {divConfig} from './kernels/Div';
import {flipLeftRightConfig} from './kernels/FlipLeftRight';
import {fromPixelsConfig} from './kernels/FromPixels';
import {maxConfig} from './kernels/Max';
import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';
Expand All @@ -30,8 +31,8 @@ import {transposeConfig} from './kernels/Transpose';

// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
maxConfig, fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig,
nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
maxConfig, flipLeftRightConfig, fromPixelsConfig, divConfig,
maxPoolWithArgmaxConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
nonMaxSuppressionV5Config, rotateWithOffsetConfig, squareConfig,
squaredDifferenceConfig, transposeConfig
];
Expand Down
Loading