Skip to content
Merged
26 changes: 11 additions & 15 deletions tfjs-backend-wasm/src/kernels/Reshape.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,23 @@
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core';
import {TensorInfo} from '@tensorflow/tfjs-core';
import {NamedAttrMap, NamedTensorInfoMap, registerKernel, Reshape, ReshapeAttrs, ReshapeInputs} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface ReshapeInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface ReshapeAttrs extends NamedAttrMap {
shape: number[];
}

export function reshape(
args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {shape}} = args;
return {dataId: x.dataId, shape, dtype: x.dtype};
export function reshape(args: {
inputs: NamedTensorInfoMap,
attrs: NamedAttrMap,
backend: BackendWasm
}) {
const {inputs, attrs} = args;
const {tensor} = inputs as {} as ReshapeInputs;
const {shape} = attrs as {} as ReshapeAttrs;
return {dataId: tensor.dataId, shape, dtype: tensor.dtype};
}

registerKernel({
kernelName: 'Reshape',
kernelName: Reshape,
backendName: 'wasm',
kernelFunc: reshape,
});
36 changes: 17 additions & 19 deletions tfjs-backend-wasm/src/kernels/Unpack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,42 @@
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';
import {BackendWasm} from '../backend_wasm';
import {slice} from './Slice';
import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, Unpack, UnpackAttrs, UnpackInputs} from '@tensorflow/tfjs-core';

interface UnpackInputs extends NamedTensorInfoMap {
x: TensorInfo;
}
import {BackendWasm} from '../backend_wasm';

interface UnpackAttrs extends NamedAttrMap {
axis: number;
}
import {slice} from './Slice';

function unpack(
args: {inputs: UnpackInputs, backend: BackendWasm, attrs: UnpackAttrs}):
TensorInfo[] {
const {inputs: {x}, backend, attrs: {axis}} = args;
const numOutputs = x.shape[axis];
const rank = x.shape.length;
function unpack(args: {
inputs: NamedTensorInfoMap,
backend: BackendWasm,
attrs: NamedAttrMap
}): TensorInfo[] {
const {inputs, backend, attrs} = args;
const {value} = inputs as {} as UnpackInputs;
const {axis} = attrs as {} as UnpackAttrs;
const numOutputs = value.shape[axis];
const rank = value.shape.length;
const outShape: number[] = new Array(rank - 1);
let outIndex = 0;
for (let i = 0; i < rank; i++) {
if (i !== axis) {
outShape[outIndex++] = x.shape[i];
outShape[outIndex++] = value.shape[i];
}
}
const outs: TensorInfo[] = new Array(numOutputs);
const begin = new Array(rank).fill(0);
const size = x.shape.slice();
const size = value.shape.slice();
size[axis] = 1;
for (let i = 0; i < outs.length; i++) {
begin[axis] = i;
outs[i] = slice({inputs: {x}, attrs: {begin, size}, backend});
outs[i] = slice({inputs: {x: value}, attrs: {begin, size}, backend});
}
return outs.map(({dataId, dtype}) => ({dataId, dtype, shape: outShape}));
}

registerKernel({
kernelName: 'Unpack',
kernelName: Unpack,
backendName: 'wasm',
kernelFunc: unpack,
});
3 changes: 2 additions & 1 deletion tfjs-backend-webgl/src/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ export class MathBackendWebGL extends KernelBackend {
inputs: TensorInfo[],
sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean {
const cpuBackend = this.getCPUBackend();
if (!this.warnedAboutCPUBackend && cpuBackend == null) {
if (!this.warnedAboutCPUBackend && cpuBackend == null &&
!env().getBool('IS_TEST')) {
console.warn(
'Your application contains ops that are small enough to be ' +
'executed on the CPU backend, however the CPU backend cannot ' +
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Atan2_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import {Atan2} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {add} from '../ops/add';
import {reshape} from '../ops/array_ops';
import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util';
import {div} from '../ops/div';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {square} from '../ops/square';
import {neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Div_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import {Div} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {reshape} from '../ops/array_ops';
import * as broadcast_util from '../ops/broadcast_util';
import {div} from '../ops/div';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {square} from '../ops/square';
import {neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/FusedBatchNorm_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {add} from '../ops/add';
import {reshape} from '../ops/array_ops';
import {getReductionAxes} from '../ops/broadcast_util';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {sub} from '../ops/sub';
import {scalar} from '../ops/tensor_ops';
import {tile} from '../ops/tile';
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Mod_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import {Mod} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {reshape} from '../ops/array_ops';
import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util';
import {div} from '../ops/div';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {floor, neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';

Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/gradients/Multiply_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

import {Multiply} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {cast, reshape} from '../ops/array_ops';
import {cast} from '../ops/array_ops';
import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {Tensor} from '../tensor';

export const multiplyGradConfig: GradConfig = {
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/gradients/Pow_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
*/
import {Pow} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {cast, reshape} from '../ops/array_ops';
import {cast} from '../ops/array_ops';
import * as broadcast_util from '../ops/broadcast_util';
import {greater} from '../ops/greater';
import {mul} from '../ops/mul';
import {pow} from '../ops/pow';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {sub} from '../ops/sub';
import {scalar, zerosLike} from '../ops/tensor_ops';
import {log} from '../ops/unary_ops';
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Prelu_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
*/
import {Prelu} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {reshape} from '../ops/array_ops';
import {getReductionAxes} from '../ops/broadcast_util';
import {greater} from '../ops/greater';
import {mul} from '../ops/mul';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {zerosLike} from '../ops/tensor_ops';
import {where} from '../ops/where';
import {Tensor} from '../tensor';
Expand Down
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/Reshape_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Reshape} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {reshape} from '../ops/reshape';
import {Tensor} from '../tensor';

export const reshapeGradConfig: GradConfig = {
kernelName: Reshape,
inputsToSave: ['tensor'],
gradFunc: (dy: Tensor, saved: Tensor[]) => {
const [tensor] = saved;
return {tensor: () => reshape(dy, tensor.shape)};
}
};
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Sub_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/
import {Sub} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {reshape} from '../ops/array_ops';
import * as broadcast_util from '../ops/broadcast_util';
import {sum} from '../ops/reduction_ops';
import {reshape} from '../ops/reshape';
import {neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';

Expand Down
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/Unpack_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Unpack, UnpackAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {stack} from '../ops/stack';
import {Tensor} from '../tensor';

export const unpackGradConfig: GradConfig = {
kernelName: Unpack,
gradFunc: (dy: Tensor[], saved: Tensor[], attrs: NamedAttrMap) => {
const unpackAttrs: UnpackAttrs = attrs as {} as UnpackAttrs;
const {axis} = unpackAttrs;
return {value: () => stack(dy, axis)};
}
};
12 changes: 12 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ export type RealInputs = Pick<NamedTensorInfoMap, 'input'>;
export const Relu = 'Relu';
export type ReluInputs = Pick<NamedTensorInfoMap, 'x'>;

export const Reshape = 'Reshape';
export type ReshapeInputs = Pick<NamedTensorInfoMap, 'tensor'>;
export interface ReshapeAttrs {
shape: number[];
}

export const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
export type ResizeNearestNeighborInputs = Pick<NamedTensorInfoMap, 'images'>;
export interface ResizeNearestNeighborAttrs {
Expand Down Expand Up @@ -508,6 +514,12 @@ export interface TransposeAttrs {
perm: number[];
}

export const Unpack = 'Unpack';
export type UnpackInputs = Pick<NamedTensorInfoMap, 'value'>;
export interface UnpackAttrs {
axis: number;
}

/**
* TensorFlow.js-only kernels
*/
Expand Down
Loading