diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index 84736ada6a5..007ee9d0d42 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -345,10 +345,10 @@ public final class Ops {
public final SignalOps signal;
- public final TrainOps train;
-
public final QuantizationOps quantization;
+ public final TrainOps train;
+
private final Scope scope;
private Ops(Scope scope) {
@@ -370,8 +370,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
- train = new TrainOps(this);
quantization = new QuantizationOps(this);
+ train = new TrainOps(this);
}
/**
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java
deleted file mode 100644
index aeab16c7c6c..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-=======================================================================*/
-
-// This class has been generated, DO NOT EDIT!
-
-package org.tensorflow.op.core;
-
-import org.tensorflow.Operand;
-import org.tensorflow.Operation;
-import org.tensorflow.OperationBuilder;
-import org.tensorflow.Output;
-import org.tensorflow.op.RawOp;
-import org.tensorflow.op.Scope;
-import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
-import org.tensorflow.types.family.TType;
-
-/**
- * Gives a guarantee to the TF runtime that the input tensor is a constant.
- *
- * The runtime is then free to make optimizations based on this.
- *
- * Only accepts value typed tensors as inputs and rejects resource variable handles
- * as input.
- *
- * Returns the input tensor without modification.
- *
- * @param data type for {@code output()} output
- */
-@Operator
-public final class GuaranteeConst extends RawOp implements Operand {
-
- /**
- * Factory method to create a class wrapping a new GuaranteeConst operation.
- *
- * @param scope current scope
- * @param input
- * @return a new instance of GuaranteeConst
- */
- @Endpoint(describeByClass = true)
- public static GuaranteeConst create(Scope scope, Operand input) {
- OperationBuilder opBuilder = scope.env().opBuilder("GuaranteeConst", scope.makeOpName("GuaranteeConst"));
- opBuilder.addInput(input.asOutput());
- opBuilder = scope.apply(opBuilder);
- return new GuaranteeConst(opBuilder.build());
- }
-
- /**
- */
- public Output output() {
- return output;
- }
-
- @Override
- public Output asOutput() {
- return output;
- }
-
- /** The name of this op, as known by TensorFlow core engine */
- public static final String OP_NAME = "GuaranteeConst";
-
- private Output output;
-
- private GuaranteeConst(Operation operation) {
- super(operation);
- int outputIdx = 0;
- output = operation.output(outputIdx++);
- }
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java
deleted file mode 100644
index f3f6e374590..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-=======================================================================*/
-
-// This class has been generated, DO NOT EDIT!
-
-package org.tensorflow.op.core;
-
-import org.tensorflow.Operand;
-import org.tensorflow.Operation;
-import org.tensorflow.OperationBuilder;
-import org.tensorflow.Output;
-import org.tensorflow.op.RawOp;
-import org.tensorflow.op.Scope;
-import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
-import org.tensorflow.types.family.TType;
-
-/**
- * Makes its input available to the next iteration.
- *
- * @param data type for {@code output()} output
- */
-@Operator
-public final class RefNextIteration extends RawOp implements Operand {
-
- /**
- * Factory method to create a class wrapping a new RefNextIteration operation.
- *
- * @param scope current scope
- * @param data The tensor to be made available to the next iteration.
- * @return a new instance of RefNextIteration
- */
- @Endpoint(describeByClass = true)
- public static RefNextIteration create(Scope scope, Operand data) {
- OperationBuilder opBuilder = scope.env().opBuilder("RefNextIteration", scope.makeOpName("RefNextIteration"));
- opBuilder.addInput(data.asOutput());
- opBuilder = scope.apply(opBuilder);
- return new RefNextIteration(opBuilder.build());
- }
-
- /**
- * The same tensor as `data`.
- */
- public Output output() {
- return output;
- }
-
- @Override
- public Output asOutput() {
- return output;
- }
-
- /** The name of this op, as known by TensorFlow core engine */
- public static final String OP_NAME = "RefNextIteration";
-
- private Output output;
-
- private RefNextIteration(Operation operation) {
- super(operation);
- int outputIdx = 0;
- output = operation.output(outputIdx++);
- }
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java
deleted file mode 100644
index 8c60fc6350f..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-=======================================================================*/
-
-// This class has been generated, DO NOT EDIT!
-
-package org.tensorflow.op.random.experimental;
-
-import org.tensorflow.Operand;
-import org.tensorflow.Operation;
-import org.tensorflow.OperationBuilder;
-import org.tensorflow.Output;
-import org.tensorflow.op.RawOp;
-import org.tensorflow.op.Scope;
-import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
-import org.tensorflow.types.family.TType;
-
-/**
- */
-public final class DummySeedGenerator extends RawOp implements Operand {
-
- /**
- * Factory method to create a class wrapping a new DummySeedGenerator operation.
- *
- * @param scope current scope
- * @return a new instance of DummySeedGenerator
- */
- @Endpoint(describeByClass = true)
- public static DummySeedGenerator create(Scope scope) {
- OperationBuilder opBuilder = scope.env().opBuilder("DummySeedGenerator", scope.makeOpName("DummySeedGenerator"));
- opBuilder = scope.apply(opBuilder);
- return new DummySeedGenerator(opBuilder.build());
- }
-
- /**
- */
- public Output> handle() {
- return handle;
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public Output asOutput() {
- return (Output) handle;
- }
-
- /** The name of this op, as known by TensorFlow core engine */
- public static final String OP_NAME = "DummySeedGenerator";
-
- private Output> handle;
-
- private DummySeedGenerator(Operation operation) {
- super(operation);
- int outputIdx = 0;
- handle = operation.output(outputIdx++);
- }
-}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java
index c7edfcca24e..3417c07372a 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java
@@ -202,13 +202,12 @@ public BinaryCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param The data type of the predictions, sampleWeights and loss.
- * @param The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java
index 77c6ab2bf87..5aac163c1e4 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java
@@ -69,7 +69,7 @@
public class CategoricalCrossentropy extends Loss {
public static final boolean FROM_LOGITS_DEFAULT = false;
public static final float LABEL_SMOOTHING_DEFAULT = 0.0f;
- public static final int DEFAULT_AXIS = -1;
+ public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST;
private final boolean fromLogits;
private final float labelSmoothing;
@@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) {
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
- * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the
- * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a
- * value of 0.1 for label 0 and 0.9 for label 1
+ * @param labelSmoothing Float in [0, 1]. When > 0, label values are
+ * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1 for label 0 and
+ * 0.9 for label 1
*/
public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) {
this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
}
/**
- * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT},
- * and a channel axis of {@link #DEFAULT_AXIS}
+ * Creates a categorical cross entropy Loss using a Loss Reduction of {@link
+ * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS}
*
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
- * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the
- * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a
- * value of 0.1 for label 0 and 0.9 for label 1
+ * @param labelSmoothing Float in [0, 1]. When > 0, label values are
+ * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1 for label 0 and
+ * 0.9 for label 1
*/
public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
@@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
- * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the
- * confidence on label values are relaxed. e.g. x=0.2 means that we will use a
- * value of 0.1 for label 0 and 0.9 for label 1
+ * @param labelSmoothing Float in [0, 1]. When > 0, label values are
+ * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means
+ * that we will use a value of 0.1 for label 0 and 0.9
+ * for label 1
* @param reduction Type of Reduction to apply to loss.
*/
public CategoricalCrossentropy(
@@ -199,12 +202,14 @@ public CategoricalCrossentropy(
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
- * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the
- * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a
- * value of 0.1 for label 0 and 0.9 for label 1
+ * @param labelSmoothing Float in [0, 1]. When > 0, label values are
+ * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1 for label 0 and
+ * 0.9 for label 1
* @param reduction Type of Reduction to apply to loss.
- * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last'
- * and axis=1 corresponds to data format 'Channels First'.
+ * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last"
+ * and axis=1 corresponds to data format "Channels First". {@link
+ * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
* @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1.
*/
public CategoricalCrossentropy(
@@ -241,13 +246,12 @@ public CategoricalCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param The data type of the predictions, sampleWeights and loss.
- * @param The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java
index f592c19f8bb..73837ed1756 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java
@@ -25,7 +25,7 @@
* loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions)
* and pos=sum(labels*predictions)
*
- *
labels values are expected to be 0 or 1.
+ * labels values are expected to be 0 or 1.
*
*
Standalone usage:
*
@@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) {
/** {@inheritDoc} */
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand losses = Losses.categoricalHinge(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
index 137c7025c04..0a18d93caf3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
@@ -22,12 +22,13 @@
/**
* Computes the cosine similarity between labels and predictions.
*
- * Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0
- * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to
- * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
- * try to maximize the proximity between predictions and targets. If either labels or predictions is
- * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and
- * targets.
+ *
Note that it is a number between -1 and 1. When it is a negative
+ * number between -1 and 0, 0 indicates orthogonality and
+ * values closer to -1 indicate greater similarity. The values closer to 1
+ * indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
+ * try to maximize the proximity between predictions and targets. If either labels or
+ * predictions is a zero vector, cosine similarity will be 0 regardless of
+ * the proximity between predictions and targets.
*
*
loss = -sum(l2Norm(labels) * l2Norm(predictions))
*
@@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss {
public static final int DEFAULT_AXIS = -1;
public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO;
- private final int axis;
+ private final int[] axis;
/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis
@@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) {
this(tf, null, axis, DEFAULT_REDUCTION);
}
+ /**
+ * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a
+ * Loss Reduction of {@link #DEFAULT_REDUCTION}
+ *
+ * @param tf the TensorFlow Ops
+ * @param axis The dimension along which the cosine similarity is computed.
+ */
+ public CosineSimilarity(Ops tf, int[] axis) {
+
+ this(tf, null, axis, DEFAULT_REDUCTION);
+ }
/**
* Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
@@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) {
this(tf, name, axis, DEFAULT_REDUCTION);
}
+ /**
+ * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the loss
+ * @param axis The dimension along which the cosine similarity is computed.
+ */
+ public CosineSimilarity(Ops tf, String name, int[] axis) {
+
+ this(tf, name, axis, DEFAULT_REDUCTION);
+ }
+
/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an
* axis of {@link #DEFAULT_AXIS}
@@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) {
*/
public CosineSimilarity(Ops tf, int axis, Reduction reduction) {
+ this(tf, null, new int[] {axis}, reduction);
+ }
+
+ /**
+ * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name
+ *
+ * @param tf the TensorFlow Ops
+ * @param axis The dimension along which the cosine similarity is computed.
+ * @param reduction Type of Reduction to apply to the loss.
+ */
+ public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) {
+
this(tf, null, axis, reduction);
}
@@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) {
* @param reduction Type of Reduction to apply to the loss.
*/
public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) {
+ this(tf, name, new int[] {axis}, reduction);
+ }
+
+ /**
+ * Creates a Cosine Similarity Loss
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the loss
+ * @param axis The dimension along which the cosine similarity is computed.
+ * @param reduction Type of Reduction to apply to the loss.
+ */
+ public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) {
super(tf, name, reduction);
this.axis = axis;
}
/** {@inheritDoc} */
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis);
+ losses = tf.math.neg(losses);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
index 88b4a7aa056..db3569441ef 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
@@ -18,15 +18,16 @@
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;
+
import static org.tensorflow.framework.utils.CastHelper.cast;
/**
* Computes the hinge loss between labels and predictions.
*
- * loss = maximum(1 - labels * predictions, 0)
.
+ * loss = maximum(1 - labels * predictions, 0).
*
- *
labels values are expected to be -1 or 1.
- * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.
+ * labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
+ * they will be converted to -1 or 1.
*
*
Standalone usage:
*
@@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) {
* label values are not in the set [-1., 0., 1.].
*
* @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be
- * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
+ * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
* @param predictions the predictions, values must be in the range [0. to 1.] inclusive.
* @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is
* provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor
@@ -121,16 +122,19 @@ public Hinge(Ops tf, String name, Reduction reduction) {
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
@SuppressWarnings("unchecked")
-    Operand<T> tLabels = predictions.type() == labels.type() ?
-        (Operand<T>)labels : cast(tf, labels, predictions.type());
-    tLabels = LossesHelper.valueCheck(
+    Operand<T> tLabels =
+        predictions.type() == labels.type()
+            ? (Operand<T>) labels
+            : cast(tf, labels, predictions.type());
+    tLabels =
+        LossesHelper.valueCheck(
getTF(),
"labels value check [-1, 0, 1]",
tLabels,
- cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type()));
+ cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type()));
Operand losses = Losses.hinge(getTF(), tLabels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
index 6d3e3f0c2ac..665a9ac157d 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
@@ -130,8 +130,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) {
/** {@inheritDoc} */
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand losses = Losses.huber(getTF(), labels, predictions, delta);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
index 8cf3db8d518..2aa1f72092b 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
@@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) {
/** {@inheritDoc} */
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
index 1669669a768..78325713e3e 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
@@ -105,8 +105,8 @@ public LogCosh(Ops tf, String name, Reduction reduction) {
/** {@inheritDoc} */
@Override
-  public <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand losses = Losses.logCosh(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
index ae33d5dfa37..cdd35d28aba 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
@@ -25,7 +25,7 @@ public abstract class Loss {
protected final Reduction reduction;
/**
- * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
+ * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
* Loss#REDUCTION_DEFAULT}
*
* @param tf the TensorFlow Ops
@@ -62,10 +62,10 @@ protected Loss(Ops tf, String name, Reduction reduction) {
* @param labels the truth values or labels
* @param predictions the predictions
* @param The data type of the predictions and loss.
- * @param The data type of the labels.
* @return the loss
*/
-  public <T extends TNumber, U extends TNumber> Operand<T> call(Operand<U> labels, Operand<T> predictions) {
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions) {
return call(labels, predictions, null);
}
@@ -82,11 +82,10 @@ public Operand call(Operand labels,
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param The data type of the predictions, sampleWeights and loss.
- * @param The data type of the labels.
* @return the loss
*/
-  public abstract <T extends TNumber, U extends TNumber> Operand<T> call(
-      Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights);
+  public abstract <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights);
/**
* Gets the TensorFlow Ops
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
index 81d9e13c8a9..2222ebb41f8 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
@@ -36,6 +36,9 @@ public class Losses {
/** Default Fuzz factor. */
public static final float EPSILON = 1e-7f;
+ public static final int CHANNELS_LAST = -1;
+ public static final int CHANNELS_FIRST = 1;
+
/**
* Calculates the mean absolute error between labels and predictions.
*
@@ -45,11 +48,10 @@ public class Losses {
* @param labels the labels
* @param predictions the predictions
* @param the data type of the predictions and result
- * @param the data type of the labels
* @return the mean absolute error
*/
-  public static <T extends TNumber, U extends TNumber> Operand<T> meanAbsoluteError(
-      Ops tf, Operand<U> labels, Operand<T> predictions) {
+  public static <T extends TNumber> Operand<T> meanAbsoluteError(
+      Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions) {
Operand tLabels = cast(tf, labels, predictions.type());
LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
predictions = ops.getTarget();
@@ -67,11 +69,10 @@ public static Operand meanAbsoluteErro
* @param labels the labels
* @param predictions the predictions
* @param the data type of the predictions and result
- * @param the data type of the labels
* @return the mean squared error
*/
-  public static <T extends TNumber, U extends TNumber> Operand<T> meanSquaredError(
-      Ops tf, Operand<U> labels, Operand<T> predictions) {
+  public static <T extends TNumber> Operand<T> meanSquaredError(
+      Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions) {
Operand tLabels = cast(tf, labels, predictions.type());
LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
predictions = ops.getTarget();
@@ -88,11 +89,10 @@ public static Operand meanSquaredError
* @param labels the labels
* @param predictions the predictions
* @param the data type of the predictions and result
- * @param the data type of the labels
* @return the mean absolute percentage error
*/
-  public static <T extends TNumber, U extends TNumber> Operand<T> meanAbsolutePercentageError(
-      Ops tf, Operand<U> labels, Operand<T> predictions) {
+  public static <T extends TNumber> Operand<T> meanAbsolutePercentageError(
+      Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions) {
Class predictionType = predictions.type();
Operand tLabels = cast(tf, labels, predictionType);
LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
@@ -102,8 +102,10 @@ public static Operand meanAbsolutePerc
tf.math.abs(
tf.math.div(
tf.math.sub(tLabels, predictions),
- tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType))));
- return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1)));
+ tf.math.maximum(
+ tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType))));
+ return tf.math.mul(
+ cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1)));
}
/**
@@ -115,11 +117,10 @@ public static Operand meanAbsolutePerc
* @param labels the labels
* @param predictions the predictions
* @param the data type of the predictions and result
- * @param the data type of the labels
* @return the mean squared logarithmic percentage error
*/
- public static Operand meanSquaredLogarithmicError(
- Ops tf, Operand labels, Operand predictions) {
+ public static Operand meanSquaredLogarithmicError(
+ Ops tf, Operand extends TNumber> labels, Operand predictions) {
Class predictionType = predictions.type();
Operand tLabels = cast(tf, labels, predictionType);
LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
@@ -149,8 +150,12 @@ public static Operand meanSquaredLogar
* @param the data type of the predictions and labels
* @return the binary crossentropy loss.
*/
- public static Operand binaryCrossentropy(
- Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) {
+ public static Operand binaryCrossentropy(
+ Ops tf,
+ Operand extends TNumber> labels,
+ Operand predictions,
+ boolean fromLogits,
+ float labelSmoothing) {
Operand tLabels = cast(tf, labels, predictions.type());
LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
predictions = ops.getTarget();
@@ -178,7 +183,7 @@ private static Operand binaryCrossentropyHelper(
Ops tf, Operand target, Operand output, boolean fromLogits) {
if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output);
- /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible
+ /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible
if (!(output instanceof Variable) && (!tf.scope().env().isEager())) {
// TODO - this does not work
// TODO output = backtrackIdentity(output);
@@ -215,16 +220,17 @@ private static Operand binaryCrossentropyHelper(
* @param labels true targets
* @param predictions the predictions
* @param fromLogits Whether to interpret predictions as a tensor of logit values
- * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the
- * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a
- * value of 0.1 for label 0 and 0.9 for label 1
+ * @param labelSmoothing Float in [0, 1]. When > 0, label values are
+ * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1 for label 0 and
+ * 0.9 for label 1
* @param axis the
* @param the data type of the predictions and labels
* @return the categorical crossentropy loss.
*/
- public static Operand categoricalCrossentropy(
+ public static Operand categoricalCrossentropy(
Ops tf,
- Operand labels,
+ Operand extends TNumber> labels,
Operand predictions,
boolean fromLogits,
float labelSmoothing,
@@ -239,7 +245,7 @@ public static Operand categoricalCross
tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing);
}
if (fromLogits) {
- return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1);
+ return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis);
}
/* TODO
if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) {
@@ -280,8 +286,8 @@ public static Operand categoricalCross
* @param the data type of the predictions and labels
* @return the categorical hinge loss
*/
- public static Operand categoricalHinge(
- Ops tf, Operand labels, Operand predictions) {
+ public static Operand categoricalHinge(
+ Ops tf, Operand extends TNumber> labels, Operand predictions) {
Class predictionType = predictions.type();
Operand tLabels = cast(tf, labels, predictionType);
LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
@@ -326,8 +332,8 @@ public static Operand categoricalHinge
* @param the data type of the predictions and labels
* @return the cosine similarity loss
*/
- public static Operand cosineSimilarity(
- Ops tf, Operand labels, Operand predictions, int axis) {
+ public static Operand cosineSimilarity(
+ Ops tf, Operand extends TNumber> labels, Operand predictions, int[] axis) {
Operand tLabels = cast(tf, labels, predictions.type());
LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
predictions = lossTuple.getTarget();
@@ -336,8 +342,7 @@ public static Operand cosineSimilarity
tLabels = l2Normalize(tf, tLabels, axis);
predictions = l2Normalize(tf, predictions, axis);
Operand mathMul = tf.math.mul(tLabels, predictions);
- Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE));
- return tf.math.neg(sum);
+ return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE));
}
/**
@@ -352,8 +357,8 @@ public static Operand cosineSimilarity
* @param the data type of the predictions and labels
* @return the hinge loss
*/
- public static Operand hinge(
- Ops tf, Operand