Skip to content

Commit

Permalink
Modularize 8 more kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
lina128 committed Nov 19, 2020
1 parent c90287c commit e34ab88
Show file tree
Hide file tree
Showing 15 changed files with 628 additions and 329 deletions.
311 changes: 0 additions & 311 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ const split = kernel_impls.split;
const tile = kernel_impls.tile;
const topkImpl = kernel_impls.topkImpl;
const whereImpl = kernel_impls.whereImpl;
import * as seedrandom from 'seedrandom';
import {assertNotComplex} from './cpu_util';

interface DataId {}
Expand Down Expand Up @@ -242,20 +241,6 @@ export class MathBackendCPU extends KernelBackend {
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
}

minimum(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'minimum');

return this.broadcastedBinaryOp(
a, b, a.dtype, (aVal, bVal) => Math.min(aVal, bVal));
}

maximum(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'maximum');

return this.broadcastedBinaryOp(
a, b, a.dtype, (aVal, bVal) => Math.max(aVal, bVal));
}

eluDer<T extends Tensor>(dy: T, y: T): T {
assertNotComplex([dy, y], 'eluDer');

Expand Down Expand Up @@ -314,27 +299,6 @@ export class MathBackendCPU extends KernelBackend {
return result.toTensor().reshape(shapeInfo.outputShape);
}

batchToSpaceND<T extends Tensor>(
x: T, blockShape: number[], crops: number[][]): T {
assertNotComplex([x], 'batchToSpaceND');

const prod = blockShape.reduce((a, b) => a * b);

const reshaped = backend_util.getReshaped(x.shape, blockShape, prod);
const permuted =
backend_util.getPermuted(reshaped.length, blockShape.length);
const reshapedPermuted =
backend_util.getReshapedPermuted(x.shape, blockShape, prod);
const sliceBeginCoords =
backend_util.getSliceBeginCoords(crops, blockShape.length);
const sliceSize =
backend_util.getSliceSize(reshapedPermuted, crops, blockShape.length);

return tf.transpose(x.reshape(reshaped), permuted)
.reshape(reshapedPermuted)
.slice(sliceBeginCoords, sliceSize) as T;
}

private pool3d(
x: Tensor5D, convInfo: backend_util.Conv3DInfo,
poolType: 'max'|'avg'): Tensor5D {
Expand Down Expand Up @@ -1065,47 +1029,6 @@ export class MathBackendCPU extends KernelBackend {
return tf.tensor4d(result, dy.shape);
}

multinomial(
logits: Tensor2D, normalized: boolean, numSamples: number,
seed: number): Tensor2D {
assertNotComplex(logits, 'multinomial');

const probabilities = normalized ? logits : tf.softmax(logits);
const batchSize = probabilities.shape[0];
const numEvents = probabilities.shape[1];
const res = tf.zeros<Rank.R2>([batchSize, numSamples], 'int32');
const resVals = this.readSync(res.dataId) as TypedArray;
const probVals = this.readSync(probabilities.dataId) as TypedArray;

for (let b = 0; b < batchSize; ++b) {
const offset = b * numEvents;
// The cdf won't include the last event. It will be implicit if no other
// event happened.
const cdf = new Float32Array(numEvents - 1);
cdf[0] = probVals[offset];
for (let event = 1; event < cdf.length; ++event) {
cdf[event] = cdf[event - 1] + probVals[offset + event];
}

const random = seedrandom.alea(seed.toString());
const outOffset = b * numSamples;
for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
const r = random();

// Assume last event happened by default.
resVals[outOffset + sampleId] = cdf.length;

for (let event = 0; event < cdf.length; event++) {
if (r < cdf[event]) {
resVals[outOffset + sampleId] = event;
break;
}
}
}
}
return res;
}

nonMaxSuppression(
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
iouThreshold: number, scoreThreshold: number): Tensor1D {
Expand All @@ -1117,87 +1040,6 @@ export class MathBackendCPU extends KernelBackend {
boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
}

depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):
Tensor4D {
util.assert(
dataFormat === 'NHWC',
() => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${
dataFormat}`);
util.assert(
blockSize > 1,
() =>
`blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);

const batchSize = x.shape[0];
const inputHeight = x.shape[1];
const inputWidth = x.shape[2];
const inputDepth = x.shape[3];

const outputHeight = inputHeight * blockSize;
const outputWidth = inputWidth * blockSize;
const outputDepth = inputDepth / (blockSize * blockSize);

const xValues = this.readSync(x.dataId) as TypedArray;
const result =
new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);

let outputIdx = 0;
for (let b = 0; b < batchSize; ++b) {
for (let h = 0; h < outputHeight; ++h) {
const inH = Math.floor(h / blockSize);
const offsetH = (h % blockSize);
for (let w = 0; w < outputWidth; ++w) {
const inW = Math.floor(w / blockSize);
const offsetW = (w % blockSize);
const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
for (let d = 0; d < outputDepth; ++d) {
const inD = d + offsetD;
const inputIdx =
inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
result[outputIdx++] = xValues[inputIdx];
}
}
}
}
return tf.tensor4d(
result, [batchSize, outputHeight, outputWidth, outputDepth]);
}

private broadcastedBinaryOp(
a: Tensor, b: Tensor, dtype: DataType,
op: (a: number, b: number) => number): Tensor {
const newShape = backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
const result = tf.buffer(newShape, dtype);
const aVals = this.readSync(a.dataId) as TypedArray;
const bVals = this.readSync(b.dataId) as TypedArray;
const aBroadcastDims = backend_util.getBroadcastDims(a.shape, newShape);
const bBroadcastDims = backend_util.getBroadcastDims(b.shape, newShape);

const resVals = result.values;
if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (let i = 0; i < resVals.length; ++i) {
resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
}
} else {
const aBuf = this.bufferSync(a);
const bBuf = this.bufferSync(b);
for (let i = 0; i < resVals.length; ++i) {
const loc = result.indexToLoc(i);

const aLoc = loc.slice(-a.rank);
aBroadcastDims.forEach(d => aLoc[d] = 0);
const aIndex = aBuf.locToIndex(aLoc);

const bLoc = loc.slice(-b.rank);
bBroadcastDims.forEach(d => bLoc[d] = 0);
const bIndex = bBuf.locToIndex(bLoc);

resVals[i] = op(aVals[aIndex], bVals[bIndex]);
}
}
return result.toTensor();
}

split<T extends Tensor>(x: T, sizeSplits: number[], axis: number): T[] {
return split(x, sizeSplits, axis);
}
Expand All @@ -1213,143 +1055,6 @@ export class MathBackendCPU extends KernelBackend {
return super.epsilon();
}

cropAndResize(
images: Tensor4D,
boxes: Tensor2D,
boxIndex: Tensor1D,
cropSize: [number, number],
method: string,
extrapolationValue: number,
) {
const [batch, imageHeight, imageWidth, numChannels] = images.shape;
const numBoxes = boxes.shape[0];

const [cropHeight, cropWidth] = cropSize;
const output =
tf.buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');

const boxVals = this.readSync(boxes.dataId) as TypedArray;
const boxIndVals = this.readSync(boxIndex.dataId) as TypedArray;
const imageVals = this.readSync(images.dataId) as TypedArray;

const inStride = images.strides; // to calculate flat indexes into image
const outStride = output.strides; // to calculate flat indexes into output

// Reference implementation
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
for (let b = 0; b < numBoxes; b++) {
const startInd = b * 4;
const y1 = boxVals[startInd];
const x1 = boxVals[startInd + 1];
const y2 = boxVals[startInd + 2];
const x2 = boxVals[startInd + 3];

const bInd: number = boxIndVals[b];
if (bInd >= batch) {
continue;
}

const heightScale = (cropHeight > 1) ?
(y2 - y1) * (imageHeight - 1) / (cropHeight - 1) :
0;
const widthScale =
(cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;

for (let y = 0; y < cropHeight; y++) {
const yInd: number = (cropHeight > 1) ?
y1 * (imageHeight - 1) + y * (heightScale) :
0.5 * (y1 + y2) * (imageHeight - 1);

if (yInd < 0 || yInd > imageHeight - 1) {
for (let x = 0; x < cropWidth; x++) {
for (let c = 0; c < numChannels; c++) {
const ind =
c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
}
continue;
}

if (method === 'bilinear') {
const topInd = Math.floor(yInd);
const bottomInd = Math.ceil(yInd);
const yLerp = yInd - topInd;

for (let x = 0; x < cropWidth; x++) {
const xInd = (cropWidth > 1) ?
x1 * (imageWidth - 1) + x * widthScale :
0.5 * (x1 + x2) * (imageWidth - 1);

if (xInd < 0 || xInd > imageWidth - 1) {
for (let c = 0; c < numChannels; c++) {
const ind =
c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
continue;
}

const leftInd = Math.floor(xInd);
const rightInd = Math.ceil(xInd);
const xLerp = xInd - leftInd;

for (let c = 0; c < numChannels; c++) {
let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
bInd * inStride[0];
const topLeft = imageVals[ind];

ind = c + rightInd * inStride[2] + topInd * inStride[1] +
bInd * inStride[0];
const topRight = imageVals[ind];

ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
bInd * inStride[0];
const bottomLeft = imageVals[ind];

ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
bInd * inStride[0];
const bottomRight = imageVals[ind];

const top = topLeft + (topRight - topLeft) * xLerp;
const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;

ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = top + ((bottom - top) * yLerp);
}
}
} else { // method == "nearest"
for (let x = 0; x < cropWidth; ++x) {
const xInd = (cropWidth > 1) ?
x1 * (imageWidth - 1) + x * widthScale :
0.5 * (x1 + x2) * (imageWidth - 1);

if (xInd < 0 || xInd > imageWidth - 1) {
for (let c = 0; c < numChannels; c++) {
const ind =
c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
continue;
}

const closestX = Math.round(xInd);
const closestY = Math.round(yInd);
for (let c = 0; c < numChannels; c++) {
const inInd = c + closestX * inStride[2] +
closestY * inStride[1] + bInd * inStride[0];
const outInd =
c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[outInd] = imageVals[inInd];
}
}
}
}
}
return output.toTensor() as Tensor4D;
}

sparseToDense<R extends Rank>(
sparseIndices: Tensor, sparseValues: Tensor, outputShape: ShapeMap[R],
defaultValue: Scalar): Tensor<R> {
Expand Down Expand Up @@ -1406,22 +1111,6 @@ export class MathBackendCPU extends KernelBackend {
strides, defaultValue, sumDupeIndices);
}

onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
if (x.dtype === 'string') {
throw new Error('onesLike is not supported for string tensors');
} else {
// TODO(lina128): Use fill kernel directly once this kernel is
// modularized.
return tf.fill(x.shape, 1, x.dtype);
}
}

zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
const values = util.getArrayFromDType(
x.dtype, util.sizeFromShape(x.shape)) as TypedArray;
return this.makeOutput(values, x.shape, x.dtype);
}

linspace(start: number, stop: number, num: number): Tensor1D {
return backend_util.linspaceImpl(start, stop, num);
}
Expand Down

0 comments on commit e34ab88

Please sign in to comment.