Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the CSV hash table in Hash layer and fix some bugs. (#385)
* delete the Lambda sublayer in LocalActivationUnit Layer class * add vocabulary_path in the SparseFeat to support the csv HashTable functionality * update docs and add examples in doc * Remove trailing whitespace * add the pytest code for testing hash vocabulary_path parameter * add the import tensorflow.python.ops.numpy_ops.np_config as np_config * update the utils_test * use CustomObjectScope and layer_test * add comput_output_shape * modify the tf.constant to np.array * adapt to v1 session to may be solve the err in the ci: Failed precondition: Table not initialized. * revert the last commit * modify the PostionEncoding layer to fit the tf v2 mode successfully * modify the tf.version to compatible with tf1.4
- Loading branch information
Showing
11 changed files
with
119 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
import numpy as np | ||
import tensorflow as tf | ||
from deepctr.layers.utils import Hash | ||
from tests.utils import layer_test | ||
try: | ||
from tensorflow.python.keras.utils import CustomObjectScope | ||
except: | ||
from tensorflow.keras.utils import CustomObjectScope | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'num_buckets,mask_zero,vocabulary_path,input_data,expected_output', | ||
[ | ||
(3+1, False, None, ['lakemerson'], None), | ||
(3+1, True, None, ['lakemerson'], None), | ||
(3+1, False, "./tests/layers/vocabulary_example.csv", [['lake'], ['johnson'], ['lakemerson']], [[1], [3], [0]]) | ||
] | ||
) | ||
def test_Hash(num_buckets, mask_zero, vocabulary_path, input_data, expected_output): | ||
if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0': | ||
return | ||
|
||
with CustomObjectScope({'Hash': Hash}): | ||
layer_test(Hash, kwargs={'num_buckets': num_buckets, 'mask_zero': mask_zero, 'vocabulary_path': vocabulary_path}, | ||
input_dtype=tf.string, input_data=np.array(input_data, dtype='str'), | ||
expected_output_dtype=tf.int64, expected_output=expected_output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
1,lake | ||
2,merson | ||
3,johnson |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters