diff --git a/tfjs-core/src/ops/squared_difference.ts b/tfjs-core/src/ops/squared_difference.ts index 0653191326c..23abbe633a8 100644 --- a/tfjs-core/src/ops/squared_difference.ts +++ b/tfjs-core/src/ops/squared_difference.ts @@ -25,7 +25,6 @@ import {TensorLike} from '../types'; import {assertAndGetBroadcastShape} from './broadcast_util'; import {op} from './operation'; -import {scalar} from './tensor_ops'; /** * Returns (a - b) * (a - b) element-wise. @@ -61,13 +60,7 @@ function squaredDifference_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - const der = (dy: Tensor, saved: Tensor[]) => { - const [$a, $b] = saved; - const two = scalar(2); - const derA = () => dy.mul($a.sub($b).mul(two)); - const derB = () => dy.mul($b.sub($a).mul(two)); - return {a: derA, b: derB}; - }; + const forward: ForwardFunc = (backend, save) => { const res = backend.squaredDifference($a, $b); save([$a, $b]); @@ -77,11 +70,9 @@ function squaredDifference_( const inputs: SquaredDifferenceInputs = {a: $a, b: $b}; const attrs = {}; - const inputsToSave = [$a, $b]; - const outputToSave: boolean[] = []; return ENGINE.runKernelFunc( - forward, inputs as unknown as NamedTensorMap, der, - SquaredDifference, attrs, inputsToSave, outputToSave) as T; + forward, inputs as unknown as NamedTensorMap, null /* grad */, + SquaredDifference, attrs) as T; } export const squaredDifference = op({squaredDifference_});