Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
1555ed8
clean
annxingyuan Feb 4, 2020
22171d3
shell
annxingyuan Feb 4, 2020
99b1e5c
compute lse
annxingyuan Feb 4, 2020
3865aac
softmax
annxingyuan Feb 4, 2020
971fe5a
add cpu
annxingyuan Feb 4, 2020
14100b7
upgrade deps
annxingyuan Feb 4, 2020
cb717af
setup
annxingyuan Feb 4, 2020
5560bb8
wtf
annxingyuan Feb 4, 2020
8a59b0c
use
annxingyuan Feb 4, 2020
d5037f0
simplify
annxingyuan Feb 4, 2020
6981390
softmax test
annxingyuan Feb 4, 2020
6b01d0c
fix
annxingyuan Feb 5, 2020
099a594
pass in dim
annxingyuan Feb 5, 2020
188bde6
test case
annxingyuan Feb 5, 2020
19a91bd
logs
annxingyuan Feb 5, 2020
74af251
delete
annxingyuan Feb 5, 2020
007edeb
add batch to key
annxingyuan Feb 5, 2020
bbcea30
move log statement
annxingyuan Feb 5, 2020
932e7a6
remove batch from cache
annxingyuan Feb 6, 2020
1244ca3
Merge branch 'master' into softmax
annxingyuan Feb 6, 2020
ccebc1e
add note
annxingyuan Feb 6, 2020
fcb079a
remove build flags
annxingyuan Feb 7, 2020
2daaa4f
remove logs
annxingyuan Feb 7, 2020
dd0ba9f
testing
annxingyuan Feb 7, 2020
7962041
remove header
annxingyuan Feb 7, 2020
4e7fdc8
add neg
annxingyuan Feb 7, 2020
869768b
register
annxingyuan Feb 8, 2020
3c9dda6
notequal
annxingyuan Feb 8, 2020
9c58b33
lint
annxingyuan Feb 10, 2020
88d8705
revive spy
annxingyuan Feb 10, 2020
648bc6c
save
annxingyuan Feb 10, 2020
40d5e82
add neg
annxingyuan Feb 10, 2020
e149666
save outputs
annxingyuan Feb 11, 2020
8903f08
start
annxingyuan Feb 11, 2020
2c9572b
fix
annxingyuan Feb 11, 2020
8d9f789
edit
annxingyuan Feb 13, 2020
fa76890
Merge branch 'master' into softmax
annxingyuan Feb 13, 2020
29a7cdc
remove
annxingyuan Feb 13, 2020
2ae0a59
revive
annxingyuan Feb 13, 2020
4442f31
revive
annxingyuan Feb 13, 2020
c4385c4
build
annxingyuan Feb 13, 2020
82921ca
revive test
annxingyuan Feb 13, 2020
cde9135
move kernels
annxingyuan Feb 13, 2020
657f5c5
start
annxingyuan Feb 13, 2020
3cda765
add sub
annxingyuan Feb 13, 2020
5991de3
subimpl
annxingyuan Feb 14, 2020
3d6ae21
merge
annxingyuan Feb 14, 2020
bdddd58
add exp
annxingyuan Feb 14, 2020
83e6d46
add sum
annxingyuan Feb 14, 2020
4e67f8d
softmax
annxingyuan Feb 14, 2020
3bff003
remove softmax
annxingyuan Feb 14, 2020
583a3f4
binary impl
annxingyuan Feb 14, 2020
cb5c1a5
clean
annxingyuan Feb 14, 2020
2543cc2
Merge branch 'master' into modularize_softmax
annxingyuan Feb 17, 2020
a2a5792
webgl wip
annxingyuan Feb 17, 2020
93d8c41
max webgl
annxingyuan Feb 18, 2020
277d246
clean
annxingyuan Feb 18, 2020
66cabfb
add reshape
annxingyuan Feb 18, 2020
94eb151
add sub shell
annxingyuan Feb 18, 2020
e278865
Merge branch 'master' into modularize_softmax
annxingyuan Feb 18, 2020
1716084
add sub
annxingyuan Feb 18, 2020
ab272bd
add exp
annxingyuan Feb 18, 2020
bf3af8e
add sum
annxingyuan Feb 18, 2020
209a602
getting correct answer
annxingyuan Feb 19, 2020
69a9d07
fix
annxingyuan Feb 19, 2020
018392b
softmax
annxingyuan Feb 19, 2020
66aad43
fix sum
annxingyuan Feb 19, 2020
c61a309
extract reduce
annxingyuan Feb 19, 2020
107c384
share reduce code
annxingyuan Feb 19, 2020
0092d25
setup
annxingyuan Feb 19, 2020
c42416d
dataid
annxingyuan Feb 19, 2020
bf4b9e7
remove logs
annxingyuan Feb 19, 2020
6ab20b6
merge
annxingyuan Mar 18, 2020
1e9da7c
merge conflict
annxingyuan Mar 18, 2020
d07b645
build
annxingyuan Mar 18, 2020
67e7153
add cpu div
annxingyuan Mar 18, 2020
a505261
clean
annxingyuan Mar 19, 2020
9ca3b63
move things around
annxingyuan Mar 19, 2020
69b5293
max
annxingyuan Mar 19, 2020
26e1db2
exp
annxingyuan Mar 19, 2020
5607f5c
sub
annxingyuan Mar 19, 2020
9687794
add sum
annxingyuan Mar 20, 2020
e9b00bf
cap
annxingyuan Mar 20, 2020
26dadc3
simplify
annxingyuan Mar 20, 2020
5eb5143
simplify
annxingyuan Mar 20, 2020
c594bc7
div
annxingyuan Mar 20, 2020
8b4ca3d
add arg
annxingyuan Mar 20, 2020
a332b15
reduce
annxingyuan Mar 20, 2020
ea195f6
updates
annxingyuan Mar 20, 2020
02f1201
reduce shape helper
annxingyuan Mar 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,6 @@ export class MathBackendCPU extends KernelBackend {
return result.toTensor() as T;
}

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
const b = this.exp(a);
const sumExp = this.sum(b, axes).reshape(expandedShape);

return this.realDivide(b, sumExp) as T;
}

subtract(a: Tensor, b: Tensor): Tensor {
if (a.dtype === 'complex64' || b.dtype === 'complex64') {
return this.broadcastedBinaryComplexOp(
Expand Down Expand Up @@ -493,13 +482,13 @@ export class MathBackendCPU extends KernelBackend {
(aValue, bValue) => aValue * bValue);
}

realDivide(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'realDivide');
// realDivide(a: Tensor, b: Tensor): Tensor {
// assertNotComplex([a, b], 'realDivide');

const op = (a: number, b: number) => a / b;
const outputDtype = 'float32';
return this.broadcastedBinaryOp(a, b, outputDtype, op);
}
// const op = (a: number, b: number) => a / b;
// const outputDtype = 'float32';
// return this.broadcastedBinaryOp(a, b, outputDtype, op);
// }

floorDiv(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'floorDiv');
Expand Down
23 changes: 23 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Div.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Div} from '../../../kernel_names';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {createBinaryOp} from '../utils/kernel_utils';

// Element-wise division for the CPU backend. Broadcasting between the two
// operand shapes is handled by createBinaryOp.
const divImpl = (aValue: number, bValue: number): number => aValue / bValue;

export const div = createBinaryOp(divImpl);
export const divConfig = createBinaryKernelConfig(Div, div);
48 changes: 48 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Exp.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Exp, ExpInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import * as util from '../../../util';
import {MathBackendCPU} from '../backend_cpu';

/**
 * Element-wise natural exponential.
 *
 * Returns a freshly allocated float32 array with out[i] === Math.exp(x[i]).
 * The input is not modified.
 */
export const exp = (x: TypedArray): TypedArray => {
  const n = x.length;
  const result = util.getTypedArrayFromDType('float32', n);

  for (let i = 0; i < n; ++i) {
    result[i] = Math.exp(x[i]);
  }

  return result;
};

// Kernel registration for Exp on the CPU backend: reads the input buffer,
// applies the element-wise exp implementation, and writes a new buffer with
// the same shape.
// NOTE(review): exp() always produces float32 values, but the output is
// written/declared with x.dtype — presumably callers only pass float inputs;
// confirm before relying on non-float dtypes.
export const expConfig: KernelConfig = {
  kernelName: Exp,
  backendName: 'cpu',
  kernelFunc: ({inputs, backend}) => {
    const {x} = inputs as ExpInputs;
    const cpuBackend = backend as MathBackendCPU;

    const inputValues = cpuBackend.data.get(x.dataId).values as Float32Array;
    const outputValues = exp(inputValues);

    const outId = cpuBackend.write(outputValues, x.shape, x.dtype);
    return {dataId: outId, shape: x.shape, dtype: x.dtype};
  }
};
69 changes: 69 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Max.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import * as axis_util from '../../../ops/axis_util';
import {DataType, NumericDataType, TypedArray} from '../../../types';
import * as util from '../../../util';
import {sizeFromShape} from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

/**
 * Reduces the innermost dimension(s) of `x` by taking the maximum.
 *
 * @param x Flattened input values; each consecutive run of `reduceSize`
 *     elements forms one reduction slice.
 * @param reduceSize Number of input elements folded into each output element.
 * @param outShape Shape of the reduced output.
 * @param dtype Output dtype; selects the typed-array kind allocated.
 * @returns A typed array of size `sizeFromShape(outShape)` holding the
 *     per-slice maxima.
 */
export const max =
    (x: TypedArray, reduceSize: number, outShape: number[], dtype: DataType):
        TypedArray => {
      const outValues = util.getTypedArrayFromDType(
          dtype as NumericDataType, util.sizeFromShape(outShape));

      // BUG FIX: iterate over output elements (outValues.length), not input
      // elements (x.length). The old bound ran reduceSize times too many
      // iterations; the out-of-range typed-array writes were silently
      // dropped, which hid the wasted work but made the kernel
      // O(inputSize * reduceSize).
      for (let i = 0; i < outValues.length; ++i) {
        const offset = i * reduceSize;
        let max = x[offset];
        // Start at j = 1: x[offset] already seeded the running max.
        for (let j = 1; j < reduceSize; ++j) {
          const value = x[offset + j];
          if (value > max) {
            max = value;
          }
        }
        outValues[i] = max;
      }

      return outValues;
    };

// Kernel registration for Max on the CPU backend. Axes must already be the
// innermost dimensions (asserted below); the reduction itself is delegated
// to the shared max() implementation.
export const maxConfig: KernelConfig = {
  kernelName: Max,
  backendName: 'cpu',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {x} = inputs as MaxInputs;
    const {axes} = attrs as {} as MaxAttrs;
    const cpuBackend = backend as MathBackendCPU;

    assertNotComplex(x, 'max');
    axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length);

    const [outShape, reduceShape] =
        axis_util.computeOutAndReduceShapes(x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);

    const values = cpuBackend.data.get(x.dataId).values as Float32Array;
    const outValues = max(values, reduceSize, outShape, x.dtype);

    const dataId = cpuBackend.write(outValues, outShape, x.dtype);
    return {dataId, shape: outShape, dtype: x.dtype};
  }
};
64 changes: 64 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Softmax.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Softmax, SoftmaxAttrs, SoftmaxInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import * as axis_util from '../../../ops/axis_util';
import {parseAxisParam, sizeFromShape} from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

import {div} from './Div';
import {exp} from './Exp';
import {max} from './Max';
import {sub} from './Sub';
import {sum} from './Sum';

// Kernel registration for Softmax on the CPU backend.
// Computes exp(logits - max(logits)) / sum(exp(logits - max(logits))) along
// `dim`, using the max-subtraction trick for numerical stability.
export const softmaxConfig: KernelConfig = {
  kernelName: Softmax,
  backendName: 'cpu',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {logits} = inputs as SoftmaxInputs;
    const {dim} = attrs as {} as SoftmaxAttrs;
    const cpuBackend = backend as MathBackendCPU;
    assertNotComplex(logits, 'softmax');

    const axes = parseAxisParam([dim], logits.shape);

    const [reduceOutShape, reduceShape] =
        axis_util.computeOutAndReduceShapes(logits.shape, axes);
    const logitsValues =
        cpuBackend.data.get(logits.dataId).values as Float32Array;

    // Per-slice max, used to shift logits so exp() cannot overflow.
    const maxLogit = max(
        logitsValues, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);

    // Reduced shape with the reduced axes kept as size-1 dims so the reduced
    // tensors broadcast against the full logits shape.
    const expandedShape = axis_util.expandShapeToKeepDim(reduceOutShape, axes);

    const [aValues, ] =
        sub(logits.shape, expandedShape, logitsValues, maxLogit, logits.dtype);

    const b = exp(aValues);

    const sumExp =
        sum(b, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);

    // BUG FIX: sumExp holds sizeFromShape(reduceOutShape) elements, so its
    // broadcast shape must be expandedShape (keep-dims) — the same shape used
    // for the sub() above — not reduceShape, whose size need not even match
    // the data.
    const [resultData, resultShape] =
        div(logits.shape, expandedShape, b, sumExp, logits.dtype);
    const dataId = cpuBackend.write(resultData, resultShape, logits.dtype);
    return {dataId, shape: resultShape, dtype: logits.dtype};
  }
};
35 changes: 9 additions & 26 deletions tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,14 @@
* =============================================================================
*/

import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';
import {broadcastedBinaryOp} from '../utils/kernel_utils';
import {SquaredDifference} from '../../../kernel_names';
import {createBinaryOp} from '../utils/kernel_utils';
import {createBinaryKernelConfig} from '../utils/kernel_utils';

export const squaredDifferenceConfig: KernelConfig = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: ({inputs, backend}) => {
const {a, b} = inputs as SquaredDifferenceInputs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex([a, b], SquaredDifference);
const squaredDifferenceImpl = createBinaryOp((aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

const [resultData, resultShape] = broadcastedBinaryOp(
a.shape, b.shape, aVals, bVals, a.dtype, (aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
return {dataId, shape: resultShape, dtype: a.dtype};
}
};
export const squaredDifferenceConfig =
createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
24 changes: 24 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Sub.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Sub} from '../../../kernel_names';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {createBinaryOp} from '../utils/kernel_utils';

// Element-wise subtraction for the CPU backend. Broadcasting between the two
// operand shapes is handled by createBinaryOp.
const subImpl = (aValue: number, bValue: number): number => aValue - bValue;

export const sub = createBinaryOp(subImpl);

export const subConfig = createBinaryKernelConfig(Sub, sub);
69 changes: 69 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Sum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Sum, SumAttrs, SumInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import * as axis_util from '../../../ops/axis_util';
import {upcastType} from '../../../types';
import {DataType, NumericDataType, TypedArray} from '../../../types';
import * as util from '../../../util';
import {sizeFromShape} from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

/**
 * Reduces the innermost dimension(s) of `x` by summation.
 *
 * @param x Flattened input values; each consecutive run of `reduceSize`
 *     elements forms one reduction slice.
 * @param reduceSize Number of input elements folded into each output element.
 * @param outShape Shape of the reduced output.
 * @param dtype Output dtype; selects the typed-array kind allocated.
 * @returns A typed array of size `sizeFromShape(outShape)` holding the
 *     per-slice sums.
 */
export const sum =
    (x: TypedArray, reduceSize: number, outShape: number[], dtype: DataType):
        TypedArray => {
      const outValues = util.getTypedArrayFromDType(
          dtype as NumericDataType, util.sizeFromShape(outShape));

      // BUG FIX: iterate over output elements (outValues.length), not input
      // elements (x.length). The old bound ran reduceSize times too many
      // iterations; the out-of-range typed-array writes were silently
      // dropped, which hid the wasted work but made the kernel
      // O(inputSize * reduceSize).
      for (let i = 0; i < outValues.length; ++i) {
        const offset = i * reduceSize;
        let total = 0;
        for (let j = 0; j < reduceSize; ++j) {
          total += x[offset + j];
        }
        outValues[i] = total;
      }

      return outValues;
    };

// Kernel registration for Sum on the CPU backend. Axes must already be the
// innermost dimensions (asserted below); the reduction itself is delegated
// to the shared sum() implementation.
export const sumConfig: KernelConfig = {
  kernelName: Sum,
  backendName: 'cpu',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {x} = inputs as SumInputs;
    const {axes} = attrs as {} as SumAttrs;
    const cpuBackend = backend as MathBackendCPU;

    assertNotComplex(x, 'sum');

    axis_util.assertAxesAreInnerMostDims('sum', axes, x.shape.length);

    const [outShape, reduceShape] =
        axis_util.computeOutAndReduceShapes(x.shape, axes);
    // Sum upcasts at least to int32 so e.g. bool inputs produce counts
    // rather than values clamped to 0/1.
    const resultDtype = upcastType(x.dtype, 'int32');

    const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;
    // BUG FIX: allocate the result with resultDtype (previously x.dtype) so
    // the typed array actually matches the dtype declared to the backend in
    // the write/return below.
    const result =
        sum(xVals, sizeFromShape(reduceShape), outShape, resultDtype);

    const dataId = cpuBackend.write(result, outShape, resultDtype);
    return {dataId, shape: outShape, dtype: resultDtype};
  }
};
Loading