Skip to content

Commit

Permalink
Rewrite example test (#1539)
Browse files Browse the repository at this point in the history
* run deepfm example test success

* add test_deepfm_eval and test_mnist_train

* add resnet50

* follow comments
  • Loading branch information
mhaoli committed Dec 2, 2019
1 parent a6120c2 commit 13654d6
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 47 deletions.
163 changes: 163 additions & 0 deletions elasticdl/python/tests/example_test.py
@@ -0,0 +1,163 @@
import os
import unittest

from elasticdl.python.tests.test_utils import (
DatasetName,
create_pserver,
distributed_train_and_evaluate,
)

_model_zoo_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../../model_zoo"
)


class ExampleTest(unittest.TestCase):
    """End-to-end tests for the example models in the model zoo.

    Each test starts in-process parameter servers, drives a short
    distributed training or evaluation run through
    ``distributed_train_and_evaluate``, and — for training — checks the
    expected relation between the sync and async model versions.
    """

    def _test_train(
        self,
        feature_shape,
        model_def,
        model_params="",
        dataset_name=DatasetName.IMAGE_DEFAULT,
    ):
        """Train ``model_def`` twice: first sync, then async.

        Returns:
            A list of the final model version of each run, sync first.
        """
        num_ps_pods = 2
        final_versions = []
        for use_async in (False, True):
            # Sync SGD waits for two gradients per update; async applies
            # each gradient as it arrives.
            grads_to_wait = 2 if not use_async else 1
            _, channels, servers = create_pserver(
                _model_zoo_path,
                model_def,
                grads_to_wait,
                use_async,
                num_ps_pods,
            )
            try:
                version = distributed_train_and_evaluate(
                    feature_shape,
                    _model_zoo_path,
                    model_def,
                    model_params=model_params,
                    training=True,
                    dataset_name=dataset_name,
                    use_async=use_async,
                    ps_channels=channels,
                    pservers=servers,
                )
            finally:
                # Always release the pserver gRPC servers, even on failure.
                for server in servers:
                    server.server.stop(0)
            final_versions.append(version)
        return final_versions

    def _test_evaluate(
        self,
        feature_shape,
        model_def,
        model_params="",
        dataset_name=DatasetName.IMAGE_DEFAULT,
    ):
        """Run a single synchronous evaluation job for ``model_def``.

        Returns:
            The model version reported by the evaluation run.
        """
        num_ps_pods = 2
        grads_to_wait = 1
        _, channels, servers = create_pserver(
            _model_zoo_path, model_def, grads_to_wait, False, num_ps_pods
        )
        try:
            version = distributed_train_and_evaluate(
                feature_shape,
                _model_zoo_path,
                model_def,
                model_params=model_params,
                training=False,
                dataset_name=dataset_name,
                ps_channels=channels,
                pservers=servers,
            )
        finally:
            # Shut the pservers down regardless of the run's outcome.
            for server in servers:
                server.server.stop(0)
        return version

    def test_deepfm_functional_train(self):
        self._test_train(
            10,
            "deepfm_functional_api.deepfm_functional_api.custom_model",
            model_params=(
                "input_dim=5383;embedding_dim=4;input_length=10;fc_unit=4"
            ),
            dataset_name=DatasetName.FRAPPE,
        )

    def test_deepfm_functional_evaluate(self):
        self._test_evaluate(
            10,
            "deepfm_functional_api.deepfm_functional_api.custom_model",
            model_params=(
                "input_dim=5383;embedding_dim=4;input_length=10;fc_unit=4"
            ),
            dataset_name=DatasetName.FRAPPE,
        )

    def test_mnist_train(self):
        for model_def in (
            "mnist_functional_api.mnist_functional_api.custom_model",
            "mnist_subclass.mnist_subclass.CustomModel",
        ):
            sync_version, async_version = self._test_train(
                feature_shape=[28, 28], model_def=model_def,
            )
            # async model version = sync model version * 2
            self.assertEqual(sync_version * 2, async_version)

    def test_mnist_evaluate(self):
        for model_def in (
            "mnist_functional_api.mnist_functional_api.custom_model",
            "mnist_subclass.mnist_subclass.CustomModel",
        ):
            self._test_evaluate([28, 28], model_def)

    def test_cifar10_train(self):
        for model_def in (
            "cifar10_functional_api.cifar10_functional_api.custom_model",
            "cifar10_subclass.cifar10_subclass.CustomModel",
        ):
            sync_version, async_version = self._test_train(
                [32, 32, 3], model_def,
            )
            # async model version = sync model version * 2
            self.assertEqual(sync_version * 2, async_version)

    def test_cifar10_evaluate(self):
        for model_def in (
            "cifar10_functional_api.cifar10_functional_api.custom_model",
            "cifar10_subclass.cifar10_subclass.CustomModel",
        ):
            self._test_evaluate([32, 32, 3], model_def)

    def test_resnet50_subclass_train(self):
        self._test_train(
            [224, 224, 3],
            "resnet50_subclass.resnet50_subclass.CustomModel",
            dataset_name=DatasetName.IMAGENET,
        )

    def test_resnet50_subclass_evaluate(self):
        self._test_evaluate(
            [224, 224, 3],
            "resnet50_subclass.resnet50_subclass.CustomModel",
            model_params='num_classes=10;dtype="float32"',
            dataset_name=DatasetName.IMAGENET,
        )


# Allow running this test module directly (e.g. `python example_test.py`).
if __name__ == "__main__":
    unittest.main()
23 changes: 11 additions & 12 deletions elasticdl/python/tests/in_process_master.py
Expand Up @@ -7,28 +7,27 @@ def __init__(self, master, callbacks=[]):
self._m = master
self._callbacks = callbacks

def GetTask(self, req):
return self._m.GetTask(req, None)

def GetModel(self, req):
return self._m.GetModel(req, None)

def ReportVariable(self, req):
return self._m.ReportVariable(req, None)
def get_task(self, req):
    """Fetch the next task from the in-process master.

    Passes ``None`` as the gRPC context — the master servicer is called
    directly in-process, so no real RPC context exists.
    """
    return self._m.get_task(req, None)

"""
def ReportGradient(self, req):
for callback in self._callbacks:
if test_call_back.ON_REPORT_GRADIENT_BEGIN in callback.call_times:
callback()
return self._m.ReportGradient(req, None)
"""

def ReportEvaluationMetrics(self, req):
def report_evaluation_metrics(self, req):
    """Report evaluation metrics to the in-process master.

    Fires every registered callback whose ``call_times`` includes the
    ON_REPORT_EVALUATION_METRICS_BEGIN stage before forwarding the
    request (with a ``None`` gRPC context) to the master servicer.
    """
    for callback in self._callbacks:
        if test_call_back.ON_REPORT_EVALUATION_METRICS_BEGIN in (
            callback.call_times
        ):
            callback()
    # Fix: a stale `return self._m.ReportEvaluationMetrics(req, None)`
    # preceded this line, making the renamed snake_case call below
    # unreachable. Keep only the renamed API, consistent with
    # get_task/report_task_result.
    return self._m.report_evaluation_metrics(req, None)

def report_task_result(self, req):
    """Report a task's completion result to the in-process master.

    Uses a ``None`` gRPC context since the master runs in-process.
    """
    return self._m.report_task_result(req, None)

def ReportTaskResult(self, req):
return self._m.ReportTaskResult(req, None)
def report_version(self, req):
    """Report a model version to the in-process master.

    Uses a ``None`` gRPC context since the master runs in-process.
    """
    return self._m.report_version(req, None)
90 changes: 77 additions & 13 deletions elasticdl/python/tests/test_utils.py
Expand Up @@ -4,14 +4,20 @@
from contextlib import closing
from pathlib import Path

import grpc
import numpy as np
import recordio
import tensorflow as tf
from odps import ODPS

from elasticdl.proto import elasticdl_pb2
from elasticdl.python.common.args import parse_worker_args
from elasticdl.python.common.constants import JobType, ODPSConfig
from elasticdl.python.common.constants import (
DistributionStrategy,
JobType,
ODPSConfig,
)
from elasticdl.python.common.grpc_utils import build_channel
from elasticdl.python.common.model_utils import (
get_module_file_path,
load_module,
Expand All @@ -22,6 +28,7 @@
from elasticdl.python.master.evaluation_service import EvaluationService
from elasticdl.python.master.servicer import MasterServicer
from elasticdl.python.master.task_dispatcher import _TaskDispatcher
from elasticdl.python.ps.parameter_server import ParameterServer
from elasticdl.python.tests.in_process_master import InProcessMaster
from elasticdl.python.worker.worker import Worker

Expand Down Expand Up @@ -148,6 +155,31 @@ def create_recordio_file(size, dataset_name, shape, temp_dir=None):
return temp_file.name


def create_pserver(
    model_zoo_path, model_def, grads_to_wait, use_async, num_ps_pods
):
    """Create in-process parameter servers and gRPC channels to them.

    Args:
        model_zoo_path: Root directory of the model zoo.
        model_def: Dotted path to the model definition inside the zoo.
        grads_to_wait: Number of gradients to accumulate before applying
            an update (meaningful for synchronous SGD).
        use_async: Whether the pservers apply gradient updates
            asynchronously.
        num_ps_pods: Number of parameter server instances to start.

    Returns:
        A tuple ``(ports, channels, pservers)`` of the listening ports,
        one gRPC channel per pserver, and the prepared ParameterServer
        instances.
    """
    # NOTE(review): ports are fixed starting at 12345, so concurrent test
    # runs on one host would collide — acceptable for a test-only helper.
    ports = [i + 12345 for i in range(num_ps_pods)]
    channels = []
    for port in ports:
        addr = "localhost:%d" % port
        channel = build_channel(addr)
        channels.append(channel)

    pservers = []
    for port in ports:
        args = PserverArgs(
            grads_to_wait=grads_to_wait,
            # Bug fix: the caller's flag was previously ignored
            # (hard-coded to True), so "synchronous" tests still
            # configured asynchronous pservers.
            use_async=use_async,
            port=port,
            model_zoo=model_zoo_path,
            model_def=model_def,
        )
        pserver = ParameterServer(args)
        pserver.prepare()
        pservers.append(pserver)
    return ports, channels, pservers


def distributed_train_and_evaluate(
feature_shape,
model_zoo_path,
Expand All @@ -158,7 +190,11 @@ def distributed_train_and_evaluate(
training=True,
dataset_name=DatasetName.IMAGE_DEFAULT,
callback_classes=[],
use_async=False,
get_model_steps=1,
ps_channels=None,
pservers=None,
distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
"""Runs distributed training and evaluation with a local master. Grpc
calls are mocked by local master call.
Expand All @@ -179,8 +215,14 @@ def distributed_train_and_evaluate(
dataset_name: A dataset name from `DatasetName`.
callback_classes: A List of callbacks that will be called at given
stages of the training procedure.
use_async: A bool. True if using asynchronous updates.
get_model_steps: Worker will perform `get_model` from the parameter
server every this many steps.
ps_channels: A channel list to all parameter server pods.
pservers: A list of parameter server pods.
distribution_strategy: The distribution startegy used by workers, e.g.
DistributionStrategy.PARAMETER_SERVER or
DistributionStrategy.AllreduceStrategy.
Returns:
An integer indicating the model version after the distributed training
Expand All @@ -191,8 +233,18 @@ def distributed_train_and_evaluate(
if training
else JobType.EVALUATION_ONLY
)
evaluation_steps = 1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
arguments = [
pservers = pservers or []
ps_channels = ps_channels or []

model_module = load_module(
get_module_file_path(model_zoo_path, model_def)
).__dict__

for channel in ps_channels:
grpc.channel_ready_future(channel).result()
worker_arguments = [
"--worker_id",
"1",
"--job_type",
Expand All @@ -209,9 +261,11 @@ def distributed_train_and_evaluate(
loss,
"--get_model_steps",
get_model_steps,
"--distribution_strategy",
distribution_strategy,
]
args = parse_worker_args(arguments)
worker = Worker(args)
args = parse_worker_args(worker_arguments)
worker = Worker(args, ps_channels=ps_channels)

if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
record_num = batch_size
Expand All @@ -237,16 +291,25 @@ def distributed_train_and_evaluate(
num_epochs=1,
)

model_module = load_module(
get_module_file_path(model_zoo_path, model_def)
).__dict__
if training:
evaluation_service = EvaluationService(
None, task_d, 0, 0, 1, False, model_module[eval_metrics_fn],
None,
task_d,
0,
0,
evaluation_steps,
False,
model_module[eval_metrics_fn],
)
else:
evaluation_service = EvaluationService(
None, task_d, 0, 0, 0, True, model_module[eval_metrics_fn],
None,
task_d,
0,
0,
evaluation_steps,
True,
model_module[eval_metrics_fn],
)
task_d.set_evaluation_service(evaluation_service)

Expand All @@ -256,16 +319,17 @@ def distributed_train_and_evaluate(
callbacks = [
callback_class(master, worker) for callback_class in callback_classes
]
worker._stub = InProcessMaster(master, callbacks)

for var in worker._model.trainable_variables:
master.set_model_var(var.name, var.numpy())
in_process_master = InProcessMaster(master, callbacks)
worker._stub = in_process_master
for pservicer in pservers:
pservicer._master_stub = in_process_master

worker.run()

req = elasticdl_pb2.GetTaskRequest()
req.worker_id = 1
task = master.GetTask(req, None)
task = master.get_task(req, None)
# No more task.
if task.shard_name:
raise RuntimeError(
Expand Down

0 comments on commit 13654d6

Please sign in to comment.