Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit ecfefd1

Browse files
manrajgroverNikhil Thorat
authored andcommitted
Loss Operations: Adds Absolute Difference loss. (#775)
1 parent c27ed8f commit ecfefd1

File tree

4 files changed

+435
-1
lines changed

4 files changed

+435
-1
lines changed

src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export {loadWeights} from './weights_loader';
4242

4343
export * from './ops/ops';
4444
export {LSTMCellFunc} from './ops/lstm';
45+
export {Reduction} from './ops/loss_ops';
4546

4647
export * from './train';
4748
export * from './globals';

src/ops/loss_ops.ts

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {doc} from '../doc';
19+
import {Tensor} from '../tensor';
20+
import * as util from '../util';
21+
22+
import {operation} from './operation';
23+
import * as ops from './ops';
24+
25+
export enum Reduction {
26+
NONE,
27+
MEAN,
28+
SUM,
29+
SUM_BY_NONZERO_WEIGHTS
30+
}
31+
32+
export class LossOps {
33+
/**
34+
* Computes the weighted loss between two tensors.
35+
*
36+
* @param losses Tensor of shape `[batch_size, d1, ... dN]`.
37+
* @param weights Tensor whose rank is either 0, or the same rank as
38+
* `losses`, and must be broadcastable to `losses` (i.e., all
39+
* dimensions must be either `1`, or the same as the corresponding
40+
* `losses` dimension).
41+
*/
42+
@doc({heading: 'Training', subheading: 'Losses', namespace: 'losses'})
43+
@operation
44+
static computeWeightedLoss<T extends Tensor, O extends Tensor>(
45+
losses: T, weights?: Tensor,
46+
reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {
47+
const weightedLoss = (weights == null) ? losses : losses.mul(weights);
48+
49+
if (reduction === Reduction.NONE) {
50+
return weightedLoss as O;
51+
}
52+
if (reduction === Reduction.SUM) {
53+
return weightedLoss.sum();
54+
}
55+
if (reduction === Reduction.MEAN) {
56+
return (weights == null) ? weightedLoss.mean() :
57+
weightedLoss.sum().div(weights.sum());
58+
}
59+
if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) {
60+
if (weights == null) {
61+
return weightedLoss.sum().div(ops.scalar(losses.size));
62+
} else {
63+
const numNonZeros = weights.notEqual(ops.scalar(0)).sum().toFloat();
64+
return weightedLoss.sum().div(numNonZeros);
65+
}
66+
}
67+
68+
throw Error(`Unknown reduction: ${reduction}`);
69+
}
70+
71+
/**
72+
* Computes the absolute difference loss between two tensors.
73+
*
74+
* @param labels The ground truth output tensor, same dimensions as
75+
* 'predictions'.
76+
* @param predictions The predicted outputs.
77+
* @param weights Tensor whose rank is either 0, or the same rank as
78+
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
79+
* must be either `1`, or the same as the corresponding `losses`
80+
* dimension).
81+
* @param reduction Type of reduction to apply to loss. Should be of type
82+
* `Reduction`
83+
*/
84+
@doc({heading: 'Training', subheading: 'Losses', namespace: 'losses'})
85+
@operation
86+
static absoluteDifference<T extends Tensor, O extends Tensor>(
87+
labels: T, predictions: T, weights?: Tensor,
88+
reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {
89+
util.assertShapesMatch(
90+
labels.shape, predictions.shape, 'Error in absoluteDifference: ');
91+
const losses = labels.sub(predictions).abs();
92+
return LossOps.computeWeightedLoss(losses, weights, reduction);
93+
}
94+
}

0 commit comments

Comments
 (0)