Introduce Kendall's Tau computation. #1147
Conversation
This looks great; thanks a lot for the contribution! I've gone ahead and made a first pass of comments, mostly focused on TFP-specific requirements and style. Happy to discuss further.
```python
exchanges = 0
num = tf.size(y)
k = tf.constant(1, tf.int32)
while tf.less(k, num):
```
This is a bit annoying, but I'm going to ask that you write this loop, and all loops in this file, using TF graph ops (`tf.while_loop` or `tf.scan`) in place of Python `while` or `for` loop constructs. Since we don't know the contexts in which TFP code will be called, we need the whole library to work in both eager and graph modes (these days you can read 'graph mode' as equivalent to 'inside a `@tf.function` tracing context'), and in particular when shape information is not statically available. That means we can't assume that conditions like `tf.less(k, num)` will be concrete values available to the Python interpreter, so the control flow has to occur at the TF level.
It's usually pretty mechanical to do this conversion, though it does force you to think explicitly about the state carried through the loop. In general, if you can frame a Python loop in the form

```python
loop_state = initial_loop_state
while condition(loop_state):
  loop_state = loop_body(loop_state)
```

for some structure of Tensors `loop_state`, then the translation is just:

```python
final_loop_state = tf.while_loop(
    condition,
    loop_body,
    initial_loop_state)
```

and you can use `TensorArray`s to store any values accumulated during the loop (or `tf.scan`, which is just a thin wrapper around `while_loop` + `TensorArray`).
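To make the state-threading concrete, here is a pure-Python stand-in for `tf.while_loop` (an illustrative sketch only, not the TF API) applied to a toy counter loop shaped like the one in this PR; the variable names are placeholders:

```python
def while_loop(condition, body, initial_state):
    # Pure-Python stand-in mimicking tf.while_loop's factoring:
    # all mutable state is threaded explicitly through the loop,
    # rather than living as Python locals.
    state = initial_state
    while condition(*state):
        state = body(*state)
    return state

# Toy loop: advance k from 1 up to num, counting iterations,
# written in the condition/body/state form described above.
num = 5
final_k, iterations = while_loop(
    condition=lambda k, count: k < num,
    body=lambda k, count: (k + 1, count + 1),
    initial_state=(1, 0))
```

With real Tensors, the same `condition` and `body` callables would be handed to `tf.while_loop`, which stages them as graph ops instead of evaluating them in the Python interpreter.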
The test for whether everything is working is to trace the code with inputs of unknown shape:

```python
traced_kendalls_tau = tf.function(
    kendalls_tau, autograph=False).get_concrete_function(
        y_true=tf.TensorSpec(shape=None, dtype=tf.float32),
        y_pred=tf.TensorSpec(shape=None, dtype=tf.float32))
```

and verify that `traced_kendalls_tau` behaves identically to the original `kendalls_tau` function.
The `@test_util.test_all_tf_execution_regimes` decorator in the unit-test file should be doing something like this (though with known shapes). It'd surprise me if those tests are all passing currently? In any case, the above check is the gold standard.
We have a lot of experience writing graph-mode control flow (it's a pain, but you get used to it), so feel free to ask for help if you get stuck.
My test file was missing the `main`, so the tests were not executing.
```python
while tf.less(i, rght) and tf.less(j, rend):
  permij = aperm.gather([i, j])
  yij = tf.gather(y, permij)
  if tf.less_equal(yij[0], yij[1]):
```
Similarly to the point above about loop control flow, we also need to replace

```python
if condition_fn():
  result = do_one_thing()
else:
  result = do_another_thing()
```

with the equivalent graph op

```python
result = tf.cond(
    condition_fn,
    do_one_thing,
    do_another_thing)
```

to handle the case where the condition can't be statically evaluated.
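The calling convention can be sketched in plain Python. This toy `cond` (not the real `tf.cond`, just an illustration of its shape; the branch strings are placeholders) shows why both branches must be passed as zero-argument callables rather than evaluated eagerly:

```python
def cond(condition, true_fn, false_fn):
    # Stand-in for tf.cond's calling convention: the branches are
    # zero-argument callables. In this Python sketch only the chosen
    # branch runs; in real tf.cond both are traced into the graph,
    # and the selection happens at graph execution time.
    return true_fn() if condition else false_fn()

# The PR's pairwise comparison, rewritten in this form.
yij = (2.0, 3.0)
result = cond(
    yij[0] <= yij[1],
    lambda: "less-equal branch",
    lambda: "greater branch")
```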
(here and elsewhere)
```python
v += ((n - first) * (n - first - 1)) // 2

tot = (n * (n - 1)) // 2
if tf.equal(tot, u) or tf.equal(tot, v):
```
As above, prefer raising an error (`assert_util.assert_none_equal(tot, y)`) to returning `NaN`.
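For intuition about why `tot == u` or `tot == v` is a degenerate case: in tau-b the denominator is `sqrt((tot - u) * (tot - v))`, where `u` and `v` count tied pairs in each variable, so it vanishes when one variable has no untied pairs at all. A brute-force O(n²) pure-Python sketch (my own reference implementation for illustration, not the PR's code) makes this explicit:

```python
import math
from itertools import combinations

def kendalls_tau_b(x, y):
    # Brute-force Kendall's tau-b over all n*(n-1)/2 pairs.
    n = len(x)
    tot = n * (n - 1) // 2
    concordant = discordant = ties_x = ties_y = 0
    for (xi, yi), (xj, yj) in combinations(zip(x, y), 2):
        dx, dy = xi - xj, yi - yj
        if dx == 0:
            ties_x += 1          # pair tied in x (possibly in y too)
        if dy == 0:
            ties_y += 1          # pair tied in y
        if dx != 0 and dy != 0:
            if dx * dy > 0:
                concordant += 1
            else:
                discordant += 1
    if tot == ties_x or tot == ties_y:
        # All pairs tied in one variable: denominator is zero,
        # so tau-b is undefined. Raising beats returning NaN.
        raise ValueError("tau-b undefined: a variable is constant")
    return (concordant - discordant) / math.sqrt(
        (tot - ties_x) * (tot - ties_y))
```

A constant input (e.g. `x = [1, 1, 1]`) makes every pair tied in `x`, so `ties_x == tot` and the function raises, which is exactly the condition the quoted code checks with `tf.equal(tot, u) or tf.equal(tot, v)`.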
Adds conversion of arguments via `convert_to_tensor`, plus additional parameters for type and name.
Working on integrating via TensorFlow Probability.
Is there another PR?
Not yet, but the code has diverged a lot from this PR, so I thought I should close. Will ping this thread when it's made part of a future release.
Migrated and updated: https://github.com/tensorflow/addons/pull/2169/files