Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tfx/extensions/experimental/kubernetes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
242 changes: 242 additions & 0 deletions tfx/extensions/experimental/kubernetes/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper class to start TFX multi-worker training jobs on Kubernetes."""

import json
from typing import Any, Dict, List, Text

from absl import logging

from tfx import types
from tfx import version
from tfx.types import artifact_utils
from tfx.utils import kube_utils
from kubernetes.client.rest import ApiException
import kubernetes.client as client

# Default TFX container image to use in Kubernetes Training. For GPU
# training, specify a custom image in executor.TRAINING_ARGS_KEY (read below
# from the 'tfx_image' key of the training inputs).
_TFX_IMAGE = 'tensorflow/tfx:%s' % (version.__version__)

# Command run in every worker container; the executor arguments are supplied
# separately as the container's args.
_COMMAND = ["python", "-m", "tfx.scripts.run_executor"]


def _build_pod_names(num_workers: int, unique_id: Text) -> List[Text]:
  """Return the name of every worker pod, ordered by worker index.

  Args:
    num_workers: Number of worker pods in the training job.
    unique_id: Identifier embedded in each pod name.

  Returns:
    A list of `num_workers` names of the form
    'training-worker-<unique_id>-<index>'.
  """
  prefix = 'training-worker-{}'.format(unique_id)
  return ['{}-{}'.format(prefix, index) for index in range(num_workers)]


def _build_service_names(num_workers: int, unique_id: Text) -> List[Text]:
  """Return the name of every worker service, ordered by worker index.

  Args:
    num_workers: Number of worker services in the training job.
    unique_id: Identifier embedded in each service name.

  Returns:
    A list of `num_workers` names of the form
    'training-service-<unique_id>-<index>'.
  """
  names = []
  for index in range(num_workers):
    names.append('training-service-{}-{}'.format(unique_id, index))
  return names


def _pod_is_done(resp: client.V1Pod):
  """Return the `is_done` flag of the phase reported in the pod's status."""
  phase = kube_utils.PodPhase(resp.status.phase)
  return phase.is_done


def create_worker_pods(job_args: List[Text],
                       training_inputs: Dict[Text, Any],
                       unique_id: Text):
  """Create the worker pods used for multi-worker training.

  One pod per worker is created in the 'default' namespace. Every pod runs
  `_COMMAND` with `job_args`, declares container port 5000 and receives a
  TF_CONFIG environment variable that lists all worker service addresses
  together with the pod's own task index.

  Args:
    job_args: Arguments passed to the container command.
    training_inputs: Training settings. 'tfx_image', 'num_workers' and
      'num_gpus_per_worker' are read, each falling back to a default.
    unique_id: Identifier embedded in the pod names, service names and labels.

  Raises:
    RuntimeError: if a pod creation request raises an ApiException.
  """
  image = training_inputs.get('tfx_image', _TFX_IMAGE)
  worker_count = training_inputs.get('num_workers', 1)
  gpus_per_worker = training_inputs.get('num_gpus_per_worker', 0)

  core_api = kube_utils.make_core_v1_api()
  service_names = _build_service_names(num_workers=worker_count,
                                       unique_id=unique_id)
  pod_names = _build_pod_names(num_workers=worker_count, unique_id=unique_id)
  # Workers address each other through their per-worker services.
  cluster_spec = {
      'worker': ['{}:5000'.format(name) for name in service_names]
  }

  # TODO(ericlege): consider using a jinja2 template instead
  for task_index, pod_name in enumerate(pod_names):
    tf_config = json.dumps({
        'cluster': cluster_spec,
        'task': {'type': 'worker', 'index': task_index},
    })

    # The labels must match the selector of the corresponding worker service.
    metadata = client.V1ObjectMeta(
        name=pod_name,
        labels={
            'name': 'training',
            'id': unique_id,
            'task': str(task_index),
        },
    )

    # Only request GPUs when the caller asked for at least one per worker.
    if gpus_per_worker > 0:
      resources = client.V1ResourceRequirements(
          limits={'nvidia.com/gpu': gpus_per_worker},
      )
    else:
      resources = None

    container = client.V1Container(
        name='worker-pod',
        image=image,
        command=_COMMAND,
        args=job_args,
        security_context=client.V1SecurityContext(privileged=True),
        env=[client.V1EnvVar(name='TF_CONFIG', value=tf_config)],
        ports=[client.V1ContainerPort(container_port=5000)],
        resources=resources,
    )
    pod_spec = client.V1PodSpec(
        containers=[container],
        restart_policy=kube_utils.RestartPolicy.NEVER.value,
    )
    pod = client.V1Pod(metadata=metadata, spec=pod_spec)

    try:
      core_api.create_namespaced_pod(namespace='default', body=pod)
    except ApiException as e:
      raise RuntimeError('Worker pod creation failed.') from e
  logging.info('created {} worker pods'.format(worker_count))


def create_worker_services(training_inputs: Dict[Text, Any],
                           unique_id: Text):
  """Create the worker services used for multi-worker training.

  One service per worker is created in the 'default' namespace on port 5000.
  Each service selects the pod labelled with the same `unique_id` and task
  index.

  Args:
    training_inputs: Training settings. Only 'num_workers' is read
      (default 1).
    unique_id: Identifier embedded in the service names and selectors.

  Raises:
    RuntimeError: if a service creation request raises an ApiException.
  """
  worker_count = training_inputs.get('num_workers', 1)
  service_names = _build_service_names(num_workers=worker_count,
                                       unique_id=unique_id)
  core_api = kube_utils.make_core_v1_api()

  # TODO(ericlege): consider using a jinja2 template instead
  for task_index, service_name in enumerate(service_names):
    metadata = client.V1ObjectMeta(name=service_name)
    # The selector must match the labels set on the corresponding worker pod.
    spec = client.V1ServiceSpec(
        selector={
            'name': 'training',
            'id': unique_id,
            'task': str(task_index),
        },
        ports=[client.V1ServicePort(port=5000)],
    )
    service = client.V1Service(metadata=metadata, spec=spec)

    try:
      core_api.create_namespaced_service(namespace='default', body=service)
    except ApiException as e:
      raise RuntimeError('Worker service creation failed.') from e
  logging.info('created {} worker services'.format(worker_count))


def delete_worker_services(training_inputs: Dict[Text, Any],
                           unique_id: Text):
  """Clean up worker services deployed to the kubernetes cluster.

  Deletion is best effort: a failure to delete one service is logged and the
  remaining services are still attempted.

  Args:
    training_inputs: Training settings. Only 'num_workers' is read
      (default 1).
    unique_id: Identifier embedded in the names of the services to delete.
  """
  num_workers = training_inputs.get('num_workers', 1)
  service_names = _build_service_names(num_workers=num_workers,
                                       unique_id=unique_id)
  api_instance = kube_utils.make_core_v1_api()

  num_deleted = 0
  for service_name in service_names:
    try:
      api_instance.delete_namespaced_service(namespace='default',
                                             name=service_name)
    except ApiException as e:
      logging.error(
          'Exception when calling CoreV1Api.delete_namespaced_service: %s', e)
    else:
      num_deleted += 1
  # Report what was actually deleted rather than assuming every call
  # succeeded.
  logging.info('Deleted %d of %d worker services', num_deleted, num_workers)


def start_kubernetes_training(input_dict: Dict[Text, List[types.Artifact]],
                              output_dict: Dict[Text, List[types.Artifact]],
                              exec_properties: Dict[Text, Any],
                              executor_class_path: Text,
                              training_inputs: Dict[Text, Any],
                              unique_id: Text):
  """Start a trainer job on Kubernetes.

  This is done by forwarding the inputs/outputs/exec_properties to the
  tfx.scripts.run_executor module on a kubernetes pod.

  Args:
    input_dict: Passthrough input dict for tfx.components.Trainer.executor.
    output_dict: Passthrough output dict for tfx.components.Trainer.executor.
    exec_properties: Passthrough exec properties dict for
      tfx.components.Trainer.executor.
    executor_class_path: class path for TFX core default trainer.
    training_inputs: Training input argument for Kubernetes.
      'num_workers', 'num_gpus_per_worker' and 'tfx_image' will be consumed.
    unique_id: Identifier embedded in the names and labels of the worker pods
      and services created for this job.

  Returns:
    None

  Raises:
    RuntimeError: if a worker service or pod cannot be created, or if the
      chief worker pod finishes in the failed phase.
  """
  training_inputs = training_inputs.copy()

  json_inputs = artifact_utils.jsonify_artifact_dict(input_dict)
  logging.info('json_inputs=\'%s\'.', json_inputs)
  json_outputs = artifact_utils.jsonify_artifact_dict(output_dict)
  logging.info('json_outputs=\'%s\'.', json_outputs)
  json_exec_properties = json.dumps(exec_properties, sort_keys=True)
  logging.info('json_exec_properties=\'%s\'.', json_exec_properties)

  # We use custom containers to launch training on Kubernetes, which invokes
  # the specified image using the container's entrypoint. The entrypoint used
  # for the worker container is to call scripts/run_executor.py. The arguments
  # below are passed to this run_executor entry to run the executor specified
  # in `executor_class_path`.
  job_args = [
      '--executor_class_path', executor_class_path, '--inputs', json_inputs,
      '--outputs', json_outputs, '--exec-properties', json_exec_properties
  ]

  num_workers = training_inputs.get('num_workers', 1)
  pod_names = _build_pod_names(unique_id=unique_id,
                               num_workers=num_workers)

  try:
    # Launch the ClusterIP services.
    create_worker_services(training_inputs=training_inputs,
                           unique_id=unique_id)

    # Launch the worker pods.
    create_worker_pods(job_args=job_args,
                       training_inputs=training_inputs,
                       unique_id=unique_id)

    # Wait indefinitely until training finishes.
    resp = kube_utils.wait_pod(core_api=kube_utils.make_core_v1_api(),
                               pod_name=pod_names[0],  # chief
                               namespace='default',
                               exit_condition_lambda=_pod_is_done,
                               condition_description='Chief finished',
                               exponential_backoff=True)
  finally:
    # Clean up the ClusterIP services even when creation or waiting raised,
    # so a failed job does not leave services behind in the cluster.
    delete_worker_services(training_inputs=training_inputs,
                           unique_id=unique_id)

  if resp.status.phase == kube_utils.PodPhase.FAILED.value:
    raise RuntimeError('Pod "%s:%s" failed with status "%s".' %
                       ('default', pod_names[0], resp.status))

  # Kubernetes training complete.
  logging.info('Job successful.')
107 changes: 107 additions & 0 deletions tfx/extensions/experimental/kubernetes/runner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.extensions.experimental.kubernetes.runner."""

import copy
import os
from typing import Any, Dict, Text, List

import mock
import tensorflow as tf

from tfx.extensions.experimental.kubernetes import runner
from tfx.extensions.experimental.kubernetes.trainer import executor
from tfx.utils import json_utils


def mock_build_service_names(num_workers: int, unique_id: Text) -> List[Text]:
  """Stand-in for runner._build_service_names using a 'TEST-SERVICE' prefix.

  Args:
    num_workers: Number of service names to generate.
    unique_id: Identifier embedded in each generated name.

  Returns:
    A list of `num_workers` names of the form
    'TEST-SERVICE-<unique_id>-<index>'.
  """
  names = []
  for index in range(num_workers):
    names.append('TEST-SERVICE-{}-{}'.format(unique_id, index))
  return names


class RunnerTest(tf.test.TestCase):
  """Tests runner.start_kubernetes_training against a mocked Kubernetes API."""

  def setUp(self):
    """Prepare shared mocks and the training configuration under test."""
    super(RunnerTest, self).setUp()
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    self._mock_api_client = mock.Mock()
    self._mock_pod = mock.Mock()
    self._mock_service = mock.Mock()
    self._inputs = {}
    self._outputs = {}
    self._unique_id = "UNIQUE_ID"
    self._num_workers = 5
    self._num_gpus_per_worker = 2
    self._training_inputs = {
        'num_workers': self._num_workers,
        'num_gpus_per_worker': self._num_gpus_per_worker
    }
    # Dict format of exec_properties. custom_config needs to be serialized
    # before being passed into start_kubernetes_training function.
    self._exec_properties = {
        'custom_config': {
            executor.TRAINING_ARGS_KEY: self._training_inputs,
        },
    }
    self._model_name = 'model_name'
    self._executor_class_path = 'my.executor.Executor'

  def _set_up_training_mocks(self):
    """Attach dedicated mocks for the API calls made by the runner."""
    self._mock_create_pod = mock.Mock()
    self._mock_api_client.create_namespaced_pod = self._mock_create_pod
    self._mock_create_service = mock.Mock()
    self._mock_api_client.create_namespaced_service = self._mock_create_service
    self._mock_delete_service = mock.Mock()
    # Bind to the method name the runner actually calls.
    self._mock_api_client.delete_namespaced_service = self._mock_delete_service

  def _serialize_custom_config_under_test(self) -> Dict[Text, Any]:
    """Converts self._exec_properties['custom_config'] to string."""
    result = copy.deepcopy(self._exec_properties)
    result['custom_config'] = json_utils.dumps(result['custom_config'])
    return result

  @mock.patch.object(runner, '_build_service_names', mock_build_service_names)
  @mock.patch('tfx.extensions.experimental.kubernetes.runner.client')
  @mock.patch('tfx.extensions.experimental.kubernetes.runner.kube_utils')
  def testStartKubernetesTraining(self, mock_kube_utils, mock_client):
    """Checks services and pods are created and services are cleaned up."""
    mock_client.V1Pod.return_value = self._mock_pod
    mock_client.V1Service.return_value = self._mock_service
    mock_kube_utils.make_core_v1_api.return_value = self._mock_api_client
    mock_kube_utils.wait_pod.return_value = mock.Mock()
    self._set_up_training_mocks()

    runner.start_kubernetes_training(self._inputs, self._outputs,
                                     self._serialize_custom_config_under_test(),
                                     self._executor_class_path,
                                     self._training_inputs, self._unique_id)

    # assert_called_with only inspects the most recent call, so the number of
    # calls is checked separately: one service and one pod per worker.
    self._mock_api_client.create_namespaced_service.assert_called_with(
        namespace='default',
        body=self._mock_service)
    self.assertEqual(self._num_workers, self._mock_create_service.call_count)

    self._mock_api_client.create_namespaced_pod.assert_called_with(
        namespace='default',
        body=self._mock_pod)
    self.assertEqual(self._num_workers, self._mock_create_pod.call_count)

    expected_service_names = mock_build_service_names(self._num_workers,
                                                      self._unique_id)
    expected_calls = [mock.call(namespace='default', name=expected_service_name)
                      for expected_service_name in expected_service_names]
    self.assertEqual(expected_calls,
                     self._mock_api_client.delete_namespaced_service.mock_calls)


if __name__ == '__main__':
tf.test.main()
13 changes: 13 additions & 0 deletions tfx/extensions/experimental/kubernetes/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading