Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, Ten
export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types';

export * from './ops/ops';
export {Reduction} from './ops/loss_ops';
export {Reduction} from './ops/loss_ops_utils';

export * from './train';
export * from './globals';
Expand Down
60 changes: 60 additions & 0 deletions tfjs-core/src/ops/absolute_difference.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assertShapesMatch} from '../util';
import {computeWeightedLoss} from './compute_weighted_loss';
import {Reduction} from './loss_ops_utils';
import {op} from './operation';
import {sub} from './sub';
import {abs} from './unary_ops';

/**
* Computes the absolute difference loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*/
/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */
function absoluteDifference_<T extends Tensor, O extends Tensor>(
labels: T|TensorLike, predictions: T|TensorLike,
weights?: Tensor|TensorLike,
reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {
const $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
const $predictions =
convertToTensor(predictions, 'predictions', 'absoluteDifference');
let $weights: Tensor = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'absoluteDifference');
}
assertShapesMatch(
$labels.shape, $predictions.shape, 'Error in absoluteDifference: ');

const losses = abs(sub($labels, $predictions));
return computeWeightedLoss(losses, $weights, reduction);
}

export const absoluteDifference = op({absoluteDifference_});
222 changes: 222 additions & 0 deletions tfjs-core/src/ops/absolute_difference_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

describeWithFlags('absoluteDifference', ALL_ENVS, () => {
it('1D', async () => {
const predictions = tf.tensor1d([1, 2, 3]);
const label = tf.tensor1d([0.3, -0.6, -0.1]);

const y = tf.losses.absoluteDifference(label, predictions);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(1 - 0.3) + Math.abs(2 - (-0.6)) + Math.abs(3 - (-0.1))) / 3);
});

it('1D - weighted - Reduction.SUM_BY_NONZERO_WEIGHTS', async () => {
const predictions = tf.tensor1d([1, 2, 3]);
const label = tf.tensor1d([0.3, -0.6, -0.1]);
const weights = tf.tensor1d([0.1, 0.2, 0.3]);

const y = tf.losses.absoluteDifference(label, predictions, weights);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(1 - 0.3) * 0.1 + Math.abs(2 - (-0.6)) * 0.2 +
Math.abs(3 - (-0.1)) * 0.3) /
3);
});

it('1D - weighted - Reduction.NONE', async () => {
const predictions = tf.tensor1d([1, 2, 3]);
const label = tf.tensor1d([0.3, -0.6, -0.1]);
const weights = tf.tensor1d([0.1, 0.2, 0.3]);

const y = tf.losses.absoluteDifference(
label, predictions, weights, tf.Reduction.NONE);

expect(y.shape).toEqual([3]);
expectArraysClose(await y.data(), [
Math.abs(1 - 0.3) * 0.1, Math.abs(2 - (-0.6)) * 0.2,
Math.abs(3 - (-0.1)) * 0.3
]);
});

it('1D - Reduction.MEAN', async () => {
const predictions = tf.tensor1d([1, 2, 3]);
const label = tf.tensor1d([0.3, -0.6, -0.1]);

const y = tf.losses.absoluteDifference(
label, predictions, undefined, tf.Reduction.MEAN);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(1 - 0.3) + Math.abs(2 - (-0.6)) + Math.abs(3 - (-0.1))) / 3);
});

it('1D - weighted - Reduction.MEAN', async () => {
const predictions = tf.tensor1d([1, 2, 3]);
const label = tf.tensor1d([0.3, -0.6, -0.1]);
const weights = tf.tensor1d([0.1, 0.2, 0.3]);

const y = tf.losses.absoluteDifference(
label, predictions, weights, tf.Reduction.MEAN);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
((Math.abs(1 - 0.3) * 0.1) + (Math.abs(2 - (-0.6)) * 0.2) +
(Math.abs(3 - (-0.1)) * 0.3)) /
0.6);
});

it('2D', async () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);

const y = tf.losses.absoluteDifference(label, predictions);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(4 - 1) + Math.abs(8 - 9) + Math.abs(12 - 2) +
Math.abs(8 - (-5)) + Math.abs(1 - (-2)) + Math.abs(3 - 6)) /
6);
});

it('2D - weighted - Reduction.SUM_BY_NONZERO_WEIGHTS', async () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);
const weights = tf.tensor2d([3, 0, 5, 0, 4, 2], [2, 3]);

const y = tf.losses.absoluteDifference(label, predictions, weights);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(4 - 1) * 3 + Math.abs(8 - 9) * 0 + Math.abs(12 - 2) * 5 +
Math.abs(8 - (-5)) * 0 + Math.abs(1 - (-2)) * 4 +
Math.abs(3 - 6) * 2) /
4);
});

it('2D - weighted - Reduction.NONE', async () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);
const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]);

const y = tf.losses.absoluteDifference(
label, predictions, weights, tf.Reduction.NONE);

expect(y.shape).toEqual([2, 3]);
expectArraysClose(await y.data(), [
Math.abs(4 - 1) * 3, Math.abs(8 - 9) * 6, Math.abs(12 - 2) * 5,
Math.abs(8 - (-5)) * 0, Math.abs(1 - (-2)) * 4, Math.abs(3 - 6) * 2
]);
});

it('2D - Reduction.MEAN', async () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);

const y = tf.losses.absoluteDifference(
label, predictions, undefined, tf.Reduction.MEAN);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(4 - 1) + Math.abs(8 - 9) + Math.abs(12 - 2) +
Math.abs(8 - (-5)) + Math.abs(1 - (-2)) + Math.abs(3 - 6)) /
6);
});

it('2D - weighted - Reduction.MEAN', async () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);
const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]);

const y = tf.losses.absoluteDifference(
label, predictions, weights, tf.Reduction.MEAN);

expect(y.shape).toEqual([]);
expectArraysClose(
await y.data(),
(Math.abs(4 - 1) * 3 + Math.abs(8 - 9) * 6 + Math.abs(12 - 2) * 5 +
Math.abs(8 - (-5)) * 0 + Math.abs(1 - (-2)) * 4 +
Math.abs(3 - 6) * 2) /
20);
});

it('throws when passed label as a non-tensor', () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]);

const e =
/Argument 'labels' passed to 'absoluteDifference' must be a Tensor/;
expect(
() => tf.losses.absoluteDifference(
{} as tf.Tensor, predictions, weights, tf.Reduction.MEAN))
.toThrowError(e);
});

it('throws when passed label as a non-tensor', () => {
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);
const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]);

const e = new RegExp(
'Argument \'predictions\' passed to \'absoluteDifference\' ' +
'must be a Tensor');
expect(
() => tf.losses.absoluteDifference(
label, {} as tf.Tensor, weights, tf.Reduction.MEAN))
.toThrowError(e);
});

it('throws when passed weights as a non-tensor', () => {
const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]);
const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]);

const e =
/Argument 'weights' passed to 'absoluteDifference' must be a Tensor/;
expect(
() => tf.losses.absoluteDifference(
label, predictions, {} as tf.Tensor, tf.Reduction.MEAN))
.toThrowError(e);
});

it('accepts a tensor-like object', async () => {
const predictions = [1, 2, 3];
const label = [0.3, -0.6, -0.1];
const weights = [0.1, 0.2, 0.3];

const y = tf.losses.absoluteDifference(
label, predictions, weights, tf.Reduction.NONE);

expect(y.shape).toEqual([3]);
expectArraysClose(await y.data(), [
Math.abs(1 - 0.3) * 0.1, Math.abs(2 - (-0.6)) * 0.2,
Math.abs(3 - (-0.1)) * 0.3
]);
});
});
81 changes: 81 additions & 0 deletions tfjs-core/src/ops/compute_weighted_loss.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

import {cast} from './array_ops';
import {div} from './div';
import {Reduction} from './loss_ops_utils';
import {mul} from './mul';
import {notEqual} from './not_equal';
import {op} from './operation';
import {mean, sum} from './reduction_ops';
import {ones, scalar} from './tensor_ops';

/**
* Computes the weighted loss between two tensors.
*
* @param losses Tensor of shape `[batch_size, d1, ... dN]`.
* @param weights Tensor whose rank is either 0, or the same rank as
* `losses`, and must be broadcastable to `losses` (i.e., all
* dimensions must be either `1`, or the same as the corresponding
* `losses` dimension).
*/
/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */
function computeWeightedLoss_<T extends Tensor, O extends Tensor>(
losses: T|TensorLike, weights?: Tensor|TensorLike,
reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {
const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
let $weights: Tensor = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
}

const weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);

if (reduction === Reduction.NONE) {
return weightedLoss as O;
}
if (reduction === Reduction.SUM) {
return sum(weightedLoss);
}
if (reduction === Reduction.MEAN) {
if ($weights == null) {
return mean(weightedLoss);
} else {
const broadcastFactor = $losses.size / $weights.size;
const result = div(sum(weightedLoss), sum($weights));
return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) :
result as O;
}
}
if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) {
if ($weights == null) {
return div(sum(weightedLoss), scalar($losses.size));
} else {
const broadcastedWeights = mul($weights, ones($losses.shape));

const numNonZeros =
cast(sum(notEqual(broadcastedWeights, scalar(0))), 'float32');
return div(sum(weightedLoss), numNonZeros);
}
}

throw Error(`Unknown reduction: ${reduction}`);
}
export const computeWeightedLoss = op({computeWeightedLoss_});
Loading