Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions tfjs-backend-wasm/src/kernels/FusedConv2D.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
* =============================================================================
*/

import {backend_util, KernelConfig, KernelFunc, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core';
import {backend_util, FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs, KernelConfig, KernelFunc, Tensor4D} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {FusableActivation} from './types';

interface FusedConv2DInputs extends NamedTensorInfoMap {
x: TensorInfo;
filter: TensorInfo;
bias?: TensorInfo;
}

let wasmFusedConv2d: (
xId: number, batchSize: number, inputHeight: number, inputWidth: number,
filterId: number, filterHeight: number, filterWidth: number, biasId: number,
Expand Down Expand Up @@ -66,11 +60,17 @@ function setup(backend: BackendWasm) {
function fusedConv2d(args: {
inputs: FusedConv2DInputs,
backend: BackendWasm,
attrs:
{convInfo: backend_util.Conv2DInfo, activation: backend_util.Activation}
attrs: FusedConv2DAttrs
}) {
const {inputs, attrs, backend} = args;
const {convInfo, activation} = attrs;
const {x, filter, bias, preluActivationWeights} = inputs;
const {strides, pad, dilations, dataFormat, dimRoundingMode, activation} =
attrs;

const convInfo = backend_util.computeConv2DInfo(
(x as Tensor4D).shape, (filter as Tensor4D).shape, strides, dilations,
pad, dimRoundingMode);

const fusedActivation =
FusableActivation[activation as {} as keyof typeof FusableActivation];
if (fusedActivation == null) {
Expand All @@ -79,7 +79,6 @@ function fusedConv2d(args: {
`in the wasm backend.`);
}

const {x, filter, bias, preluActivationWeights} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;
const filterId = backend.dataIdMap.get(filter.dataId).id;

Expand Down Expand Up @@ -117,10 +116,10 @@ function fusedConv2d(args: {
const inHeight = convInfo.inHeight;
const inWidth = convInfo.inWidth;

if (convInfo.dataFormat !== 'channelsLast') {
if (dataFormat !== 'NHWC') {
throw new Error(
`wasm backend FusedConv2D does not support dataFormat:'` +
`${convInfo.dataFormat}'. Please use 'channelsLast'.`);
`${dataFormat}'. Please use 'NHWC'.`);
}

const out = backend.makeOutput(convInfo.outShape, 'float32');
Expand All @@ -137,7 +136,7 @@ function fusedConv2d(args: {
}

export const fusedConv2DConfig: KernelConfig = {
kernelName: 'FusedConv2D',
kernelName: FusedConv2D,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: fusedConv2d as {} as KernelFunc
Expand Down
29 changes: 14 additions & 15 deletions tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
* =============================================================================
*/

import {backend_util, KernelConfig, KernelFunc, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core';
import {backend_util, FusedDepthwiseConv2D, FusedDepthwiseConv2DAttrs, FusedDepthwiseConv2DInputs, KernelConfig, KernelFunc, Tensor4D} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {FusableActivation} from './types';

interface FusedDepthwiseConv2DInputs extends NamedTensorInfoMap {
x: TensorInfo;
filter: TensorInfo;
bias?: TensorInfo;
}

let wasmFusedDepthwiseConv2d: (
xId: number, batchSize: number, inputHeight: number, inputWidth: number,
filterId: number, filterHeight: number, filterWidth: number, biasId: number,
Expand All @@ -38,7 +32,7 @@ let wasmFusedDepthwiseConv2d: (

function setup(backend: BackendWasm) {
wasmFusedDepthwiseConv2d =
backend.wasm.cwrap('FusedDepthwiseConv2D', null /* void */, [
backend.wasm.cwrap(FusedDepthwiseConv2D, null /* void */, [
'number', // xId
'number', // batchSize
'number', // inputHeight
Expand Down Expand Up @@ -67,11 +61,17 @@ function setup(backend: BackendWasm) {
function fusedDepthwiseConv2d(args: {
inputs: FusedDepthwiseConv2DInputs,
backend: BackendWasm,
attrs:
{convInfo: backend_util.Conv2DInfo, activation: backend_util.Activation}
attrs: FusedDepthwiseConv2DAttrs
}) {
const {inputs, attrs, backend} = args;
const {convInfo, activation} = attrs;
const {x, filter, bias, preluActivationWeights} = inputs;
const {strides, pad, dilations, dataFormat, dimRoundingMode, activation} =
attrs;

const convInfo = backend_util.computeConv2DInfo(
(x as Tensor4D).shape, (filter as Tensor4D).shape, strides, dilations,
pad, dimRoundingMode);

const fusedActivation =
FusableActivation[activation as {} as keyof typeof FusableActivation];
if (fusedActivation == null) {
Expand All @@ -80,7 +80,6 @@ function fusedDepthwiseConv2d(args: {
`in the wasm backend.`);
}

const {x, filter, bias, preluActivationWeights} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;
const filterId = backend.dataIdMap.get(filter.dataId).id;

Expand Down Expand Up @@ -118,10 +117,10 @@ function fusedDepthwiseConv2d(args: {
const inHeight = convInfo.inHeight;
const inWidth = convInfo.inWidth;

if (convInfo.dataFormat !== 'channelsLast') {
if (dataFormat !== 'NHWC') {
throw new Error(
`wasm backend FusedDepthwiseConv2D does not support dataFormat:'` +
`${convInfo.dataFormat}'. Please use 'channelsLast'.`);
`${dataFormat}'. Please use 'NHWC'.`);
}

const out = backend.makeOutput(convInfo.outShape, 'float32');
Expand All @@ -138,7 +137,7 @@ function fusedDepthwiseConv2d(args: {
}

export const fusedDepthwiseConv2DConfig: KernelConfig = {
kernelName: 'FusedDepthwiseConv2D',
kernelName: FusedDepthwiseConv2D,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: fusedDepthwiseConv2d as {} as KernelFunc
Expand Down
25 changes: 6 additions & 19 deletions tfjs-backend-wasm/src/kernels/_FusedMatMul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,20 @@
* =============================================================================
*/

import {KernelConfig, NamedAttrMap, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core';
import {_FusedMatMul, _FusedMatMulAttrs, _FusedMatMulInputs, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {FusableActivation} from './types';

interface FusedMatMulInputs extends NamedTensorInfoMap {
a: TensorInfo;
b: TensorInfo;
bias?: TensorInfo;
preluActivationWeights?: TensorInfo;
}

interface FusedMatMulAttrs extends NamedAttrMap {
transposeA: boolean;
transposeB: boolean;
activation: FusableActivation;
}

let wasmFusedMatMul: (
aId: number, aShape: Uint8Array, aShapeSize: number, bId: number,
bShape: Uint8Array, bShapeSize: number, transposeA: boolean,
transposeB: boolean, activation: number, biasId: number,
preluActivationWeightsId: number, outId: number) => void;

function setup(backend: BackendWasm) {
wasmFusedMatMul = backend.wasm.cwrap('_FusedMatMul', null /* void */, [
wasmFusedMatMul = backend.wasm.cwrap(_FusedMatMul, null /* void */, [
'number', // a_id
'array', // a_shape
'number', // a_shape.length
Expand All @@ -58,9 +45,9 @@ function setup(backend: BackendWasm) {
}

function fusedBatchMatMul(args: {
inputs: FusedMatMulInputs,
inputs: _FusedMatMulInputs,
backend: BackendWasm,
attrs: FusedMatMulAttrs
attrs: _FusedMatMulAttrs
}) {
const {inputs, backend, attrs} = args;
const {a, b, bias, preluActivationWeights} = inputs;
Expand Down Expand Up @@ -114,8 +101,8 @@ function fusedBatchMatMul(args: {
}

export const fusedMatMulConfig: KernelConfig = {
kernelName: '_FusedMatMul',
kernelName: _FusedMatMul,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: fusedBatchMatMul
kernelFunc: fusedBatchMatMul as {} as KernelFunc
};
4 changes: 3 additions & 1 deletion tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ const TEST_FILTERS: TestFilter[] = [
'basic with elu', // Only fused relu, relu6, prelu activations
// supported.
'gradient', // Gradients not defined yet.
'NCHW', // xnn pack does not support channels first.
'backProp input x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // Gradients not
// defined.
'NCHW', // xnn pack does not support channels first.
// Issue: https://github.com/tensorflow/tfjs/issues/3104.
// Actual != expected.
'relu bias stride 2 x=[1,8,8,16] f=[3,3,16,1] s=[2,2] d=8 p=same',
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_types';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {BackendValues, DataType, Rank, ShapeMap} from '../types';

Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/backends/backend_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ export * from '../ops/axis_util';
export * from '../ops/broadcast_util';
export * from '../ops/concat_util';
export * from '../ops/conv_util';
export {Activation, FusedConv2DConfig} from '../ops/fused_util';
export * from '../ops/fused_util';
export * from '../ops/fused_types';
export * from '../ops/reduce_util';

export {BackendValues, TypedArray, upcastType, PixelData} from '../types';
Expand Down
48 changes: 48 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import {ExplicitPadding} from '../src/ops/conv_util';

import {NamedTensorInfoMap, TensorInfo} from './kernel_registry';
import {Activation} from './ops/fused_types';
import {DataType, PixelData} from './types';

export const Abs = 'Abs';
Expand Down Expand Up @@ -786,3 +787,50 @@ export interface RotateWithOffsetAttrs {
fillValue: number|[number, number, number];
center: number|[number, number];
}

export const _FusedMatMul = '_FusedMatMul';
// tslint:disable-next-line: class-name
export interface _FusedMatMulInputs extends NamedTensorInfoMap {
a: TensorInfo;
b: TensorInfo;
bias?: TensorInfo;
preluActivationWeights?: TensorInfo;
}
// tslint:disable-next-line: class-name
export interface _FusedMatMulAttrs {
transposeA: boolean;
transposeB: boolean;
activation: Activation;
}

export const FusedConv2D = 'FusedConv2D';
export interface FusedConv2DInputs extends NamedTensorInfoMap {
x: TensorInfo;
filter: TensorInfo;
bias?: TensorInfo;
preluActivationWeights?: TensorInfo;
}
export interface FusedConv2DAttrs {
strides: [number, number]|number;
pad: 'valid'|'same'|number|ExplicitPadding;
dataFormat: 'NHWC'|'NCHW';
dilations: [number, number]|number;
dimRoundingMode: 'floor'|'round'|'ceil';
activation: Activation;
}

export const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
export interface FusedDepthwiseConv2DInputs extends NamedTensorInfoMap {
x: TensorInfo;
filter: TensorInfo;
bias?: TensorInfo;
preluActivationWeights?: TensorInfo;
}
export interface FusedDepthwiseConv2DAttrs {
strides: [number, number]|number;
pad: 'valid'|'same'|number;
dataFormat: 'NHWC'|'NCHW';
dilations: [number, number]|number;
dimRoundingMode: 'floor'|'round'|'ceil';
activation: Activation;
}
Loading