diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts
index fd8f7677fcc..f9806bad8d8 100644
--- a/tfjs-core/scripts/touch_modular_op_files.ts
+++ b/tfjs-core/scripts/touch_modular_op_files.ts
@@ -63,6 +63,11 @@ async function main() {
   let command = `touch ${filePath}`;
   execSync(command);
 
+  // Create a test file for the op as well.
+  filePath = `./src/ops/${args.op}_test.ts`;
+  command = `touch ${filePath}`;
+  execSync(command);
+
   if (args.chained) {
     filePath = `./src/public/chained_ops/${args.op}.ts`;
     command = `touch ${filePath}`;
diff --git a/tfjs-core/src/ops/band_part.ts b/tfjs-core/src/ops/band_part.ts
new file mode 100644
index 00000000000..1a65674cbb2
--- /dev/null
+++ b/tfjs-core/src/ops/band_part.ts
@@ -0,0 +1,122 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Tensor} from '../tensor';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+import {assert} from '../util';
+
+import {reshape, stack, unstack} from './array_ops';
+import {greaterEqual} from './greater_equal';
+import {lessEqual} from './less_equal';
+import {logicalAnd, where} from './logical_ops';
+import {op} from './operation';
+import {sub} from './sub';
+import {range, scalar, zeros} from './tensor_ops';
+
+/**
+ * Copy a tensor setting everything outside a central band in each innermost
+ * matrix to zero.
+ *
+ * The band part is computed as follows: Assume input has `k` dimensions
+ * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
+ * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
+ * The indicator function is
+ * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
+ * `&& (num_upper < 0 || (n-m) <= num_upper)`
+ *
+ * ```js
+ * const x = tf.tensor2d([[ 0, 1, 2, 3],
+ *                        [-1, 0, 1, 2],
+ *                        [-2, -1, 0, 1],
+ *                        [-3, -2, -1, 0]]);
+ * let y = tf.linalg.bandPart(x, 1, -1);
+ * y.print(); // [[ 0, 1, 2, 3],
+ *            //  [-1, 0, 1, 2],
+ *            //  [ 0, -1, 0, 1],
+ *            //  [ 0, 0, -1, 0]]
+ * let z = tf.linalg.bandPart(x, 2, 1);
+ * z.print(); // [[ 0, 1, 0, 0],
+ *            //  [-1, 0, 1, 0],
+ *            //  [-2, -1, 0, 1],
+ *            //  [ 0, -2, -1, 0]]
+ * ```
+ *
+ * @param x Rank `k` tensor.
+ * @param numLower Number of subdiagonals to keep.
+ *   If negative, keep entire lower triangle.
+ * @param numUpper Number of superdiagonals to keep.
+ *   If negative, keep entire upper triangle.
+ * @returns Rank `k` tensor of the same shape as the input.
+ *   The extracted banded tensor.
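+ *
+ * As a simpler illustration of the indicator function, keeping one
+ * subdiagonal (`numLower = 1`) and no superdiagonals (`numUpper = 0`)
+ * zeroes out everything else; the expected values below mirror the unit
+ * tests added alongside this op:
+ *
+ * ```js
+ * const ones = tf.ones([3, 3]);
+ * tf.linalg.bandPart(ones, 1, 0).print();
+ * // [[1, 0, 0],
+ * //  [1, 1, 0],
+ * //  [0, 1, 1]]
+ * ```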
+ */
+/**
+ * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
+ */
+function bandPart_<T extends Tensor>(
+    a: T|TensorLike, numLower: number, numUpper: number): T {
+  assert(
+      numLower % 1 === 0,
+      () => `bandPart(): numLower must be an integer, got ${numLower}.`);
+  assert(
+      numUpper % 1 === 0,
+      () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
+
+  const $a = convertToTensor(a, 'a', 'bandPart');
+
+  assert(
+      $a.rank >= 2,
+      () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
+
+  const shape = $a.shape;
+  const [M, N] = $a.shape.slice(-2);
+
+  if (!(numLower <= M)) {
+    throw new Error(
+        `bandPart(): numLower (${numLower})` +
+        ` must not be greater than the number of rows (${M}).`);
+  }
+  if (!(numUpper <= N)) {
+    throw new Error(
+        `bandPart(): numUpper (${numUpper})` +
+        ` must not be greater than the number of columns (${N}).`);
+  }
+
+  if (numLower < 0) {
+    numLower = M;
+  }
+  if (numUpper < 0) {
+    numUpper = N;
+  }
+
+  const i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
+  const j = range(0, N, 1, 'int32');
+  const ij = sub(i, j);
+
+  const inBand = logicalAnd(
+      lessEqual(ij, scalar(+numLower, 'int32')),
+      greaterEqual(ij, scalar(-numUpper, 'int32')));
+
+  const zero = zeros([M, N], $a.dtype);
+
+  return reshape(
+             stack(unstack(reshape($a, [-1, M, N]))
+                       .map(mat => where(inBand, mat, zero))),
+             shape) as T;
+}
+
+export const bandPart = op({bandPart_});
diff --git a/tfjs-core/src/ops/band_part_test.ts b/tfjs-core/src/ops/band_part_test.ts
new file mode 100644
index 00000000000..0c7d9f22392
--- /dev/null
+++ b/tfjs-core/src/ops/band_part_test.ts
@@ -0,0 +1,181 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {Tensor2D, Tensor3D} from '../tensor'; +import {expectArraysClose} from '../test_util'; + +import {scalar, tensor1d, tensor2d, tensor3d} from './ops'; + +describeWithFlags('bandPart', ALL_ENVS, () => { + it('keeps tensor unchanged', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, -1, -1).array(), + [[1, 1, 1], [1, 1, 1], [1, 1, 1]]); + }); + + it('upper triangular matrix', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, 0, -1).array(), + [[1, 1, 1], [0, 1, 1], [0, 0, 1]]); + }); + + it('lower triangular matrix', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, -1, 0).array(), + [[1, 0, 0], [1, 1, 0], [1, 1, 1]]); + }); + + it('diagonal elements', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, 0, 0).array(), + [[1, 0, 0], [0, 1, 0], [0, 0, 1]]); + }); + + it('lower triangular elements', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, 1, 0).array(), + [[1, 0, 0], [1, 1, 0], [0, 1, 1]]); + }); + + it('upper triangular elements', async () => { + const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); + expectArraysClose( + await tf.linalg.bandPart(x, 0, 1).array(), + [[1, 1, 0], [0, 1, 1], [0, 0, 1]]); + }); + + it('4X4 matrix - tensorflow python examples', async () => { + const x: Tensor2D = tensor2d( + [[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]); + expectArraysClose( + await tf.linalg.bandPart(x, 1, -1).array(), + [[0, 1, 2, 3], [-1, 0, 1, 2], [0, -1, 0, 1], [0, 0, -1, 0]]); + expectArraysClose( + await tf.linalg.bandPart(x, 2, 1).array(), + [[0, 1, 0, 0], [-1, 0, 1, 0], [-2, -1, 0, 1], [0, -2, -1, 0]]); + }); + + it('3 dimensional matrix', async () => { + const x: Tensor3D = tensor3d([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]); + expectArraysClose( + await tf.linalg.bandPart(x, 0, 0).array(), + [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]); + }); + + it('2X3X3 tensor', async () => { + const x: Tensor3D = tensor3d( + [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]); + expectArraysClose( + await tf.linalg.bandPart(x, 1, 2).array(), + [[[1, 1, 1], [1, 1, 1], [0, 1, 1]], [[1, 1, 1], [1, 1, 1], [0, 1, 1]]]); + }); + + const la = tf.linalg; + + it('fails for scalar', async () => { + const x = scalar(1); + expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); + }); + + it('fails for 1D tensor', async () => { + const x = tensor1d([1, 2, 3, 4, 5]); + expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); + }); + + it('fails if numLower or numUpper too large', async () => { + const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + + for (const numLower of [3, 5, 8, 13]) { + for (const numUpper of [-1, 0, 1, 2]) { + expect(() => tf.linalg.bandPart(a, numLower, numUpper)) + .toThrowError(/bandPart.*numLower/i); + } + } + + for (const numLower of [-1, 0, 1]) { + for (const numUpper of [4, 5, 9]) { + expect(() => tf.linalg.bandPart(a, numLower, numUpper)) + .toThrowError(/bandPart.*numUpper/i); + } + } + + for (const numLower of [3, 5, 
8, 13]) {
+      for (const numUpper of [4, 5, 9]) {
+        expect(() => tf.linalg.bandPart(a, numLower, numUpper))
+            .toThrowError(/bandPart.*(numLower|numUpper)/i);
+      }
+    }
+  });
+
+  it('works for 3x4 example', async () => {
+    const a = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]);
+
+    expectArraysClose(
+        await la.bandPart(a, 0, 0).array(),
+        [[1, 0, 0, 0], [0, 6, 0, 0], [0, 0, 11, 0]]);
+    expectArraysClose(
+        await la.bandPart(a, 0, 1).array(),
+        [[1, 2, 0, 0], [0, 6, 7, 0], [0, 0, 11, 12]]);
+    expectArraysClose(
+        await la.bandPart(a, 0, 2).array(),
+        [[1, 2, 3, 0], [0, 6, 7, 8], [0, 0, 11, 12]]);
+    for (const numUpper of [3, 4, -1, -2]) {
+      expectArraysClose(
+          await la.bandPart(a, 0, numUpper).array(),
+          [[1, 2, 3, 4], [0, 6, 7, 8], [0, 0, 11, 12]]);
+    }
+
+    expectArraysClose(
+        await la.bandPart(a, 1, 0).array(),
+        [[1, 0, 0, 0], [5, 6, 0, 0], [0, 10, 11, 0]]);
+    expectArraysClose(
+        await la.bandPart(a, 1, 1).array(),
+        [[1, 2, 0, 0], [5, 6, 7, 0], [0, 10, 11, 12]]);
+    expectArraysClose(
+        await la.bandPart(a, 1, 2).array(),
+        [[1, 2, 3, 0], [5, 6, 7, 8], [0, 10, 11, 12]]);
+    for (const numUpper of [3, 4, -1, -2]) {
+      expectArraysClose(
+          await la.bandPart(a, 1, numUpper).array(),
+          [[1, 2, 3, 4], [5, 6, 7, 8], [0, 10, 11, 12]]);
+    }
+
+    for (const numLower of [2, 3, -1, -2]) {
+      expectArraysClose(
+          await la.bandPart(a, numLower, 0).array(),
+          [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]);
+      expectArraysClose(
+          await la.bandPart(a, numLower, 1).array(),
+          [[1, 2, 0, 0], [5, 6, 7, 0], [9, 10, 11, 12]]);
+      expectArraysClose(
+          await la.bandPart(a, numLower, 2).array(),
+          [[1, 2, 3, 0], [5, 6, 7, 8], [9, 10, 11, 12]]);
+      for (const numUpper of [3, 4, -1, -2]) {
+        expectArraysClose(
+            await la.bandPart(a, numLower, numUpper).array(),
+            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]);
+      }
+    }
+  });
+});
diff --git a/tfjs-core/src/ops/gram_schmidt.ts b/tfjs-core/src/ops/gram_schmidt.ts
new file mode 100644
index 00000000000..599a85c373c
--- /dev/null
+++ b/tfjs-core/src/ops/gram_schmidt.ts
@@ -0,0 +1,109 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE} from '../engine';
+import {Tensor1D, Tensor2D} from '../tensor';
+import {assert} from '../util';
+
+import {squeeze, stack} from './array_ops';
+import {div} from './div';
+import {mul} from './mul';
+import {norm} from './norm';
+import {op} from './operation';
+import {sum} from './reduction_ops';
+import {split} from './split';
+import {sub} from './sub';
+
+/**
+ * Gram-Schmidt orthogonalization.
+ *
+ * ```js
+ * const x = tf.tensor2d([[1, 2], [3, 4]]);
+ * let y = tf.linalg.gramSchmidt(x);
+ * y.print();
+ * console.log('Orthogonalized:');
+ * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
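+ * // (The rows of y are orthonormal after Gram-Schmidt, so this Gram
+ * // matrix y·yᵀ should be close to the 2x2 identity.)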
+ * console.log('First row direction maintained:'); + * const data = await y.array(); + * console.log(data[0][1] / data[0][0]); // should be nearly 2. + * ``` + * + * @param xs The vectors to be orthogonalized, in one of the two following + * formats: + * - An Array of `tf.Tensor1D`. + * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows + * of `xs`. + * In each case, all the vectors must have the same length and the length + * must be greater than or equal to the number of vectors. + * @returns The orthogonalized and normalized vectors or matrix. + * Orthogonalization means that the vectors or the rows of the matrix + * are orthogonal (zero inner products). Normalization means that each + * vector or each row of the matrix has an L2 norm that equals `1`. + */ +/** + * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'} + */ +function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { + let inputIsTensor2D: boolean; + if (Array.isArray(xs)) { + inputIsTensor2D = false; + assert( + xs != null && xs.length > 0, + () => 'Gram-Schmidt process: input must not be null, undefined, or ' + + 'empty'); + const dim = xs[0].shape[0]; + for (let i = 1; i < xs.length; ++i) { + assert( + xs[i].shape[0] === dim, + () => + 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + + `(${(xs as Tensor1D[])[i].shape[0]} vs. ${dim})`); + } + } else { + inputIsTensor2D = true; + xs = split(xs, xs.shape[0], 0).map(x => squeeze(x, [0])); + } + + assert( + xs.length <= xs[0].shape[0], + () => `Gram-Schmidt: Number of vectors (${ + (xs as Tensor1D[]).length}) exceeds ` + + `number of dimensions (${(xs as Tensor1D[])[0].shape[0]}).`); + + const ys: Tensor1D[] = []; + const xs1d = xs; + for (let i = 0; i < xs.length; ++i) { + ys.push(ENGINE.tidy(() => { + let x = xs1d[i]; + if (i > 0) { + for (let j = 0; j < i; ++j) { + const proj = mul(sum(mul(ys[j], x)), ys[j]); + x = sub(x, proj); + } + } + return div(x, norm(x, 'euclidean')); + })); + } + + if (inputIsTensor2D) { + return stack(ys, 0) as Tensor2D; + } else { + return ys; + } +} + +export const gramSchmidt = op({gramSchmidt_}); diff --git a/tfjs-core/src/ops/gram_schmidt_test.ts b/tfjs-core/src/ops/gram_schmidt_test.ts new file mode 100644 index 00000000000..b976e8ae14e --- /dev/null +++ b/tfjs-core/src/ops/gram_schmidt_test.ts @@ -0,0 +1,87 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {Tensor1D, Tensor2D} from '../tensor'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { + it('2x2, Array of Tensor1D', async () => { + const xs: Tensor1D[] = [ + tf.randomNormal([2], 0, 1, 'float32', 1), + tf.randomNormal([2], 0, 1, 'float32', 2) + ]; + const ys = tf.linalg.gramSchmidt(xs) as Tensor1D[]; + const y = tf.stack(ys) as Tensor2D; + // Test that the results are orthogonalized and normalized. + expectArraysClose( + await y.transpose().matMul(y).array(), await tf.eye(2).array()); + // Test angle between xs[0] and ys[0] is zero, i.e., the orientation of the + // first vector is kept. + expectArraysClose( + await tf.sum(xs[0].mul(ys[0])).array(), + await tf.norm(xs[0]).mul(tf.norm(ys[0])).array()); + }); + + it('3x3, Array of Tensor1D', async () => { + const xs: Tensor1D[] = [ + tf.randomNormal([3], 0, 1, 'float32', 1), + tf.randomNormal([3], 0, 1, 'float32', 2), + tf.randomNormal([3], 0, 1, 'float32', 3) + ]; + const ys = tf.linalg.gramSchmidt(xs) as Tensor1D[]; + const y = tf.stack(ys) as Tensor2D; + expectArraysClose( + await y.transpose().matMul(y).array(), await tf.eye(3).array()); + expectArraysClose( + await tf.sum(xs[0].mul(ys[0])).array(), + await tf.norm(xs[0]).mul(tf.norm(ys[0])).array()); + }); + + it('3x3, Matrix', async () => { + const xs: Tensor2D = tf.randomNormal([3, 3], 0, 1, 'float32', 1); + const y = tf.linalg.gramSchmidt(xs) as Tensor2D; + expectArraysClose( + await y.transpose().matMul(y).array(), await tf.eye(3).array()); + }); + + it('2x3, Matrix', async () => { + const xs: Tensor2D = tf.randomNormal([2, 3], 0, 1, 'float32', 1); + const y = tf.linalg.gramSchmidt(xs) as Tensor2D; + const yT: Tensor2D = y.transpose(); + expectArraysClose(await y.matMul(yT).array(), await tf.eye(2).array()); + }); + + it('3x2 Matrix throws Error', () => { + const xs = tf.tensor2d([[1, 2], [3, -1], [5, 1]]); + expect(() => tf.linalg.gramSchmidt(xs)) + .toThrowError( + /Number of vectors \(3\) exceeds number of dimensions \(2\)/); + }); + + it('Mismatching dimensions input throws Error', () => { + const xs: Tensor1D[] = + [tf.tensor1d([1, 2, 3]), tf.tensor1d([-1, 5, 1]), tf.tensor1d([0, 0])]; + + expect(() => tf.linalg.gramSchmidt(xs)).toThrowError(/Non-unique/); + }); + + it('Empty input throws Error', () => { + expect(() => tf.linalg.gramSchmidt([])).toThrowError(/empty/); + }); +}); diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts deleted file mode 100644 index 4f10e56fb33..00000000000 --- a/tfjs-core/src/ops/linalg_ops.ts +++ /dev/null @@ -1,368 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -/** - * Linear algebra ops. 
- */
-
-import {ENGINE} from '../engine';
-import {dispose} from '../globals';
-import {Tensor, Tensor1D, Tensor2D} from '../tensor';
-import {convertToTensor} from '../tensor_util_env';
-import {TensorLike} from '../types';
-import {assert} from '../util';
-
-import {squeeze, stack, unstack} from './array_ops';
-import {eye} from './eye';
-import {logicalAnd, where} from './logical_ops';
-import {norm} from './norm';
-import {op} from './operation';
-import {sum} from './reduction_ops';
-import {split} from './split';
-import {sub} from './sub';
-import {range, scalar, tensor2d, zeros} from './tensor_ops';
-
-/**
- * Copy a tensor setting everything outside a central band in each innermost
- * matrix to zero.
- *
- * The band part is computed as follows: Assume input has `k` dimensions
- * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
- * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
- * The indicator function
- * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
- * `&& (num_upper < 0 || (n-m) <= num_upper)`
- *
- * ```js
- * const x = tf.tensor2d([[ 0, 1, 2, 3],
- *                        [-1, 0, 1, 2],
- *                        [-2, -1, 0, 1],
- *                        [-3, -2, -1, 0]]);
- * let y = tf.linalg.bandPart(x, 1, -1);
- * y.print(); // [[ 0, 1, 2, 3],
- *            //  [-1, 0, 1, 2],
- *            //  [ 0, -1, 0, 1],
- *            //  [ 0, 0 , -1, 0]]
- * let z = tf.linalg.bandPart(x, 2, 1);
- * z.print(); // [[ 0, 1, 0, 0],
- *            //  [-1, 0, 1, 0],
- *            //  [-2, -1, 0, 1],
- *            //  [ 0, -2, -1, 0]]
- * ```
- *
- * @param x Rank `k` tensor
- * @param numLower Number of subdiagonals to keep.
- *   If negative, keep entire lower triangle.
- * @param numUpper Number of subdiagonals to keep.
- *   If negative, keep entire upper triangle.
- * @returns Rank `k` tensor of the same shape as input.
- *   The extracted banded tensor.
- */
-/**
- * @doc {heading:'Operations',
- *       subheading:'Linear Algebra',
- *       namespace:'linalg'}
- */
-function bandPart_<T extends Tensor>(
-    a: T|TensorLike, numLower: number, numUpper: number): T {
-  if (numLower % 1 !== 0) {
-    throw new Error(
-        `bandPart(): numLower must be an integer, got ${numLower}.`);
-  }
-  if (numUpper % 1 !== 0) {
-    throw new Error(
-        `bandPart(): numUpper must be an integer, got ${numUpper}.`);
-  }
-
-  const $a = convertToTensor(a, 'a', 'bandPart');
-
-  if ($a.rank < 2) {
-    throw new Error(`bandPart(): Rank must be at least 2, got ${$a.rank}.`);
-  }
-
-  const shape = $a.shape, [M, N] = $a.shape.slice(-2);
-
-  if (!(numLower <= M)) {
-    throw new Error(
-        `bandPart(): numLower (${numLower})` +
-        ` must not be greater than the number of rows (${M}).`);
-  }
-  if (!(numUpper <= N)) {
-    throw new Error(
-        `bandPart(): numUpper (${numUpper})` +
-        ` must not be greater than the number of columns (${N}).`);
-  }
-
-  if (numLower < 0) {
-    numLower = M;
-  }
-  if (numUpper < 0) {
-    numUpper = N;
-  }
-
-  const i = range(0, M, 1, 'int32').reshape([-1, 1]),
-        j = range(0, N, 1, 'int32'), ij = sub(i, j);
-
-  const inBand = logicalAnd(
-      ij.lessEqual(scalar(+numLower, 'int32')),
-      ij.greaterEqual(scalar(-numUpper, 'int32')));
-
-  const zero = zeros([M, N], $a.dtype);
-
-  return stack(unstack($a.reshape([-1, M, N]))
-                   .map(mat => where(inBand, mat, zero)))
-      .reshape(shape) as T;
-}
-
-/**
- * Gram-Schmidt orthogonalization.
- *
- * ```js
- * const x = tf.tensor2d([[1, 2], [3, 4]]);
- * let y = tf.linalg.gramSchmidt(x);
- * y.print();
- * console.log('Othogonalized:');
- * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
- * console.log('First row direction maintained:'); - * const data = await y.array(); - * console.log(data[0][1] / data[0][0]); // should be nearly 2. - * ``` - * - * @param xs The vectors to be orthogonalized, in one of the two following - * formats: - * - An Array of `tf.Tensor1D`. - * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows - * of `xs`. - * In each case, all the vectors must have the same length and the length - * must be greater than or equal to the number of vectors. - * @returns The orthogonalized and normalized vectors or matrix. - * Orthogonalization means that the vectors or the rows of the matrix - * are orthogonal (zero inner products). Normalization means that each - * vector or each row of the matrix has an L2 norm that equals `1`. - */ -/** - * @doc {heading:'Operations', - * subheading:'Linear Algebra', - * namespace:'linalg'} - */ -function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { - let inputIsTensor2D: boolean; - if (Array.isArray(xs)) { - inputIsTensor2D = false; - assert( - xs != null && xs.length > 0, - () => 'Gram-Schmidt process: input must not be null, undefined, or ' + - 'empty'); - const dim = xs[0].shape[0]; - for (let i = 1; i < xs.length; ++i) { - assert( - xs[i].shape[0] === dim, - () => - 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + - `(${(xs as Tensor1D[])[i].shape[0]} vs. ${dim})`); - } - } else { - inputIsTensor2D = true; - xs = split(xs, xs.shape[0], 0).map(x => squeeze(x, [0])); - } - - assert( - xs.length <= xs[0].shape[0], - () => `Gram-Schmidt: Number of vectors (${ - (xs as Tensor1D[]).length}) exceeds ` + - `number of dimensions (${(xs as Tensor1D[])[0].shape[0]}).`); - - const ys: Tensor1D[] = []; - const xs1d = xs; - for (let i = 0; i < xs.length; ++i) { - ys.push(ENGINE.tidy(() => { - let x = xs1d[i]; - if (i > 0) { - for (let j = 0; j < i; ++j) { - const proj = sum(ys[j].mul(x)).mul(ys[j]); - x = x.sub(proj); - } - } - return x.div(norm(x, 'euclidean')); - })); - } - - if (inputIsTensor2D) { - return stack(ys, 0) as Tensor2D; - } else { - return ys; - } -} - -/** - * Compute QR decomposition of m-by-n matrix using Householder transformation. - * - * Implementation based on - * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] - * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) - * - * ```js - * const a = tf.tensor2d([[1, 2], [3, 4]]); - * let [q, r] = tf.linalg.qr(a); - * console.log('Q'); - * q.print(); - * console.log('R'); - * r.print(); - * console.log('Orthogonalized'); - * q.dot(q.transpose()).print() // should be nearly the identity matrix. - * console.log('Reconstructed'); - * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]]; - * ``` - * - * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose - * it has the shape `[..., M, N]`. - * @param fullMatrices An optional boolean parameter. Defaults to `false`. - * If `true`, compute full-sized `Q`. If `false` (the default), - * compute only the leading N columns of `Q` and `R`. - * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix, - * i.e., its columns all have unit norm and are mutually orthogonal. - * If `M >= N`, - * If `fullMatrices` is `false` (default), - * - `Q` has a shape of `[..., M, N]`, - * - `R` has a shape of `[..., N, N]`. - * If `fullMatrices` is `true` (default), - * - `Q` has a shape of `[..., M, M]`, - * - `R` has a shape of `[..., M, N]`. 
- * If `M < N`, - * - `Q` has a shape of `[..., M, M]`, - * - `R` has a shape of `[..., M, N]`. - * @throws If the rank of `x` is less than 2. - */ -/** - * @doc {heading:'Operations', - * subheading:'Linear Algebra', - * namespace:'linalg'} - */ -function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] { - if (x.rank < 2) { - throw new Error( - `qr() requires input tensor to have a rank >= 2, but got rank ${ - x.rank}`); - } else if (x.rank === 2) { - return qr2d(x as Tensor2D, fullMatrices); - } else { - // Rank > 2. - // TODO(cais): Below we split the input into individual 2D tensors, - // perform QR decomposition on them and then stack the results back - // together. We should explore whether this can be parallelized. - const outerDimsProd = x.shape.slice(0, x.shape.length - 2) - .reduce((value, prev) => value * prev); - const x2ds = unstack( - x.reshape([ - outerDimsProd, x.shape[x.shape.length - 2], - x.shape[x.shape.length - 1] - ]), - 0); - const q2ds: Tensor2D[] = []; - const r2ds: Tensor2D[] = []; - x2ds.forEach(x2d => { - const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices); - q2ds.push(q2d); - r2ds.push(r2d); - }); - const q = stack(q2ds, 0).reshape(x.shape); - const r = stack(r2ds, 0).reshape(x.shape); - return [q, r]; - } -} - -function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { - return ENGINE.tidy(() => { - if (x.shape.length !== 2) { - throw new Error( - `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`); - } - - const m = x.shape[0]; - const n = x.shape[1]; - - let q = eye(m); // Orthogonal transform so far. - let r = x.clone(); // Transformed matrix so far. - - const one2D = tensor2d([[1]], [1, 1]); - let w: Tensor2D = one2D.clone(); - - const iters = m >= n ? n : m; - for (let j = 0; j < iters; ++j) { - // This tidy within the for-loop ensures we clean up temporary - // tensors as soon as they are no longer needed. - const rTemp = r; - const wTemp = w; - const qTemp = q; - [w, r, q] = ENGINE.tidy((): [Tensor2D, Tensor2D, Tensor2D] => { - // Find H = I - tau * w * w', to put zeros below R(j, j). - const rjEnd1 = r.slice([j, j], [m - j, 1]); - const normX = rjEnd1.norm(); - const rjj = r.slice([j, j], [1, 1]); - - // The sign() function returns 0 on 0, which causes division by zero. - const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]])); - - const u1 = rjj.sub(s.mul(normX)); - const wPre = rjEnd1.div(u1); - if (wPre.shape[0] === 1) { - w = one2D.clone(); - } else { - w = one2D.concat( - wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) as - Tensor2D, - 0); - } - const tau = s.matMul(u1).div(normX).neg() as Tensor2D; - - // -- R := HR, Q := QH. 
- const rjEndAll = r.slice([j, 0], [m - j, n]); - const tauTimesW: Tensor2D = tau.mul(w); - const wT: Tensor2D = w.transpose(); - if (j === 0) { - r = rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll))); - } else { - const rTimesTau: Tensor2D = - rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll))); - r = r.slice([0, 0], [j, n]).concat(rTimesTau, 0); - } - const tawTimesWT: Tensor2D = tauTimesW.transpose(); - const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); - if (j === 0) { - q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT)); - } else { - const qTimesTau: Tensor2D = - qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT)); - q = q.slice([0, 0], [m, j]).concat(qTimesTau, 1); - } - return [w, r, q]; - }); - dispose([rTemp, wTemp, qTemp]); - } - - if (!fullMatrices && m > n) { - q = q.slice([0, 0], [m, n]); - r = r.slice([0, 0], [n, n]); - } - - return [q, r]; - }) as [Tensor2D, Tensor2D]; -} - -export const bandPart = op({bandPart_}); -export const gramSchmidt = op({gramSchmidt_}); -export const qr = op({qr_}); diff --git a/tfjs-core/src/ops/linalg_ops_test.ts b/tfjs-core/src/ops/linalg_ops_test.ts deleted file mode 100644 index bb5d76a413f..00000000000 --- a/tfjs-core/src/ops/linalg_ops_test.ts +++ /dev/null @@ -1,369 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as tf from '../index'; -import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; -import {Tensor1D, Tensor2D, Tensor3D} from '../tensor'; -import {expectArraysClose} from '../test_util'; - -import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; - -describeWithFlags('bandPart', ALL_ENVS, () => { - it('keeps tensor unchanged', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, -1, -1).array(), - [[1, 1, 1], [1, 1, 1], [1, 1, 1]]); - }); - - it('upper triangular matrix', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, 0, -1).array(), - [[1, 1, 1], [0, 1, 1], [0, 0, 1]]); - }); - - it('lower triangular matrix', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, -1, 0).array(), - [[1, 0, 0], [1, 1, 0], [1, 1, 1]]); - }); - - it('diagonal elements', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, 0, 0).array(), - [[1, 0, 0], [0, 1, 0], [0, 0, 1]]); - }); - - it('lower triangular elements', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, 1, 0).array(), - [[1, 0, 0], [1, 1, 0], [0, 1, 1]]); - }); - - it('upper triangular elements', async () => { - const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]); - expectArraysClose( - await tf.linalg.bandPart(x, 0, 1).array(), - [[1, 1, 0], [0, 1, 1], [0, 0, 1]]); - }); - - it('4X4 matrix - tensorflow python examples', async () => { - const x: Tensor2D = tensor2d( - [[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]); - expectArraysClose( - await tf.linalg.bandPart(x, 1, -1).array(), - [[0, 1, 2, 3], [-1, 0, 1, 2], [0, -1, 0, 1], [0, 0, -1, 0]]); - expectArraysClose( - await tf.linalg.bandPart(x, 2, 1).array(), - [[0, 1, 0, 0], [-1, 0, 1, 0], [-2, -1, 0, 1], [0, -2, -1, 0]]); - }); - - it('3 dimensional matrix', async () => { - const x: Tensor3D = tensor3d([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]); - expectArraysClose( - await tf.linalg.bandPart(x, 0, 0).array(), - [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]); - }); - - it('2X3X3 tensor', async () => { - const x: Tensor3D = tensor3d( - [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]); - expectArraysClose( - await tf.linalg.bandPart(x, 1, 2).array(), - [[[1, 1, 1], [1, 1, 1], [0, 1, 1]], [[1, 1, 1], [1, 1, 1], [0, 1, 1]]]); - }); - - const la = tf.linalg; - - it('fails for scalar', async () => { - const x = scalar(1); - expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); - }); - - it('fails for 1D tensor', async () => { - const x = tensor1d([1, 2, 3, 4, 5]); - expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); - }); - - it('fails if numLower or numUpper too large', async () => { - const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - - for (const numLower of [3, 5, 8, 13]) { - for (const numUpper of [-1, 0, 1, 2]) { - expect(() => tf.linalg.bandPart(a, numLower, numUpper)) - .toThrowError(/bandPart.*numLower/i); - } - } - - for (const numLower of [-1, 0, 1]) { - for (const numUpper of [4, 5, 9]) { - expect(() => tf.linalg.bandPart(a, numLower, numUpper)) - .toThrowError(/bandPart.*numUpper/i); - } - } - - for 
(const numLower of [3, 5, 8, 13]) { - for (const numUpper of [4, 5, 9]) { - expect(() => tf.linalg.bandPart(a, numLower, numUpper)) - .toThrowError(/bandPart.*(numLower|numUpper)/i); - } - } - }); - - it('works for 3x4 example', async () => { - const a = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]); - - expectArraysClose( - await la.bandPart(a, 0, 0).array(), - [[1, 0, 0, 0], [0, 6, 0, 0], [0, 0, 11, 0]]); - expectArraysClose( - await la.bandPart(a, 0, 1).array(), - [[1, 2, 0, 0], [0, 6, 7, 0], [0, 0, 11, 12]]); - expectArraysClose( - await la.bandPart(a, 0, 2).array(), - [[1, 2, 3, 0], [0, 6, 7, 8], [0, 0, 11, 12]]); - for (const numUpper of [3, 4, -1, -2]) { - expectArraysClose( - await la.bandPart(a, 0, numUpper).array(), - [[1, 2, 3, 4], [0, 6, 7, 8], [0, 0, 11, 12]]); - } - - expectArraysClose( - await la.bandPart(a, 1, 0).array(), - [[1, 0, 0, 0], [5, 6, 0, 0], [0, 10, 11, 0]]); - expectArraysClose( - await la.bandPart(a, 1, 1).array(), - [[1, 2, 0, 0], [5, 6, 7, 0], [0, 10, 11, 12]]); - expectArraysClose( - await la.bandPart(a, 1, 2).array(), - [[1, 2, 3, 0], [5, 6, 7, 8], [0, 10, 11, 12]]); - for (const numUpper of [3, 4, -1, -2]) { - expectArraysClose( - await la.bandPart(a, 1, numUpper).array(), - [[1, 2, 3, 4], [5, 6, 7, 8], [0, 10, 11, 12]]); - } - - for (const numLower of [2, 3, -1, -2]) { - expectArraysClose( - await la.bandPart(a, numLower, 0).array(), - [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]); - expectArraysClose( - await la.bandPart(a, numLower, 1).array(), - [[1, 2, 0, 0], [5, 6, 7, 0], [9, 10, 11, 12]]); - expectArraysClose( - await la.bandPart(a, numLower, 2).array(), - [[1, 2, 3, 0], [5, 6, 7, 8], [9, 10, 11, 12]]); - for (const numUpper of [3, 4, -1, -2]) { - expectArraysClose( - await la.bandPart(a, numLower, numUpper).array(), - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]); - } - } - }); -}); // end bandPart - -describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { - it('2x2, Array of Tensor1D', async () => { - const xs: Tensor1D[] = [ - tf.randomNormal([2], 0, 1, 'float32', 1), - tf.randomNormal([2], 0, 1, 'float32', 2) - ]; - const ys = tf.linalg.gramSchmidt(xs) as Tensor1D[]; - const y = tf.stack(ys) as Tensor2D; - // Test that the results are orthogonalized and normalized. - expectArraysClose( - await y.transpose().matMul(y).array(), await tf.eye(2).array()); - // Test angle between xs[0] and ys[0] is zero, i.e., the orientation of the - // first vector is kept. 
- expectArraysClose( - await tf.sum(xs[0].mul(ys[0])).array(), - await tf.norm(xs[0]).mul(tf.norm(ys[0])).array()); - }); - - it('3x3, Array of Tensor1D', async () => { - const xs: Tensor1D[] = [ - tf.randomNormal([3], 0, 1, 'float32', 1), - tf.randomNormal([3], 0, 1, 'float32', 2), - tf.randomNormal([3], 0, 1, 'float32', 3) - ]; - const ys = tf.linalg.gramSchmidt(xs) as Tensor1D[]; - const y = tf.stack(ys) as Tensor2D; - expectArraysClose( - await y.transpose().matMul(y).array(), await tf.eye(3).array()); - expectArraysClose( - await tf.sum(xs[0].mul(ys[0])).array(), - await tf.norm(xs[0]).mul(tf.norm(ys[0])).array()); - }); - - it('3x3, Matrix', async () => { - const xs: Tensor2D = tf.randomNormal([3, 3], 0, 1, 'float32', 1); - const y = tf.linalg.gramSchmidt(xs) as Tensor2D; - expectArraysClose( - await y.transpose().matMul(y).array(), await tf.eye(3).array()); - }); - - it('2x3, Matrix', async () => { - const xs: Tensor2D = tf.randomNormal([2, 3], 0, 1, 'float32', 1); - const y = tf.linalg.gramSchmidt(xs) as Tensor2D; - const yT: Tensor2D = y.transpose(); - expectArraysClose(await y.matMul(yT).array(), await tf.eye(2).array()); - }); - - it('3x2 Matrix throws Error', () => { - const xs = tf.tensor2d([[1, 2], [3, -1], [5, 1]]); - expect(() => tf.linalg.gramSchmidt(xs)) - .toThrowError( - /Number of vectors \(3\) exceeds number of dimensions \(2\)/); - }); - - it('Mismatching dimensions input throws Error', () => { - const xs: Tensor1D[] = - [tf.tensor1d([1, 2, 3]), tf.tensor1d([-1, 5, 1]), tf.tensor1d([0, 0])]; - - expect(() => tf.linalg.gramSchmidt(xs)).toThrowError(/Non-unique/); - }); - - it('Empty input throws Error', () => { - expect(() => tf.linalg.gramSchmidt([])).toThrowError(/empty/); - }); -}); - -describeWithFlags('qr', ALL_ENVS, () => { - it('1x1', async () => { - const x = tensor2d([[10]], [1, 1]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(await q.array(), [[-1]]); - expectArraysClose(await r.array(), [[-10]]); - }); - - it('2x2', async () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(await q.array(), [[-0.4472, -0.8944], [0.8944, -0.4472]]); - expectArraysClose(await r.array(), [[-2.2361, -4.9193], [0, -0.8944]]); - }); - - it('2x2x2', async () => { - const x = tensor3d([[[-1, -3], [2, 4]], [[1, 3], [-2, -4]]], [2, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(await q.array(), [ - [[-0.4472, -0.8944], [0.8944, -0.4472]], - [[-0.4472, -0.8944], [0.8944, -0.4472]] - ]); - expectArraysClose( - await r.array(), - [[[2.2361, 4.9193], [0, 0.8944]], [[-2.2361, -4.9193], [0, -0.8944]]]); - }); - - it('2x1x2x2', async () => { - const x = - tensor4d([[[[-1, -3], [2, 4]]], [[[1, 3], [-2, -4]]]], [2, 1, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(await q.array(), [ - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - ]); - expectArraysClose(await r.array(), [ - [[[2.2361, 4.9193], [0, 0.8944]]], [[[-2.2361, -4.9193], [0, -0.8944]]] - ]); - }); - - it('3x3', async () => { - const x = tensor2d([[1, 3, 2], [-2, 0, 7], [8, -9, 4]], [3, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(await q.array(), [ - [-0.1204, 0.8729, 0.4729], [0.2408, -0.4364, 0.8669], - [-0.9631, -0.2182, 0.1576] - ]); - expectArraysClose( - await r.array(), - [[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]]); - }); - - it('3x3, zero on diagonal', async () => { - const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]); - const [q, r] = 
tf.linalg.qr(x); - expectArraysClose(await q.data(), [ - [0., -0.89442719, 0.4472136], [1., 0., 0.], [0., -0.4472136, -0.89442719] - ]); - expectArraysClose( - await r.data(), - [[1., 1., 1.], [0., -2.23606798, -2.68328157], [0., 0., -0.89442719]]); - }); - - it('3x2, fullMatrices = default false', async () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - await q.array(), - [[-0.2673, 0.9221], [-0.8018, -0.3738], [0.5345, -0.0997]]); - expectArraysClose(await r.array(), [[-3.7417, 2.4054], [0, 2.8661]]); - }); - - it('3x2, fullMatrices = true', async () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose(await q.array(), [ - [-0.2673, 0.9221, 0.2798], [-0.8018, -0.3738, 0.4663], - [0.5345, -0.0997, 0.8393] - ]); - expectArraysClose( - await r.array(), [[-3.7417, 2.4054], [0, 2.8661], [0, 0]]); - }); - - it('2x3, fullMatrices = default false', async () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - await q.array(), [[-0.3162278, -0.9486833], [0.9486833, -0.31622773]]); - expectArraysClose( - await r.array(), - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]]); - }); - - it('2x3, fullMatrices = true', async () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - await q.array(), [[-0.3162278, -0.9486833], [0.9486833, -0.31622773]]); - expectArraysClose( - await r.array(), - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]]); - }); - - it('Does not leak memory', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); - // The first call to qr creates and keeps internal singleton tensors. - // Subsequent calls should always create exactly two tensors. - tf.linalg.qr(x); - // Count before real call. - const numTensors = tf.memory().numTensors; - tf.linalg.qr(x); - expect(tf.memory().numTensors).toEqual(numTensors + 2); - }); - - it('Insuffient input tensor rank leads to error', () => { - const x1 = scalar(12); - expect(() => tf.linalg.qr(x1)).toThrowError(/rank >= 2.*got rank 0/); - const x2 = tensor1d([12]); - expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/); - }); -}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 8d88529c810..5e6e31d579d 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -125,7 +125,6 @@ export * from './in_top_k'; export {op} from './operation'; import * as losses from './loss_ops'; -import * as linalg from './linalg_ops'; import * as spectral from './spectral_ops'; import * as fused from './fused_ops'; import * as signal from './signal_ops'; @@ -148,5 +147,15 @@ const image = { nonMaxSuppressionWithScoreAsync }; +// linalg namespace +import {bandPart} from './band_part'; +import {gramSchmidt} from './gram_schmidt'; +import {qr} from './qr'; +const linalg = { + bandPart, + gramSchmidt, + qr +}; + // Second level exports. export {image, linalg, losses, spectral, fused, signal}; diff --git a/tfjs-core/src/ops/qr.ts b/tfjs-core/src/ops/qr.ts new file mode 100644 index 00000000000..fbbddb10264 --- /dev/null +++ b/tfjs-core/src/ops/qr.ts @@ -0,0 +1,200 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {ENGINE} from '../engine';
+import {dispose} from '../globals';
+import {Tensor, Tensor2D} from '../tensor';
+import {assert} from '../util';
+
+import {reshape, stack, unstack} from './array_ops';
+import {clone} from './clone';
+import {concat} from './concat';
+import {div} from './div';
+import {eye} from './eye';
+import {greater} from './greater';
+import {where} from './logical_ops';
+import {matMul} from './mat_mul';
+import {mul} from './mul';
+import {norm} from './norm';
+import {op} from './operation';
+import {slice} from './slice';
+import {sub} from './sub';
+import {tensor2d} from './tensor_ops';
+import {transpose} from './transpose';
+import {neg} from './unary_ops';
+
+/**
+ * Compute the QR decomposition of an m-by-n matrix using Householder
+ * transformations.
+ *
+ * Implementation based on
+ * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
+ * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
+ *
+ * ```js
+ * const a = tf.tensor2d([[1, 2], [3, 4]]);
+ * let [q, r] = tf.linalg.qr(a);
+ * console.log('Q');
+ * q.print();
+ * console.log('R');
+ * r.print();
+ * console.log('Orthogonalized');
+ * q.dot(q.transpose()).print(); // should be nearly the identity matrix.
+ * console.log('Reconstructed');
+ * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
+ * ```
+ *
+ * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
+ *   it has the shape `[..., M, N]`.
+ * @param fullMatrices An optional boolean parameter. Defaults to `false`.
+ *   If `true`, compute full-sized `Q`. If `false` (the default),
+ *   compute only the leading N columns of `Q` and `R`.
+ * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
+ *   i.e., its columns all have unit norm and are mutually orthogonal.
+ *   If `M >= N`,
+ *     If `fullMatrices` is `false` (default),
+ *       - `Q` has a shape of `[..., M, N]`,
+ *       - `R` has a shape of `[..., N, N]`.
+ *     If `fullMatrices` is `true`,
+ *       - `Q` has a shape of `[..., M, M]`,
+ *       - `R` has a shape of `[..., M, N]`.
+ *   If `M < N`,
+ *     - `Q` has a shape of `[..., M, M]`,
+ *     - `R` has a shape of `[..., M, N]`.
+ * @throws If the rank of `x` is less than 2.
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] {
+  assert(
+      x.rank >= 2,
+      () => `qr() requires input tensor to have a rank >= 2, but got rank ${
+          x.rank}`);
+
+  if (x.rank === 2) {
+    return qr2d(x as Tensor2D, fullMatrices);
+  } else {
+    // Rank > 2.
+    // TODO(cais): Below we split the input into individual 2D tensors,
+    // perform QR decomposition on them and then stack the results back
+    // together. We should explore whether this can be parallelized.
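+    // Collapse all outer (batch) dimensions into a single dimension, so
+    // the [..., M, N] input becomes [batch, M, N]; decompose each [M, N]
+    // slice independently, then stack the per-slice Q and R factors and
+    // reshape them back to the original batch shape.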
+    const outerDimsProd = x.shape.slice(0, x.shape.length - 2)
+                              .reduce((value, prev) => value * prev);
+    const x2ds = unstack(
+        reshape(
+            x,
+            [
+              outerDimsProd, x.shape[x.shape.length - 2],
+              x.shape[x.shape.length - 1]
+            ]),
+        0);
+    const q2ds: Tensor2D[] = [];
+    const r2ds: Tensor2D[] = [];
+    x2ds.forEach(x2d => {
+      const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices);
+      q2ds.push(q2d);
+      r2ds.push(r2d);
+    });
+    const q = reshape(stack(q2ds, 0), x.shape);
+    const r = reshape(stack(r2ds, 0), x.shape);
+    return [q, r];
+  }
+}
+
+function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
+  return ENGINE.tidy(() => {
+    assert(
+        x.shape.length === 2,
+        () => `qr2d() requires a 2D Tensor, but got a ${
+            x.shape.length}D Tensor.`);
+
+    const m = x.shape[0];
+    const n = x.shape[1];
+
+    let q = eye(m);    // Orthogonal transform so far.
+    let r = clone(x);  // Transformed matrix so far.
+
+    const one2D = tensor2d([[1]], [1, 1]);
+    let w: Tensor2D = clone(one2D);
+
+    const iters = m >= n ? n : m;
+    for (let j = 0; j < iters; ++j) {
+      // This tidy within the for-loop ensures we clean up temporary
+      // tensors as soon as they are no longer needed.
+      const rTemp = r;
+      const wTemp = w;
+      const qTemp = q;
+      [w, r, q] = ENGINE.tidy((): [Tensor2D, Tensor2D, Tensor2D] => {
+        // Find H = I - tau * w * w', to put zeros below R(j, j).
+        const rjEnd1 = slice(r, [j, j], [m - j, 1]);
+        const normX = norm(rjEnd1);
+        const rjj = slice(r, [j, j], [1, 1]);
+
+        // The sign() function returns 0 on 0, which causes division by zero.
+        const s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
+
+        const u1 = sub(rjj, mul(s, normX));
+        const wPre = div(rjEnd1, u1);
+        if (wPre.shape[0] === 1) {
+          w = clone(one2D);
+        } else {
+          w = concat(
+              [
+                one2D,
+                slice(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) as
+                    Tensor2D
+              ],
+              0);
+        }
+        const tau = neg(div(matMul(s, u1), normX)) as Tensor2D;
+
+        // -- R := HR, Q := QH.
+        const rjEndAll = slice(r, [j, 0], [m - j, n]);
+        const tauTimesW: Tensor2D = mul(tau, w);
+        const wT: Tensor2D = transpose(w);
+        if (j === 0) {
+          r = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
+        } else {
+          const rTimesTau: Tensor2D =
+              sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
+          r = concat([slice(r, [0, 0], [j, n]), rTimesTau], 0);
+        }
+        const tauTimesWT: Tensor2D = transpose(tauTimesW);
+        const qAllJEnd = slice(q, [0, j], [m, q.shape[1] - j]);
+        if (j === 0) {
+          q = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tauTimesWT));
+        } else {
+          const qTimesTau: Tensor2D =
+              sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tauTimesWT));
+          q = concat([slice(q, [0, 0], [m, j]), qTimesTau], 1);
+        }
+        return [w, r, q];
+      });
+      dispose([rTemp, wTemp, qTemp]);
+    }
+
+    if (!fullMatrices && m > n) {
+      q = slice(q, [0, 0], [m, n]);
+      r = slice(r, [0, 0], [n, n]);
+    }
+
+    return [q, r];
+  }) as [Tensor2D, Tensor2D];
+}
+
+export const qr = op({qr_});
diff --git a/tfjs-core/src/ops/qr_test.ts b/tfjs-core/src/ops/qr_test.ts
new file mode 100644
index 00000000000..f5b0b424f4a
--- /dev/null
+++ b/tfjs-core/src/ops/qr_test.ts
@@ -0,0 +1,144 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; + +describeWithFlags('qr', ALL_ENVS, () => { + it('1x1', async () => { + const x = tensor2d([[10]], [1, 1]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.array(), [[-1]]); + expectArraysClose(await r.array(), [[-10]]); + }); + + it('2x2', async () => { + const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.array(), [[-0.4472, -0.8944], [0.8944, -0.4472]]); + expectArraysClose(await r.array(), [[-2.2361, -4.9193], [0, -0.8944]]); + }); + + it('2x2x2', async () => { + const x = tensor3d([[[-1, -3], [2, 4]], [[1, 3], [-2, -4]]], [2, 2, 2]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.array(), [ + [[-0.4472, -0.8944], [0.8944, -0.4472]], + [[-0.4472, -0.8944], [0.8944, -0.4472]] + ]); + expectArraysClose( + await r.array(), + [[[2.2361, 4.9193], [0, 0.8944]], [[-2.2361, -4.9193], [0, -0.8944]]]); + }); + + it('2x1x2x2', async () => { + const x = + tensor4d([[[[-1, -3], [2, 4]]], [[[1, 3], [-2, -4]]]], [2, 1, 2, 2]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.array(), [ + [[[-0.4472, -0.8944], [0.8944, -0.4472]]], + [[[-0.4472, -0.8944], [0.8944, -0.4472]]], + ]); + expectArraysClose(await r.array(), [ + [[[2.2361, 4.9193], [0, 0.8944]]], [[[-2.2361, -4.9193], [0, -0.8944]]] + ]); + }); + + it('3x3', async () => { + const x = tensor2d([[1, 3, 2], [-2, 0, 7], [8, -9, 4]], [3, 3]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.array(), [ + [-0.1204, 0.8729, 0.4729], [0.2408, -0.4364, 0.8669], + [-0.9631, -0.2182, 0.1576] + ]); + expectArraysClose( + await r.array(), + [[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]]); + }); + + it('3x3, zero on diagonal', async () => { + const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose(await q.data(), [ + [0., -0.89442719, 0.4472136], [1., 0., 0.], [0., -0.4472136, -0.89442719] + ]); + expectArraysClose( + await r.data(), + [[1., 1., 1.], [0., -2.23606798, -2.68328157], [0., 0., -0.89442719]]); + }); + + it('3x2, fullMatrices = default false', async () => { + const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose( + await q.array(), + [[-0.2673, 0.9221], [-0.8018, -0.3738], [0.5345, -0.0997]]); + expectArraysClose(await r.array(), [[-3.7417, 2.4054], [0, 2.8661]]); + }); + + it('3x2, fullMatrices = true', async () => { + const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); + const [q, r] = tf.linalg.qr(x, true); + expectArraysClose(await q.array(), [ + [-0.2673, 0.9221, 0.2798], [-0.8018, -0.3738, 0.4663], + [0.5345, -0.0997, 0.8393] + ]); + expectArraysClose( + await r.array(), [[-3.7417, 2.4054], [0, 2.8661], [0, 0]]); + }); + + it('2x3, fullMatrices = default false', async () 
=> {
+    const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]);
+    const [q, r] = tf.linalg.qr(x);
+    expectArraysClose(
+        await q.array(), [[-0.3162278, -0.9486833], [0.9486833, -0.31622773]]);
+    expectArraysClose(
+        await r.array(),
+        [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]]);
+  });
+
+  it('2x3, fullMatrices = true', async () => {
+    const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]);
+    const [q, r] = tf.linalg.qr(x, true);
+    expectArraysClose(
+        await q.array(), [[-0.3162278, -0.9486833], [0.9486833, -0.31622773]]);
+    expectArraysClose(
+        await r.array(),
+        [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]]);
+  });
+
+  it('Does not leak memory', () => {
+    const x = tensor2d([[1, 3], [-2, -4]], [2, 2]);
+    // The first call to qr creates and keeps internal singleton tensors.
+    // Subsequent calls should always create exactly two tensors.
+    tf.linalg.qr(x);
+    // Count before real call.
+    const numTensors = tf.memory().numTensors;
+    tf.linalg.qr(x);
+    expect(tf.memory().numTensors).toEqual(numTensors + 2);
+  });
+
+  it('Insufficient input tensor rank leads to error', () => {
+    const x1 = scalar(12);
+    expect(() => tf.linalg.qr(x1)).toThrowError(/rank >= 2.*got rank 0/);
+    const x2 = tensor1d([12]);
+    expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/);
+  });
+});
diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts
index 47510b52c88..04b0e1c054d 100644
--- a/tfjs-core/src/tests.ts
+++ b/tfjs-core/src/tests.ts
@@ -45,6 +45,7 @@ import './ops/array_ops_test';
 import './ops/avg_pool_3d_test';
 import './ops/avg_pool_test';
 import './ops/axis_util_test';
+import './ops/band_part_test';
 import './ops/batch_to_space_nd_test';
 import './ops/batchnorm_test';
 import './ops/binary_ops_test';
@@ -75,6 +76,7 @@ import './ops/equal_test';
 import './ops/eye_test';
 import './ops/fused_test';
 import './ops/gather_nd_test';
+import './ops/gram_schmidt_test';
 import './ops/greater_equal_test';
 import './ops/greater_test';
 import './ops/image_ops_test';
@@ -82,7 +84,6 @@ import './ops/in_top_k_test';
 import './ops/leaky_relu_test';
 import './ops/less_equal_test';
 import './ops/less_test';
-import './ops/linalg_ops_test';
 import './ops/local_response_normalization_test';
 import './ops/logical_ops_test';
 import './ops/loss_ops_test';
@@ -98,6 +99,7 @@ import './ops/one_hot_test';
 import './ops/operation_test';
 import './ops/pad_test';
 import './ops/pool_test';
+import './ops/qr_test';
 import './ops/rand_test';
 import './ops/random_gamma_test';
 import './ops/random_normal_test';