Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ node_js: "8"
install:
- npm install
script:
- npm run build
- npm run lint
- npm run build --silent
- npm run lint --silent
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"tsify": "~3.0.1",
"tslint": "~5.6.0",
"typedoc": "~0.7.2",
"typescript": "2.3.4",
"typescript": "2.4.2",
"watchify": "~3.9.0"
},
"scripts": {
Expand Down
16 changes: 8 additions & 8 deletions src/math/activation_functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ export interface ActivationFunction {
}

export class TanHFunc implements ActivationFunction {
output(math: NDArrayMath, x: NDArray) {
output<T extends NDArray>(math: NDArrayMath, x: T) {
return math.scope(() => {
return math.tanh(x);
});
}

der(math: NDArrayMath, x: NDArray, y: NDArray) {
der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
return math.scope(() => {
const ySquared = math.elementWiseMul(y, y);
// 1 - y^2.
Expand All @@ -39,27 +39,27 @@ export class TanHFunc implements ActivationFunction {
}

export class ReLUFunc implements ActivationFunction {
output(math: NDArrayMath, x: NDArray) {
output<T extends NDArray>(math: NDArrayMath, x: T) {
return math.scope(() => {
return math.relu(x);
});
}

der(math: NDArrayMath, x: NDArray, y: NDArray) {
der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
return math.scope(() => {
return math.step(x);
});
}
}

export class SigmoidFunc implements ActivationFunction {
output(math: NDArrayMath, x: NDArray) {
output<T extends NDArray>(math: NDArrayMath, x: T) {
return math.scope(() => {
return math.sigmoid(x);
});
}

der(math: NDArrayMath, x: NDArray, y: NDArray) {
der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
return math.scope(() => {
// y * (1 - y) = y - y^2
const ySquared = math.elementWiseMul(y, y);
Expand All @@ -69,13 +69,13 @@ export class SigmoidFunc implements ActivationFunction {
}

export class SquareFunc implements ActivationFunction {
output(math: NDArrayMath, x: NDArray) {
output<T extends NDArray>(math: NDArrayMath, x: T) {
return math.scope(() => {
return math.elementWiseMul(x, x);
});
}

der(math: NDArrayMath, x: NDArray, y: NDArray) {
der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
return math.scope(() => {
// dy/dx = 2*x.
return math.scalarTimesArray(Scalar.TWO, x);
Expand Down
4 changes: 2 additions & 2 deletions src/math/cost_functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export interface ElementWiseCostFunction {
export class SquareCostFunc implements ElementWiseCostFunction {
private halfOne = Scalar.new(0.5);

cost(math: NDArrayMath, x1: NDArray, x2: NDArray): NDArray {
cost<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
const diff = math.sub(x1, x2);
const diffSquared = math.elementWiseMul(diff, diff);
const result = math.scalarTimesArray(this.halfOne, diffSquared);
Expand All @@ -39,7 +39,7 @@ export class SquareCostFunc implements ElementWiseCostFunction {
return result;
}

der(math: NDArrayMath, x1: NDArray, x2: NDArray): NDArray {
der<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
return math.sub(x1, x2);
}

Expand Down
2 changes: 1 addition & 1 deletion src/math/math_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class NDArrayMathCPU extends NDArrayMath {
const outputShape =
concat3d_util.computeConcat3DOutputShape(x1.shape, x2.shape, axis);

const values = NDArray.zeros<Array3D>(outputShape);
const values = Array3D.zeros(outputShape);

for (let i = 0; i < outputShape[0]; i++) {
for (let j = 0; j < outputShape[1]; j++) {
Expand Down
12 changes: 6 additions & 6 deletions src/math/math_cpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ describe('NDArrayMathCPU matMul', () => {
});

it('A x B^t shapes do not match', () => {
const a = NDArray.zeros<Array2D>([2, 3]);
const b = NDArray.zeros<Array2D>([3, 2]);
const a = Array2D.zeros([2, 3]);
const b = Array2D.zeros([3, 2]);
const f = () => {
math.matMul(
a, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED);
Expand All @@ -255,8 +255,8 @@ describe('NDArrayMathCPU matMul', () => {
});

it('A^t x B shapes do not match', () => {
const a = NDArray.zeros<Array2D>([2, 3]);
const b = NDArray.zeros<Array2D>([3, 2]);
const a = Array2D.zeros([2, 3]);
const b = Array2D.zeros([3, 2]);
const f = () => {
math.matMul(
a, b, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR);
Expand All @@ -265,8 +265,8 @@ describe('NDArrayMathCPU matMul', () => {
});

it('A^t x B^t shapes do not match', () => {
const a = NDArray.zeros<Array2D>([3, 2]);
const b = NDArray.zeros<Array2D>([3, 2]);
const a = Array2D.zeros([3, 2]);
const b = Array2D.zeros([3, 2]);
const f = () => {
math.matMul(
a, b, MatrixOrientation.TRANSPOSED, MatrixOrientation.TRANSPOSED);
Expand Down
12 changes: 6 additions & 6 deletions src/math/ndarray.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ export class NDArray {
}

/** Creates a ndarray of zeros with the specified shape. */
static zeros<T extends NDArray>(shape: number[]): T {
static zeros(shape: number[]): NDArray {
const values = new Float32Array(util.sizeFromShape(shape));
return NDArray.make<T>(shape, {values});
return NDArray.make(shape, {values});
}

/** Creates a ndarray of zeros with the same shape as the specified ndarray.
Expand Down Expand Up @@ -390,7 +390,7 @@ export class Array1D extends NDArray {
}

static zeros(shape: [number]): Array1D {
return NDArray.zeros<Array1D>(shape);
return NDArray.zeros(shape) as Array1D;
}
}

Expand Down Expand Up @@ -441,7 +441,7 @@ export class Array2D extends NDArray {
}

static zeros(shape: [number, number]): Array2D {
return NDArray.zeros<Array2D>(shape);
return NDArray.zeros(shape) as Array2D;
}
}

Expand Down Expand Up @@ -496,7 +496,7 @@ export class Array3D extends NDArray {
}

static zeros(shape: [number, number, number]): Array3D {
return NDArray.zeros<Array3D>(shape);
return NDArray.zeros(shape) as Array3D;
}
}

Expand Down Expand Up @@ -559,7 +559,7 @@ export class Array4D extends NDArray {
}

static zeros(shape: [number, number, number, number]): Array4D {
return NDArray.zeros<Array4D>(shape);
return NDArray.zeros(shape) as Array4D;
}
}

Expand Down