Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,9 @@ def crf_log_likelihood(inputs,

# Get the transition matrix if not provided.
if transition_params is None:
transition_params = tf.get_variable("transitions",
[num_tags, num_tags])
initializer = tf.initializers.GlorotUniform()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any precedent for using GlorotUniform here? Thanks very much for the PR as usual

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the default behavior for get_variable.

https://www.tensorflow.org/api_docs/python/tf/get_variable?hl=en
If initializer is None (the default), the default initializer passed in the variable scope will be used. If that one is None too, a glorot_uniform_initializer will be used.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was why I used GlorotUniform, I could parameterize it with GlorotUniform being the default.

Copy link
Member

@seanpmorgan seanpmorgan Aug 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm I think parametizer-ing it may be overkill. This looks good to me.

transition_params = tf.Variable(
initializer([num_tags, num_tags]), "transitions")

sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params)
Expand Down
6 changes: 6 additions & 0 deletions tensorflow_addons/text/crf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ def testCrfLogLikelihood(self):
tf_total_log_likelihood = self.evaluate(total_log_likelihood)
self.assertAllClose(tf_total_log_likelihood, 0.0)

# check if `transition_params = None` raises an error
text.crf_log_likelihood(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0))

def testViterbiDecode(self):
inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
dtype=np.float32)
Expand Down