Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
"editor.insertSpaces": true,
"files.insertFinalNewline": true,
"editor.detectIndentation": false,
"editor.wrappingIndent": "none",
"typescript.tsdk": "node_modules/typescript/lib"
}
4 changes: 2 additions & 2 deletions src/graph_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ export class GraphRunner {
}

});
setTimeout(() => this.trainNetwork());
requestAnimationFrame(() => this.trainNetwork());
}

infer(
Expand Down Expand Up @@ -243,7 +243,7 @@ export class GraphRunner {
this.currentInferenceLoopNumPasses = numPasses;
if (!this.isInferring) {
this.inferencePassesThisRun = 0;
setTimeout(() => this.inferNetwork());
requestAnimationFrame(() => this.inferNetwork());
}
this.isInferring = true;
}
Expand Down
62 changes: 29 additions & 33 deletions src/math/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ export abstract class NDArrayMath {
*/
enableDebugMode() {
this.debugMode = true;
console.warn('Debugging mode is ON. The output of every math call will ' +
'be downloaded to CPU and checked for NaNs. ' +
'This significantly impacts performance.');
console.warn(
'Debugging mode is ON. The output of every math call will ' +
'be downloaded to CPU and checked for NaNs. ' +
'This significantly impacts performance.');
}

/**
Expand All @@ -97,7 +98,7 @@ export abstract class NDArrayMath {
endScope(result: ScopeResult) {
let arraysToKeep = this.activeScopeNDArraysToKeep;
if (result != null) {
arraysToKeep = arraysToKeep.concat(result as NDArray|NDArray[]);
arraysToKeep = arraysToKeep.concat(result as NDArray | NDArray[]);
}
// Dispose the current scope.
for (let i = 0; i < this.activeScope.length; i++) {
Expand Down Expand Up @@ -321,22 +322,15 @@ export abstract class NDArrayMath {
protected abstract cloneInternal<T extends NDArray>(ndarray: T): T;

/**
* Reshapes an NDArray to a new shape. The size of the input NDArray must
* match the size of the requested shape.
* @param ndarray The input NDArray.
* @param newShape The new shape to reshape the NDArray to. Must be the same
* size as the NDArray.
* @deprecated Please call reshape() directly on the ndarray object.
*/
reshape<T1 extends NDArray, T2 extends NDArray>(
ndarray: T1, newShape: number[]): T2 {
util.assert(
ndarray.size === util.sizeFromShape(newShape),
`Error in reshape: old size ${ndarray.size} must match new size ` +
`${util.sizeFromShape(newShape)}.`);
return this.track(this.reshapeInternal<T1, T2>(ndarray, newShape));
console.warn(
'math.reshape() is deprecated. Please call reshape() ' +
'directly on the ndarray object');
return ndarray.reshape(newShape);
}
protected abstract reshapeInternal<T1 extends NDArray, T2 extends NDArray>(
ndarray: T1, newShape: number[]): T2;

/**
* Extracts a slice from a matrix. The operation extraces a slice from input
Expand Down Expand Up @@ -1148,7 +1142,8 @@ export abstract class NDArrayMath {
* @param h Array of previous cell outputs.
* @return Tuple [nextCellStates, cellOutputs]
*/
multiRNNCell(lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
multiRNNCell(
lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
h: Array2D[]): [Array2D[], Array2D[]] {
util.assert(
data.shape[0] === 1,
Expand Down Expand Up @@ -1187,8 +1182,9 @@ export abstract class NDArrayMath {
* @param h Previous cell output.
* @return Tuple [nextCellState, cellOutput]
*/
basicLSTMCell(forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D,
data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D] {
basicLSTMCell(
forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D, data: Array2D,
c: Array2D, h: Array2D): [Array2D, Array2D] {
const res = this.scope(() => {
util.assert(
data.shape[0] === 1,
Expand All @@ -1207,25 +1203,25 @@ export abstract class NDArrayMath {

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
const i = this.slice2D(res, [0, 0], [res.shape[0], res.shape[1] / 4]);
const j = this.slice2D(res, [0, res.shape[1] / 4 * 1],
[res.shape[0], res.shape[1] / 4]);
const f = this.slice2D(res, [0, res.shape[1] / 4 * 2],
[res.shape[0], res.shape[1] / 4]);
const o = this.slice2D(res, [0, res.shape[1] / 4 * 3],
[res.shape[0], res.shape[1] / 4]);

const newC = this.add(
this.multiplyStrict(c,
this.sigmoid(this.scalarPlusArray(forgetBias, f))),
this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
const newH = this.multiplyStrict(
this.tanh(newC), this.sigmoid(o)) as Array2D;
const j = this.slice2D(
res, [0, res.shape[1] / 4 * 1], [res.shape[0], res.shape[1] / 4]);
const f = this.slice2D(
res, [0, res.shape[1] / 4 * 2], [res.shape[0], res.shape[1] / 4]);
const o = this.slice2D(
res, [0, res.shape[1] / 4 * 3], [res.shape[0], res.shape[1] / 4]);

const newC =
this.add(
this.multiplyStrict(
c, this.sigmoid(this.scalarPlusArray(forgetBias, f))),
this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
const newH =
this.multiplyStrict(this.tanh(newC), this.sigmoid(o)) as Array2D;

return [newC, newH];
});
return [res[0], res[1]];
}

}

export enum MatrixOrientation {
Expand Down
5 changes: 0 additions & 5 deletions src/math/math_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ export class NDArrayMathCPU extends NDArrayMath {
ndarray.shape, {values: new Float32Array(ndarray.getValues())});
}

protected reshapeInternal<T1 extends NDArray, T2 extends NDArray>(
ndarray: T1, newShape: number[]): T2 {
return this.cloneInternal(ndarray).reshape<T2>(newShape);
}

protected slice2DInternal(
input: Array2D, beginRowCol: [number, number],
sizeRowCol: [number, number]): Array2D {
Expand Down
Loading