Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"editor.tabSize": 2,
"editor.insertSpaces": true,
"files.insertFinalNewline": true,
"editor.detectIndentation": false
"editor.detectIndentation": false,
"typescript.tsdk": "node_modules/typescript/lib"
}
8 changes: 4 additions & 4 deletions demos/mnist/mnist.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ export function buildModelMathAPI(

return (x: Array1D): Scalar => {
return math.scope(() => {
const hidden1 =
math.relu(math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B));
const hidden2 = math.relu(
math.add(math.vectorTimesMatrix(hidden1, hidden2W), hidden2B));
const hidden1 = math.relu(
math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B)) as Array1D;
const hidden2 = math.relu(math.add(
math.vectorTimesMatrix(hidden1, hidden2W), hidden2B)) as Array1D;
const logits =
math.add(math.vectorTimesMatrix(hidden2, softmaxW), softmaxB);
return math.argMax(logits);
Expand Down
4 changes: 2 additions & 2 deletions src/math/activation_functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ export class SigmoidFunc implements ActivationFunction {
});
}

der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
der<T extends NDArray>(math: NDArrayMath, x: T, y: T): T {
return math.scope(() => {
// y * (1 - y) = y - y^2
const ySquared = math.elementWiseMul(y, y);
return math.sub(y, ySquared);
return math.subStrict(y, ySquared);
});
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/math/cost_functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export class SquareCostFunc implements ElementWiseCostFunction {
private halfOne = Scalar.new(0.5);

cost<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
const diff = math.sub(x1, x2);
const diff = math.subStrict(x1, x2);
const diffSquared = math.elementWiseMul(diff, diff);
const result = math.scalarTimesArray(this.halfOne, diffSquared);

Expand All @@ -40,7 +40,7 @@ export class SquareCostFunc implements ElementWiseCostFunction {
}

der<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
return math.sub(x1, x2);
return math.subStrict(x1, x2);
}

dispose() {
Expand Down
135 changes: 89 additions & 46 deletions src/math/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -558,10 +558,8 @@ export abstract class NDArrayMath {
c.size === 1,
`Error in scalarPlusArray: first argument must be rank 0, but got ` +
`rank ${c.rank}.`);
return this.track(this.scalarPlusArrayInternal(c, a));
return this.add(c, a) as T;
}
protected abstract scalarPlusArrayInternal<T extends NDArray>(
c: Scalar, a: T): T;

/**
* Computes a scalar minus NDArray, c - A.
Expand All @@ -573,25 +571,21 @@ export abstract class NDArrayMath {
c.size === 1,
`Error in scalarMinusArray: first argument must be rank 0, but got ` +
`rank ${c.rank}.`);
return this.track(this.scalarMinusArrayInternal(c, a));
return this.sub(c, a) as T;
}
protected abstract scalarMinusArrayInternal<T extends NDArray>(
c: Scalar, a: T): T;

/**
* Computes a scalar minus NDArray, A - c.
* Computes A - c. A is NDArray, c is Scalar.
* @param a The NDArray A in A - c.
* @param c The scalar c in A - c.
* @param c The Scalar c in A - c.
*/
arrayMinusScalar<T extends NDArray>(a: T, c: Scalar): T {
util.assert(
c.size === 1,
`Error in arrayMinusScalar: second argument must be rank 0, but ` +
`got rank ${c.rank}.`);
return this.track(this.arrayMinusScalarInternal(a, c));
return this.sub(a, c) as T;
}
protected abstract arrayMinusScalarInternal<T extends NDArray>(
a: T, c: Scalar): T;

/**
* Computes -1 * A element-wise.
Expand All @@ -603,50 +597,111 @@ export abstract class NDArrayMath {
protected abstract negInternal<T extends NDArray>(a: T): T;

/**
* Adds two NDArrays element-wise, A + B. Inputs must be the same shape.
* Adds two NDArrays element-wise, A + B. Supports broadcasting.
* For a stricter version without broadcasting use math.addStrict().
*
* @param a The first NDArray to add element-wise.
* @param b The second NDArray to add element-wise.
*/
add<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in add: ');
add(a: NDArray, b: NDArray): NDArray {
util.assertAndGetBroadcastedShape(a.shape, b.shape);
return this.track(this.addInternal(a, b));
}
protected abstract addInternal<T extends NDArray>(a: T, b: T): T;
protected abstract addInternal(a: NDArray, b: NDArray): NDArray;

/**
* Adds two NDArrays element-wise, A + B. Inputs must
* be the same shape. For broadcasting support, use math.add() instead.
*
* @param a The first NDArray to multiply element-wise.
* @param b The second NDArray to multiply element-wise.
*/
addStrict<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in addStrict: ');
return this.add(a, b) as T;
}

/**
* Subtracts two NDArrays element-wise, A - B. Inputs must be the same shape.
* Subtracts two NDArrays element-wise, A - B. Supports broadcasting.
* For a stricter version without broadcasting use math.subStrict().
*
* @param a The first NDArray to subtract element-wise.
* @param b The second NDArray to subtract element-wise.
*/
sub<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in sub: ');
sub(a: NDArray, b: NDArray): NDArray {
util.assertAndGetBroadcastedShape(a.shape, b.shape);
return this.track(this.subInternal(a, b));
}
protected abstract subInternal<T extends NDArray>(a: T, b: T): T;
protected abstract subInternal(a: NDArray, b: NDArray): NDArray;

/**
* Multiplies two NDArrays element-wise (hadamard product), A * B. Inputs must
* be the same shape.
* Subtracts two NDArrays element-wise, A - B. Inputs must
* be the same shape. For broadcasting support, use math.sub() instead.
*
* @param a The first NDArray to multiply element-wise.
* @param b The second NDArray to multiply element-wise.
*/
subStrict<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in subStrict: ');
return this.sub(a, b) as T;
}

/**
* Multiplies two NDArrays element-wise, A * B. Supports broadcasting.
* For a stricter version without broadcasting use math.multiplyStrict().
*
* @param a The first NDArray to multiply element-wise.
* @param b The second NDArray to multiply element-wise.
*/
multiply(a: NDArray, b: NDArray): NDArray {
util.assertAndGetBroadcastedShape(a.shape, b.shape);
return this.track(this.multiplyInternal(a, b));
}
protected abstract multiplyInternal<T extends NDArray>(a: T, b: T): T;

/**
* @deprecated Use math.multiplyStrict() instead.
*/
elementWiseMul<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in elementWiseMul: ');
return this.track(this.elementWiseMulInternal(a, b));
return this.multiplyStrict(a, b);
}

/**
* Multiplies two NDArrays element-wise, A * B. Inputs must
* be the same shape. For broadcasting support, use math.multiply() instead.
*
* @param a The first NDArray to multiply element-wise.
* @param b The second NDArray to multiply element-wise.
*/
multiplyStrict<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in multiplyStrict: ');
return this.multiply(a, b) as T;
}
protected abstract elementWiseMulInternal<T extends NDArray>(a: T, b: T): T;

/**
* Divides two NDArrays element-wise (hadamard product), A / B. Inputs must be
* the same shape.
* Divides two NDArrays element-wise, A / B. Supports broadcasting.
* For a stricter version without broadcasting use math.divideStrict().
*
* @param a The first NDArray to divide element-wise.
* @param b The second NDArray to divide element-wise.
*/
divide<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in divide: ');
divide(a: NDArray, b: NDArray): NDArray {
util.assertAndGetBroadcastedShape(a.shape, b.shape);
return this.track(this.divideInternal(a, b));
}
protected abstract divideInternal<T extends NDArray>(a: T, b: T): T;
protected abstract divideInternal(a: NDArray, b: NDArray): NDArray;

/**
* Divides two NDArrays element-wise, A / B. Inputs must
* be the same shape. For broadcasting support, use math.divide() instead.
*
* @param a The first NDArray to multiply element-wise.
* @param b The second NDArray to multiply element-wise.
*/
divideStrict<T extends NDArray>(a: T, b: T): T {
util.assertShapesMatch(a.shape, b.shape, 'Error in divideStrict: ');
return this.divide(a, b) as T;
}

/**
* Computes a scalar divided by an NDArray, broadcasted over the NDArray, c /
Expand All @@ -659,10 +714,8 @@ export abstract class NDArrayMath {
c.size === 1,
`Error in scalarDividedByArray: first argument must be rank 0, but ` +
`got NDArray of rank ${c.rank}.`);
return this.track(this.scalarDividedByArrayInternal(c, a));
return this.divide(c, a) as T;
}
protected abstract scalarDividedByArrayInternal<T extends NDArray>(
c: Scalar, a: T): T;

/**
* Computes an NDArray divided by a scalar, broadcasted over the NDArray, A /
Expand All @@ -675,10 +728,8 @@ export abstract class NDArrayMath {
c.size === 1,
`Error in arrayDividedByScalar: second argument must be rank 0, ` +
`but got NDArray of rank ${c.rank}.`);
return this.track(this.arrayDividedByScalarInternal(a, c));
return this.divide(a, c) as T;
}
protected abstract arrayDividedByScalarInternal<T extends NDArray>(
a: T, c: Scalar): T;

/**
* Computes exponential of the input NDArray element-wise. y = e ^ x
Expand Down Expand Up @@ -778,17 +829,11 @@ export abstract class NDArrayMath {
c.size === 1,
`Error in arrayDividedByScalar: first argument must be rank 0, but ` +
`got rank ${c.rank}.`);
return this.track(this.scalarTimesArrayInternal(c, a));
return this.multiply(c, a) as T;
}
protected abstract scalarTimesArrayInternal<T extends NDArray>(
c: Scalar, a: T): T;

/**
* Computes an element-wise broadcasted multiplication of two matrices A and
* B. Will return a new matrix that is the max of A and B, where the smaller
* matrix will broadcast over the larger matrix.
* @param c The scalar in the operation.
* @param A the NDArray in the operation that will be broadcasted over.
* @deprecated Use math.multiply() instead.
*/
elementWiseMulBroadcast(a: Array2D, b: Array2D): Array2D {
util.assert(
Expand All @@ -799,10 +844,8 @@ export abstract class NDArrayMath {
b.rank === 2,
`Error in elementWiseMulBroadcast: second argument must be ` +
`rank 2, but got rank ${b.rank}.`);
return this.track(this.elementWiseMulBroadcastInternal(a, b));
return this.multiply(a, b) as Array2D;
}
protected abstract elementWiseMulBroadcastInternal(a: Array2D, b: Array2D):
Array2D;

/////////////////////
// Convolution ops //
Expand Down
Loading