add a test for deepfm #1442
Changes from 5 commits
Commits: eb60db9, cb3eb5d, 9d077f5, 0b9ddde, a69722d, d88de17, 08f2c5f
@@ -61,7 +61,7 @@ def create_server_and_stub(
             port=self._port,
             model_zoo=_test_model_zoo_path,
             model_def="test_module.custom_model",
-            **kwargs,
+            **kwargs
         )
         pserver = ParameterServer(args)
         pserver.prepare()
@@ -259,18 +259,18 @@ def push_gradient_test_setup(self):
         self.embedding_table = (
             np.random.rand(4 * dim).reshape((4, dim)).astype(np.float32)
         )
-        self.embedding_grads = tf.IndexedSlices(
+        self.embedding_grads0 = tf.IndexedSlices(
             values=np.random.rand(3 * dim)
             .reshape((3, dim))
             .astype(np.float32),
             indices=(3, 1, 3),
         )
-        self.expected_embed_table = np.copy(self.embedding_table)
-        for gv, gi in zip(
-            self.embedding_grads.values, self.embedding_grads.indices
-        ):
-            self.expected_embed_table[gi] -= self._lr * gv
-
+        self.embedding_grads1 = tf.IndexedSlices(
+            values=np.random.rand(3 * dim)
+            .reshape((3, dim))
+            .astype(np.float32),
+            indices=(2, 2, 3),
+        )
         push_model_req = elasticdl_pb2.Model()
         push_model_req.version = self._parameters.version
         for name, value in zip(self.var_names, self.var_values):

Review comment (on the added `self.embedding_grads1` line): Just to address the issue here, refine the unit test.
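A note on this hunk: both gradient slices deliberately carry a duplicated row index (row 3 in `embedding_grads0`, row 2 in `embedding_grads1`), and the expected-table computation later applies each duplicate separately with an explicit loop. The sketch below is illustrative only (the lr, dim, and values are made up, not taken from the test); it shows why that loop matters: NumPy's in-place fancy indexing collapses duplicate indices into a single update.

import numpy as np

lr = 0.1
dim = 2
table = np.zeros((4, dim), dtype=np.float32)
values = np.ones((3, dim), dtype=np.float32)
indices = (3, 1, 3)  # row 3 appears twice, as in embedding_grads0

# Vectorized in-place update: duplicated rows receive only ONE update,
# because the gather/compute/scatter happens on the original values.
wrong = table.copy()
wrong[list(indices)] -= lr * values
print(wrong[3])  # [-0.1 -0.1]: the second update to row 3 is lost

# Explicit loop, mirroring push_gradient_test_setup: both updates land.
right = table.copy()
for gv, gi in zip(values, indices):
    right[gi] -= lr * gv
print(right[3])  # [-0.2 -0.2]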
@@ -297,8 +297,8 @@ def test_push_gradient_async_update(self):
             emplace_tensor_pb_from_ndarray(req.gradients, g, name=name)
         emplace_tensor_pb_from_ndarray(
             req.gradients,
-            values=self.embedding_grads.values,
-            indices=self.embedding_grads.indices,
+            values=self.embedding_grads0.values,
+            indices=self.embedding_grads0.indices,
             name=self._embedding_info.name,
         )
         res = self._stub.push_gradient(req)
@@ -316,12 +316,16 @@ def test_push_gradient_async_update(self):
                 )
             )

+        expected_embed_table = np.copy(self.embedding_table)
+        for gv, gi in zip(
+            self.embedding_grads0.values, self.embedding_grads0.indices
+        ):
+            expected_embed_table[gi] -= self._lr * gv
+
         actual_embed_table = self._parameters.get_embedding_param(
-            self._embedding_info.name, range(len(self.expected_embed_table))
+            self._embedding_info.name, range(len(expected_embed_table))
         )
-        self.assertTrue(
-            np.allclose(self.expected_embed_table, actual_embed_table)
-        )
+        self.assertTrue(np.allclose(expected_embed_table, actual_embed_table))

         # Test applying gradients with same name
         for name, var in zip(self.var_names, self.var_values):
@@ -358,6 +362,12 @@ def test_push_gradient_sync_update(self):
         req.model_version = 0
         for g, name in zip(self.grad_values0, self.var_names):
             emplace_tensor_pb_from_ndarray(req.gradients, g, name=name)
+        emplace_tensor_pb_from_ndarray(
+            req.gradients,
+            values=self.embedding_grads0.values,
+            indices=self.embedding_grads0.indices,
+            name=self._embedding_info.name,
+        )
         res = self._stub.push_gradient(req)
         self.assertEqual(res.accepted, True)
         self.assertEqual(res.model_version, 0)
@@ -368,8 +378,8 @@ def test_push_gradient_sync_update(self):
             emplace_tensor_pb_from_ndarray(req.gradients, g, name=name)
         emplace_tensor_pb_from_ndarray(
             req.gradients,
-            values=self.embedding_grads.values,
-            indices=self.embedding_grads.indices,
+            values=self.embedding_grads1.values,
+            indices=self.embedding_grads1.indices,
             name=self._embedding_info.name,
         )
         res = self._stub.push_gradient(req)
@@ -398,9 +408,21 @@ def test_push_gradient_sync_update(self):
                 )
             )

+        expected_embed_table = np.copy(self.embedding_table)
+        for gv, gi in zip(
+            self.embedding_grads0.values, self.embedding_grads0.indices
+        ):
+            expected_embed_table[gi] -= self._lr * gv
+        for gv, gi in zip(
+            self.embedding_grads1.values, self.embedding_grads1.indices
+        ):
+            expected_embed_table[gi] -= self._lr * gv
+
         actual_embed_table = self._parameters.get_embedding_param(
-            self._embedding_info.name, range(len(self.expected_embed_table))
+            self._embedding_info.name, range(len(expected_embed_table))
         )
-        self.assertTrue(
-            np.allclose(self.expected_embed_table, actual_embed_table)
-        )
+        self.assertTrue(np.allclose(expected_embed_table, actual_embed_table))


 if __name__ == "__main__":
     unittest.main()
@@ -1,5 +1,6 @@
 import os
 import unittest
+from pathlib import Path

 import grpc
 import numpy as np
@@ -9,6 +10,9 @@
 from elasticdl.python.common.constants import GRPC
 from elasticdl.python.common.hash_utils import int_to_id, string_to_id
 from elasticdl.python.common.model_utils import get_model_spec
+from elasticdl.python.data.recordio_gen.frappe_recordio_gen import (
+    load_raw_data,
+)
 from elasticdl.python.ps.embedding_table import EmbeddingTable
 from elasticdl.python.ps.parameter_server import ParameterServer
 from elasticdl.python.tests.test_utils import PserverArgs
@@ -266,6 +270,67 @@ def test_worker_pull_embedding(self):
             expected_result = np.concatenate(expected_result)
             self.assertTrue(np.allclose(expected_result, result_dict[layer]))

+    def test_deepfm_train(self):
+        worker = Worker(
+            worker_id=0,
+            job_type=elasticdl_pb2.TRAINING,
+            minibatch_size=self._batch_size,
+            model_zoo=self._model_zoo_path,
+            model_def=(
+                "deepfm_functional_api.deepfm_functional_api" ".custom_model"
+            ),
+            ps_channels=self._channel,
+        )
+
+        home = str(Path.home())
+
+        class TmpArgs(object):
+            def __init__(self, data):
+                self.data = data
+
+        args = TmpArgs(data=home + "/.keras/datasets/")
+
+        x_train, y_train, x_val, y_val, x_test, y_test = load_raw_data(args)
+        x_train = tf.convert_to_tensor(x_train, dtype=tf.int64)
+        x_test = tf.convert_to_tensor(x_test, dtype=tf.int64)
+        y_train = tf.convert_to_tensor(y_train, dtype=tf.int64)
+        y_test = tf.convert_to_tensor(y_test, dtype=tf.int64)
+
+        db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+        db = db.batch(self._batch_size).repeat(10)
+        test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+        test_db = test_db.batch(self._batch_size)
+
+        acc_meter = tf.keras.metrics.Accuracy()
+
+        for step, (x, y) in enumerate(db):
+            if step == 0:
+                worker._run_model_call_before_training(x)
+                worker.report_variable()
+
+            worker.get_model(step, elasticdl_pb2.MINIMUM)
+            w_loss, w_grads = worker.training_process_eagerly(x, y)
+            worker.report_gradient(w_grads)
+
+            if step % 20 == 0:
+                worker.get_model(step, elasticdl_pb2.MINIMUM)
+                for (x, y) in test_db:
+                    out = worker.forward_process(x)
+                    out["probs"] = tf.reshape(out["probs"], [-1])
+                    acc_meter.update_state(
+                        tf.where(
+                            out["probs"] < 0.5,
+                            x=tf.zeros_like(y),
+                            y=tf.ones_like(y),
+                        ),
+                        y,
+                    )
+                acc = acc_meter.result().numpy()
+                print("loss: ", w_loss.numpy(), " acc: ", acc)
+                if acc > 0.7:
+                    return
+                acc_meter.reset_states()


 if __name__ == "__main__":
     unittest.main()

Review comment (on `def test_deepfm_train`): Extract the common training logic of … Reply: Done. (One possible shape of such a helper is sketched at the end of this page.)

Review comment (on the `print` line): remove print? Reply: Done.
Review comment: This will generate too many warnings. So I removed it.
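The "extract the common training logic" suggestion above could plausibly land as a shared helper. The following is a hedged sketch only: the helper name train_and_evaluate and its signature are invented for illustration and are not ElasticDL API; the worker calls mirror the ones that appear verbatim in the diff above.

import tensorflow as tf

from elasticdl.proto import elasticdl_pb2


def train_and_evaluate(worker, train_db, test_db, target_acc, eval_every=20):
    """Train via the worker/PS loop; return True once accuracy > target_acc."""
    acc_meter = tf.keras.metrics.Accuracy()
    for step, (x, y) in enumerate(train_db):
        if step == 0:
            # First batch: build the model and push initial variables to the PS.
            worker._run_model_call_before_training(x)
            worker.report_variable()

        worker.get_model(step, elasticdl_pb2.MINIMUM)
        w_loss, w_grads = worker.training_process_eagerly(x, y)
        worker.report_gradient(w_grads)

        if step % eval_every == 0:
            worker.get_model(step, elasticdl_pb2.MINIMUM)
            for (x_t, y_t) in test_db:
                probs = tf.reshape(worker.forward_process(x_t)["probs"], [-1])
                # Threshold probabilities into hard 0/1 predictions.
                preds = tf.where(
                    probs < 0.5, tf.zeros_like(y_t), tf.ones_like(y_t)
                )
                acc_meter.update_state(preds, y_t)
            if acc_meter.result().numpy() > target_acc:
                return True
            acc_meter.reset_states()
    return False

With a helper like this, test_deepfm_train would shrink to building the Worker and the two datasets, then asserting train_and_evaluate(worker, db, test_db, target_acc=0.7), and other model tests could reuse the same loop.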