diff --git a/tfjs-core/src/ops/norm.ts b/tfjs-core/src/ops/norm.ts index e65b7254c7d..5e8e9b557b5 100644 --- a/tfjs-core/src/ops/norm.ts +++ b/tfjs-core/src/ops/norm.ts @@ -20,9 +20,15 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; +import {reshape} from './array_ops'; import * as axis_util from './axis_util'; +import {max} from './max'; import {op} from './operation'; +import {pow} from './pow'; +import {min, sum} from './reduction_ops'; +import {square} from './square'; import {scalar} from './tensor_ops'; +import {abs, sqrt} from './unary_ops'; /** * Computes the norm of scalar, vectors, and matrices. @@ -78,29 +84,29 @@ function norm_( function normImpl( x: Tensor, p: number|string, axis: number|number[] = null): Tensor { if (x.rank === 0) { - return x.abs(); + return abs(x); } // consider vector when no axis is specified if (x.rank !== 1 && axis === null) { - return normImpl(x.reshape([-1]), p, axis); + return normImpl(reshape(x, [-1]), p, axis); } // vector if (x.rank === 1 || typeof axis === 'number' || Array.isArray(axis) && axis.length === 1) { if (p === 1) { - return x.abs().sum(axis); + return sum(abs(x), axis); } if (p === Infinity) { - return x.abs().max(axis); + return max(abs(x), axis); } if (p === -Infinity) { - return x.abs().min(axis); + return min(abs(x), axis); } if (p === 'euclidean' || p === 2) { // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2 - return x.abs().pow(scalar(2, 'int32')).sum(axis).sqrt(); + return sqrt(sum(pow(abs(x), scalar(2, 'int32')), axis)); } throw new Error(`Error in norm: invalid ord value: ${p}`); @@ -109,17 +115,17 @@ function normImpl( // matrix (assumption axis[0] < axis[1]) if (Array.isArray(axis) && axis.length === 2) { if (p === 1) { - return x.abs().sum(axis[0]).max(axis[1] - 1); + return max(sum(abs(x), axis[0]), axis[1] - 1); } if (p === Infinity) { - return x.abs().sum(axis[1]).max(axis[0]); + return max(sum(abs(x), axis[1]), axis[0]); } if (p === -Infinity) { - return x.abs().sum(axis[1]).min(axis[0]); + return min(sum(abs(x), axis[1]), axis[0]); } if (p === 'fro' || p === 'euclidean') { // norm(x) = sqrt(sum(pow(x, 2))) - return x.square().sum(axis).sqrt(); + return sqrt(sum(square(x), axis)); } throw new Error(`Error in norm: invalid ord value: ${p}`);