Metrics Phase 1 #180
Conversation
Sync with master tensorflow on upstream
Merge main branch to local branch
Update after losses merge
Fix Javadoc errors (tensorflow#152)
pull type def
Is "cosine proximity" actually a thing outside of Keras? Searching the net, it looks like this is simply called euclidean distance, while cosine similarity stands on its own. Since we already agreed in the past that we don't have to copy exactly what is found in Keras, would it make sense to simply call it |
Instead of wrapping all these loss classes inside metrics, would it make sense to simply add a …? The actual way is better for metrics discovery, as they all reside in the same package, but it might be better to avoid having classes with the same name and the same instance of … (though I'm not sure what would be the use case for doing it, as the loss should always be available as a metric, right?)
```java
public abstract class Metric<U extends TNumber, T extends TNumber> {

  /** variables are stored by ExecutionEnvironment, and then by an identifier name */
  protected static Map<ExecutionEnvironment, Map<String, MetricVariable<? extends TNumber>>>
```
It sounds a bit dangerous to keep static references to `ExecutionEnvironment` instances, since in both cases (eager and graph mode) they are native resources that should be explicitly released. Any way we can avoid that?
The requirements and constraints are:
- A metric instance with the same name should reuse the same variables.
- Variables are only valid under the same execution environment.
This issue first showed up in the unit test cases where each test case has a different execution environment.
I tried to solve this issue by using a WeakHashMap. I am open to a better way. Is there some other way to find a variable by name from the graph?
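For context, a minimal sketch of the `WeakHashMap` approach being described (the class and method names here are illustrative, not the PR's actual code):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.function.Function;

import org.tensorflow.ExecutionEnvironment;

// Sketch: the outer WeakHashMap lets an ExecutionEnvironment's entry (and the
// variables registered under it) become collectable once the environment is no
// longer strongly referenced, so stale environments don't pin their variables.
// Note: entries are only dropped at GC time, not when the native resource is
// explicitly closed, which is part of the concern raised above.
final class MetricVariableRegistry {
  private static final Map<ExecutionEnvironment, Map<String, Object>> VARIABLES =
      new WeakHashMap<>();

  static Object variable(ExecutionEnvironment env, String name, Function<String, Object> create) {
    return VARIABLES
        .computeIfAbsent(env, e -> new HashMap<>())
        .computeIfAbsent(name, create);
  }
}
```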
So far, `variableMap` is only used by `Metric` and its subclass `Reduce`. Also, a `Metric` has an `Ops`, so it is tied to a particular `ExecutionEnvironment`. Can all access to `variableMap` be encapsulated in `protected` methods, and then `variableMap` (without the `ExecutionEnvironment` key) become an instance variable of `Metric`?
I originally had it set up the way you describe here. The problem is that if the Exec Env changes, the variables are no longer valid but are still referenced in the map. The way Python TF works is that individual Metric instances with the same name use the same Variables.
> I tried to solve this issue by using a WeakHashMap. I am open to a better way. Is there some other way to find a variable by name from the graph?

I think this should work to look up existing variables with the same name in the same graph?
I went back to the original python code, and basically the metric only lives as long as its object does, so there is no need to keep the `Variable` instance longer than the actual `Metric` instance. So I have removed the `WeakHashMap`. Now, the variables are stored as class attributes on the `Metric` subclass.

@karllessard: You can find the `Operation` by iterating `Graph.operations` on `op.name()`, but once you have the `Operation`, how do you turn that into an `org.tensorflow.op.core.Variable` instance? The constructor to do it is private in `org.tensorflow.op.core.Variable`.
That's a good point, I don't think there is a way of doing it right now... But do you really need it to be a `Variable`? Can't it just keep a reference to an `Operation` instead? Looking quickly at the code, I see that we need the shape and the type of the variables, but both are also accessible via `operation.output(0).shape()` and `operation.output(0).type()` respectively. Would that be enough? (A small sketch of that lookup follows.)

I would also prefer keeping a reference to `Variable`, as it enforces which type of operation is allowed in this context, but we would need to add a mapping in the core to revert an `Operation` back to its operator based on its name (that would be a great feature to have, though).
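For example, something along these lines (a sketch, assuming the by-name lookup and the output accessors mentioned above):

```java
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;

// Sketch: recover the shape and element type of a variable from its Operation,
// without needing a Variable instance.
final class VariableLookupSketch {
  static void describeVariable(Graph graph, String name) {
    Operation op = graph.operation(name); // lookup by operation name; null if absent
    Output<?> output = op.output(0);
    System.out.println(output.shape());   // static shape of the variable
    System.out.println(output.type());    // its tensor type
  }
}
```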
👍 @JimClarke5 addressed my concern.
```java
LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights);
lValues = tuple.getTarget();
lSampleWeights = tuple.getSampleWeights();
// lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues);
```
Is this commented line still present on purpose?
`LossTuple` takes care of the broadcast, so I'll remove the commented-out code.
The main difference between the specific “loss” metrics and … I think if you tried to combine these disparate concepts into one class, things would get messy. Also, the real core of the actual loss calculation is already broken out into the … Lastly, there is a whole other set of metrics that don’t rely on …
Fix JavaDoc, modify one line code block to include braces.
```diff
@@ -78,8 +79,8 @@ public MetricVariable(
       } else {
         throw new IllegalArgumentException(
             String.format(
-                "An initializer for variable %s of type %s is required",
-                variable.toString(), type.getSimpleName()));
+                "Type %s is not a supported for metric variables",
```
"Type %s is not a supported for metric variables" -> "Type %s is not supported for metric variables"
OK
@JimClarke5, any comment on this suggestion?
Going back to review this, the only difference between … From researching Keras and notes on GitHub for Keras, it seems this was done to facilitate the disparate goals of a loss versus the goals of a metric. Losses are meant to be minimized, so they apply a negative to the result to facilitate that minimization. For normal cosine similarity, -1 indicates the vectors are diametrically opposed, while 1 means they are similar. As the loss focuses on minimization, they flip the sign, so -1 becomes similar, while +1 becomes dissimilar. That makes it easier to minimize down toward similarity. In a sense, CosineSimilarity is the right paradigm for both; it is just that the loss version flips the sign to facilitate minimization. Maybe the correct answer is to rename …
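For concreteness, a small plain-Java illustration of the sign convention being discussed (not the PR's code):

```java
// The metric reports cosine similarity as-is (+1 = similar directions,
// -1 = diametrically opposed), while the Keras-style loss negates it so that
// minimizing the loss drives the vectors toward similarity.
final class CosineSignDemo {
  static double cosineSimilarity(double[] a, double[] b) {
    double dot = 0, normA = 0, normB = 0;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  static double cosineSimilarityLoss(double[] a, double[] b) {
    return -cosineSimilarity(a, b); // flipped sign: -1 now means "similar"
  }
}
```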
…nly live within a single instance of a Metric.
```java
* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to values, may be null.
* @param <V> the data type for the sample weights
```
OK
Not sure I like GitHub's "hide resolved" functionality. I'll just use a 👍 .
👍
```java
* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to values, may be null.
* @param <V> the data type for the sample weights
```
-> for the labels
OK
👍 done
```java
* @param tf the TensorFlow Ops used to create the result
* @return the result, possibly with control dependencies
*/
public abstract Operand<T> result(Ops tf);
```
Is `tf` required to be on the same `ExecutionEnvironment` as `this.tf`? If so, should we document and check that?
I originally had it to form a control dependency on the result for the `callOnce()` method. I reworked it, so I replaced the abstract `result(Ops tf)` with abstract `result()` and reworked the control dependency logic in `callOnce()`.
👍 done
```java
private final String totalName;
private final String countName;

private final Class<T> type;
```
-> resultType
OK
👍 done
```java
* will always produce the same random tensor for a given shape and data type.
* @param type the type for the variables and result
*/
protected Reduce(Ops tf, String name, long seed, Class<T> type) {
```
-> resultType
OK
👍 done
```java
* @throws IllegalArgumentException if values is null
*/
@Override
public List<Op> updateStateList(Operand<U> values, Operand<T> sampleWeights) {
```
Does it usually help us to type `sampleWeights` as an `Operand<T>`, instead of letting it have its own tensor type? It feels arbitrary to me, and I see it doesn't help us in this case.
I don't quite understand the question. Did you mean sampleWeights should be typed to something like `Operand<…>` so that the sample weights are not tied to the Metric class type?
Right. Is it helpful that the tensor type of `sampleWeights` is tied to the `Metric` result type `T`, instead of either being method-local or being tied to the tensor type `U` of `values`?
Eventually, all the types get cast to the Metric type `T` for the result and the variables, but there is no reason sampleWeights could not be another independent type. I will change the class signature.
👍 done
```java
        .withControlDependencies(Collections.singletonList(broadcastWeightsCheck))
        .math
        .mul(lValues, lSampleWeights);
  } catch (IllegalArgumentException ex) {
```
It feels unsafe to me to assume the only possible reason for an `IllegalArgumentException` in this try/catch is the one handled here. Instead, I would think we'd want to explicitly check for this condition?
I understand your point. Should we create something like `BroadcastException`?
A special-purpose exception would certainly make this safe and more explicit, and I'd be fine with it. To me, encapsulating this additional information in the return type would feel more idiomatic, but for this intendedly-module-private method, I don't see that as a strong argument.
There is already `TFInvalidArgumentException`, which extends `IllegalArgumentException` and is thrown from within the graph for the dynamic analysis part of this function. Should we throw `TFInvalidArgumentException` instead for static shapes? Though that is still somewhat vague, so maybe `BroadcastException` makes more sense.
Yeah, I'd be a fan of the more specific exception name. Perhaps even `NotBroadcastableException`?
Added `NotBroadcastableException`, which is thrown if the static shapes are not broadcastable. `TFInvalidArgumentException` is still thrown when evaluating dynamic shapes, as that is done inside the graph. Both exceptions inherit from `IllegalArgumentException`.
I think we now need to change the catch to `NotBroadcastableException`.
ok
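For illustration, the new exception can be as small as this (a sketch; the PR's actual class may differ):

```java
// Sketch: a narrow exception type for static-shape broadcast failures.
// Extending IllegalArgumentException keeps existing catch sites compatible
// while letting callers catch the broadcast case specifically.
public class NotBroadcastableException extends IllegalArgumentException {
  public NotBroadcastableException(String message) {
    super(message);
  }
}
```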
```java
* values</code>, an optional additional input for the <code>sampleWeights</code>
*
* @param values the inputs to be passed to update state, this may not be null
* @param sampleWeights sample weights to be applied to values, may be null.
```
The supported shape relationships between `values` and `sampleWeights` are complex. Do we want to document and test them?
Broadcasting seems to be inherent throughout Numpy and TF Python. The most concise definition is from `tf.broadcast_to()` in TF Python:

> Broadcasting is the process of making arrays to have compatible shapes for arithmetic operations. Two shapes are compatible if for each dimension pair they are either equal or one of them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing dimensions, and works its way forward.

If shapes aren't compatible, the low-level TF Ops will throw TF exceptions. The methods here that test the shapes are trying to catch this before the TF Ops do, because the low-level failure makes it harder to debug the actual problem.
Here is a link to the Numpy broadcast rules: Broadcasting. It might be helpful to document this, but maybe in a separate document like Numpy does. A plain-Java sketch of the rule follows below.
If it were simply that this method supports broadcasting, then I'd see no reason to say much about it at this stage of the project. But I think the permitted shape relationships are far more complex than that:

- We use `squeezeOrExpandDimensions` to "squeeze or expand the last dim of `sampleWeight` if its rank differs by 1 from [the rank of `values`]."
- Then we assert, either at method runtime or at graph execution time (depending on our inputs), that "the `sampleWeights` can be broadcast to `values`." (I think we intentionally disallow the case where `values` would be broadcast to `sampleWeights`?)
- But if the inputs contain enough static shape information (how much is enough?) to allow us to check broadcasting at method runtime, and if we discover that the shapes don't support the unidirectional broadcasting that we allow, then instead we apply either `reduceSum` or `math.mean` (depending on `this.reduction`) to reduce `values` down to the same rank as `sampleWeights`.
- After all of the above logic, we make a broadcasting call to `math.mul`. So, for example, the previous step may use `reduceSum` to bring `values` to the same rank as `sampleWeights`, after which we may still broadcast one or more dimensions (having size 1) of `values` or `sampleWeights`, or both.
I found the test cases for broadcasting in TF Python so I am creating Java test cases. I already found an anomaly with dynamic shapes and scalar, so I will investigate that, and when I'm done with the test cases I will check them in.
Also, the method name was wrong and I have corrected it to `assertBroadcastable` to match TF Python.
👍
`Reduce.java` line 131: `LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights);`. Here `null` is the `labels` parameter, which is `y_true` in the python version, so I don't understand your point here. I always map `y_true` to `labels`, and `y_pred` to `predictions`.

From `LossesHelper.java`:

```java
public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
    Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights)
```

From `metrics.py`:

```python
values, _, sample_weight = tf_losses_utils.squeeze_or_expand_dimensions(
    values, sample_weight=sample_weight)
```

Here `values` is `y_pred`.

From `python/ops/losses/util.py`:

```python
def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
```
Oops, sorry, I wasn't reading the code I thought I was! Your response is dead-on for the code I thought I was reading: `Reduce.update_state`.
My concern that opened this thread remains open:

> The supported shape relationships between values and sampleWeights are complex. Do we want to document and test them?
I have added methods `assertBroadcastable` and `broadcastWeights` to `MetricsHelper`, along with corresponding test classes. This should address the testing part, at least for these two topics. However, there are shape operations sprinkled all around the code, not just Metrics. I am not sure how feasible it is to capture all of these. Much of the Python wrapper classes around low level operations involve manipulating shapes into shapes that are acceptable to the low level Op.
Well, the Python code doesn't document the permitted shape relationships either. No reason this PR has to exceed the Python state of the art in this area!
👍 nothing to do
```java
        .mul(lValues, lSampleWeights);
  } catch (IllegalArgumentException ex) {
    // reduce the values down to the rank of the samples
    int nDim = lValues.shape().numDimensions();
```
Presumably we need to either prohibit shapes with unknown dimensions or handle them here?
Essentially, `broadcastWeights` will only throw the `IllegalArgumentException` if the shapes are fully static and cannot be broadcast. Maybe a code comment would help explain this state.
👍 done
```java
// reduce the values down to the rank of the samples
int nDim = lValues.shape().numDimensions();
int wDim = lSampleWeights.shape().numDimensions();
int numAxes = nDim - wDim;
```
Why is `nDim` necessarily > `wDim` at this point?
You're right. This is an error, I'll fix it.
👍 done
```java
* incorrect shape that prohibit broadcasting to <code>values</code>
*/
@SuppressWarnings("unchecked")
public static <U extends TNumber> Op broadcastWeights(
```
Would a clearer name be `assertCanBroadcastWeights`? Given that a caller wants to base control flow on a static check showing incorrect shape, would it be cleaner to return `Optional<Op>` instead of using the `IllegalArgumentException` to signal that situation?
Name was changed to `assertBroadcastable` to match TF Python.
👍 done
```java
*/
private static <T extends TNumber> Operand<TBool> hasValidDims(
    Ops tf, Operand<T> weightsShape, Operand<T> valuesShape) {
  tf = tf.withSubScope("has_invalid_dims");
```
-> "has_valid_dims"
OK
```java
*
* @param tf the TensorFlow Ops
* @param weightsShape the operand for the shape of the sample weights
* @param valuesShape the operand for the shape of the sample weights
```
-> "of the values"
OK
```java
* @param <T> the data type for the operands
* @return a boolean operand to determine if the Shape is scalar or not.
*/
private static <T extends TNumber> Operand<TBool> hasValidNonscalarShape(
```
Perhaps `canBroadcastNonscalarShapes`?
It was `has_valid_nonscalar_shape` in Python. Since this is private, the name doesn't matter to me. So if you think `canBroadcastNonscalarShapes` is a better name, then let's change it.
`canBroadcastNonscalarShapes` does seem better to me. I like the clue that this is part of the `canBroadcast` collection of methods, whereas the name `hasValidNonscalarShape` doesn't really mean anything from a standing start.
OK, changed to `canBroadcastNonscalarShapes`.
👍 done
```java
* @param <T> the data type for the operands
* @return a boolean operand to determine if the shapes have valid dimensions or not.
*/
private static <T extends TNumber> Operand<TBool> hasValidDims(
```
Perhaps "canBroadcastDims"?
Same comment as for `canBroadcastNonscalarShapes`.
Did you accidentally omit a comment for `canBroadcastNonscalarShapes`? None is visible to me.
I forgot to commit the text, you should see it now.
Thanks! I do prefer `canBroadcastDims` for the same reason as above.
OK, changed to `canBroadcastDims`.
👍 done
```java
* @param axis Axes to compute the mean.
* @param <T> the type of the Operand.
* @param <U> the type of the axis.
* @return the mean of the operand, alongside the specified axis.
```
"alongside" -> "along"
ok
```java
*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
```
axis or axes? (We currently say both.)
It should be `axes`, but `tf.math.mean` says `axis`, though its description says `dims`, plural. I'll change it to `axes`.
👍 done
```java
}

/**
 * Calculate the mean of the operand, along all axis.
```
-> "all axes"
OK
👍 done
```java
*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
```
`axis` versus "Axes"
Changed to `axes`.
👍 done
```java
* @return the mean of elements of `x`.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public static <T extends TType, U extends TNumber> Operand<T> mean(
```
The return value of this method won't necessarily have a runtime tensor type of `T`, because `T` could be `TBool`, which could get converted.
OK, I added `Z extends TNumber` and return `Operand<Z>`. Also changed `U extends TNumber` to `U extends TIntegral`, as `U` is for shapes.
I think the new version still has similar problems. I'm going to thumbs-up this comment as done and add an updated comment to the new version of the code.
👍 moving discussion to new version of code
…e same type as the return or internal variables,
…on when static shapes cannot broadcast.
change hasValidNonscalarShape to canBroadcastNonscalarShapes
move the dynamic shapes and rank down to the dynamic section so they are not created needlessly when static; fix if statement to check for unknown size and unknown dimensions
renamed WeightBroadcastTest to AssertBroadcastableTest and added BroadcastWeightsTest
….sparse.sparseToDense with the output of tf.sparse.denseToDenseSetOperation
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only `@googlebot I consent.`

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the `cla: yes` label.

ℹ️ Googlers: Go here for more info.
Ok thanks @JimClarke5 this looks good on my side, I'll merge it as soon as @deansher gives me a green light
Looking at this in the context of #193 and #176, it seems like there are a few cases where there's a type parameter on the class that isn't really necessary, since it's never used to enforce anything (as best I can tell), i.e. … Thoughts? It would make the inline Kotlin wrappers much nicer if this is possible.
@rnett, following this comment, Jim agreed to go through the whole framework for cleaning up the superfluous generics in a separate PR, probably right after this one.
@rnett I think the Googlebot CLA is asking you to consent as Jim merged in from master rather than rebasing. |
I studied all non-test-code deltas plus `SetsOpsTest` and `AssertBroadcastableTest`. Thanks again, Jim, both for all of your hard work on this and for your patience as I reviewed it.
```java
* and <code>axis=1</code> corresponds to data format 'Channels First'.
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format "Channels Last"
*     and <code>axis=1</code> corresponds to data format "Channels First".
*     {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
```
Did you intend to add a "See" in front of the links?
```java
* @param b The other operand representing set <code>b</code>
* @param <T>the data type for the sets
* @return An Operand with the same rank as <code>a</code> and <code>b</code>, and all but the
*     last dimension the * same. Elements along the last dimension contain the results of the set
```
👍 done
@googlebot I consent.
Thanks a lot @JimClarke5 for all the great work you provide to this project, merging that PR now
Metrics Phase 1 (tensorflow#180)

This is the first phase of the metrics classes, focused on those metrics that leverage `Losses`. The second phase will focus on more complex metrics, like Area Under Curve (AUC).

A couple of notes:

1. `CosineSimilarity` in Keras does not call its Loss counterpart, `losses.cosineSimilarity`, but instead calls `Metrics.cosineProximity`. This metric is calculating the Euclidean distance using L2 norms, while the loss class `losses.CosineSimilarity` is using the dot product proportional to the product of their magnitudes. While the two concepts are similar, they are different. Should we rename this metric to `CosineProximity`?
2. In the `Metrics` class, there are two methods for `l2Normalize()`, which is defined as `tf.math.l2_normalize` in TF Python and implemented in Python. Where should these methods reside, or should we just leave them inside `Metrics`? (A plain-Java sketch of the computation follows.)
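For reference, the computation behind `l2_normalize` for a single vector is just this (plain-Java sketch, not the framework code):

```java
// Plain-Java sketch of what tf.math.l2_normalize computes for one vector:
// x / sqrt(max(sum(x^2), epsilon)); the epsilon guards against division by zero.
final class L2NormalizeSketch {
  static double[] l2Normalize(double[] x) {
    double sumSquares = 0;
    for (double v : x) {
      sumSquares += v * v;
    }
    double norm = Math.sqrt(Math.max(sumSquares, 1e-12));
    double[] out = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      out[i] = x[i] / norm;
    }
    return out;
  }
}
```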