Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util';
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';

Expand Down Expand Up @@ -410,8 +410,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
}

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
throw new Error('Not yet implemented');
}

Expand All @@ -426,6 +426,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
throw new Error('Not yet implemented');
}

fusedDepthwiseConv2D(
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
throw new Error('Not yet implemented');
}

depthwiseConv2D(input: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
Tensor4D {
throw new Error('Not yet implemented');
Expand Down
24 changes: 20 additions & 4 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import {complex, imag, real} from '../../ops/complex_ops';
import * as concat_util from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import * as erf_util from '../../ops/erf_util';
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
import * as ops from '../../ops/ops';
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
Expand Down Expand Up @@ -1531,9 +1531,9 @@ export class MathBackendCPU implements KernelBackend {
}

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
let result = this.conv2d(x, filter, convInfo);
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
let result = this.conv2d(input, filter, convInfo);

if (bias) {
result = this.add(result, bias) as Tensor4D;
Expand Down Expand Up @@ -1973,6 +1973,22 @@ export class MathBackendCPU implements KernelBackend {
return dw.toTensor();
}

fusedDepthwiseConv2D(
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
let result = this.depthwiseConv2D(input, filter, convInfo);

if (bias) {
result = this.add(result, bias) as Tensor4D;
}
if (activation) {
result =
mapActivation(this, result, activation, preluActivationWeights) as
Tensor4D;
}
return result;
}

depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
Tensor4D {
this.assertNotComplex([x, filter], 'depthwiseConv2D');
Expand Down
52 changes: 43 additions & 9 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import * as axis_util from '../../ops/axis_util';
import {complex, imag, real} from '../../ops/complex_ops';
import {computeOutShape} from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
import * as reduce_util from '../../ops/reduce_util';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
Expand Down Expand Up @@ -1909,7 +1909,7 @@ export class MathBackendWebGL implements KernelBackend {
}

private conv2dByMatMul(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
// result from 2D to 4D.
Expand Down Expand Up @@ -2008,7 +2008,7 @@ export class MathBackendWebGL implements KernelBackend {
}

private conv2dWithIm2Row(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
// Rearranges conv2d input so each block to be convolved over forms the
// column of a new matrix with shape [filterWidth * filterHeight *
Expand Down Expand Up @@ -2067,19 +2067,19 @@ export class MathBackendWebGL implements KernelBackend {
}

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID')) {
return this.conv2dByMatMul(
x, filter, convInfo, bias, activation, preluActivationWeights);
input, filter, convInfo, bias, activation, preluActivationWeights);
}
if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
if (ENV.getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) {
return this.conv2dWithIm2Row(
x, filter, convInfo, bias, activation, preluActivationWeights);
input, filter, convInfo, bias, activation, preluActivationWeights);
}

const hasBias = bias != null;
Expand All @@ -2088,7 +2088,7 @@ export class MathBackendWebGL implements KernelBackend {
activation ? mapActivationToShaderProgram(activation, false) : null;
const program = new Conv2DProgram(
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
const inputs: TensorHandle[] = [x, filter];
const inputs: TensorHandle[] = [input, filter];
if (bias) {
inputs.push(bias);
}
Expand Down Expand Up @@ -2124,6 +2124,40 @@ export class MathBackendWebGL implements KernelBackend {
return this.compileAndRun(program, [x, dy]);
}

fusedDepthwiseConv2D(
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
const shouldPackDepthwiseConv = ENV.getBool('WEBGL_PACK_DEPTHWISECONV') &&
convInfo.strideWidth <= 2 &&
convInfo.outChannels / convInfo.inChannels === 1;
const fusedActivation = activation ?
mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
null;
const inputs: Tensor[] = [input, filter];

const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
if (hasBias) {
inputs.push(bias);
}
if (hasPreluActivationWeights) {
inputs.push(preluActivationWeights);
}

let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;
if (shouldPackDepthwiseConv) {
program = new DepthwiseConvPacked2DProgram(
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
return this.compileAndRun(
program, inputs,
this.makePackedTensor(convInfo.outShape, input.dtype));
}

program = new DepthwiseConv2DProgram(
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
return this.compileAndRun(program, inputs);
}

depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
Tensor4D {
let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;
Expand Down
39 changes: 37 additions & 2 deletions tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
outputShape: number[];
userCode: string;

constructor(convInfo: Conv2DInfo) {
constructor(
convInfo: Conv2DInfo, addBias = false, activation: string = null,
hasPreluActivation = false) {
this.outputShape = convInfo.outShape;

const xNumRows = convInfo.inHeight;
Expand All @@ -38,7 +40,36 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
const filterWidth = convInfo.filterWidth;
const channelMul = convInfo.outChannels / convInfo.inChannels;

let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
} else {
activationSnippet = `
float activation(float x) {
${activation}
}
`;
}

applyActivationSnippet = `result = activation(result);`;
}

const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}

this.userCode = `
${activationSnippet}

const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});

Expand Down Expand Up @@ -76,7 +107,11 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
dotProd += xVal * wVal;
}
}
setOutput(dotProd);

float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
Expand Down
38 changes: 35 additions & 3 deletions tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
outputShape: number[];
userCode: string;

constructor(convInfo: Conv2DInfo) {
constructor(
convInfo: Conv2DInfo, addBias = false, activation: string = null,
hasPreluActivation = false) {
this.outputShape = convInfo.outShape;

const xNumRows = convInfo.inHeight;
Expand Down Expand Up @@ -257,11 +259,38 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {

for (let r = 0; r < filterHeight; r++) {
for (let c = 0; c < filterWidth; c++) {
mainLoop += `result += xR${r}C${c} * wR${r}C${c};`;
mainLoop += `dotProd += xR${r}C${c} * wR${r}C${c};`;
}
}

let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
} else {
activationSnippet = `vec4 activation(vec4 x) {
${activation}
}`;
}

applyActivationSnippet = `result = activation(result);`;
}

const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}

this.userCode = `
${activationSnippet}

const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});

Expand All @@ -276,10 +305,13 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;

vec4 result = vec4(0.);
vec4 dotProd = vec4(0.);

${mainLoop}

vec4 result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
Expand Down
11 changes: 6 additions & 5 deletions tfjs-core/src/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,9 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
`are not yet supported in gradients. Got dilations '${dilations}'`);

return {
x: () =>
conv2dDerInput_(x4D.shape, dy, $filter, strides, pad, dataFormat),
x: () => conv2dDerInput(x4D.shape, dy, $filter, strides, pad, dataFormat),
$filter: () =>
conv2dDerFilter_(x4D, dy, $filter.shape, strides, pad, dataFormat)
conv2dDerFilter(x4D, dy, $filter.shape, strides, pad, dataFormat)
};
};

Expand Down Expand Up @@ -675,7 +674,7 @@ function eitherStridesOrDilationsAreOne(
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}

function depthwiseConv2dDerInput<T extends Tensor3D|Tensor4D>(
function depthwiseConv2dDerInput_<T extends Tensor3D|Tensor4D>(
xShape: [number, number, number, number]|[number, number, number], dy: T,
filter: Tensor4D, convInfo: conv_util.Conv2DInfo): T {
let dy4D = dy as Tensor4D;
Expand All @@ -693,7 +692,7 @@ function depthwiseConv2dDerInput<T extends Tensor3D|Tensor4D>(
return res as T;
}

function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
function depthwiseConv2dDerFilter_<T extends Tensor3D|Tensor4D>(
x: T, dy: T, filterShape: [number, number, number, number],
convInfo: conv_util.Conv2DInfo): Tensor4D {
let x4D = x as Tensor4D;
Expand Down Expand Up @@ -973,6 +972,8 @@ export const conv3d = op({conv3d_});
export const conv2dDerFilter = op({conv2dDerFilter_});
export const conv2dDerInput = op({conv2dDerInput_});
export const depthwiseConv2d = op({depthwiseConv2d_});
export const depthwiseConv2dDerInput = op({depthwiseConv2dDerInput_});
export const depthwiseConv2dDerFilter = op({depthwiseConv2dDerFilter_});
export const separableConv2d = op({separableConv2d_});
export const conv2dTranspose = op({conv2dTranspose_});
export const conv3dTranspose = op({conv3dTranspose_});
Loading