Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a test for deepfm #1442

Merged
merged 7 commits into from Nov 12, 2019
Merged

Conversation

QiJune
Copy link
Collaborator

@QiJune QiJune commented Nov 12, 2019

Fix #1441

"ElasticDL Tensor ignores dense_shape in "
"TensorFlow.IndexedSlices."
)
# TODO(yunjian.lmh): Support dense shape
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will generate too many warnings. So I remove it.

):
self.expected_embed_table[gi] -= self._lr * gv

self.embedding_grads1 = tf.IndexedSlices(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to address the issue here, refine the unit test.

y,
)
acc = acc_meter.result().numpy()
print("loss: ", w_loss.numpy(), " acc: ", acc)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove print?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -266,6 +270,67 @@ def test_worker_pull_embedding(self):
expected_result = np.concatenate(expected_result)
self.assertTrue(np.allclose(expected_result, result_dict[layer]))

def test_deepfm_train(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extract the common training logic of test_deepfm_train and test_mnist_train as a function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

@mhaoli mhaoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent!

@QiJune QiJune merged commit 21ebd82 into sql-machine-learning:develop Nov 12, 2019
@QiJune QiJune deleted the deepfm_test branch November 12, 2019 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DeepFM unit test for new PS
2 participants