From 2236b530606cb4da9412dd65e186815ca0baa355 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 11 May 2020 10:19:55 -0700 Subject: [PATCH] Resolve merge conflict. --- tfjs-core/src/gradients/BatchMatMul_grad.ts | 53 +++ tfjs-core/src/kernel_names.ts | 7 + tfjs-core/src/ops/compare_ops_test.ts | 336 ------------------ tfjs-core/src/ops/dot.ts | 81 +++++ tfjs-core/src/ops/fused_ops.ts | 2 +- tfjs-core/src/ops/mat_mul.ts | 110 ++++++ .../ops/{matmul_test.ts => mat_mul_test.ts} | 0 tfjs-core/src/ops/matmul.ts | 193 ---------- tfjs-core/src/ops/ops.ts | 4 +- tfjs-core/src/ops/outer_product.ts | 54 +++ tfjs-core/src/public/chained_ops/dot.ts | 30 ++ tfjs-core/src/public/chained_ops/mat_mul.ts | 32 ++ .../chained_ops/register_all_chained_ops.ts | 2 + .../register_all_chained_ops_test.ts | 2 + tfjs-core/src/register_all_gradients.ts | 32 +- tfjs-core/src/tensor.ts | 13 - tfjs-core/src/tests.ts | 2 +- 17 files changed, 398 insertions(+), 555 deletions(-) create mode 100644 tfjs-core/src/gradients/BatchMatMul_grad.ts create mode 100644 tfjs-core/src/ops/dot.ts create mode 100644 tfjs-core/src/ops/mat_mul.ts rename tfjs-core/src/ops/{matmul_test.ts => mat_mul_test.ts} (100%) delete mode 100644 tfjs-core/src/ops/matmul.ts create mode 100644 tfjs-core/src/ops/outer_product.ts create mode 100644 tfjs-core/src/public/chained_ops/dot.ts create mode 100644 tfjs-core/src/public/chained_ops/mat_mul.ts diff --git a/tfjs-core/src/gradients/BatchMatMul_grad.ts b/tfjs-core/src/gradients/BatchMatMul_grad.ts new file mode 100644 index 00000000000..231530b73b1 --- /dev/null +++ b/tfjs-core/src/gradients/BatchMatMul_grad.ts @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {BatchMatMul, BatchMatMulAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {matMul} from '../ops/mat_mul'; +import {Tensor, Tensor3D} from '../tensor'; + +export const batchMatMulGradConfig: GradConfig = { + kernelName: BatchMatMul, + inputsToSave: ['a', 'b'], + gradFunc: (dy: Tensor3D, saved: Tensor[], attrs: NamedAttrMap) => { + const [a, b] = saved as Tensor3D[]; + + const {transposeA, transposeB} = attrs as {} as BatchMatMulAttrs; + + if (!transposeA && !transposeB) { + return { + a: () => matMul(dy, b, false, true), + b: () => matMul(a, dy, true, false) + }; + } else if (!transposeA && transposeB) { + return { + a: () => matMul(dy, b, false, false), + b: () => matMul(dy, a, true, false) + }; + } else if (transposeA && !transposeB) { + return { + a: () => matMul(b, dy, false, true), + b: () => matMul(a, dy, false, false) + }; + } else { + return { + a: () => matMul(b, dy, true, true), + b: () => matMul(dy, a, true, true) + }; + } + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 027c565ccb1..90a08befa5a 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -27,6 +27,13 @@ export type AddInputs = BinaryInputs; export const AddN = 'AddN'; export type AddNInputs = TensorInfo[]; +export const BatchMatMul = 'BatchMatMul'; +export type BatchMatMulInputs = Pick; +export interface BatchMatMulAttrs { + transposeA: boolean; + transposeB: boolean; +} + export type BinaryInputs = Pick; export const BroadcastTo = 'BroadcastTo'; diff --git a/tfjs-core/src/ops/compare_ops_test.ts b/tfjs-core/src/ops/compare_ops_test.ts index d56bc652ed9..edbe1b6a61a 100644 --- a/tfjs-core/src/ops/compare_ops_test.ts +++ b/tfjs-core/src/ops/compare_ops_test.ts @@ -694,342 +694,6 @@ describeWithFlags('greaterStrict', ALL_ENVS, () => { }); }); -describeWithFlags('greaterEqual', ALL_ENVS, () => { - // Tensor1D: - it('Tensor1D - int32', async () => { - let a = tf.tensor1d([1, 4, 5], 'int32'); - let b = tf.tensor1d([2, 3, 5], 'int32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 1]); - - a = tf.tensor1d([2, 2, 2], 'int32'); - b = tf.tensor1d([2, 2, 2], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1]); - - a = tf.tensor1d([0, 0], 'int32'); - b = tf.tensor1d([3, 3], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 0]); - }); - it('Tensor1D - float32', async () => { - let a = tf.tensor1d([1.1, 4.1, 5.1], 'float32'); - let b = tf.tensor1d([2.2, 3.2, 5.1], 'float32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 1]); - - a = tf.tensor1d([2.31, 2.31, 2.31], 'float32'); - b = tf.tensor1d([2.31, 2.31, 2.31], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1]); - - a = tf.tensor1d([0.45, 0.123], 'float32'); - b = tf.tensor1d([3.123, 3.321], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 0]); - }); - - it('upcasts when dtypes dont match', async () => { - const a = [1.1, 4.1, 5]; - const b = [2.2, 3.2, 5]; - - let res = tf.greaterEqual( - tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); - expect(res.dtype).toBe('bool'); - expect(res.shape).toEqual([3]); - expectArraysClose(await res.data(), [0, 1, 1]); - - res = - tf.greaterEqual(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); - expect(res.dtype).toBe('bool'); - expect(res.shape).toEqual([3]); - expectArraysClose(await res.data(), [1, 1, 1]); - }); - - it('mismatched Tensor1D shapes - int32', () => { - const a = tf.tensor1d([1, 2], 'int32'); - const b = tf.tensor1d([1, 2, 3], 'int32'); - const f = () => { - tf.greaterEqual(a, b); - }; - expect(f).toThrowError(); - }); - it('mismatched Tensor1D shapes - float32', () => { - const a = tf.tensor1d([1.1, 2.1], 'float32'); - const b = tf.tensor1d([1.1, 2.1, 3.1], 'float32'); - const f = () => { - tf.greaterEqual(a, b); - }; - expect(f).toThrowError(); - }); - it('NaNs in Tensor1D - float32', async () => { - const a = tf.tensor1d([1.1, NaN, 2.1], 'float32'); - const b = tf.tensor1d([2.1, 3.1, NaN], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 0, 0]); - }); - - // Tensor2D: - it('Tensor2D - int32', async () => { - let a = tf.tensor2d([[1, 4, 5], [8, 9, 12]], [2, 3], 'int32'); - let b = tf.tensor2d([[2, 3, 6], [7, 10, 11]], [2, 3], 'int32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1, 0, 1]); - - a = tf.tensor2d([[0, 0], [1, 1]], [2, 2], 'int32'); - b = tf.tensor2d([[0, 0], [1, 1]], [2, 2], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1]); - }); - it('Tensor2D - float32', async () => { - let a = tf.tensor2d([[1.1, 4.1, 5.1], [8.1, 9.1, 12.1]], [2, 3], 'float32'); - let b = - tf.tensor2d([[2.1, 3.1, 6.1], [7.1, 10.1, 11.1]], [2, 3], 'float32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1, 0, 1]); - - a = tf.tensor2d([[0.2, 0.2], [1.2, 1.2]], [2, 2], 'float32'); - b = tf.tensor2d([[0.2, 0.2], [1.2, 1.2]], [2, 2], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1]); - }); - it('broadcasting Tensor2D shapes - int32', async () => { - const a = tf.tensor2d([[3], [7]], [2, 1], 'int32'); - const b = tf.tensor2d([[2, 3, 4], [7, 8, 9]], [2, 3], 'int32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 0, 1, 0, 0]); - }); - it('broadcasting Tensor2D shapes - float32', async () => { - const a = tf.tensor2d([[1.1], [7.1]], [2, 1], 'float32'); - const b = - tf.tensor2d([[0.1, 1.1, 2.1], [7.1, 8.1, 9.1]], [2, 3], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 0, 1, 0, 0]); - }); - it('NaNs in Tensor2D - float32', async () => { - const a = tf.tensor2d([[1.1, NaN], [0.1, NaN]], [2, 2], 'float32'); - const b = tf.tensor2d([[0.1, NaN], [1.1, NaN]], [2, 2], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 0, 0]); - }); - - // Tensor3D: - it('Tensor3D - int32', async () => { - let a = - tf.tensor3d([[[1], [4], [5]], [[8], [9], [12]]], [2, 3, 1], 'int32'); - let b = - tf.tensor3d([[[2], [3], [6]], [[7], [10], [11]]], [2, 3, 1], 'int32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1, 0, 1]); - - a = tf.tensor3d([[[0], [0], [0]], [[1], [1], [1]]], [2, 3, 1], 'int32'); - b = tf.tensor3d([[[0], [0], [0]], [[1], [1], [1]]], [2, 3, 1], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1, 1, 1]); - }); - it('Tensor3D - float32', async () => { - let a = tf.tensor3d( - [[[1.1], [4.1], [5.1]], [[8.1], [9.1], [12.1]]], [2, 3, 1], 'float32'); - let b = tf.tensor3d( - [[[2.1], [3.1], [6.1]], [[7.1], [10.1], [11.1]]], [2, 3, 1], 'float32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1, 0, 1]); - - a = tf.tensor3d( - [[[0.1], [0.1], [0.1]], [[1.1], [1.1], [1.2]]], [2, 3, 1], 'float32'); - b = tf.tensor3d( - [[[0.1], [0.1], [0.1]], [[1.1], [1.1], [1.1]]], [2, 3, 1], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1, 1, 1]); - }); - it('broadcasting Tensor3D shapes - int32', async () => { - const a = tf.tensor3d( - [[[1, 0], [2, 3], [4, 5]], [[6, 7], [9, 8], [10, 11]]], [2, 3, 2], - 'int32'); - const b = - tf.tensor3d([[[1], [2], [3]], [[7], [10], [9]]], [2, 3, 1], 'int32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]); - }); - it('broadcasting Tensor3D shapes - float32', async () => { - const a = tf.tensor3d( - [ - [[1.1, 0.1], [2.1, 3.1], [4.1, 5.1]], - [[6.1, 7.1], [9.1, 8.1], [10.1, 11.1]] - ], - [2, 3, 2], 'float32'); - const b = tf.tensor3d( - [[[1.1], [2.1], [3.1]], [[7.1], [10.1], [9.1]]], [2, 3, 1], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]); - }); - it('NaNs in Tensor3D - float32', async () => { - const a = tf.tensor3d( - [[[1.1], [NaN], [1.1]], [[0.1], [0.1], [0.1]]], [2, 3, 1], 'float32'); - const b = tf.tensor3d( - [[[0.1], [0.1], [1.1]], [[1.1], [0.1], [NaN]]], [2, 3, 1], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 1, 0, 1, 0]); - }); - - // Tensor4D: - it('Tensor4D - int32', async () => { - let a = tf.tensor4d([1, 4, 5, 8], [2, 2, 1, 1], 'int32'); - let b = tf.tensor4d([2, 3, 6, 7], [2, 2, 1, 1], 'int32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1]); - - a = tf.tensor4d([0, 1, 2, 3], [2, 2, 1, 1], 'int32'); - b = tf.tensor4d([0, 1, 2, 3], [2, 2, 1, 1], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1]); - - a = tf.tensor4d([1, 1, 1, 1], [2, 2, 1, 1], 'int32'); - b = tf.tensor4d([2, 2, 2, 2], [2, 2, 1, 1], 'int32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 0, 0, 0]); - }); - it('Tensor4D - float32', async () => { - let a = tf.tensor4d([1.1, 4.1, 5.1, 8.1], [2, 2, 1, 1], 'float32'); - let b = tf.tensor4d([2.1, 3.1, 6.1, 7.1], [2, 2, 1, 1], 'float32'); - let res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 0, 1]); - - a = tf.tensor4d([0.1, 1.1, 2.2, 3.3], [2, 2, 1, 1], 'float32'); - b = tf.tensor4d([0.1, 1.1, 2.2, 3.3], [2, 2, 1, 1], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 1, 1, 1]); - - a = tf.tensor4d([0.1, 0.1, 0.1, 0.1], [2, 2, 1, 1], 'float32'); - b = tf.tensor4d([1.1, 1.1, 1.1, 1.1], [2, 2, 1, 1], 'float32'); - res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 0, 0, 0]); - }); - it('broadcasting Tensor4D shapes - int32', async () => { - const a = tf.tensor4d([1, 2, 5, 9], [2, 2, 1, 1], 'int32'); - const b = tf.tensor4d( - [[[[1, 2]], [[3, 4]]], [[[5, 6]], [[7, 8]]]], [2, 2, 1, 2], 'int32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 0, 0, 1, 0, 1, 1]); - }); - it('broadcasting Tensor4D shapes - float32', async () => { - const a = tf.tensor4d([1.1, 2.1, 5.1, 9.1], [2, 2, 1, 1], 'float32'); - const b = tf.tensor4d( - [[[[1.1, 2.1]], [[3.1, 4.1]]], [[[5.1, 6.1]], [[7.1, 8.1]]]], - [2, 2, 1, 2], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 0, 0, 1, 0, 1, 1]); - }); - it('NaNs in Tensor4D - float32', async () => { - const a = tf.tensor4d([1.1, NaN, 0.1, 0.1], [2, 2, 1, 1], 'float32'); - const b = tf.tensor4d([0.1, 1.1, 1.1, NaN], [2, 2, 1, 1], 'float32'); - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [1, 0, 0, 0]); - }); - - it('throws when passed a as a non-tensor', () => { - expect(() => tf.greaterEqual({} as tf.Tensor, tf.scalar(1))) - .toThrowError(/Argument 'a' passed to 'greaterEqual' must be a Tensor/); - }); - it('throws when passed b as a non-tensor', () => { - expect(() => tf.greaterEqual(tf.scalar(1), {} as tf.Tensor)) - .toThrowError(/Argument 'b' passed to 'greaterEqual' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const a = [1, 4, 5]; - const b = [2, 3, 5]; - const res = tf.greaterEqual(a, b); - - expect(res.dtype).toBe('bool'); - expectArraysClose(await res.data(), [0, 1, 1]); - }); - - it('has gradient', async () => { - const a = tf.tensor1d([3, 2, 5]); - const b = tf.tensor1d([4, 1, 5]); - const dy = tf.ones([3], 'float32'); - const da = tf.grad((a: tf.Tensor1D) => tf.greaterEqual(a, b))(a, dy); - - expect(da.dtype).toBe('float32'); - expect(da.shape).toEqual([3]); - expectArraysClose(await da.data(), [0, 0, 0]); - }); - - it('gradient with clones', async () => { - const a = tf.tensor1d([3, 2, 5]); - const b = tf.tensor1d([4, 1, 5]); - const dy = tf.ones([3], 'float32'); - const da = tf.grad( - (a: tf.Tensor1D) => tf.greaterEqual(a.clone(), b.clone()).clone())( - a, dy); - - expect(da.dtype).toBe('float32'); - expect(da.shape).toEqual([3]); - expectArraysClose(await da.data(), [0, 0, 0]); - }); -}); - describeWithFlags('greaterEqualStrict', ALL_ENVS, () => { it('Tensor1D - strict version throws when a and b are different shape', () => { diff --git a/tfjs-core/src/ops/dot.ts b/tfjs-core/src/ops/dot.ts new file mode 100644 index 00000000000..162b98d2aab --- /dev/null +++ b/tfjs-core/src/ops/dot.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor,} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {reshape} from './array_ops'; +import {matMul} from './mat_mul'; +import {op} from './operation'; + +/** + * Computes the dot product of two matrices and/or vectors, `t1` and `t2`. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor2d([[1, 2], [3, 4]]); + * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + * + * a.dot(b).print(); // or tf.dot(a, b) + * b.dot(a).print(); + * b.dot(c).print(); + * ``` + * @param t1 The first tensor in the dot operation. + * @param t2 The second tensor in the dot operation. + */ +/** @doc {heading: 'Operations', subheading: 'Matrices'} */ +function dot_(t1: Tensor|TensorLike, t2: Tensor|TensorLike): Tensor { + const $t1 = convertToTensor(t1, 't1', 'dot'); + const $t2 = convertToTensor(t2, 't2', 'dot'); + + util.assert( + ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), + () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` + + `${$t1.rank} and ${$t2.rank}.`); + + const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]); + const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]); + + util.assert( + t1Inner === t2Inner, + () => `Error in dot: inner dimensions of inputs must match, but got ` + + `${t1Inner} and ${t2Inner}.`); + + if ($t1.rank === 1 && $t2.rank === 1) { + const t12D = reshape($t1, [1, -1]); + const t22D = reshape($t2, [-1, 1]); + const t1t2 = matMul(t12D, t22D); + return reshape(t1t2, []); + } else if ($t1.rank === 1 && $t2.rank === 2) { + const t12D = reshape($t1, [1, -1]); + const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]); + const t1t2 = matMul(t12D, t22D); + return reshape(t1t2, [t1t2.size]); + } else if ($t1.rank === 2 && $t2.rank === 1) { + const t22D = reshape($t2, [-1, 1]); + const t1t2 = matMul($t1, t22D); + return reshape(t1t2, [t1t2.size]); + } else { + const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]); + const t1t2 = matMul($t1, t22D); + return t1t2; + } +} + +export const dot = op({dot_}); diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts index 14931598441..54d82bc350f 100644 --- a/tfjs-core/src/ops/fused_ops.ts +++ b/tfjs-core/src/ops/fused_ops.ts @@ -33,7 +33,7 @@ import {depthwiseConv2d as unfusedDepthwiseConv2d} from './depthwise_conv2d'; import {depthwiseConv2dNativeBackpropFilter} from './depthwise_conv2d_native_backprop_filter'; import {depthwiseConv2dNativeBackpropInput} from './depthwise_conv2d_native_backprop_input'; import {Activation, shouldFuse} from './fused_util'; -import {matMul as unfusedMatMul} from './matmul'; +import {matMul as unfusedMatMul} from './mat_mul'; import {elu, prelu, relu, relu6} from './relu_ops'; // Returns gradient for fused activation. diff --git a/tfjs-core/src/ops/mat_mul.ts b/tfjs-core/src/ops/mat_mul.ts new file mode 100644 index 00000000000..b0d834345db --- /dev/null +++ b/tfjs-core/src/ops/mat_mul.ts @@ -0,0 +1,110 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE, ForwardFunc} from '../engine'; +import {BatchMatMul, BatchMatMulAttrs, BatchMatMulInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor3D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {makeTypesMatch} from '../tensor_util'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {reshape} from './array_ops'; +import {op} from './operation'; + +/** + * Computes the dot product of two matrices, A * B. These must be matrices. + * + * ```js + * const a = tf.tensor2d([1, 2], [1, 2]); + * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.matMul(b).print(); // or tf.matMul(a, b) + * ``` + * @param a First matrix in dot product operation. + * @param b Second matrix in dot product operation. + * @param transposeA If true, `a` is transposed before multiplication. + * @param transposeB If true, `b` is transposed before multiplication. + */ +/** @doc {heading: 'Operations', subheading: 'Matrices'} */ +function matMul_( + a: T|TensorLike, b: T|TensorLike, transposeA = false, + transposeB = false): T { + let $a = convertToTensor(a, 'a', 'matMul'); + let $b = convertToTensor(b, 'b', 'matMul'); + [$a, $b] = makeTypesMatch($a, $b); + + util.assert( + $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, + () => `Error in matMul: inputs must have the same rank of at least 2, ` + + `got ranks ${$a.rank} and ${$b.rank}.`); + + const innerShapeA = + transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; + const innerShapeB = + transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; + + const outerShapeA = + transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; + const outerShapeB = + transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; + + const outerDimsA = $a.shape.slice(0, -2); + const outerDimsB = $b.shape.slice(0, -2); + const batchDimA = util.sizeFromShape(outerDimsA); + const batchDimB = util.sizeFromShape(outerDimsB); + + util.assert( + util.arraysEqual(outerDimsA, outerDimsB), + () => `Error in matMul: outer dimensions (${outerDimsA}) and (` + + `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` + + `${$b.shape} must match.`); + + util.assert( + innerShapeA === innerShapeB, + () => `Error in matMul: inner shapes (${innerShapeA}) and (` + + `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + + `${$b.shape} and transposeA=${transposeA}` + + ` and transposeB=${transposeB} must match.`); + + const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); + + const a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : + reshape($a, [batchDimA, outerShapeA, innerShapeA]); + const b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : + reshape($b, [batchDimB, innerShapeB, outerShapeB]); + + const forward: ForwardFunc = (backend, save) => { + save([a3D, b3D]); + + return backend.batchMatMul( + a3D as Tensor3D, b3D as Tensor3D, transposeA, transposeB); + }; + + const inputs: BatchMatMulInputs = {a: a3D, b: b3D}; + + const attrs: BatchMatMulAttrs = {transposeA, transposeB}; + + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, BatchMatMul, + attrs as {} as NamedAttrMap); + + return reshape(res, outShape) as T; +} + +export const matMul = op({matMul_}); diff --git a/tfjs-core/src/ops/matmul_test.ts b/tfjs-core/src/ops/mat_mul_test.ts similarity index 100% rename from tfjs-core/src/ops/matmul_test.ts rename to tfjs-core/src/ops/mat_mul_test.ts diff --git a/tfjs-core/src/ops/matmul.ts b/tfjs-core/src/ops/matmul.ts deleted file mode 100644 index a09f8c98801..00000000000 --- a/tfjs-core/src/ops/matmul.ts +++ /dev/null @@ -1,193 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {ENGINE} from '../engine'; -import {Tensor, Tensor1D, Tensor2D, Tensor3D} from '../tensor'; -import {makeTypesMatch} from '../tensor_util'; -import {convertToTensor} from '../tensor_util_env'; -import {TensorLike} from '../types'; -import * as util from '../util'; -import {op} from './operation'; - -/** - * Computes the dot product of two matrices, A * B. These must be matrices. - * - * ```js - * const a = tf.tensor2d([1, 2], [1, 2]); - * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * a.matMul(b).print(); // or tf.matMul(a, b) - * ``` - * @param a First matrix in dot product operation. - * @param b Second matrix in dot product operation. - * @param transposeA If true, `a` is transposed before multiplication. - * @param transposeB If true, `b` is transposed before multiplication. - */ -/** @doc {heading: 'Operations', subheading: 'Matrices'} */ -function matMul_( - a: T|TensorLike, b: T|TensorLike, transposeA = false, - transposeB = false): T { - let $a = convertToTensor(a, 'a', 'matMul'); - let $b = convertToTensor(b, 'b', 'matMul'); - [$a, $b] = makeTypesMatch($a, $b); - - const innerShapeA = - transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; - const innerShapeB = - transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; - - const outerShapeA = - transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; - const outerShapeB = - transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; - - const outerDimsA = $a.shape.slice(0, -2); - const outerDimsB = $b.shape.slice(0, -2); - const batchDimA = util.sizeFromShape(outerDimsA); - const batchDimB = util.sizeFromShape(outerDimsB); - - util.assert( - $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, - () => `Error in matMul: inputs must have the same rank of at least 2, ` + - `got ranks ${$a.rank} and ${$b.rank}.`); - - util.assert( - util.arraysEqual(outerDimsA, outerDimsB), - () => `Error in matMul: outer dimensions (${outerDimsA}) and (` + - `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} must match.`); - - util.assert( - innerShapeA === innerShapeB, - () => `Error in matMul: inner shapes (${innerShapeA}) and (` + - `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} and transposeA=${transposeA}` + - ` and transposeB=${transposeB} must match.`); - - const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); - - const a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) : - $a.as3D(batchDimA, outerShapeA, innerShapeA); - const b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) : - $b.as3D(batchDimB, innerShapeB, outerShapeB); - - const grad = (dy: Tensor3D, saved: Tensor[]) => { - const [a3D, b3D] = saved as Tensor3D[]; - if (!transposeA && !transposeB) { - return { - a: () => dy.matMul(b3D, false, true), - b: () => a3D.matMul(dy, true, false) - }; - } else if (!transposeA && transposeB) { - return { - a: () => dy.matMul(b3D, false, false), - b: () => dy.matMul(a3D, true, false) - }; - } else if (transposeA && !transposeB) { - return { - a: () => b3D.matMul(dy, false, true), - b: () => a3D.matMul(dy, false, false) - }; - } else { - return { - a: () => b3D.matMul(dy, true, true), - b: () => dy.matMul(a3D, true, true) - }; - } - }; - - const attrs = {transposeA, transposeB}; - const res = ENGINE.runKernelFunc((backend, save) => { - const res = backend.batchMatMul(a3D, b3D, transposeA, transposeB); - save([a3D, b3D]); - return res; - }, {a: a3D, b: b3D}, grad, 'BatchMatMul', attrs); - return res.reshape(outShape) as T; -} - -/** - * Computes the outer product of two vectors, `v1` and `v2`. - * - * ```js - * const a = tf.tensor1d([1, 2, 3]); - * const b = tf.tensor1d([3, 4, 5]); - * - * tf.outerProduct(a, b).print(); - * ``` - * @param v1 The first vector in the outer product operation. - * @param v2 The second vector in the outer product operation. - */ -/** @doc {heading: 'Operations', subheading: 'Matrices'} */ -function outerProduct_( - v1: Tensor1D|TensorLike, v2: Tensor1D|TensorLike): Tensor2D { - const $v1 = convertToTensor(v1, 'v1', 'outerProduct'); - const $v2 = convertToTensor(v2, 'v2', 'outerProduct'); - - util.assert( - $v1.rank === 1 && $v2.rank === 1, - () => `Error in outerProduct: inputs must be rank 1, but got ranks ` + - `${$v1.rank} and ${$v2.rank}.`); - - return $v1.as2D(-1, 1).matMul($v2.as2D(1, -1)); -} - -/** - * Computes the dot product of two matrices and/or vectors, `t1` and `t2`. - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor2d([[1, 2], [3, 4]]); - * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - * - * a.dot(b).print(); // or tf.dot(a, b) - * b.dot(a).print(); - * b.dot(c).print(); - * ``` - * @param t1 The first tensor in the dot operation. - * @param t2 The second tensor in the dot operation. - */ -/** @doc {heading: 'Operations', subheading: 'Matrices'} */ -function dot_(t1: Tensor|TensorLike, t2: Tensor|TensorLike): Tensor { - const $t1 = convertToTensor(t1, 't1', 'dot'); - const $t2 = convertToTensor(t2, 't2', 'dot'); - util.assert( - ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), - () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` + - `${$t1.rank} and ${$t2.rank}.`); - - const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]); - const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]); - - util.assert( - t1Inner === t2Inner, - () => `Error in dot: inner dimensions of inputs must match, but got ` + - `${t1Inner} and ${t2Inner}.`); - - if ($t1.rank === 1 && $t2.rank === 1) { - return $t1.as2D(1, -1).matMul($t2.as2D(-1, 1)).asScalar(); - } else if ($t1.rank === 1 && $t2.rank === 2) { - return $t1.as2D(1, -1).matMul($t2.as2D($t2.shape[0], $t2.shape[1])).as1D(); - } else if ($t1.rank === 2 && $t2.rank === 1) { - return $t1.matMul($t2.as2D(-1, 1)).as1D(); - } else { - return $t1.matMul($t2.as2D($t2.shape[0], $t2.shape[1])); - } -} - -export const matMul = op({matMul_}); -export const dot = op({dot_}); -export const outerProduct = op({outerProduct_}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 2ff0ba3cde3..adc64067146 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -38,16 +38,19 @@ export {depthwiseConv2d} from './depthwise_conv2d'; export {diag} from './diag'; export {div} from './div'; export {divNoNan} from './div_no_nan'; +export {dot} from './dot'; export {equal} from './equal'; export {eye} from './eye'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {less} from './less'; export {lessEqual} from './less_equal'; +export {matMul} from './mat_mul'; export {max} from './max'; export {multinomial} from './multinomial'; export {notEqual} from './not_equal'; export {oneHot} from './one_hot'; +export {outerProduct} from './outer_product'; export {pad} from './pad'; export {pad1d} from './pad1d'; export {pad2d} from './pad2d'; @@ -67,7 +70,6 @@ export {truncatedNormal} from './truncated_normal'; export * from './boolean_mask'; export * from './complex_ops'; -export * from './matmul'; export * from './reverse'; export * from './pool'; export * from './slice'; diff --git a/tfjs-core/src/ops/outer_product.ts b/tfjs-core/src/ops/outer_product.ts new file mode 100644 index 00000000000..b350d2820e5 --- /dev/null +++ b/tfjs-core/src/ops/outer_product.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor1D, Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {reshape} from './array_ops'; +import {matMul} from './mat_mul'; +import {op} from './operation'; + +/** + * Computes the outer product of two vectors, `v1` and `v2`. + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([3, 4, 5]); + * + * tf.outerProduct(a, b).print(); + * ``` + * @param v1 The first vector in the outer product operation. + * @param v2 The second vector in the outer product operation. + */ +/** @doc {heading: 'Operations', subheading: 'Matrices'} */ +function outerProduct_( + v1: Tensor1D|TensorLike, v2: Tensor1D|TensorLike): Tensor2D { + const $v1 = convertToTensor(v1, 'v1', 'outerProduct'); + const $v2 = convertToTensor(v2, 'v2', 'outerProduct'); + + util.assert( + $v1.rank === 1 && $v2.rank === 1, + () => `Error in outerProduct: inputs must be rank 1, but got ranks ` + + `${$v1.rank} and ${$v2.rank}.`); + + const v12D = reshape($v1, [-1, 1]); + const v22D = reshape($v2, [1, -1]); + return matMul(v12D, v22D) as Tensor2D; +} + +export const outerProduct = op({outerProduct_}); diff --git a/tfjs-core/src/public/chained_ops/dot.ts b/tfjs-core/src/public/chained_ops/dot.ts new file mode 100644 index 00000000000..14639bd41e4 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/dot.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {dot} from '../../ops/dot'; +import {Tensor} from '../../tensor'; +import {Rank, TensorLike} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + dot(b: Tensor|TensorLike): Tensor; + } +} + +Tensor.prototype.dot = function(b: T|TensorLike): Tensor { + this.throwIfDisposed(); + return dot(this, b); +}; diff --git a/tfjs-core/src/public/chained_ops/mat_mul.ts b/tfjs-core/src/public/chained_ops/mat_mul.ts new file mode 100644 index 00000000000..d7c592d6878 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/mat_mul.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {matMul} from '../../ops/mat_mul'; +import {Tensor} from '../../tensor'; +import {Rank, TensorLike} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + matMul( + b: T|TensorLike, transposeA?: boolean, transposeB?: boolean): T; + } +} + +Tensor.prototype.matMul = function( + this: T, b: T|TensorLike, transposeA?: boolean, transposeB?: boolean): T { + this.throwIfDisposed(); + return matMul(this, b, transposeA, transposeB); +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index dca44625289..2dc27b0b8b0 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -26,11 +26,13 @@ import './depthwise_conv2d'; import './depthwise_conv2D_deprecated'; import './div'; import './div_no_nan'; +import './dot'; import './equal'; import './greater'; import './greater_equal'; import './less'; import './less_equal'; +import './mat_mul'; import './one_hot'; import './not_equal'; import './pad'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 7188841fea0..e3c8981c477 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -35,11 +35,13 @@ const CHAINED_OPS = [ 'depthwiseConv2D', 'div', 'divNoNan', + 'dot', 'equal', 'greater', 'greaterEqual', 'less', 'lessEqual', + 'matMul', 'notEqual', 'oneHot', 'pad', diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index a425f50f391..21fb7f4c1ef 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -16,6 +16,7 @@ */ import {addGradConfig} from './gradients/Add_grad'; import {addNGradConfig} from './gradients/AddN_grad'; +import {batchMatMulGradConfig} from './gradients/BatchMatMul_grad'; import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {concatGradConfig} from './gradients/Concat_grad'; import {conv2DGradConfig} from './gradients/Conv2D_grad'; @@ -40,16 +41,27 @@ import {registerGradient} from './kernel_registry'; // Export all kernel configs here so that the package can auto register them const gradConfigs: GradConfig[] = [ - addGradConfig, addNGradConfig, - broadcastToGradConfig, concatGradConfig, - conv2DGradConfig, conv2DBackpropInputGradConfig, - conv3DGradConfig, depthwiseConv2dNativeGradConfig, - divGradConfig, fusedBatchNormGradConfig, - greaterEqualGradConfig, identityGradConfig, - oneHotGradConfig, padV2GradConfig, - splitVGradConfig, maxGradConfig, - squareGradConfig, squaredDifferenceGradConfig, - tileGradConfig, transposeGradConfig, + addGradConfig, + addNGradConfig, + batchMatMulGradConfig, + broadcastToGradConfig, + concatGradConfig, + conv2DGradConfig, + conv2DBackpropInputGradConfig, + conv3DGradConfig, + depthwiseConv2dNativeGradConfig, + divGradConfig, + fusedBatchNormGradConfig, + greaterEqualGradConfig, + identityGradConfig, + oneHotGradConfig, + padV2GradConfig, + splitVGradConfig, + maxGradConfig, + squareGradConfig, + squaredDifferenceGradConfig, + tileGradConfig, + transposeGradConfig, subGradConfig ]; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 1df1994bf40..1af9132e3b8 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -181,9 +181,6 @@ export interface OpHandler { squeeze(x: Tensor, axis?: number[]): T; clone(x: T): T; gather(x: T, indices: Tensor|TensorLike, axis: number): T; - matMul( - a: T, b: T|TensorLike, transposeA: boolean, transposeB: boolean): T; - dot(t1: Tensor, t2: Tensor|TensorLike): Tensor; norm( x: Tensor, ord: number|'euclidean'|'fro', axis: number|number[], keepDims: boolean): Tensor; @@ -732,16 +729,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.gather(this, indices, axis); } - - matMul( - this: T, b: T|TensorLike, transposeA = false, transposeB = false): T { - this.throwIfDisposed(); - return opHandler.matMul(this, b, transposeA, transposeB); - } - dot(b: Tensor|TensorLike): Tensor { - this.throwIfDisposed(); - return opHandler.dot(this, b); - } norm( ord: number|'euclidean'|'fro' = 'euclidean', axis: number|number[] = null, keepDims = false): Tensor { diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index c0f71240526..a4ccf06c019 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -80,7 +80,7 @@ import './ops/logical_ops_test'; import './ops/loss_ops_test'; import './ops/lrn_test'; import './ops/lstm_test'; -import './ops/matmul_test'; +import './ops/mat_mul_test'; import './ops/moving_average_test'; import './ops/multinomial_test'; import './ops/not_equal_test';