Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

2d atrous convolution and atrous depthwise convolution #794

Merged
merged 3 commits into from
Mar 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,8 @@ export class MathBackendCPU implements KernelBackend {
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const padLeft = convInfo.padInfo.left;
const padTop = convInfo.padInfo.top;
const y = ops.buffer<Rank.R4>(convInfo.outShape, x.dtype);
Expand All @@ -871,17 +873,24 @@ export class MathBackendCPU implements KernelBackend {
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const xRCorner = yR * convInfo.strideHeight - padLeft;
const xRMin = Math.max(0, xRCorner);
const xRMax = Math.min(convInfo.inHeight, filterHeight + xRCorner);
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const xCCorner = yC * convInfo.strideWidth - padTop;
const xCMin = Math.max(0, xCCorner);
const xCMax = Math.min(convInfo.inWidth, filterWidth + xCCorner);

let dotProd = 0;
for (let xR = xRMin; xR < xRMax; ++xR) {
const wR = xR - xRCorner;
for (let xC = xCMin; xC < xCMax; ++xC) {
const wC = xC - xCCorner;
for (let wR = 0; wR < filterHeight; wR++) {
const xR = xRCorner + wR * dilationHeight;

if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}

for (let wC = 0; wC < filterWidth; wC++) {
const xC = xCCorner + wC * dilationWidth;

if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}

for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
const pixel = x.get(b, xR, xC, d1);
const weight = filter.get(wR, wC, d1, d2);
Expand Down Expand Up @@ -989,6 +998,8 @@ export class MathBackendCPU implements KernelBackend {
Tensor4D {
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const padLeft = convInfo.padInfo.left;
const padTop = convInfo.padInfo.top;
const chMul = convInfo.outChannels / convInfo.inChannels;
Expand All @@ -998,18 +1009,24 @@ export class MathBackendCPU implements KernelBackend {
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const xRCorner = yR * convInfo.strideHeight - padLeft;
const xRMin = Math.max(0, xRCorner);
const xRMax = Math.min(convInfo.inHeight, filterHeight + xRCorner);
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const xCCorner = yC * convInfo.strideWidth - padTop;
const xCMin = Math.max(0, xCCorner);
const xCMax = Math.min(convInfo.inWidth, filterWidth + xCCorner);
for (let q = 0; q < chMul; ++q) {
let dotProd = 0;
for (let xR = xRMin; xR < xRMax; ++xR) {
const wR = xR - xRCorner;
for (let xC = xCMin; xC < xCMax; ++xC) {
const wC = xC - xCCorner;
for (let wR = 0; wR < filterHeight; ++wR) {
const xR = xRCorner + wR * dilationHeight;

if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}

for (let wC = 0; wC < filterWidth; ++wC) {
const xC = xCCorner + wC * dilationWidth;

if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}

const pixel = x.get(b, xR, xC, d1);
const weight = filter.get(wR, wC, d1, q);
dotProd += pixel * weight;
Expand All @@ -1021,6 +1038,7 @@ export class MathBackendCPU implements KernelBackend {
}
}
}

return y.toTensor();
}

Expand Down
6 changes: 4 additions & 2 deletions src/kernels/webgl/conv_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ export class Conv2DProgram implements GPGPUProgram {
const padLeft = convInfo.padInfo.left;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;

Expand All @@ -52,14 +54,14 @@ export class Conv2DProgram implements GPGPUProgram {
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR;
int xR = xRCorner + wR * ${dilationHeight};

if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}

for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC;
int xC = xCCorner + wC * ${dilationWidth};

if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
Expand Down
6 changes: 4 additions & 2 deletions src/kernels/webgl/conv_gpu_depthwise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
const padLeft = convInfo.padInfo.left;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const channelMul = convInfo.outChannels / convInfo.inChannels;
Expand All @@ -56,14 +58,14 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
float dotProd = 0.0;
// TODO(dsmilkov): Flatten the two for loops and vec4 the operations.
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR;
int xR = xRCorner + wR * ${dilationHeight};

if (xR < 0 || xR >= ${xNumRows}) {
continue;
}

for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC;
int xC = xCCorner + wC * ${dilationWidth};

if (xC < 0 || xC >= ${xNumCols}) {
continue;
Expand Down
79 changes: 68 additions & 11 deletions src/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ export class ConvOps {
* - For more info, see this guide:
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
* @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
* the data is stored in the order of [batch, in_width, in_channels]. Only
* "NWC" is currently supported.
* @param dilation The dilation rate in which we sample input values in
* atrous convolution. Defaults to `1`. If it is greater than 1, then
* stride must be `1`.
* @param dimRoundingMode The rounding mode used when computing output
* dimensions if pad is a number. If none is provided, it will not round
* and error if the output is of fractional size.
Expand All @@ -48,6 +54,7 @@ export class ConvOps {
@operation
static conv1d<T extends Tensor2D|Tensor3D>(
input: T, filter: Tensor3D, stride: number, pad: 'valid'|'same'|number,
dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
let input3D = input as Tensor3D;
let reshapedTo3D = false;
Expand All @@ -74,15 +81,27 @@ export class ConvOps {
input3D.shape[2] === filter.shape[1],
`Error in conv1d: depth of input (${input3D.shape[2]}) must match ` +
`input depth for filter ${filter.shape[1]}.`);
util.assert(
eitherStridesOrDilationsAreOne(stride, dilation),
'Error in conv1D: Either stride or dilation must be 1.' +
`Got stride ${stride} and dilation '${dilation}'`);
util.assert(
dataFormat === 'NWC',
`Error in conv1d: got dataFormat of ${
dataFormat} but only NWC is currently supported.`);

const filter4D =
filter.as4D(1, filter.shape[0], filter.shape[1], filter.shape[2]);
const input4D =
input3D.as4D(input3D.shape[0], 1, input3D.shape[1], input3D.shape[2]);
const strides: [number, number] = [1, stride];
const dilations: [number, number] = [1, dilation];

const conv2dDataFormat = 'NHWC';

const res =
ConvOps.conv2d(input4D, filter4D, strides, pad, dimRoundingMode);
const res = ConvOps.conv2d(
input4D, filter4D, strides, pad, conv2dDataFormat, dilations,
dimRoundingMode);

if (reshapedTo3D) {
return res.as2D(res.shape[2], res.shape[3]) as T;
Expand All @@ -108,6 +127,15 @@ export class ConvOps {
* - For more info, see this guide:
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dimRoundingMode The rounding mode used when computing output
* dimensions if pad is a number. If none is provided, it will not round
* and error if the output is of fractional size.
Expand All @@ -116,7 +144,9 @@ export class ConvOps {
@operation
static conv2d<T extends Tensor3D|Tensor4D>(
x: T, filter: Tensor4D, strides: [number, number]|number,
pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
dilations: [number, number]|number = [1, 1],
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
let x4D = x as Tensor4D;
let reshapedTo4D = false;

Expand All @@ -142,13 +172,24 @@ export class ConvOps {
x4D.shape[3] === filter.shape[2],
`Error in conv2d: depth of input (${x4D.shape[3]}) must match ` +
`input depth for filter ${filter.shape[2]}.`);

const dilations = 1;
util.assert(
eitherStridesOrDilationsAreOne(strides, dilations),
'Error in conv2D: Either strides or dilations must be 1.' +
`Got strides ${strides} and dilations '${dilations}'`);
util.assert(
dataFormat === 'NHWC',
`Error in conv2d: got dataFormat of ${
dataFormat} but only NHWC is currently supported.`);

const convInfo = conv_util.computeConv2DInfo(
x4D.shape, filter.shape, strides, dilations, pad, dimRoundingMode);

const grad = (dy: Tensor4D) => {
util.assert(
tupleValuesAreOne(dilations),
'Error in gradient of conv2D: dilation rates greater than 1 are not' +
`yet supported in gradients. Got dilations '${dilations}'`);

return {
x: () => ConvOps.conv2dDerInput(x4D.shape, dy, filter, strides, pad),
filter: () =>
Expand Down Expand Up @@ -375,9 +416,13 @@ export class ConvOps {
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
* @param dimRoundingMode The rounding mode used when computing output
* dimensions if pad is a number. If none is provided, it will not round
* and error if the output is of fractional size.
Expand All @@ -386,7 +431,8 @@ export class ConvOps {
@operation
static depthwiseConv2d<T extends Tensor3D|Tensor4D>(
input: T, filter: Tensor4D, strides: [number, number]|number,
pad: 'valid'|'same'|number, dilations: [number, number]|number = [1, 1],
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
dilations: [number, number]|number = [1, 1],
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
let input4D = input as Tensor4D;
let reshapedTo4D = false;
Expand All @@ -410,11 +456,11 @@ export class ConvOps {
if (dilations == null) {
dilations = [1, 1];
}
const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
util.assert(
dilationHeight === 1 && dilationWidth === 1,
'Error in depthwiseConv2D: dilation rates greater than 1 are not yet ' +
`supported. Got dilations '${dilations}'`);
eitherStridesOrDilationsAreOne(strides, dilations),
'Error in depthwiseConv2d: Either strides or dilations must be 1.' +
`Got strides ${strides} and dilations '${dilations}'`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
Expand All @@ -438,3 +484,14 @@ export class ConvOps {
function parseTupleParam(param: number|[number, number]): [number, number] {
return typeof param === 'number' ? [param, param] : param;
}

function tupleValuesAreOne(param: number|[number, number]): boolean {
const [dimA, dimB] = parseTupleParam(param);
return dimA === 1 && dimB === 1;
}

function eitherStridesOrDilationsAreOne(
strides: number|[number, number],
dilations: number|[number, number]): boolean {
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}
Loading