
Metrics Phase 1 #180

Merged: 67 commits, Feb 1, 2021

Conversation

JimClarke5
Contributor

This is the first phase of the metrics classes focused on those metrics that leverage Losses.
The second phase will focus on more complex metrics, like Area Under Curve (AUC).

A couple of notes.

  1. The metric CosineSimilarity in Keras does not call its Loss counterpart, losses.cosineSimilarity, but instead calls Metrics.cosineProximity. This metric calculates the Euclidean distance using L2 norms, while the loss class losses.CosineSimilarity uses the dot product proportional to the product of the magnitudes. While the two concepts are similar, they are different. Should we rename this metric to CosineProximity?
  2. In the Metrics class, there are two methods for l2Normalize(), which corresponds to tf.math.l2_normalize in TF Python. Where should these methods reside, or should we just leave them inside Metrics?
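For context, l2Normalize computes x / sqrt(max(sum(x²), epsilon)) along an axis. A plain-Java sketch of just the math for the 1-D case (illustrative only, not the TF Java Ops API; the class name is hypothetical):

```java
public class L2NormalizeSketch {
    /** Computes x / sqrt(max(sum(x^2), epsilon)) -- the arithmetic behind l2 normalization. */
    public static double[] l2Normalize(double[] x, double epsilon) {
        double sumSquares = 0.0;
        for (double v : x) {
            sumSquares += v * v;
        }
        // epsilon guards against division by zero for all-zero inputs
        double norm = Math.sqrt(Math.max(sumSquares, epsilon));
        double[] result = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            result[i] = x[i] / norm;
        }
        return result;
    }
}
```

In the real method the reduction runs along a caller-supplied axis and stays in the graph as TF ops; this sketch only shows the arithmetic being discussed.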

@JimClarke5 JimClarke5 mentioned this pull request Jan 3, 2021
@karllessard
Collaborator

Should we rename this metric to CosineProximity?

Is "cosine proximity" actually a thing outside of Keras? Searching the net, it looks like this is simply called euclidean distance, while cosine similarity stands on its own.

Since we already agreed in the past that we don't have to copy exactly what is found in Keras, would it make sense to simply call it EuclideanDistance or EuclideanDistanceL2?

@karllessard
Collaborator

karllessard commented Jan 5, 2021

Instead of wrapping all these loss classes inside metrics, would it make sense to simply add a Metric interface implemented by the same classes to avoid duplicating them?

The current way is better for metrics discovery, as they all reside in the same package, but it might be better to avoid having classes with the same name, and then the same instance of Loss could be used both for computing the loss and collecting it as a metric, wdyt?

(though I'm not sure what would be the use case for doing it, as the loss should always be available as a metric, right?)

public abstract class Metric<U extends TNumber, T extends TNumber> {

/** variables are stored by ExecutionEnvironment, and then by an identifier name */
protected static Map<ExecutionEnvironment, Map<String, MetricVariable<? extends TNumber>>>
Collaborator

It sounds a bit dangerous to keep static references to ExecutionEnvironment instances, since in both cases (eager and graph mode) they are native resources that should be explicitly released. Any way we can avoid that?

Contributor Author

The requirements and constraints are:

  1. A metric instance with the same name should reuse the same variables.
  2. Variables are only valid under the same execution environment.

This issue first showed up in the unit test cases where each test case has a different execution environment.

I tried to solve this issue by using a WeakHashMap. I am open to a better way. Is there some other way to find a variable by name from the graph?

Contributor

So far, variableMap is only used by Metric and its subclass Reduce. Also, a Metric has an Ops, so is tied to a particular ExecutionEnvironment. Can all access to variableMap be encapsulated in protected methods and then variableMap (without the ExecutionEnvironment key) become an instance variable of Metric?

Contributor Author

I originally had it set up the way you describe here. The problem is that if the Exec Env changes, the variables are no longer valid but are still referenced in the map. The way Python TF works is that individual Metric instances with the same name use the same Variables.

Collaborator

I tried to solve this issue by using a WeakHashMap. I am open to a better way. Is there some other way to find a variable by name from the graph?

I think this should work to look up existing variables with the same name in the same graph?

Contributor Author

I went back to the original python code, and basically the metric only lives as long as its object does, so there is no need to keep the Variable instance longer than the actual Metric instance. So I have removed the WeakHashMap. Now, the variables are stored as class attributes on the Metric subclass.

@karllessard: You can find the Operation by iterating Graph.operations on op.name(), but once you have the Operation, how do you turn that into an org.tensorflow.op.core.Variable instance? The constructor to do it is private in org.tensorflow.op.core.Variable.

Collaborator

That's a good point, I don't think there is a way of doing it right now... But do you really need it to be a Variable? Can't it just keep a reference to an Operation instead? Looking quickly at the code, I see that we need the shape and the type of the variables, but both are also accessible via operation.output(0).shape() and operation.output(0).type() respectively. Would that be enough?

I would also prefer keeping a reference to Variable, as it enforces which type of operation is allowed in this context, but we would need to add a mapping in the core to convert an Operation back to its operator based on its name (that would be a great feature to have, though).

Contributor

👍 @JimClarke5 addressed my concern.

LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights);
lValues = tuple.getTarget();
lSampleWeights = tuple.getSampleWeights();
// lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues);
Collaborator

Is this commented line still present on purpose?

Contributor Author

LossTuple takes care of the broadcast, so I'll remove the commented out code.

@JimClarke5
Contributor Author

Instead of wrapping all these loss classes inside metrics, would it make sense to simply add a Metric interface implemented by the same classes to avoid duplicating them?

The current way is better for metrics discovery, as they all reside in the same package, but it might be better to avoid having classes with the same name, and then the same instance of Loss could be used both for computing the loss and collecting it as a metric, wdyt?

(though I'm not sure what would be the use case for doing it, as the loss should always be available as a metric, right?)

The main difference between the specific “loss” metrics and Loss classes is one of state. Loss classes are stateless: the loss is calculated, returned, then forgotten. Metrics, however, are stateful. After each call, the result is added to the previous result along with a running counter. As a result, the reductions differ. Loss reduction focuses on reducing a multi-dimensional result into a single value, e.g. a reduce sum, while a Metric reduction reduces the metric results across state, e.g. a mean.
To maintain the state, Metric classes keep two variables, total and count. Loss classes do not have any variables. Also, the reduction logic is different.

I think if you tried to combine these disparate concepts into one class, things would get messy. Also, the real core of the actual loss calculation is already broken out into the Losses methods used by both classes.

Lastly, there is a whole other set of metrics that don’t rely on Losses at all, including the various accuracy metrics. These will be included in metrics phase 2.
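To illustrate the statefulness described above, a minimal plain-Java sketch of a Mean-style metric (hypothetical class, not the PR's Metric/Reduce API): each updateState call accumulates into total and count variables, and result reports the running mean.

```java
public class MeanMetricSketch {
    // State carried across calls -- this is what distinguishes a metric from a stateless loss.
    private double total = 0.0;
    private long count = 0;

    /** Accumulates a batch of values into the running total and count. */
    public void updateState(double[] values) {
        for (double v : values) {
            total += v;
        }
        count += values.length;
    }

    /** Returns the mean across all values seen so far. */
    public double result() {
        return count == 0 ? 0.0 : total / count;
    }

    /** Clears the accumulated state, as Keras metrics do between epochs. */
    public void resetState() {
        total = 0.0;
        count = 0;
    }
}
```

A stateless loss would stop after one reduction; here the second reduction (the mean across batches) is exactly the state the total and count variables exist to support.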

Fix JavaDoc,
modify one-line code blocks to include braces.
@@ -78,8 +79,8 @@ public MetricVariable(
} else {
throw new IllegalArgumentException(
String.format(
"An initializer for variable %s of type %s is required",
variable.toString(), type.getSimpleName()));
"Type %s is not a supported for metric variables",
Collaborator

"Type %s is not a supported for metric variables" -> "Type %s is not supported for metric variables"

Contributor Author

OK

@karllessard
Collaborator

@JimClarke5 , any comment on this suggestion?

@JimClarke5
Contributor Author

Should we rename this metric to CosineProximity?

Is "cosine proximity" actually a thing outside of Keras? Searching the net, it looks like this is simply called euclidean distance, while cosine similarity stands on its own.

Since we already agreed in the past that we don't have to copy exactly what is found in Keras, would it make sense to simply call it EuclideanDistance or EuclideanDistanceL2?

Going back to review this, the only difference between losses.CosineSimilarity and metrics.CosineSimilarity is the sign of the result. losses.CosineSimilarity applies a tf.math.neg() on the final result, while metrics.CosineSimilarity does not.

From researching Keras and notes on GitHub for Keras, it seems this was done to reconcile the disparate goals of a loss versus those of a metric. Losses are meant to be minimized, so they apply a negative to the result to facilitate that minimization. For normal cosine similarity, -1 indicates the vectors are diametrically opposed, while 1 means they are similar. As the loss focuses on minimization, they flip the sign, so -1 becomes similar, while +1 becomes dissimilar. That makes it easier to minimize down toward similarity. In a sense, CosineSimilarity is the right paradigm for both; it is just that the loss version flips the sign to facilitate minimization.

Maybe the correct answer is to rename losses.CosineSimilarity to losses.CosineSimilarityLoss, to indicate that they are both calculating Cosine Similarity, but that the Loss class finagles the result for help with the loss calculation.
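To make the sign difference concrete, a plain-Java sketch (hypothetical helper, not the PR's classes): the metric reports cos(a, b) directly, while the loss negates it so that minimization drives predictions toward similarity.

```java
public class CosineSimilaritySketch {
    /** Cosine similarity: dot(a, b) / (||a|| * ||b||), in [-1, 1]. */
    public static double cosineSimilarityMetric(double[] a, double[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** The loss flips the sign, so identical vectors yield -1 (the minimum). */
    public static double cosineSimilarityLoss(double[] a, double[] b) {
        return -cosineSimilarityMetric(a, b);
    }
}
```

With identical vectors the metric reaches its maximum of 1 while the loss reaches its minimum of -1, which is the whole point of the tf.math.neg() mentioned above.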

* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to values, may be null.
* @param <V> the data type for the sample weights

This comment was marked as resolved.

Contributor Author

OK

Contributor

Not sure I like GitHub's "hide resolved" functionality. I'll just use a 👍 .

👍

* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to values, may be null.
* @param <V> the data type for the sample weights
Contributor

-> for the labels

Contributor Author

OK

Contributor

👍 done

* @param tf the TensorFlow Ops used to create the result
* @return the result, possibly with control dependencies
*/
public abstract Operand<T> result(Ops tf);
Contributor

Is tf required to be on the same ExecutionEnvironment as this.tf? If so, should we document and check that?

Contributor Author

I originally had it there to form a control dependency on the result for the callOnce() method. I reworked it, so I replaced the abstract result(Ops tf) with abstract result() and reworked the control dependency logic in callOnce().

Contributor

👍 done

private final String totalName;
private final String countName;

private final Class<T> type;
Contributor

-> resultType

Contributor Author

OK

Contributor

👍 done

* will always produce the same random tensor for a given shape and data type.
* @param type the type for the variables and result
*/
protected Reduce(Ops tf, String name, long seed, Class<T> type) {
Contributor

-> resultType

Contributor Author

OK

Contributor

👍 done

* @throws IllegalArgumentException if values is null
*/
@Override
public List<Op> updateStateList(Operand<U> values, Operand<T> sampleWeights) {
Contributor

Does it usually help us to type sampleWeights as an Operand<T>, instead of letting it have its own tensor type? It feels arbitrary to me, and I see it doesn't help us in this case.

Contributor Author

I don't quite understand the question. Did you mean sampleWeights should be typed to something like an Operand with its own type parameter, so that the sample weights are not tied to the Metric class type?

Contributor

Right. Is it helpful that the tensor type of sampleWeights is tied to the Metric result type T, instead of either being method-local or being tied to the tensor type U of values?

Contributor Author

Eventually, all the types get cast to the Metric type T for the result and the variables, but there is no reason sampleWeights could not be another independent type. I will change the class signature.

Contributor

👍 done

.withControlDependencies(Collections.singletonList(broadcastWeightsCheck))
.math
.mul(lValues, lSampleWeights);
} catch (IllegalArgumentException ex) {
Contributor

It feels unsafe to me to assume the only possible reason for an IllegalArgumentException in this try/catch is the one handled here. Instead, I would think we'd want to explicitly check for this condition?

Contributor Author

I understand your point. Should we create something like BroadcastException?

Contributor

A special-purpose exception would certainly make this safe and more explicit, and I'd be fine with it. To me, encapsulating this additional information in the return type would feel more idiomatic, but for this intendedly-module-private method, I don't see that as a strong argument.

Contributor Author

There is already TFInvalidArgumentException which extends IllegalArgumentException that is thrown for the dynamic analysis part of this function thrown from within the graph. Should we throw TFInvalidArgumentException instead for static shapes? Though, that is still somewhat vague, so maybe BroadcastException makes more sense.

Contributor

Yeah, I'd be a fan of the more specific exception name. Perhaps even NotBroadcastableException?

Contributor Author

Added NotBroadcastableException which is thrown if the static shapes are not broadcastable. TFInvalidArgumentException is still thrown when evaluating dynamic shapes as that is done inside the graph. Both exceptions inherit from IllegalArgumentException.

Contributor

I think we now need to change the catch to NotBroadcastableException.

Contributor Author

ok

* values</code>, an optional additional input for the <code>sampleWeights</code>
*
* @param values the inputs to be passed to update state, this may not be null
* @param sampleWeights sample weights to be applied to values, may be null.
Contributor

The supported shape relationships between values and sampleWeights are complex. Do we want to document and test them?

Contributor Author

Broadcasting seems to be inherent throughout NumPy and TF Python. The most concise definition is from tf.broadcast_to() in TF Python.

Broadcasting is the process of making arrays have compatible shapes for arithmetic operations. Two shapes are compatible if, for each dimension pair, they are either equal or one of them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing dimensions and works its way forward.

If the shapes aren't compatible, the low-level TF Ops will throw TF exceptions. The methods that test the shapes are trying to catch this before the TF Ops do, as the low-level errors make it harder to debug the actual problem.

Here is a link to the NumPy broadcast rules: Broadcasting.

It might be helpful to document this, but maybe in a separate document, like NumPy does.
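The trailing-dimension rule quoted above can be sketched on static shapes in plain Java (hypothetical helper mirroring the NumPy/TF rule, not the PR's assertBroadcastable):

```java
public class BroadcastSketch {
    /**
     * Two shapes are broadcast-compatible if, walking from the trailing
     * dimension forward, each dimension pair is equal or one of them is 1.
     * Missing leading dimensions are treated as 1, so a shorter shape
     * that matches on the trailing dimensions is always compatible.
     */
    public static boolean areBroadcastCompatible(long[] shapeA, long[] shapeB) {
        int i = shapeA.length - 1;
        int j = shapeB.length - 1;
        while (i >= 0 && j >= 0) {
            long a = shapeA[i--];
            long b = shapeB[j--];
            if (a != b && a != 1 && b != 1) {
                return false;
            }
        }
        return true;
    }
}
```

This only covers fully static shapes; unknown dimensions are exactly why the PR also emits a graph-time check.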

Contributor

If it were simply that this method supports broadcasting, then I'd see no reason to say much about it at this stage of the project. But I think the permitted shape relationships are far more complex than that:

  • We use squeezeOrExpandDimensions to "squeeze or expand the last dim of sampleWeight if its rank differs by 1 from [the rank of values]."

  • Then we assert either at method runtime or at graph execution time (depending on our inputs) that "the sampleWeights can be broadcast to values." (I think we intentionally disallow the case where values would be broadcast to sampleWeights?)

  • But if the inputs contain enough static shape information (how much is enough?) to allow us to check broadcasting at method runtime, and if we discover that the shapes don't support the unidirectional broadcasting that we allow, then instead we apply either reduceSum or math.mean (depending on this.reduction) to reduce values down to the same rank as sampleWeights.

  • After all of the above logic, we make a broadcasting call to math.mul. So, for example, the previous step may use reduceSum to bring values to the same rank as sampleWeights, after which we may still broadcast one or more dimensions (having size 1) of values or sampleWeights, or both.
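A loose sketch of the first step above, operating on shape arrays only (hypothetical helper; the PR's LossesHelper.squeezeOrExpandDimensions works on Operands inside the graph): when the ranks differ by exactly one, a trailing size-1 dimension is squeezed off or appended to the weights shape.

```java
import java.util.Arrays;

public class SqueezeOrExpandSketch {
    /**
     * Adjusts the sampleWeights shape when its rank differs from the values
     * rank by exactly 1: squeeze a trailing size-1 dim, or expand with one.
     * Returns the (possibly unchanged) weights shape.
     */
    public static long[] adjustWeightsShape(long[] valuesShape, long[] weightsShape) {
        int diff = weightsShape.length - valuesShape.length;
        if (diff == 1 && weightsShape[weightsShape.length - 1] == 1) {
            // weights has one extra trailing size-1 dim: squeeze it off
            return Arrays.copyOf(weightsShape, weightsShape.length - 1);
        }
        if (diff == -1) {
            // weights is one rank short: expand with a trailing size-1 dim
            long[] expanded = Arrays.copyOf(weightsShape, weightsShape.length + 1);
            expanded[weightsShape.length] = 1;
            return expanded;
        }
        return weightsShape;
    }
}
```

The later bullets (the broadcast assertion and the reduceSum/mean fallback) happen after this rank adjustment, so this sketch covers only the first step.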

Contributor Author

I found the test cases for broadcasting in TF Python so I am creating Java test cases. I already found an anomaly with dynamic shapes and scalar, so I will investigate that, and when I'm done with the test cases I will check them in.

Also, the method name was wrong and I have corrected it to assertBroadcastable to match TF Python.

Contributor

👍

Contributor Author

Reduce.java line 131, LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights);
null is the labels parameter, which is y_true in the Python version. I don't understand your point here.

I always map y_true to labels, and y_pred to predictions.

From LossesHelper.java

public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
      Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights) 

from metrics.py

values, _, sample_weight = tf_losses_utils.squeeze_or_expand_dimensions(
          values, sample_weight=sample_weight)

Here values is y_pred.

from python/ops/losses/util.py
def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):

Contributor

Oops, sorry, I wasn't reading the code I thought I was! Your response is dead-on for the code I thought I was reading: Reduce.update_state.

Contributor

My concern that opened this thread remains open:

The supported shape relationships between values and sampleWeights are complex. Do we want to document and test them?

Contributor Author

I have added methods assertBroadcastable and broadcastWeights to MetricsHelper, along with corresponding test classes. This should address the testing part, at least for these two topics. However, there are shape operations sprinkled all around the code, not just in Metrics, and I am not sure how feasible it is to capture all of these. Much of the Python wrapper code around low-level operations involves manipulating shapes into shapes that are acceptable to the low-level Op.

Contributor

Well, the Python code doesn't document the permitted shape relationships either. No reason this PR has to exceed the Python state of the art in this area!

👍 nothing to do

.mul(lValues, lSampleWeights);
} catch (IllegalArgumentException ex) {
// reduce the values down to the rank of the samples
int nDim = lValues.shape().numDimensions();
Contributor

Presumably we need to either prohibit shapes with unknown dimensions or handle them here?

Contributor Author

Essentially, broadcastWeights will only throw the IllegalArgumentException if the shapes are fully static and cannot be broadcast. Maybe a code comment would help explain this state.

Contributor

👍 done

// reduce the values down to the rank of the samples
int nDim = lValues.shape().numDimensions();
int wDim = lSampleWeights.shape().numDimensions();
int numAxes = nDim - wDim;
Contributor

Why is nDim necessarily > wDim at this point?

Contributor Author

You're right. This is an error, I'll fix it.

Contributor

👍 done

* incorrect shape that prohibits broadcasting to <code>values</code>
*/
@SuppressWarnings("unchecked")
public static <U extends TNumber> Op broadcastWeights(
Contributor

Would a clearer name be assertCanBroadcastWeights? Given that a caller wants to base control flow on a static check showing incorrect shape, would it be cleaner to return Optional<Op> instead of using the IllegalArgumentException to signal that situation?

Contributor Author

Name was changed to assertBroadcastable to match TF Python.

Contributor

👍 done

*/
private static <T extends TNumber> Operand<TBool> hasValidDims(
Ops tf, Operand<T> weightsShape, Operand<T> valuesShape) {
tf = tf.withSubScope("has_invalid_dims");
Contributor

-> "has_valid_dims"

Contributor Author

OK

*
* @param tf the TensorFlow Ops
* @param weightsShape the operand for the shape of the sample weights
* @param valuesShape the operand for the shape of the sample weights
Contributor

-> "of the values"

Contributor Author

OK

* @param <T> the data type for the operands
* @return a boolean operand to determine if the Shape is scalar or not.
*/
private static <T extends TNumber> Operand<TBool> hasValidNonscalarShape(
Contributor

Perhaps canBroadcastNonscalarShapes?

Contributor Author

It was has_valid_nonscalar_shape in Python. Since this is private, the name doesn't matter to me, so if you think canBroadcastNonscalarShapes is a better name, then let's change it.

Contributor

canBroadcastNonscalarShapes does seem better to me. I like the clue that this is part of the canBroadcast collection of methods, whereas the name hasValidNonscalarShape doesn't really mean anything from a standing start.

Contributor Author

OK changed to canBroadcastNonscalarShapes

Contributor

👍 done

* @param <T> the data type for the operands
* @return a boolean operand to determine if the shapes have valid dimensions or not.
*/
private static <T extends TNumber> Operand<TBool> hasValidDims(
Contributor

Perhaps "canBroadcastDims"?

Contributor Author

Same comment as canBroadcastNonscalarShapes

Contributor

Did you accidentally omit a comment for canBroadcastNonscalarShapes? None is visible to me.

Contributor Author

I forgot to commit the text, you should see it now.

Contributor

Thanks! I do prefer canBroadcastDims for the same reason as above.

Contributor Author

OK changed to canBroadcastDims

Contributor

👍 done

* @param axis Axes to compute the mean.
* @param <T> the type of the Operand.
* @param <U> the type of the axis.
* @return the mean of the operand, alongside the specified axis.
Contributor

"alongside" -> "along"

Contributor Author

ok

*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
Contributor: axis or axes? (We currently say both.)

Contributor Author (JimClarke5, Jan 11, 2021): It should be axes, but tf.math.mean says axis, while the description says dims, plural. I'll change it to axes.

Contributor: 👍 done
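To make the axis/axes wording concrete, here is a hypothetical plain-Java sketch (not the framework code) of what "the mean of x along a given axis" computes for a 2-D input, in the same sense as tf.math.mean:

```java
// Hypothetical sketch: reducing a 2-D array along axis 0 (collapsing rows,
// one result per column) or axis 1 (collapsing columns, one result per row).
public class AxisMean {
    static double[] mean(double[][] x, int axis) {
        int rows = x.length, cols = x[0].length;
        if (axis == 0) {
            double[] out = new double[cols];
            for (double[] row : x)
                for (int j = 0; j < cols; j++) out[j] += row[j] / rows;
            return out;
        }
        // axis == 1
        double[] out = new double[rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) out[i] += x[i][j];
            out[i] /= cols;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2}, {3, 4}};
        System.out.println(java.util.Arrays.toString(mean(x, 0))); // [2.0, 3.0]
        System.out.println(java.util.Arrays.toString(mean(x, 1))); // [1.5, 3.5]
    }
}
```

Reducing over several axes at once (hence "axes", plural, in the javadoc) is just this operation applied to each listed axis in turn.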

}

/**
* Calculate the mean of the operand, along all axis.
Contributor: -> "all axes"

Contributor Author: OK

Contributor: 👍 done

*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
Contributor: axis versus "Axes"

Contributor Author: Changed to axes.

Contributor: 👍 done

* @return the mean of elements of `x`.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public static <T extends TType, U extends TNumber> Operand<T> mean(
Contributor: The return value of this method won't necessarily have a runtime tensor type of T, because T could be TBool, which could get converted.

Contributor Author: OK, I added Z extends TNumber and return Operand<Z>. I also changed U extends TNumber to U extends TIntegral, as U is for shapes.

Contributor: I think the new version still has similar problems. I'm going to thumbs-up this comment as done and add an updated comment to the new version of the code.

👍 moving discussion to new version of code
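The point in the thread above, that the result's tensor type can legitimately differ from the input's (a TBool input yielding a numeric mean), can be sketched with plain Java generics. Everything here is hypothetical (simplified Operand interface, a "like" parameter used only for type inference), not the real TF-Java type hierarchy:

```java
import java.util.List;

// Hypothetical sketch of the signature change: the result gets its own type
// parameter Z, decoupled from input type T, analogous to
// <T extends TType, Z extends TNumber> Operand<Z> mean(Operand<T> x, ...).
public class ResultTypeDemo {
    interface Operand<T> { T value(); }

    // T may be Boolean; the caller picks the numeric result type Z via "like".
    static <T, Z extends Number> Operand<Z> mean(List<T> xs, Z like) {
        double sum = 0;
        for (T x : xs)
            sum += (x instanceof Boolean) ? ((Boolean) x ? 1 : 0)
                                          : ((Number) x).doubleValue();
        final double m = sum / xs.size();
        @SuppressWarnings("unchecked")
        Operand<Z> out = () -> (Z) Double.valueOf(m); // sketch: Double-backed
        return out;
    }

    public static void main(String[] args) {
        // Boolean input, numeric output: impossible if the return type were
        // tied to T, natural once the result has its own parameter.
        Operand<Double> m = mean(List.of(true, false, true, true), 0.0);
        System.out.println(m.value()); // 0.75
    }
}
```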

change hasValidNonscalarShape to canBroadcastNonscalarShapes
move the dynamic shapes and rank down to the dynamic section so they are not created needlessly when static
Fix if statement to check for unknown size and unknown dimensions
renamed WeightBroadcastTest to AssertBroadcastableTest and added BroadcastWeightsTest
….sparse.sparseToDense with the output of tf.sparse.denseToDenseSetOperation
google-cla bot commented Jan 30, 2021

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.


Collaborator @karllessard left a comment:
Ok thanks @JimClarke5 this looks good on my side, I'll merge it as soon as @deansher gives me a green light

Contributor @rnett commented Jan 31, 2021:

Looking at this in the context of #193 and #176, it seems like there are a few cases where a type parameter on the class isn't really necessary, since it's never used to enforce anything (as best I can tell); for instance, U in Metric. This also affects #201, since the Kotlin wrappers won't work as well if you have to specify extra types. Since things are cast anyway, it looks like U could be moved to the call method(s) and/or replaced by ? extends TNumber. Some places in LossHelper or MetricsHelper would need to be refactored if you use ? extends TNumber, though, and moving it to the call methods is fine for my needs.

Thoughts? It would make the inline Kotlin wrappers much nicer if this is possible.
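The refactor being proposed above can be sketched in plain Java (hypothetical metric classes, not the framework's Metric API): a class-level type parameter pins the type at construction, while a method-level one is inferred fresh at each call, which is what makes inline Kotlin wrappers simpler.

```java
import java.util.List;

public class GenericScopeDemo {
    // Before: U is fixed when the metric object is created, even though
    // only updateState ever uses it.
    static class MetricClassLevel<U extends Number> {
        double total = 0;
        void updateState(List<U> values) {
            for (U v : values) total += v.doubleValue();
        }
    }

    // After: U is scoped to the call, so one instance can accept Integer
    // on one call and Double on the next without extra type arguments.
    static class MetricMethodLevel {
        double total = 0;
        <U extends Number> void updateState(List<U> values) {
            for (U v : values) total += v.doubleValue();
        }
    }

    public static void main(String[] args) {
        MetricMethodLevel m = new MetricMethodLevel();
        m.updateState(List.of(1, 2));  // U inferred as Integer
        m.updateState(List.of(0.5));   // U inferred as Double
        System.out.println(m.total);   // 3.5
    }
}
```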

Collaborator @karllessard:
@rnett , following this comment, Jim agreed to go through the whole framework for cleaning up the superfluous generics in a separate PR, probably right after this one.

Collaborator @Craigacp:
@rnett I think the Googlebot CLA is asking you to consent as Jim merged in from master rather than rebasing.

Contributor @deansher left a comment:
I studied all non-test-code deltas plus SetsOpsTest and AssertBroadcastableTest. Thanks again, Jim, both for all of your hard work on this and for your patience as I reviewed it.

* and <code>axis=1</code> corresponds to data format 'Channels First'.
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format "Channels Last"
* and <code>axis=1</code> corresponds to data format "Channels First".
* {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
Contributor: Did you intend to add a "See" in front of the links?
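The javadoc quoted above describes the channel-axis convention. As a hypothetical helper (not part of the Losses API) showing how a negative axis like -1 resolves against a tensor's rank:

```java
// Hypothetical sketch: axis = -1 addresses the last dimension
// (channels-last, e.g. a [batch, height, width, channels] layout), while
// axis = 1 addresses the second dimension (channels-first).
public class ChannelAxis {
    static int resolveAxis(int axis, int rank) {
        return axis < 0 ? axis + rank : axis;
    }

    public static void main(String[] args) {
        int rank = 4; // e.g. a 4-D image batch
        System.out.println(resolveAxis(-1, rank)); // 3 -> channels-last
        System.out.println(resolveAxis(1, rank));  // 1 -> channels-first
    }
}
```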

* @param b The other operand representing set <code>b</code>
* @param <T> the data type for the sets
* @return An Operand with the same rank as <code>a</code> and <code>b</code>, and all but the
*     last dimension the same. Elements along the last dimension contain the results of the set
Contributor: 👍 done

Contributor @rnett commented Jan 31, 2021:
@googlebot I consent.

Collaborator @karllessard left a comment:
Thanks a lot @JimClarke5 for all the great work you provide to this project, merging that PR now

@karllessard karllessard merged commit 3a0489e into tensorflow:master Feb 1, 2021
JimClarke5 added a commit to JimClarke5/java that referenced this pull request Feb 1, 2021
@JimClarke5 JimClarke5 deleted the metrics1 branch February 3, 2021 22:42