From cccc99769bef6cc74a28756e831c1a89f992db76 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 7 Apr 2020 17:05:35 -0700 Subject: [PATCH 1/2] Clean up. Remove backward func from SquaredDifference. --- tfjs-core/src/ops/squared_difference.ts | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tfjs-core/src/ops/squared_difference.ts b/tfjs-core/src/ops/squared_difference.ts index 0653191326c..9015cf6ea6b 100644 --- a/tfjs-core/src/ops/squared_difference.ts +++ b/tfjs-core/src/ops/squared_difference.ts @@ -61,13 +61,7 @@ function squaredDifference_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - const der = (dy: Tensor, saved: Tensor[]) => { - const [$a, $b] = saved; - const two = scalar(2); - const derA = () => dy.mul($a.sub($b).mul(two)); - const derB = () => dy.mul($b.sub($a).mul(two)); - return {a: derA, b: derB}; - }; + const forward: ForwardFunc = (backend, save) => { const res = backend.squaredDifference($a, $b); save([$a, $b]); @@ -77,11 +71,9 @@ function squaredDifference_( const inputs: SquaredDifferenceInputs = {a: $a, b: $b}; const attrs = {}; - const inputsToSave = [$a, $b]; - const outputToSave: boolean[] = []; return ENGINE.runKernelFunc( - forward, inputs as unknown as NamedTensorMap, der, - SquaredDifference, attrs, inputsToSave, outputToSave) as T; + forward, inputs as unknown as NamedTensorMap, null /* grad */, + SquaredDifference, attrs) as T; } export const squaredDifference = op({squaredDifference_}); From 0e6ebd4b20941c0951027f05dc2500a39f46af04 Mon Sep 17 00:00:00 2001 From: Na Li Date: Wed, 8 Apr 2020 20:38:21 -0700 Subject: [PATCH 2/2] Fix lint. --- tfjs-core/src/ops/squared_difference.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/ops/squared_difference.ts b/tfjs-core/src/ops/squared_difference.ts index 9015cf6ea6b..23abbe633a8 100644 --- a/tfjs-core/src/ops/squared_difference.ts +++ b/tfjs-core/src/ops/squared_difference.ts @@ -25,7 +25,6 @@ import {TensorLike} from '../types'; import {assertAndGetBroadcastShape} from './broadcast_util'; import {op} from './operation'; -import {scalar} from './tensor_ops'; /** * Returns (a - b) * (a - b) element-wise.