diff --git a/.vscode/settings.json b/.vscode/settings.json index 300f0d2545..b1cd35c836 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,5 +12,6 @@ "editor.tabSize": 2, "editor.insertSpaces": true, "files.insertFinalNewline": true, - "editor.detectIndentation": false + "editor.detectIndentation": false, + "typescript.tsdk": "node_modules/typescript/lib" } diff --git a/demos/mnist/mnist.ts b/demos/mnist/mnist.ts index 1ee9e76097..bd8f6ed34d 100644 --- a/demos/mnist/mnist.ts +++ b/demos/mnist/mnist.ts @@ -68,10 +68,10 @@ export function buildModelMathAPI( return (x: Array1D): Scalar => { return math.scope(() => { - const hidden1 = - math.relu(math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B)); - const hidden2 = math.relu( - math.add(math.vectorTimesMatrix(hidden1, hidden2W), hidden2B)); + const hidden1 = math.relu( + math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B)) as Array1D; + const hidden2 = math.relu(math.add( + math.vectorTimesMatrix(hidden1, hidden2W), hidden2B)) as Array1D; const logits = math.add(math.vectorTimesMatrix(hidden2, softmaxW), softmaxB); return math.argMax(logits); diff --git a/src/math/activation_functions.ts b/src/math/activation_functions.ts index 54a735b7bb..ee484dabc1 100644 --- a/src/math/activation_functions.ts +++ b/src/math/activation_functions.ts @@ -59,11 +59,11 @@ export class SigmoidFunc implements ActivationFunction { }); } - der(math: NDArrayMath, x: T, y: T) { + der(math: NDArrayMath, x: T, y: T): T { return math.scope(() => { // y * (1 - y) = y - y^2 const ySquared = math.elementWiseMul(y, y); - return math.sub(y, ySquared); + return math.subStrict(y, ySquared); }); } } diff --git a/src/math/cost_functions.ts b/src/math/cost_functions.ts index 5c17cc46ff..c787806d88 100644 --- a/src/math/cost_functions.ts +++ b/src/math/cost_functions.ts @@ -29,7 +29,7 @@ export class SquareCostFunc implements ElementWiseCostFunction { private halfOne = Scalar.new(0.5); cost(math: NDArrayMath, x1: T, x2: T): T { - const diff = math.sub(x1, x2); + const diff = math.subStrict(x1, x2); const diffSquared = math.elementWiseMul(diff, diff); const result = math.scalarTimesArray(this.halfOne, diffSquared); @@ -40,7 +40,7 @@ export class SquareCostFunc implements ElementWiseCostFunction { } der(math: NDArrayMath, x1: T, x2: T): T { - return math.sub(x1, x2); + return math.subStrict(x1, x2); } dispose() { diff --git a/src/math/math.ts b/src/math/math.ts index 2478fff41e..8d693a570e 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -558,10 +558,8 @@ export abstract class NDArrayMath { c.size === 1, `Error in scalarPlusArray: first argument must be rank 0, but got ` + `rank ${c.rank}.`); - return this.track(this.scalarPlusArrayInternal(c, a)); + return this.add(c, a) as T; } - protected abstract scalarPlusArrayInternal( - c: Scalar, a: T): T; /** * Computes a scalar minus NDArray, c - A. @@ -573,25 +571,21 @@ export abstract class NDArrayMath { c.size === 1, `Error in scalarMinusArray: first argument must be rank 0, but got ` + `rank ${c.rank}.`); - return this.track(this.scalarMinusArrayInternal(c, a)); + return this.sub(c, a) as T; } - protected abstract scalarMinusArrayInternal( - c: Scalar, a: T): T; /** - * Computes a scalar minus NDArray, A - c. + * Computes A - c. A is NDArray, c is Scalar. * @param a The NDArray A in A - c. - * @param c The scalar c in A - c. + * @param c The Scalar c in A - c. */ arrayMinusScalar(a: T, c: Scalar): T { util.assert( c.size === 1, `Error in arrayMinusScalar: second argument must be rank 0, but ` + `got rank ${c.rank}.`); - return this.track(this.arrayMinusScalarInternal(a, c)); + return this.sub(a, c) as T; } - protected abstract arrayMinusScalarInternal( - a: T, c: Scalar): T; /** * Computes -1 * A element-wise. @@ -603,50 +597,111 @@ export abstract class NDArrayMath { protected abstract negInternal(a: T): T; /** - * Adds two NDArrays element-wise, A + B. Inputs must be the same shape. + * Adds two NDArrays element-wise, A + B. Supports broadcasting. + * For a stricter version without broadcasting use math.addStrict(). + * * @param a The first NDArray to add element-wise. * @param b The second NDArray to add element-wise. */ - add(a: T, b: T): T { - util.assertShapesMatch(a.shape, b.shape, 'Error in add: '); + add(a: NDArray, b: NDArray): NDArray { + util.assertAndGetBroadcastedShape(a.shape, b.shape); return this.track(this.addInternal(a, b)); } - protected abstract addInternal(a: T, b: T): T; + protected abstract addInternal(a: NDArray, b: NDArray): NDArray; + + /** + * Adds two NDArrays element-wise, A + B. Inputs must + * be the same shape. For broadcasting support, use math.add() instead. + * + * @param a The first NDArray to multiply element-wise. + * @param b The second NDArray to multiply element-wise. + */ + addStrict(a: T, b: T): T { + util.assertShapesMatch(a.shape, b.shape, 'Error in addStrict: '); + return this.add(a, b) as T; + } /** - * Subtracts two NDArrays element-wise, A - B. Inputs must be the same shape. + * Subtracts two NDArrays element-wise, A - B. Supports broadcasting. + * For a stricter version without broadcasting use math.subStrict(). + * * @param a The first NDArray to subtract element-wise. * @param b The second NDArray to subtract element-wise. */ - sub(a: T, b: T): T { - util.assertShapesMatch(a.shape, b.shape, 'Error in sub: '); + sub(a: NDArray, b: NDArray): NDArray { + util.assertAndGetBroadcastedShape(a.shape, b.shape); return this.track(this.subInternal(a, b)); } - protected abstract subInternal(a: T, b: T): T; + protected abstract subInternal(a: NDArray, b: NDArray): NDArray; /** - * Multiplies two NDArrays element-wise (hadamard product), A * B. Inputs must - * be the same shape. + * Subtracts two NDArrays element-wise, A - B. Inputs must + * be the same shape. For broadcasting support, use math.sub() instead. + * * @param a The first NDArray to multiply element-wise. * @param b The second NDArray to multiply element-wise. */ + subStrict(a: T, b: T): T { + util.assertShapesMatch(a.shape, b.shape, 'Error in subStrict: '); + return this.sub(a, b) as T; + } + + /** + * Multiplies two NDArrays element-wise, A * B. Supports broadcasting. + * For a stricter version without broadcasting use math.multiplyStrict(). + * + * @param a The first NDArray to multiply element-wise. + * @param b The second NDArray to multiply element-wise. + */ + multiply(a: NDArray, b: NDArray): NDArray { + util.assertAndGetBroadcastedShape(a.shape, b.shape); + return this.track(this.multiplyInternal(a, b)); + } + protected abstract multiplyInternal(a: T, b: T): T; + + /** + * @deprecated Use math.multiplyStrict() instead. + */ elementWiseMul(a: T, b: T): T { - util.assertShapesMatch(a.shape, b.shape, 'Error in elementWiseMul: '); - return this.track(this.elementWiseMulInternal(a, b)); + return this.multiplyStrict(a, b); + } + + /** + * Multiplies two NDArrays element-wise, A * B. Inputs must + * be the same shape. For broadcasting support, use math.multiply() instead. + * + * @param a The first NDArray to multiply element-wise. + * @param b The second NDArray to multiply element-wise. + */ + multiplyStrict(a: T, b: T): T { + util.assertShapesMatch(a.shape, b.shape, 'Error in multiplyStrict: '); + return this.multiply(a, b) as T; } - protected abstract elementWiseMulInternal(a: T, b: T): T; /** - * Divides two NDArrays element-wise (hadamard product), A / B. Inputs must be - * the same shape. + * Divides two NDArrays element-wise, A / B. Supports broadcasting. + * For a stricter version without broadcasting use math.divideStrict(). + * * @param a The first NDArray to divide element-wise. * @param b The second NDArray to divide element-wise. */ - divide(a: T, b: T): T { - util.assertShapesMatch(a.shape, b.shape, 'Error in divide: '); + divide(a: NDArray, b: NDArray): NDArray { + util.assertAndGetBroadcastedShape(a.shape, b.shape); return this.track(this.divideInternal(a, b)); } - protected abstract divideInternal(a: T, b: T): T; + protected abstract divideInternal(a: NDArray, b: NDArray): NDArray; + + /** + * Divides two NDArrays element-wise, A / B. Inputs must + * be the same shape. For broadcasting support, use math.divide() instead. + * + * @param a The first NDArray to multiply element-wise. + * @param b The second NDArray to multiply element-wise. + */ + divideStrict(a: T, b: T): T { + util.assertShapesMatch(a.shape, b.shape, 'Error in divideStrict: '); + return this.divide(a, b) as T; + } /** * Computes a scalar divided by an NDArray, broadcasted over the NDArray, c / @@ -659,10 +714,8 @@ export abstract class NDArrayMath { c.size === 1, `Error in scalarDividedByArray: first argument must be rank 0, but ` + `got NDArray of rank ${c.rank}.`); - return this.track(this.scalarDividedByArrayInternal(c, a)); + return this.divide(c, a) as T; } - protected abstract scalarDividedByArrayInternal( - c: Scalar, a: T): T; /** * Computes an NDArray divided by a scalar, broadcasted over the NDArray, A / @@ -675,10 +728,8 @@ export abstract class NDArrayMath { c.size === 1, `Error in arrayDividedByScalar: second argument must be rank 0, ` + `but got NDArray of rank ${c.rank}.`); - return this.track(this.arrayDividedByScalarInternal(a, c)); + return this.divide(a, c) as T; } - protected abstract arrayDividedByScalarInternal( - a: T, c: Scalar): T; /** * Computes exponential of the input NDArray element-wise. y = e ^ x @@ -778,17 +829,11 @@ export abstract class NDArrayMath { c.size === 1, `Error in arrayDividedByScalar: first argument must be rank 0, but ` + `got rank ${c.rank}.`); - return this.track(this.scalarTimesArrayInternal(c, a)); + return this.multiply(c, a) as T; } - protected abstract scalarTimesArrayInternal( - c: Scalar, a: T): T; /** - * Computes an element-wise broadcasted multiplication of two matrices A and - * B. Will return a new matrix that is the max of A and B, where the smaller - * matrix will broadcast over the larger matrix. - * @param c The scalar in the operation. - * @param A the NDArray in the operation that will be broadcasted over. + * @deprecated Use math.multiply() instead. */ elementWiseMulBroadcast(a: Array2D, b: Array2D): Array2D { util.assert( @@ -799,10 +844,8 @@ export abstract class NDArrayMath { b.rank === 2, `Error in elementWiseMulBroadcast: second argument must be ` + `rank 2, but got rank ${b.rank}.`); - return this.track(this.elementWiseMulBroadcastInternal(a, b)); + return this.multiply(a, b) as Array2D; } - protected abstract elementWiseMulBroadcastInternal(a: Array2D, b: Array2D): - Array2D; ///////////////////// // Convolution ops // diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index 1967cb8313..4298b1974b 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -93,59 +93,23 @@ export class NDArrayMathCPU extends NDArrayMath { return values; } - protected scalarPlusArrayInternal(c: Scalar, a: T): T { - const resultValues = new Float32Array(a.size); - const aValues = a.getValues(); - const cVal = c.get(); - for (let i = 0; i < resultValues.length; ++i) { - resultValues[i] = cVal + aValues[i]; - } - return NDArray.make(a.shape, {values: resultValues}); - } - protected scaledArrayAddInternal( c1: Scalar, a: T, c2: Scalar, b: T) { - const cValues = new Float32Array(a.size); + const newShape = util.assertAndGetBroadcastedShape(a.shape, b.shape); + const newValues = new Float32Array(util.sizeFromShape(newShape)); + const aValues = a.getValues(); const bValues = b.getValues(); const c1Val = c1.get(); const c2Val = c2.get(); - for (let i = 0; i < cValues.length; ++i) { - cValues[i] = c1Val * aValues[i] + c2Val * bValues[i]; - } - return NDArray.make(a.shape, {values: cValues}); - } - - protected scalarTimesArrayInternal(c: Scalar, a: T): T { - const newValues = new Float32Array(a.size); - const aValues = a.getValues(); - const cVal = c.get(); - for (let i = 0; i < aValues.length; ++i) { - newValues[i] = cVal * aValues[i]; + for (let i = 0; i < newValues.length; ++i) { + newValues[i] = c1Val * aValues[i % a.size] + c2Val * bValues[i % b.size]; } - return NDArray.make(a.shape, {values: newValues}); - } - - protected scalarMinusArrayInternal(c: Scalar, a: T): T { - const negA = this.negInternal(a); - const result = this.scalarPlusArrayInternal(c, negA); - - negA.dispose(); - - return result; - } - - protected arrayMinusScalarInternal(a: T, c: Scalar): T { - const negC = this.negInternal(c); - const result = this.scalarPlusArrayInternal(negC, a); - - negC.dispose(); - - return result; + return NDArray.make(newShape, {values: newValues}); } protected negInternal(a: T): T { - return this.scalarTimesArrayInternal(Scalar.NEG_ONE, a); + return this.scalarTimesArray(Scalar.NEG_ONE, a); } protected addInternal(a: T, b: T): T { @@ -194,61 +158,29 @@ export class NDArrayMathCPU extends NDArrayMath { return Array2D.new([leftDim, rightDim], values); } - protected elementWiseMulInternal(a: T, b: T): T { - const newValues = new Float32Array(a.size); + protected multiplyInternal(a: T, b: T): T { + const newShape = util.assertAndGetBroadcastedShape(a.shape, b.shape); + const newValues = new Float32Array(util.sizeFromShape(newShape)); + const aValues = a.getValues(); const bValues = b.getValues(); - for (let i = 0; i < aValues.length; ++i) { - newValues[i] = aValues[i] * bValues[i]; - } - return NDArray.make(a.shape, {values: newValues}); - } - - protected elementWiseMulBroadcastInternal(a: Array2D, b: Array2D): Array2D { - const maxRow = Math.max(a.shape[0], b.shape[0]); - const maxCol = Math.max(a.shape[1], b.shape[1]); - - const values = new Float32Array(maxRow * maxCol); - let index = 0; - for (let row = 0; row < maxRow; row++) { - for (let col = 0; col < maxCol; col++) { - values[index++] = a.get(row % a.shape[0], col % a.shape[1]) * - b.get(row % b.shape[0], col % b.shape[1]); - } + for (let i = 0; i < newValues.length; ++i) { + newValues[i] = aValues[i % a.size] * bValues[i % b.size]; } - return Array2D.new([maxRow, maxCol], values); + return NDArray.make(newShape, {values: newValues}); } protected divideInternal(a: T, b: T): T { - const newValues = new Float32Array(a.size); - const aValues = a.getValues(); - const bValues = b.getValues(); - for (let i = 0; i < aValues.length; ++i) { - newValues[i] = aValues[i] / bValues[i]; - } - return NDArray.make(a.shape, {values: newValues}); - } + const newShape = util.assertAndGetBroadcastedShape(a.shape, b.shape); + const newValues = new Float32Array(util.sizeFromShape(newShape)); - protected scalarDividedByArrayInternal(c: Scalar, a: T): - T { - const newValues = new Float32Array(a.size); const aValues = a.getValues(); - const cValue = c.get(); - for (let i = 0; i < aValues.length; ++i) { - newValues[i] = cValue / aValues[i]; - } - return NDArray.make(a.shape, {values: newValues}); - } + const bValues = b.getValues(); - protected arrayDividedByScalarInternal(a: T, c: Scalar): - T { - const newValues = new Float32Array(a.size); - const aValues = a.getValues(); - const cValue = c.get(); - for (let i = 0; i < aValues.length; ++i) { - newValues[i] = aValues[i] / cValue; + for (let i = 0; i < newValues.length; ++i) { + newValues[i] = aValues[i % a.size] / bValues[i % b.size]; } - return NDArray.make(a.shape, {values: newValues}); + return NDArray.make(newShape, {values: newValues}); } protected sumInternal(ndarray: NDArray): Scalar { diff --git a/src/math/math_cpu_test.ts b/src/math/math_cpu_test.ts index f7e3e76957..5e86ca59d0 100644 --- a/src/math/math_cpu_test.ts +++ b/src/math/math_cpu_test.ts @@ -396,14 +396,13 @@ describe('NDArrayMathCPU element-wise mul/div', () => { let a = Array2D.new([2, 2], [1, 2, 3, 4]); let b = Array2D.new([2, 2], [5, 4, 3, 2]); let expected = Array2D.new([2, 2], [5, 8, 9, 8]); - expect(expected.equals(math.elementWiseMulBroadcast(a, b))).toBe(true); + expect(expected.equals(math.multiply(a, b))).toBe(true); // Broadcast a over b. - a = Array2D.new([2, 2], [1, 2, 3, 4]); - b = Array2D.new([4, 4], [2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8]); - expected = Array2D.new( - [4, 4], [2, 6, 4, 10, 9, 16, 15, 24, 4, 10, 6, 14, 15, 24, 21, 32]); - expect(expected.equals(math.elementWiseMulBroadcast(a, b))).toBe(true); + a = Array2D.new([1, 2], [1, 2]); + b = Array2D.new([4, 2], [2, 3, 4, 5, 6, 7, 8, 9]); + expected = Array2D.new([4, 2], [2, 6, 4, 10, 6, 14, 8, 18]); + expect(expected.equals(math.multiply(a, b))).toBe(true); }); it('multiplication, no broadcasting', () => { diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index 0c594c4046..eaf418dadd 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -21,8 +21,6 @@ import {MatrixOrientation, NDArrayMath} from './math'; import * as ndarray from './ndarray'; import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; import * as addscaledmat_gpu from './webgl/addscaledmat_gpu'; -import * as addsubmuldiv_gpu from './webgl/addsubmuldiv_gpu'; -import {OperandType} from './webgl/addsubmuldiv_gpu'; import {ArgMaxEqualsProgram} from './webgl/argmaxequals_gpu'; import {ArgMinMaxProgram} from './webgl/argminmax_gpu'; import * as avg_pool_gpu from './webgl/avg_pool_gpu'; @@ -32,6 +30,7 @@ import * as conv_backprop_gpu from './webgl/conv_backprop_gpu'; import * as conv_gpu from './webgl/conv_gpu'; import * as copy_gpu from './webgl/copy_gpu'; import {GPGPUContext} from './webgl/gpgpu_context'; +import {BinaryOpProgram} from './webgl/binaryop_gpu'; import {GPGPUProgram, GPGPUBinary} from './webgl/gpgpu_math'; import * as gpgpu_math from './webgl/gpgpu_math'; import * as gpgpu_util from './webgl/gpgpu_util'; @@ -59,7 +58,6 @@ const ADD_SCALED_MAT_PROG = 'addscaledmat'; // Element-wise ops. const RESHAPE_PROG = 'reshape'; -const ADD_SUM_MUL_DIV_PROG = 'addsummuldiv'; // Convolution. const CONV2D_PROG = 'conv'; @@ -238,21 +236,6 @@ export class NDArrayMathGPU extends NDArrayMath { resultShapeRCD, {texture: resultTex, textureShapeRC: resultTexShape}); } - protected scalarPlusArrayInternal(c: Scalar, a: T): T { - return this.addSubMulDiv( - c, a, a.shape, OperandType.SCALAR, '+', OperandType.MATRIX) as T; - } - - protected arrayMinusScalarInternal(a: T, c: Scalar): T { - return this.addSubMulDiv( - a, c, a.shape, OperandType.MATRIX, '-', OperandType.SCALAR) as T; - } - - protected scalarMinusArrayInternal(c: Scalar, a: T): T { - return this.addSubMulDiv( - c, a, a.shape, OperandType.SCALAR, '-', OperandType.MATRIX) as T; - } - protected scaledArrayAddInternal( c1: Scalar, a: T, c2: Scalar, b: T) { let cleanupB = false; @@ -278,11 +261,6 @@ export class NDArrayMathGPU extends NDArrayMath { return NDArray.make(a.shape, {texture: resultTexture, textureShapeRC}); } - protected scalarTimesArrayInternal(c: Scalar, a: T): T { - return this.addSubMulDiv( - c, a, a.shape, OperandType.SCALAR, '*', OperandType.MATRIX) as T; - } - protected negInternal(a: T): T { const program = new UnaryOpProgram(a.shape, UnaryOp.NEG); return this.compileAndRun(program, [a]); @@ -331,13 +309,9 @@ export class NDArrayMathGPU extends NDArrayMath { return this.compileAndRun(program, [a, b]); } - protected elementWiseMulInternal(a: T, b: T): T { - return this.addSubMulDiv( - a, b, a.shape, OperandType.MATRIX, '*', OperandType.MATRIX) as T; - } - - protected elementWiseMulBroadcastInternal(a: Array2D, b: Array2D): Array2D { - throw new Error('Not yet implemented!'); + protected multiplyInternal(a: T, b: T): T { + const program = new BinaryOpProgram('*', a.shape, b.shape); + return this.compileAndRun(program, [a, b]); } protected batchNormalization3DInternal( @@ -470,30 +444,18 @@ export class NDArrayMathGPU extends NDArrayMath { } protected divideInternal(a: T, b: T): T { - return this.addSubMulDiv( - a, b, a.shape, OperandType.MATRIX, '/', OperandType.MATRIX) as T; - } - - protected scalarDividedByArrayInternal(c: Scalar, a: T): - T { - return this.addSubMulDiv( - c, a, a.shape, OperandType.SCALAR, '/', OperandType.MATRIX) as T; - } - - protected arrayDividedByScalarInternal(a: T, c: Scalar): - T { - return this.addSubMulDiv( - a, c, a.shape, OperandType.MATRIX, '/', OperandType.SCALAR) as T; + const program = new BinaryOpProgram('/', a.shape, b.shape); + return this.compileAndRun(program, [a, b]); } protected addInternal(a: T, b: T): T { - return this.addSubMulDiv( - a, b, a.shape, OperandType.MATRIX, '+', OperandType.MATRIX) as T; + const program = new BinaryOpProgram('+', a.shape, b.shape); + return this.compileAndRun(program, [a, b]); } protected subInternal(a: T, b: T): T { - return this.addSubMulDiv( - a, b, a.shape, OperandType.MATRIX, '-', OperandType.MATRIX) as T; + const program = new BinaryOpProgram('-', a.shape, b.shape); + return this.compileAndRun(program, [a, b]); } protected logSumExpInternal(a: NDArray): Scalar { @@ -1001,85 +963,6 @@ export class NDArrayMathGPU extends NDArrayMath { return this.programCache[programKey]; } - private addSubMulDiv( - a: NDArray, b: NDArray, resultShape: number[], - operandA: addsubmuldiv_gpu.OperandType, - opType: addsubmuldiv_gpu.Operation, - operandB: addsubmuldiv_gpu.OperandType): NDArray { - let cleanupB = false; - - const aOrientation = MatrixOrientation.REGULAR; - let bOrientation = MatrixOrientation.REGULAR; - - let logicalBTexShape: [number, number]; - - if (operandA === OperandType.MATRIX && operandB === OperandType.MATRIX) { - util.assertShapesMatch(a.shape, b.shape); - - if (a.inGPU()) { - // Prefer B to have the shape of A. - b.getTextureShapeRC(a.getTextureShapeRC()); - } else if (b.inGPU()) { - // Prefer A to have the shape of B. - a.getTextureShapeRC(b.getTextureShapeRC()); - } - - const aTexShape = a.getTextureShapeRC(); - const bTexShape = b.getTextureShapeRC(); - logicalBTexShape = bTexShape; - - if (a.rank === 1) { - // When dealing with vectors, we can sample in transposed way without - // the need to do physical reshape. - if (!util.arraysEqual(bTexShape, aTexShape)) { - bOrientation = MatrixOrientation.TRANSPOSED; - logicalBTexShape = [bTexShape[1], bTexShape[0]]; - } - } - - if (!util.arraysEqual(aTexShape, logicalBTexShape)) { - b = this.reshapeTexture(b, aTexShape); - bOrientation = MatrixOrientation.REGULAR; - logicalBTexShape = b.getTextureShapeRC(); - cleanupB = true; - } - } else { - logicalBTexShape = b.getTextureShapeRC(); - } - - const aTexShape = a.getTextureShapeRC(); - const bTexShape = b.getTextureShapeRC(); - - const programKey = [ - ADD_SUM_MUL_DIV_PROG, operandA, aOrientation, opType, operandB, - bOrientation - ].join('_'); - const program = this.getAndSaveProgram( - programKey, - () => addsubmuldiv_gpu.getFragmentShaderSource( - operandA, aOrientation, opType, operandB, bOrientation)); - - const resultTextureShape: [number, number] = [ - Math.max(aTexShape[0], logicalBTexShape[0]), - Math.max(aTexShape[1], logicalBTexShape[1]) - ]; - - const resultTexture = - this.textureManager.acquireTexture(resultTextureShape); - - addsubmuldiv_gpu.addSubMulDiv( - this.gpgpu, program, a.getTexture(), aTexShape, b.getTexture(), - bTexShape, resultTexture, resultTextureShape); - - if (cleanupB) { - b.dispose(); - } - - return NDArray.make( - resultShape, - {texture: resultTexture, textureShapeRC: resultTextureShape}); - } - private doGPUShapesMatch(a: NDArray, b: NDArray): boolean { util.assertShapesMatch(a.shape, b.shape); if (a.inGPU()) { diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index 8c86fa839c..8b68af880b 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -35,9 +35,9 @@ describe('NDArrayMathGPU scope', () => { math.scope(() => { const result = math.scope(() => { - b = math.add(a, b); - b = math.add(a, b); - b = math.add(a, b); + b = math.add(a, b) as Array1D; + b = math.add(a, b) as Array1D; + b = math.add(a, b) as Array1D; return math.add(a, b); }); @@ -89,9 +89,9 @@ describe('NDArrayMathGPU scope', () => { const numUsedTexturesBefore = math.getTextureManager().getNumUsedTextures(); math.scope(() => { - b = math.add(a, b); - b = math.add(a, b); - b = math.add(a, b); + b = math.add(a, b) as Array1D; + b = math.add(a, b) as Array1D; + b = math.add(a, b) as Array1D; math.add(a, b); }); @@ -109,10 +109,10 @@ describe('NDArrayMathGPU scope', () => { math.scope(() => { const result = math.scope(() => { - b = math.add(a, b); + b = math.add(a, b) as Array1D; b = math.scope(() => { b = math.scope(() => { - return math.add(a, b); + return math.add(a, b) as Array1D; }); // a, original b, and two intermediate textures should be the only // textures. @@ -126,12 +126,12 @@ describe('NDArrayMathGPU scope', () => { expect(math.getTextureManager().getNumUsedTextures()) .toEqual(numUsedTexturesBefore + 4); - return math.add(a, b); + return math.add(a, b) as Array1D; }); expect(math.getTextureManager().getNumUsedTextures()) .toEqual(numUsedTexturesBefore + 4); - return math.add(a, b); + return math.add(a, b) as Array1D; }); // a, b, and result are new textures. All intermediates should be diff --git a/src/math/webgl/addsubmuldiv_gpu.ts b/src/math/webgl/addsubmuldiv_gpu.ts deleted file mode 100644 index 3ea00a7d92..0000000000 --- a/src/math/webgl/addsubmuldiv_gpu.ts +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2017 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {MatrixOrientation} from '../math'; - -import * as binaryop_gpu from './binaryop_gpu'; -import {GPGPUContext} from './gpgpu_context'; - -export type Operation = '+' | '-' | '*' | '/'; - -export enum OperandType { - MATRIX, - SCALAR -} - -export function getFragmentShaderSource( - aType: OperandType, aOrientation: MatrixOrientation, op: Operation, - bType: OperandType, bOrientation: MatrixOrientation): string { - const aUV = operandToShaderSnippet(aType, aOrientation); - const bUV = operandToShaderSnippet(bType, bOrientation); - const resultOp = `gl_FragColor = vec4(a ${op} b, 0, 0, 0);`; - return binaryop_gpu.getFragmentShaderSource(aUV, bUV, resultOp); -} - -function operandToShaderSnippet( - operand: OperandType, orientation: MatrixOrientation): string { - switch (operand) { - case OperandType.MATRIX: - return 'resultUV' + - (orientation === MatrixOrientation.REGULAR ? '.st' : '.ts'); - case OperandType.SCALAR: - return 'vec2(0.5, 0.5)'; - default: - throw new Error('Unknown operand type'); - } -} - -export function addSubMulDiv( - gpgpu: GPGPUContext, program: WebGLProgram, a: WebGLTexture, - aShapeRowCol: [number, number], b: WebGLTexture, - bShapeRowCol: [number, number], result: WebGLTexture, - resultShapeRowCol: [number, number]) { - return binaryop_gpu.binaryOp( - gpgpu, program, a, aShapeRowCol, b, bShapeRowCol, result, - resultShapeRowCol); -} - -export function uploadScalarPlusMatrixDownload( - a: number, b: Float32Array, bShape: [number, number], - bOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.SCALAR, MatrixOrientation.REGULAR, '+', OperandType.MATRIX, - bOrientation); - return binaryop_gpu.uploadBinaryOpDownload( - new Float32Array([a]), [1, 1], b, bShape, src); -} - -export function uploadMatrixMinusScalarDownload( - a: Float32Array, aShape: [number, number], b: number, - aOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.MATRIX, aOrientation, '-', OperandType.SCALAR, - MatrixOrientation.REGULAR); - return binaryop_gpu.uploadBinaryOpDownload( - a, aShape, new Float32Array([b]), [1, 1], src); -} - -export function uploadScalarMinusMatrixDownload( - a: number, b: Float32Array, bShape: [number, number], - bOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.SCALAR, MatrixOrientation.REGULAR, '-', OperandType.MATRIX, - bOrientation); - return binaryop_gpu.uploadBinaryOpDownload( - new Float32Array([a]), [1, 1], b, bShape, src); -} - -export function uploadScalarTimesMatrixDownload( - a: number, b: Float32Array, bShape: [number, number], - bOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.SCALAR, MatrixOrientation.REGULAR, '*', OperandType.MATRIX, - bOrientation); - return binaryop_gpu.uploadBinaryOpDownload( - new Float32Array([a]), [1, 1], b, bShape, src); -} - -export function uploadMatrixTimesMatrixDownload( - a: Float32Array, b: Float32Array, shape: [number, number], - aOrientation = MatrixOrientation.REGULAR, - bOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.MATRIX, aOrientation, '*', OperandType.MATRIX, bOrientation); - return binaryop_gpu.uploadBinaryOpDownload(a, shape, b, shape, src); -} - -export function uploadMatrixPlusMatrixDownload( - a: Float32Array, b: Float32Array, shape: [number, number], - aOrientation = MatrixOrientation.REGULAR, - bOrientation = MatrixOrientation.REGULAR): Float32Array { - const src = getFragmentShaderSource( - OperandType.MATRIX, aOrientation, '+', OperandType.MATRIX, bOrientation); - return binaryop_gpu.uploadBinaryOpDownload(a, shape, b, shape, src); -} diff --git a/src/math/webgl/addsubmuldiv_gpu_test.ts b/src/math/webgl/addsubmuldiv_gpu_test.ts deleted file mode 100644 index 5087ce05c6..0000000000 --- a/src/math/webgl/addsubmuldiv_gpu_test.ts +++ /dev/null @@ -1,214 +0,0 @@ -/* Copyright 2017 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as test_util from '../../test_util'; -import {MatrixOrientation} from '../math'; - -import * as addsubmuldiv_gpu from './addsubmuldiv_gpu'; - -describe('addsubmuldiv_gpu ScalarPlusMatrix', () => { - it('returns a matrix with the same shape as the input matrix', () => { - const a = new Float32Array(12 * 513); - const result = - addsubmuldiv_gpu.uploadScalarPlusMatrixDownload(0, a, [12, 513]); - expect(result.length).toEqual(a.length); - }); - - it('preserves the matrix when the scalar is 0', () => { - const a = new Float32Array([1, 2, 3]); - const result = - addsubmuldiv_gpu.uploadScalarPlusMatrixDownload(0, a, [1, 3]); - test_util.expectArraysClose(result, a, 0); - }); - - it('adds the scalar to every element in the matrix', () => { - const a = new Float32Array([1, 2, 3, 4]); - const result = - addsubmuldiv_gpu.uploadScalarPlusMatrixDownload(0.5, a, [2, 2]); - test_util.expectArraysClose( - result, new Float32Array([1.5, 2.5, 3.5, 4.5]), 0.0001); - }); -}); - -describe('addsubmuldiv_gpu MatrixMinusScalar', () => { - it('returns a matrix with the same shape as the input matrix', () => { - const a = new Float32Array(12 * 513); - const result = - addsubmuldiv_gpu.uploadMatrixMinusScalarDownload(a, [12, 513], 0); - expect(result.length).toEqual(a.length); - }); - - it('preserves the matrix when the scalar is 0', () => { - const a = new Float32Array([1, 2, 3]); - const result = - addsubmuldiv_gpu.uploadMatrixMinusScalarDownload(a, [1, 3], 0); - test_util.expectArraysClose(result, a, 0); - }); - - it('subtracts the scalar from every element in the matrix', () => { - const a = new Float32Array([1, 2, 3, 4]); - const result = - addsubmuldiv_gpu.uploadMatrixMinusScalarDownload(a, [2, 2], 0.5); - test_util.expectArraysClose( - result, new Float32Array([0.5, 1.5, 2.5, 3.5]), 0.0001); - }); -}); - -describe('addsubmuldiv_gpu ScalarMinusMatrix', () => { - it('returns a matrix with the same shape as the input matrix', () => { - const a = new Float32Array(12 * 513); - const result = - addsubmuldiv_gpu.uploadScalarMinusMatrixDownload(0, a, [12, 513]); - expect(result.length).toEqual(a.length); - }); - - it('negates the matrix when the scalar is 0', () => { - const a = new Float32Array([1, 2, 3]); - const result = - addsubmuldiv_gpu.uploadScalarMinusMatrixDownload(0, a, [1, 3]); - test_util.expectArraysClose(result, new Float32Array([-1, -2, -3]), 0); - }); - - it('subtracts the matrix value from the scalar for every element', () => { - const a = new Float32Array([1, 2, 3, 4]); - const result = - addsubmuldiv_gpu.uploadScalarMinusMatrixDownload(0.5, a, [2, 2]); - test_util.expectArraysClose( - result, new Float32Array([-0.5, -1.5, -2.5, -3.5]), 0.0001); - }); -}); - -describe('addsubmuldiv_gpu ScalarTimesMatrix', () => { - it('returns a matrix with the same shape as the input matrix', () => { - const a = new Float32Array(12 * 513); - const result = - addsubmuldiv_gpu.uploadScalarTimesMatrixDownload(0, a, [12, 513]); - expect(result.length).toEqual(a.length); - }); - - it('zeros out the matrix when the scalar is 0', () => { - const a = new Float32Array([1, 2, 3]); - const result = - addsubmuldiv_gpu.uploadScalarTimesMatrixDownload(0, a, [1, 3]); - test_util.expectArraysClose(result, new Float32Array([0, 0, 0]), 0); - }); - - it('triples the matrix when the scalar is 3', () => { - const a = new Float32Array([1, 2, 3]); - const result = - addsubmuldiv_gpu.uploadScalarTimesMatrixDownload(3, a, [1, 3]); - test_util.expectArraysClose(result, new Float32Array([3, 6, 9]), 0); - }); -}); - -describe('addsubmuldiv_gpu element-wise matrix product', () => { - function cpuMultiply(a: Float32Array, b: Float32Array): Float32Array { - const result = new Float32Array(a.length); - for (let i = 0; i < result.length; ++i) { - result[i] = a[i] * b[i]; - } - return result; - } - - it('returns a matrix the size of A (and B)', () => { - const a = new Float32Array(1234); - const b = new Float32Array(1234); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [1234 / 2, 2]); - expect(result.length).toEqual(a.length); - }); - - it('sets all result entries to 0 if A is 0', () => { - const a = new Float32Array(257 * 257); - const b = new Float32Array(a.length); - b.fill(1.0); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [257, 257]); - expect(result).toEqual(a); - }); - - it('sets all result entries to 0 if B is 0', () => { - const a = new Float32Array(257 * 257); - const b = new Float32Array(a.length); - a.fill(1.0); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [257, 257]); - expect(result).toEqual(b); - }); - - it('sets all result entries to A if B is [1]', () => { - const a = test_util.randomArrayInRange(16, -10, 10); - const b = new Float32Array(16); - b.fill(1.0); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [4, 4]); - test_util.expectArraysClose(a, result, 0.0001); - }); - - it('sets all result entries to B if A is [1]', () => { - const a = new Float32Array(16); - a.fill(1.0); - const b = test_util.randomArrayInRange(16, -10, 10); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [4, 4]); - test_util.expectArraysClose(b, result, 0.0001); - }); - - it('writes the element-wise product of A and B to result', () => { - const a = test_util.randomArrayInRange(64, -10, 10); - const b = test_util.randomArrayInRange(64, -10, 10); - const result = - addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload(a, b, [8, 8]); - const expected = cpuMultiply(a, b); - test_util.expectArraysClose(result, expected, 0.0001); - }); - - it('writes the element-wise product A * B^T to result', () => { - const a = new Float32Array([1, 2, 3, 4]); - const b = new Float32Array([3, 1, 0, 2]); - - const result = addsubmuldiv_gpu.uploadMatrixTimesMatrixDownload( - a, b, [2, 2], MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED); - - const bTransposed = new Float32Array([3, 0, 1, 2]); - const expected = cpuMultiply(a, bTransposed); - test_util.expectArraysClose(result, expected, 0.0001); - }); -}); - -describe('addsubmuldiv_gpu element-wise matrix addition', () => { - it('writes the element-wise A + B^T to result', () => { - const a = new Float32Array([1, 2, 3, 4]); - const b = new Float32Array([3, 1, 0, 2]); - - const result = addsubmuldiv_gpu.uploadMatrixPlusMatrixDownload( - a, b, [2, 2], MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED); - - const expected = new Float32Array([4, 2, 4, 6]); - test_util.expectArraysClose(result, expected, 0.0001); - }); - - it('writes the element-wise A^T + B^T to result', () => { - const a = new Float32Array([1, 2, 3, 4]); - const b = new Float32Array([3, 1, 0, 2]); - - const result = addsubmuldiv_gpu.uploadMatrixPlusMatrixDownload( - a, b, [2, 2], MatrixOrientation.TRANSPOSED, - MatrixOrientation.TRANSPOSED); - - const expected = new Float32Array([4, 3, 3, 6]); - test_util.expectArraysClose(result, expected, 0.0001); - }); -}); diff --git a/src/math/webgl/binaryop_gpu.ts b/src/math/webgl/binaryop_gpu.ts index 4a793d9342..63298b83e6 100644 --- a/src/math/webgl/binaryop_gpu.ts +++ b/src/math/webgl/binaryop_gpu.ts @@ -13,66 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {GPGPUContext} from './gpgpu_context'; - -export function getFragmentShaderSource( - aResultUV: string, bResultUV: string, op: string): string { - return ` - precision highp float; - uniform sampler2D matrixA; - uniform sampler2D matrixB; - varying vec2 resultUV; - - void main() { - float a = texture2D(matrixA, ${aResultUV}).r; - float b = texture2D(matrixB, ${bResultUV}).r; - ${op} - }`; -} - -export function binaryOp( - gpgpu: GPGPUContext, program: WebGLProgram, a: WebGLTexture, - aShapeRowCol: [number, number], b: WebGLTexture, - bShapeRowCol: [number, number], result: WebGLTexture, - resultShapeRowCol: [number, number]) { - gpgpu.setOutputMatrixTexture( - result, resultShapeRowCol[0], resultShapeRowCol[1]); - gpgpu.setProgram(program); - gpgpu.setInputMatrixTexture(a, 'matrixA', 0); - gpgpu.setInputMatrixTexture(b, 'matrixB', 1); - gpgpu.executeProgram(); -} - -export function uploadBinaryOpDownload( - a: Float32Array, aShape: [number, number], b: Float32Array, - bShape: [number, number], fragmentShaderSource: string): Float32Array { - const gpgpu = new GPGPUContext(); - const program = gpgpu.createProgram(fragmentShaderSource); - - const aTexture: WebGLTexture = - gpgpu.createMatrixTexture(aShape[0], aShape[1]); - const bTexture: WebGLTexture = - gpgpu.createMatrixTexture(bShape[0], bShape[1]); - - const resultShape: [number, number] = - [Math.max(aShape[0], bShape[0]), Math.max(aShape[1], bShape[1])]; - - const resultTexture: WebGLTexture = - gpgpu.createMatrixTexture(resultShape[0], resultShape[1]); - - gpgpu.uploadMatrixToTexture(aTexture, aShape[0], aShape[1], a); - gpgpu.uploadMatrixToTexture(bTexture, bShape[0], bShape[1], b); - - binaryOp( - gpgpu, program, aTexture, aShape, bTexture, bShape, resultTexture, - resultShape); - const result = gpgpu.downloadMatrixFromTexture( - resultTexture, resultShape[0], resultShape[1]); - - gpgpu.deleteMatrixTexture(aTexture); - gpgpu.deleteMatrixTexture(bTexture); - gpgpu.deleteMatrixTexture(resultTexture); - gpgpu.deleteProgram(program); - gpgpu.dispose(); - return result; +import {GPGPUProgram} from './gpgpu_math'; +import * as util from '../../util'; + +export class BinaryOpProgram implements GPGPUProgram { + variableNames = ['A', 'B']; + params: Array<{}>; + outputShape: number[]; + userCode: string; + supportsBroadcasting: boolean; + + constructor(op: '+' | '-' | '*' | '/', aShape: number[], bShape: number[]) { + this.supportsBroadcasting = true; + this.params = [op]; + this.outputShape = util.assertAndGetBroadcastedShape(aShape, bShape); + this.userCode = ` + void main() { + float a = getAAtOutCoords(); + float b = getBAtOutCoords(); + setOutput(a ${op} b); + } + `; + } } diff --git a/src/math/webgl/binaryop_gpu_test.ts b/src/math/webgl/binaryop_gpu_test.ts new file mode 100644 index 0000000000..dd1ad320c0 --- /dev/null +++ b/src/math/webgl/binaryop_gpu_test.ts @@ -0,0 +1,199 @@ +/* Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as test_util from '../../test_util'; + +import {BinaryOpProgram} from './binaryop_gpu'; +import {GPGPUContext} from './gpgpu_context'; +import * as gpgpu_math from './gpgpu_math'; +import {NDArray, Array1D, Array2D, Array3D, Scalar, + initializeGPU} from '../ndarray'; +import * as util from '../../util'; +import {TextureManager} from './texture_manager'; + +describe('binaryop_gpu Add', () => { + it('returns a matrix with the same shape as the input matrix', () => { + const a = Scalar.new(0); + const b = Array2D.zeros([12, 513]); + const result = uploadBinaryOpDownload(a, b, '+'); + expect(result.length).toEqual(b.size); + }); + + it('preserves the matrix when the scalar is 0', () => { + const c = Scalar.new(0); + const a = Array1D.new([1, 2, 3]); + const result = uploadBinaryOpDownload(c, a, '+'); + test_util.expectArraysClose(result, new Float32Array([1, 2, 3]), 0); + }); + + it('adds the scalar to every element in the matrix', () => { + const a = Array1D.new([1, 2, 3, 4]); + const c = Scalar.new(0.5); + const result = uploadBinaryOpDownload(c, a, '+'); + test_util.expectArraysClose( + result, new Float32Array([1.5, 2.5, 3.5, 4.5]), 0.0001); + }); +}); + +describe('binaryop_gpu Sub', () => { + it('returns a matrix with the same shape as the input matrix', () => { + const a = Array2D.zeros([12, 513]); + const c = Scalar.new(0); + const result = uploadBinaryOpDownload(a, c, '-'); + expect(result.length).toEqual(a.size); + }); + + it('preserves the matrix when the scalar is 0', () => { + const a = Array1D.new([1, 2, 3]); + const c = Scalar.new(0); + const result = uploadBinaryOpDownload(a, c, '-'); + test_util.expectArraysClose(result, new Float32Array([1, 2, 3]), 0); + }); + + it('subtracts the scalar from every element in the matrix', () => { + const a = Array1D.new([1, 2, 3, 4]); + const c = Scalar.new(0.5); + const result = uploadBinaryOpDownload(a, c, '-'); + test_util.expectArraysClose( + result, new Float32Array([0.5, 1.5, 2.5, 3.5]), 0.0001); + }); + + it('2D - 1D broadcasting', () => { + const a = Array2D.new([3, 2], [[1, 2], [3, 4], [5, 6]]); + const b = Array1D.new([1, 3]); + const result = uploadBinaryOpDownload(a, b, '-'); + test_util.expectArraysClose( + result, new Float32Array([0, -1, 2, 1, 4, 3]), 1e-4); + }); + + it('2D - 1D invalid shapes for broadcasting', () => { + const a = Array2D.new([3, 2], [[1, 2], [3, 4], [5, 6]]); + const b = Array1D.new([1, 2, 3]); + // shape [3, 2] is not compatible with shape [3]. + const f = () => uploadBinaryOpDownload(a, b, '-'); + expect(f).toThrowError(); + }); + + it('3D - 2D broadcasting', () => { + const a = Array3D.new([2, 2, 2], [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + const b = Array2D.new([2, 2], [[1, 2], [3, 5]]); + // shape [3, 2] is not compatible with shape [3]. + const res = uploadBinaryOpDownload(a, b, '-'); + test_util.expectArraysClose( + res, new Float32Array([0, 0, 0, -1, 4, 4, 4, 3]), 1e-4); + }); +}); + + +describe('binaryop_gpu Mul', () => { + function cpuMultiply(a: Float32Array, b: Float32Array): Float32Array { + const result = new Float32Array(a.length); + for (let i = 0; i < result.length; ++i) { + result[i] = a[i] * b[i]; + } + return result; + } + + it('returns a matrix with the same shape as the input matrix', () => { + const a = Array2D.zeros([12, 513]); + const c = Scalar.new(0); + const result = uploadBinaryOpDownload(c, a, '*'); + expect(result.length).toEqual(a.size); + }); + + it('zeros out the matrix when the scalar is 0', () => { + const a = Array1D.new([1, 2, 3]); + const c = Scalar.new(0); + const result = uploadBinaryOpDownload(c, a, '*'); + test_util.expectArraysClose(result, new Float32Array([0, 0, 0]), 0); + }); + + it('triples the matrix when the scalar is 3', () => { + const a = Array1D.new([1, 2, 3]); + const c = Scalar.new(3); + const result = uploadBinaryOpDownload(c, a, '*'); + test_util.expectArraysClose(result, new Float32Array([3, 6, 9]), 0); + }); + + it('sets all result entries to 0 if A is 0', () => { + const a = Array2D.zeros([25, 25]); + const expected = a.getValues(); + const b = Array2D.zerosLike(a); + b.fill(1.0); + const result = uploadBinaryOpDownload(a, b, '*'); + expect(result).toEqual(expected); + }); + + it('sets all result entries to 0 if B is 0', () => { + const a = Array2D.zeros([25, 25]); + a.fill(1.0); + const b = Array2D.zerosLike(a); + const expected = b.getValues(); + const result = uploadBinaryOpDownload(a, b, '*'); + expect(result).toEqual(expected); + }); + + it('sets all result entries to A if B is [1]', () => { + const a = Array1D.new(test_util.randomArrayInRange(16, -10, 10)); + const expected = a.getValues(); + const b = Array1D.zeros([16]); + b.fill(1.0); + const result = uploadBinaryOpDownload(a, b, '*'); + test_util.expectArraysClose(result, expected, 0.0001); + }); + + it('writes the element-wise product of A and B to result', () => { + const a = Array1D.new(test_util.randomArrayInRange(64, -10, 10)); + const b = Array1D.new(test_util.randomArrayInRange(64, -10, 10)); + const expected = cpuMultiply(a.getValues(), b.getValues()); + const result = uploadBinaryOpDownload(a, b, '*'); + test_util.expectArraysClose(result, expected, 0.0001); + }); +}); + +describe('binaryop_gpu Divide', () => { + it('Scalar / Matrix', () => { + const c = Scalar.new(2); + const a = Array2D.new([2, 3], [1, 2, 3, 4, 5, 6]); + const r = uploadBinaryOpDownload(c, a, '/'); + expect(r[0]).toBeCloseTo(2 / 1); + expect(r[1]).toBeCloseTo(2 / 2); + expect(r[2]).toBeCloseTo(2 / 3); + expect(r[3]).toBeCloseTo(2 / 4); + expect(r[4]).toBeCloseTo(2 / 5); + expect(r[5]).toBeCloseTo(2 / 6); + }); +}); + +export function uploadBinaryOpDownload( + a: NDArray, b: NDArray, op: '+'|'-'|'*'|'/'): Float32Array { + const gpgpu = new GPGPUContext(); + const textureManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, textureManager); + + const outShape = util.assertAndGetBroadcastedShape(a.shape, b.shape); + const res = NDArray.zeros(outShape); + const program = new BinaryOpProgram(op, a.shape, b.shape); + const binary = + gpgpu_math.compileProgram(gpgpu, program, [a, b], res); + gpgpu_math.runProgram(binary, [a, b], res); + + const resValues = res.getValues(); + textureManager.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + gpgpu.dispose(); + + return resValues; +} diff --git a/src/math/webgl/gpgpu_math.ts b/src/math/webgl/gpgpu_math.ts index eb1ca3cec2..9851fe968d 100644 --- a/src/math/webgl/gpgpu_math.ts +++ b/src/math/webgl/gpgpu_math.ts @@ -25,6 +25,7 @@ export interface GPGPUProgram { outputShape: number[]; params: Array<{}>; userCode: string; + supportsBroadcasting?: boolean; } export interface GPGPUBinary { @@ -52,8 +53,9 @@ export function compileProgram( logicalShape: output.shape, texShape: output.getTextureShapeRC() }; - const source = shader_compiler.makeShader(inputInfos, outShapeInfo, - userCode); + const source = shader_compiler.makeShader( + inputInfos, outShapeInfo, userCode, + program.supportsBroadcasting === true); return { program, source, diff --git a/src/math/webgl/shader_compiler.ts b/src/math/webgl/shader_compiler.ts index e296bd9767..ef4f772bd6 100644 --- a/src/math/webgl/shader_compiler.ts +++ b/src/math/webgl/shader_compiler.ts @@ -25,13 +25,13 @@ export type InputInfo = { shapeInfo: ShapeInfo }; -export function makeShader( - inputsInfo: InputInfo[], outputShape: ShapeInfo, - userCode: string): string { +export function makeShader(inputsInfo: InputInfo[], outputShape: ShapeInfo, + userCode: string, broadcast: boolean): string { const inputPrefixSnippet = inputsInfo.map(x => `uniform sampler2D ${x.name};`).join('\n'); const inputSamplingSnippet = - inputsInfo.map(x => getInputSamplingSnippet(x, outputShape)).join('\n'); + inputsInfo.map(x => getInputSamplingSnippet(x, outputShape, broadcast)) + .join('\n'); const outTexShape = outputShape.texShape; const outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape); @@ -42,7 +42,8 @@ export function makeShader( return source; } -function getInputSamplingSnippet(inInfo: InputInfo, outShapeInfo: ShapeInfo) { +function getInputSamplingSnippet( + inInfo: InputInfo, outShapeInfo: ShapeInfo, broadcast: boolean) { const shape = inInfo.shapeInfo.logicalShape; const texShape = inInfo.shapeInfo.texShape; const outTexShape = outShapeInfo.texShape; @@ -70,9 +71,10 @@ function getInputSamplingSnippet(inInfo: InputInfo, outShapeInfo: ShapeInfo) { // If input and output have matching logical shapes, add // getTexNameAtOutCoord() method that samples the input texture using the // output coordinates. - if (util.arraysEqual( + if (broadcast || util.arraysEqual( inInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape)) { - res += getSamplerAtOutputCoords(inInfo.name, texShape, outTexShape); + res += + getSamplerAtOutputCoords(inInfo.name, texShape, outTexShape, broadcast); } res += getSamplerFlat(inInfo.name, texShape); return res; @@ -217,6 +219,13 @@ function getSampler1D( const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const tR = texShape[0]; const tC = texShape[1]; + if (texShape[0] === 1 && texShape[1] === 1) { + return ` + float ${funcName}(float index) { + return texture2D(${texName}, halfCR).r; + } + `; + } if (texShape[1] === 1) { return ` float ${funcName}(float index) { @@ -282,6 +291,29 @@ function getSamplerFlat(texName: string, texShape: [number, number]): string { 'Flat'; const tNumR = texShape[0]; const tNumC = texShape[1]; + if (tNumC === 1 && tNumR === 1) { + return ` + float ${funcName}(float index) { + return texture2D(${texName}, halfCR).r; + } + `; + } + if (tNumC === 1) { + return ` + float ${funcName}(float index) { + vec2 uv = vec2(0.5, (index + 0.5) / ${tNumR}.0); + return texture2D(${texName}, uv).r; + } + `; + } + if (tNumR === 1) { + return ` + float ${funcName}(float index) { + vec2 uv = vec2((index + 0.5) / ${tNumC}.0, 0.5); + return texture2D(${texName}, uv).r; + } + `; + } return ` float ${funcName}(float index) { float texR = floor(index / ${tNumC}.0); @@ -293,7 +325,7 @@ function getSamplerFlat(texName: string, texShape: [number, number]): string { } function getSamplerAtOutputCoords(texName: string, inTexShape: [number, number], - outTexShape: [number, number]) { + outTexShape: [number, number], broadcast: boolean) { const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1) + 'AtOutCoords'; if (util.arraysEqual(inTexShape, outTexShape)) { @@ -303,10 +335,14 @@ function getSamplerAtOutputCoords(texName: string, inTexShape: [number, number], } `; } + const inSize = util.sizeFromShape(inTexShape); + const broadcastSnippet = broadcast ? `index = mod(index, ${inSize}.0);` : ''; + return ` float ${funcName}() { vec2 resTexRC = floor(gl_FragCoord.yx); float index = dot(resTexRC, vec2(${outTexShape[1]}.0, 1.0)); + ${broadcastSnippet} float texR = floor(index / ${inTexShape[1]}.0); float texC = mod(index, ${inTexShape[1]}.0); vec2 uv = (vec2(texC, texR) + halfCR) / diff --git a/src/util.ts b/src/util.ts index a130c1524c..925e12fd99 100644 --- a/src/util.ts +++ b/src/util.ts @@ -180,3 +180,35 @@ export function createShuffledIndices(n: number): Uint32Array { shuffle(shuffledIndices); return shuffledIndices; } + +export function assertAndGetBroadcastedShape( + shapeA: number[], shapeB: number[]): number[] { + const result: number[] = []; + let nextADimMustBeOne = false; + let nextBDimMustBeOne = false; + const errMsg = `Operands could not be broadcast together with shapes ` + + `${shapeA} and ${shapeB}. Currently, we only support a ` + + `stricter version of broadcasting than numpy.`; + const l = Math.max(shapeA.length, shapeB.length); + + shapeA = shapeA.slice().reverse(); + shapeB = shapeB.slice().reverse(); + for (let i = 0; i < l; i++) { + const a = shapeA[i] || 1; + const b = shapeB[i] || 1; + if ((b > 1 && nextBDimMustBeOne) || (a > 1 && nextADimMustBeOne)) { + throw Error(errMsg); + } + if (a > 1 && b === 1) { + nextBDimMustBeOne = true; + } + if (b > 1 && a === 1) { + nextADimMustBeOne = true; + } + if (a > 1 && b > 1 && a !== b) { + throw Error(errMsg); + } + result.push(Math.max(a, b)); + } + return result.reverse(); +} diff --git a/src/util_test.ts b/src/util_test.ts index 77e713c7f4..d4afa86c7a 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -87,3 +87,60 @@ describe('Util', () => { expect(util.inferShape(a)).toEqual([2, 3, 2, 1]); }); }); + +describe('util.getBroadcastedShape', () => { + it('two scalars', () => { + const res = util.assertAndGetBroadcastedShape([], []); + expect(res).toEqual([]); + }); + + it('scalar and 1d', () => { + const res = util.assertAndGetBroadcastedShape([6], []); + expect(res).toEqual([6]); + }); + + it('scalar and 2d', () => { + const res = util.assertAndGetBroadcastedShape([2, 6], []); + expect(res).toEqual([2, 6]); + }); + + it('1d and 2d', () => { + const res = util.assertAndGetBroadcastedShape([6], [2, 6]); + expect(res).toEqual([2, 6]); + }); + + it('2d and 3d', () => { + const res = util.assertAndGetBroadcastedShape([2, 6], [7, 2, 6]); + expect(res).toEqual([7, 2, 6]); + }); + + it('3d and 3d', () => { + const res = util.assertAndGetBroadcastedShape([1, 1, 6], [7, 2, 6]); + expect(res).toEqual([7, 2, 6]); + }); + + it('incompatible inner shape', () => { + const f = () => util.assertAndGetBroadcastedShape([7, 2, 5], [7, 2, 6]); + expect(f).toThrowError(); + }); + + it('incompatible middle shape', () => { + const f = () => util.assertAndGetBroadcastedShape([7, 3, 6], [7, 2, 6]); + expect(f).toThrowError(); + }); + + it('incompatible due to stricter broadcasting support', () => { + const f = () => util.assertAndGetBroadcastedShape([7, 3, 6], [7, 1, 6]); + expect(f).toThrowError(); + }); + + it('incompatible due to stricter broadcasting support', () => { + const f = () => util.assertAndGetBroadcastedShape([7, 1, 1], [7, 1]); + expect(f).toThrowError(); + }); + + it('compatible with stricter broadcasting support', () => { + const res = util.assertAndGetBroadcastedShape([7, 1, 1], [7, 1, 1]); + expect(res).toEqual([7, 1, 1]); + }); +});