Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import * as seedrandom from 'seedrandom';

import {ENGINE} from '../../engine';
import {env} from '../../environment';

import {warn} from '../../log';
import * as array_ops_util from '../../ops/array_ops_util';
import * as axis_util from '../../ops/axis_util';
import * as broadcast_util from '../../ops/broadcast_util';
import {complex, imag, real} from '../../ops/complex_ops';
import * as concat_util from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import {div} from '../../ops/div';
import * as erf_util from '../../ops/erf_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
Expand All @@ -47,6 +47,7 @@ import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';

import {assertNotComplex} from './cpu_util';

function mapActivation(
Expand Down Expand Up @@ -377,17 +378,6 @@ export class MathBackendCPU extends KernelBackend {
return result.toTensor() as T;
}

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
const b = this.exp(a);
const sumExp = this.sum(b, axes).reshape(expandedShape);

return this.realDivide(b, sumExp) as T;
}

subtract(a: Tensor, b: Tensor): Tensor {
if (a.dtype === 'complex64' || b.dtype === 'complex64') {
return this.broadcastedBinaryComplexOp(
Expand Down Expand Up @@ -493,14 +483,6 @@ export class MathBackendCPU extends KernelBackend {
(aValue, bValue) => aValue * bValue);
}

realDivide(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'realDivide');

const op = (a: number, b: number) => a / b;
const outputDtype = 'float32';
return this.broadcastedBinaryOp(a, b, outputDtype, op);
}

floorDiv(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'floorDiv');

Expand Down
29 changes: 29 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Div.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Div} from '../../../kernel_names';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
<<<<<<< HEAD
import {createBinaryOp} from '../utils/kernel_utils';

export const div = createBinaryOp((a: number, b: number) => a / b);
=======
import {createBinaryKernelImpl} from '../utils/kernel_utils';

export const div = createBinaryKernelImpl((a: number, b: number) => a / b);
>>>>>>> modularize_div
export const divConfig = createBinaryKernelConfig(Div, div);
48 changes: 48 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Exp.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Exp, ExpInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import * as util from '../../../util';
import {MathBackendCPU} from '../backend_cpu';

export const exp = (x: TypedArray): TypedArray => {
const outValues = util.getTypedArrayFromDType('float32', x.length);

for (let i = 0; i < x.length; ++i) {
outValues[i] = Math.exp(x[i]);
}

return outValues;
};

export const expConfig: KernelConfig = {
kernelName: Exp,
backendName: 'cpu',
kernelFunc: ({inputs, backend}) => {
const {x} = inputs as ExpInputs;
const cpuBackend = backend as MathBackendCPU;

const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;

const result = exp(xVals);

const dataId = cpuBackend.write(result, x.shape, x.dtype);
return {dataId, shape: x.shape, dtype: x.dtype};
}
};
69 changes: 69 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Max.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import * as axis_util from '../../../ops/axis_util';
import {DataType, NumericDataType, TypedArray} from '../../../types';
import * as util from '../../../util';
import {sizeFromShape} from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

export const max =
(x: TypedArray, reduceSize: number, outShape: number[], dtype: DataType):
TypedArray => {
const outValues = util.getTypedArrayFromDType(
dtype as NumericDataType, util.sizeFromShape(outShape));

for (let i = 0; i < x.length; ++i) {
const offset = i * reduceSize;
let max = x[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = x[offset + j];
if (value > max) {
max = value;
}
}
outValues[i] = max;
}

return outValues;
};

export const maxConfig: KernelConfig = {
kernelName: Max,
backendName: 'cpu',
kernelFunc: ({inputs, attrs, backend}) => {
const {x} = inputs as MaxInputs;
const {axes} = attrs as {} as MaxAttrs;
const cpuBackend = backend as MathBackendCPU;

assertNotComplex(x, 'max');

axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length);

const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);

const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;
const result = max(xVals, sizeFromShape(reduceShape), outShape, x.dtype);

const dataId = cpuBackend.write(result, outShape, x.dtype);
return {dataId, shape: outShape, dtype: x.dtype};
}
};
64 changes: 64 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Softmax.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Softmax, SoftmaxAttrs, SoftmaxInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import * as axis_util from '../../../ops/axis_util';
import {parseAxisParam, sizeFromShape} from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

import {div} from './Div';
import {exp} from './Exp';
import {max} from './Max';
import {sub} from './Sub';
import {sum} from './Sum';

export const softmaxConfig: KernelConfig = {
kernelName: Softmax,
backendName: 'cpu',
kernelFunc: ({inputs, attrs, backend}) => {
const {logits} = inputs as SoftmaxInputs;
const {dim} = attrs as {} as SoftmaxAttrs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex(logits, 'softmax');

const axes = parseAxisParam([dim], logits.shape);

const [reduceOutShape, reduceShape] =
axis_util.computeOutAndReduceShapes(logits.shape, axes);
const logitsValues =
cpuBackend.data.get(logits.dataId).values as Float32Array;
const maxLogit = max(
logitsValues, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);

const expandedShape = axis_util.expandShapeToKeepDim(reduceOutShape, axes);

const [aValues,] =
sub(logits.shape, expandedShape, logitsValues, maxLogit, logits.dtype);

const b = exp(aValues);

const sumExp =
sum(b, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);

const [resultData, resultShape] =
div(logits.shape, reduceShape, b, sumExp, logits.dtype);
const dataId = cpuBackend.write(resultData, resultShape, logits.dtype);
return {dataId, shape: resultShape, dtype: logits.dtype};
}
};
35 changes: 9 additions & 26 deletions tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,14 @@
* =============================================================================
*/

import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';
import {broadcastedBinaryOp} from '../utils/kernel_utils';
import {SquaredDifference} from '../../../kernel_names';
import {createBinaryKernelImpl} from '../utils/kernel_utils';
import {createBinaryKernelConfig} from '../utils/kernel_utils';

export const squaredDifferenceConfig: KernelConfig = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: ({inputs, backend}) => {
const {a, b} = inputs as SquaredDifferenceInputs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex([a, b], SquaredDifference);
const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

const [resultData, resultShape] = broadcastedBinaryOp(
a.shape, b.shape, aVals, bVals, a.dtype, (aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
return {dataId, shape: resultShape, dtype: a.dtype};
}
};
export const squaredDifferenceConfig =
createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
24 changes: 24 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Sub.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Sub} from '../../../kernel_names';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {createBinaryOp} from '../utils/kernel_utils';

export const sub = createBinaryOp((a: number, b: number) => a - b);

export const subConfig = createBinaryKernelConfig(Sub, sub);
Loading