
Allow different dtypes in binary math ops (#1432)
Allow users to provide different dtypes in binary arithmetic ops (add/sub/mul/div/...) and matMul, just as in NumPy.

The dtype of the result is upcast to the wider of the two input dtypes, e.g. matMul(float32, int32) => float32.
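
For illustration, a minimal usage sketch of the new behavior against the public tfjs-core API (not part of the diff below; the tensor values are made up):

import * as tf from '@tensorflow/tfjs-core';

// Mixed dtypes no longer throw; the result takes the wider dtype.
const a = tf.tensor1d([1, 2, 3], 'int32');
const b = tf.tensor1d([0.5, 0.5, 0.5], 'float32');
console.log(tf.add(a, b).dtype);  // 'float32'

// bool is upcast to int32.
console.log(tf.mul(a, tf.tensor1d([true, false, true], 'bool')).dtype);  // 'int32'

// matMul follows the same rule.
const m = tf.tensor2d([[1, 2], [3, 4]], [2, 2], 'float32');
const n = tf.tensor2d([[1, 0], [0, 1]], [2, 2], 'int32');
console.log(tf.matMul(m, n).dtype);  // 'float32'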

This will go out as patch release 0.14.1, which fixes the breakage in 0.14.0 caused by #1408: with the improved dtype inference, tensor(new Int32Array()) is now inferred as int32, whereas it was previously float32.

Fixes tensorflow/tfjs#934, tensorflow/tfjs#966
dsmilkov committed Dec 6, 2018
1 parent 2ff431b commit 8db48a6
Showing 10 changed files with 328 additions and 128 deletions.
9 changes: 4 additions & 5 deletions src/kernels/backend_cpu.ts
@@ -407,8 +407,8 @@ export class MathBackendCPU implements KernelBackend {
[b.strides[1], 1, b.strides[0]];

const size = leftDim * rightDim;
const result = new Float32Array(batchDim * size);

const result = buffer([batchDim, leftDim, rightDim], a.dtype);
const resVals = result.values as TypedArray;
const blockSize = this.blockSize;

for (let b = 0; b < batchDim; b++) {
@@ -428,15 +428,14 @@ export class MathBackendCPU implements KernelBackend {
sum += aValues[b * aBatch + i * aOuterStep + k * aInnerStep] *
bValues[k * bInnerStep + j * bOuterStep + b * bBatch];
}
result[b * size + (i * rightDim + j)] += sum;
resVals[b * size + (i * rightDim + j)] += sum;
}
}
}
}
}
}

return ops.tensor3d(result, [batchDim, leftDim, rightDim]);
return result.toTensor() as Tensor3D;
}

multiply(a: Tensor, b: Tensor): Tensor {
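
The backend_cpu.ts change above replaces the hard-coded Float32Array accumulator with a dtype-aware buffer, so the result tensor carries a.dtype (by this point the op has already made the two input dtypes match) instead of always coming back as float32. The same buffer-then-toTensor pattern is available through the public API; a minimal standalone sketch with made-up values:

import * as tf from '@tensorflow/tfjs-core';

// Allocate an output buffer that carries the desired dtype, write into it,
// then materialize it as a tensor of that dtype.
const out = tf.buffer([2, 2], 'int32');
out.set(7, 0, 0);  // set(value, ...indices)
out.set(5, 1, 1);
const t = out.toTensor();
console.log(t.dtype);  // 'int32'
console.log(t.shape);  // [2, 2]
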
33 changes: 21 additions & 12 deletions src/kernels/backend_webgl.ts
@@ -35,7 +35,6 @@ import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D,
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types';
import * as util from '../util';
import {getTypedArrayFromDType, sizeFromShape} from '../util';

import {DataMover, DataStorage, KernelBackend} from './backend';
import * as backend_util from './backend_util';
import {mergeRealAndImagArrays} from './complex_util';
@@ -682,6 +681,8 @@ export class MathBackendWebGL implements KernelBackend {
return this.multiply(a3D, b3D).sum(axis, true /* keepDims */);
}

const dtype = upcastType(a.dtype, b.dtype);

// TODO(https://github.com/tensorflow/tfjs/issues/693): Support 3D tensors
if (batch === 1) {
const aSqueezed = a.as2D(a.shape[1], a.shape[2]);
@@ -690,13 +691,17 @@
const program = new MatMulPackedProgram(
aSqueezed.shape, bSqueezed.shape, [outerShapeA, outerShapeB],
transposeA, transposeB);
const output =
this.makePackedTensor(program.outputShape, dtype) as Tensor2D;
const result =
this.compileAndRun<Tensor2D>(program, [aSqueezed, bSqueezed]);

this.compileAndRun<Tensor2D>(program, [aSqueezed, bSqueezed], output);
return result.reshape([1, result.shape[0], result.shape[1]]);
} else {
return this.compileAndRun(
new MatMulProgram(a.shape, b.shape, transposeA, transposeB), [a, b]);
const program =
new MatMulProgram(a.shape, b.shape, transposeA, transposeB);
const output =
this.makeOutputArray(program.outputShape, dtype) as Tensor3D;
return this.compileAndRun(program, [a, b], output);
}
}

@@ -1517,7 +1522,8 @@
convInfo.outChannels / convInfo.inChannels === 1) {
program = new DepthwiseConvPacked2DProgram(convInfo);
return this.compileAndRun(
program, [x, filter], this.makePackedTensor(convInfo.outShape));
program, [x, filter],
this.makePackedTensor(convInfo.outShape, x.dtype));
}

program = new DepthwiseConv2DProgram(convInfo);
@@ -1769,16 +1775,17 @@
return Tensor.make(shape, {}, dtype) as T;
}

private makePackedTensor<T extends Tensor>(shape: number[]): T {
const packedTensor = Tensor.make(shape, {});
private makePackedTensor<T extends Tensor>(shape: number[], dtype: DataType):
T {
const packedTensor = Tensor.make(shape, {}, dtype);
this.texData.get(packedTensor.dataId).isPacked = true;
return packedTensor as T;
}

private unpackTensor<T extends Tensor>(input: T): T {
const program = new UnpackProgram(input.shape);
return this.compileAndRun(
program, [input], Tensor.make(program.outputShape, {}));
program, [input], Tensor.make(program.outputShape, {}, input.dtype));
}

private getBatchDim(shape: number[], dimsToSkip = 2): number {
@@ -1815,7 +1822,8 @@
pageToCpu = true): K {
if (output == null) {
if (program.usesPackedTextures) {
output = this.makePackedTensor(program.outputShape) as {} as K;
output = this.makePackedTensor(program.outputShape, inputs[0].dtype) as
{} as K;
} else {
output = this.makeOutputArray(program.outputShape, inputs[0].dtype) as
{} as K;
@@ -1872,11 +1880,12 @@
preProcessProgram = new UnpackProgram(input.shape);
processedInput = this.compileAndRun(
preProcessProgram, [input],
Tensor.make(preProcessProgram.outputShape, {}));
Tensor.make(preProcessProgram.outputShape, {}, input.dtype));
} else {
preProcessProgram = new PackProgram(input.shape);
processedInput = this.compileAndRun(
preProcessProgram, [input], this.makePackedTensor(input.shape));
preProcessProgram, [input],
this.makePackedTensor(input.shape, input.dtype));
}

texData = this.texData.get(processedInput.dataId);
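
In the backend_webgl.ts changes above, the matMul path now computes the output dtype up front via upcastType(a.dtype, b.dtype) (imported from '../types') and threads it through makeOutputArray and makePackedTensor, which now take an explicit dtype. The implementation of upcastType is not shown in this excerpt; conceptually it returns the wider of the two dtypes, roughly as in this hypothetical sketch:

type DType = 'bool'|'int32'|'float32'|'complex64';

// Hypothetical ranking: each dtype can hold everything ranked below it.
const rank: Record<DType, number> = {bool: 0, int32: 1, float32: 2, complex64: 3};

function upcast(a: DType, b: DType): DType {
  return rank[a] >= rank[b] ? a : b;
}

console.log(upcast('int32', 'float32'));  // 'float32'
console.log(upcast('bool', 'int32'));     // 'int32'
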
84 changes: 58 additions & 26 deletions src/ops/arithmetic_test.ts
@@ -102,12 +102,14 @@ describeWithFlags('div', ALL_ENVS, () => {
expectArraysClose(result, expected);
});

it('throws when passed tensors of different types', () => {
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([1, 2, 3, 4, 2, 5], [2, 3], 'int32');
it('upcasts when dtypes dont match', () => {
let res = tf.div(tf.scalar(6, 'int32'), tf.scalar(3, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [2]);

expect(() => tf.div(a, b)).toThrowError();
expect(() => tf.div(b, a)).toThrowError();
res = tf.div(tf.scalar(6, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [6]);
});

it('throws when passed tensors of different shapes', () => {
@@ -580,11 +582,18 @@ describeWithFlags('mul', ALL_ENVS, () => {
expect(() => tf.mul(tf.scalar(1), {} as tf.Tensor))
.toThrowError(/Argument 'b' passed to 'mul' must be a Tensor/);
});
it('throws when dtypes dont match', () => {
expect(() => tf.mul(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
it('upcasts when dtypes dont match', () => {
let res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(3, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [6]);

res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [2]);

res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [0]);
});

it('accepts a tensor-like object', () => {
@@ -1149,11 +1158,26 @@ describeWithFlags('add', ALL_ENVS, () => {
.toThrowError(/Argument 'b' passed to 'add' must be a Tensor/);
});

it('throws when dtypes dont match', () => {
expect(() => tf.add(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
it('upcasts when dtypes dont match', () => {
let res = tf.add(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [2]);

res = tf.add(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [2]);

res = tf.add(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [1]);

res = tf.add(tf.complex(4, 7), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [5, 7]);

res = tf.add(tf.complex(4, 7), tf.scalar(1, 'int32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [5, 7]);
});

it('accepts a tensor-like object', () => {
@@ -1495,18 +1519,26 @@ describeWithFlags('sub', ALL_ENVS, () => {
expect(() => tf.sub(tf.scalar(1), {} as tf.Tensor))
.toThrowError(/Argument 'b' passed to 'sub' must be a Tensor/);
});
it('throws when dtypes dont match', () => {
expect(() => tf.sub(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
});
it('upcasts when dtypes dont match', () => {
let res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [0]);

it('throws when dtypes dont match', () => {
expect(() => tf.sub(tf.scalar(1, 'float32'), tf.complex(1, 2)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(float32\) and second\(complex64\) input must match/);
res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [0]);

res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [1]);

res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [3, 7]);

res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'int32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [3, 7]);
});

it('accepts a tensor-like object', () => {
60 changes: 29 additions & 31 deletions src/ops/binary_ops.ts
@@ -19,7 +19,7 @@ import {ENV} from '../environment';
import {KernelBackend} from '../kernels/backend';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {assertTypesMatch} from '../tensor_util';
import {makeTypesMatch} from '../tensor_util';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike, upcastType} from '../types';
import * as util from '../util';
@@ -53,9 +53,9 @@ import {neg} from './unary_ops';
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function add_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'add');
const $b = convertToTensor(b, 'b', 'add');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'add');
let $b = convertToTensor(b, 'b', 'add');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -172,9 +172,9 @@ function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function sub_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'sub');
const $b = convertToTensor(b, 'b', 'sub');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'sub');
let $b = convertToTensor(b, 'b', 'sub');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -318,9 +318,9 @@ function powStrict_<T extends Tensor>(base: T, exp: Tensor): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mul_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'mul');
const $b = convertToTensor(b, 'b', 'mul');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'mul');
let $b = convertToTensor(b, 'b', 'mul');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -391,9 +391,9 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'div');
const $b = convertToTensor(b, 'b', 'div');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'div');
let $b = convertToTensor(b, 'b', 'div');
[$a, $b] = makeTypesMatch($a, $b);

let forwardFunc: (backend: KernelBackend) => Tensor;
if ($a.dtype === 'int32' && $b.dtype === 'int32') {
@@ -454,9 +454,9 @@ function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function floorDiv_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'floorDiv');
const $b = convertToTensor(b, 'b', 'floorDiv');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'floorDiv');
let $b = convertToTensor(b, 'b', 'floorDiv');
[$a, $b] = makeTypesMatch($a, $b);

const forwardFunc = (backend: KernelBackend) => backend.floorDiv($a, $b);
const outShape =
@@ -526,9 +526,9 @@ function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mod_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'mod');
const $b = convertToTensor(b, 'b', 'mod');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'mod');
let $b = convertToTensor(b, 'b', 'mod');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -598,14 +598,13 @@ function minimum_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'minimum');
let $b = convertToTensor(b, 'b', 'minimum');
assertTypesMatch($a, $b);
[$a, $b] = makeTypesMatch($a, $b);

if ($a.dtype === 'bool') {
$a = $a.toInt();
}
if ($b.dtype === 'bool') {
$b = $b.toInt();
}

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
const derA = () => dy.mul($a.lessEqual($b).toFloat());
@@ -660,14 +659,13 @@ function maximum_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'maximum');
let $b = convertToTensor(b, 'b', 'maximum');
assertTypesMatch($a, $b);
[$a, $b] = makeTypesMatch($a, $b);

if ($a.dtype === 'bool') {
$a = $a.toInt();
}
if ($b.dtype === 'bool') {
$b = $b.toInt();
}

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
const derA = () => dy.mul($a.greaterEqual($b).toFloat());
@@ -721,9 +719,9 @@ function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function squaredDifference_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'squaredDifference');
const $b = convertToTensor(b, 'b', 'squaredDifference');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'squaredDifference');
let $b = convertToTensor(b, 'b', 'squaredDifference');
[$a, $b] = makeTypesMatch($a, $b);

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
@@ -772,9 +770,9 @@ function squaredDifferenceStrict_<T extends Tensor>(
/** @doc {heading: 'Operations', subheading: 'Basic math'} */
function atan2_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'atan2');
const $b = convertToTensor(b, 'b', 'atan2');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'atan2');
let $b = convertToTensor(b, 'b', 'atan2');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
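
Every op in binary_ops.ts above now calls makeTypesMatch (from tensor_util, one of the changed files not shown in this excerpt) where it previously called assertTypesMatch. Presumably it casts both operands to their common upcast dtype before the kernel runs; a minimal sketch under that assumption (the helper name and body here are illustrative, not the actual source):

import {Tensor} from '../tensor';
import {upcastType} from '../types';

// Sketch: if the dtypes differ, cast both tensors to their common upcast
// dtype so downstream kernels only ever see matching dtypes.
function makeTypesMatchSketch<T extends Tensor>(a: T, b: T): [T, T] {
  if (a.dtype === b.dtype) {
    return [a, b];
  }
  const dtype = upcastType(a.dtype, b.dtype);
  return [a.cast(dtype), b.cast(dtype)] as [T, T];
}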
