Skip to content
Permalink
Browse files

Functions to initialize embedding service (#1053)

* Functions to initialize embedding service

* redis-py-cluster update to v2.0.0

* Respond to comments
  • Loading branch information...
yhjust1 committed Aug 21, 2019
1 parent 055ed1b commit a0946f9981d22311f4431c20d61c6a92acfd9fe4
Showing with 280 additions and 8 deletions.
  1. +1 −1 .isort.cfg
  2. +228 −7 elasticdl/python/common/embedding_service.py
  3. +51 −0 elasticdl/python/tests/embedding_service_test.py
@@ -1,5 +1,5 @@
[settings]
multi_line_output=3
line_length=79
known_third_party = PIL,docker,google,grpc,kubernetes,mock,numpy,odps,pyspark,recordio,requests,setuptools,tensorflow
known_third_party = PIL,docker,google,grpc,kubernetes,mock,numpy,odps,pyspark,recordio,rediscluster,requests,setuptools,tensorflow
include_trailing_comma=True
@@ -1,17 +1,230 @@
import argparse
import subprocess
import time

from rediscluster import RedisCluster

from elasticdl.python.common import k8s_client as k8s
from elasticdl.python.common.args import pos_int
from elasticdl.python.common.log_util import default_logger as logger


class EmbeddingService(object):
"""Redis implementation of EmbeddingService"""

def __init__(self, redis_address_map=None, replicas=1):
def __init__(self, embedding_endpoint=None, replicas=1):
    """Remember where the Redis cluster lives and how it is replicated.

    Arguments:
        embedding_endpoint: Map from each Redis address (ip/url) to the
            list of service ports on that address:
            {
                address0: [port list],
                address1: [port list],
                ...
            }
        replicas: Number of slaves per redis master.

    How the embedding service gets started:

        master.main       EmbeddingService       k8s_client
             |                   |                   |
             1 ----------------> 2 ----------------> |
             |                   |                   3
             5 <---------------- 4 <---------------- |

    1. master.main calls EmbeddingService.start_embedding_service
       when the model requires the embedding service.
    2. EmbeddingService.start_embedding_service calls
       EmbeddingService.start_embedding_pod_and_redis to ask
       k8s_client to create pods for Redis.
    3. k8s_client creates the pods, and the pods call
       EmbeddingService.start_redis_service() to start their local
       redis instances.
    4. Once the pods are running,
       EmbeddingService.start_embedding_service gets and saves the
       addresses (ip/dns and port) of the pods and creates a Redis
       Cluster based on these addresses.
    5. EmbeddingService.start_embedding_service returns the addresses
       to master.main, which keeps them so that master/worker can
       access Redis.
    """
    self._replicas = replicas
    self._embedding_endpoint = embedding_endpoint

def start_embedding_service(self):
pass
def start_embedding_service(
    self,
    embedding_service_id=0,
    resource_request="cpu=1,memory=4096Mi",
    resource_limit="cpu=1,memory=4096Mi",
    pod_priority=None,
    volume=None,
    image_pull_policy=None,
    restart_policy="Never",
    **kargs
):
    """Start the embedding pod, then form the Redis cluster on it.

    The pod is launched with ``python -m
    elasticdl.python.common.embedding_service``. All pod settings and
    any extra keyword arguments are forwarded unchanged to
    ``start_embedding_pod_and_redis``.

    Returns:
        Whatever ``_create_redis_cluster`` returns.
    """
    pod_options = {
        "command": ["python"],
        "args": ["-m", "elasticdl.python.common.embedding_service"],
        "embedding_service_id": embedding_service_id,
        "resource_request": resource_request,
        "resource_limit": resource_limit,
        "pod_priority": pod_priority,
        "volume": volume,
        "image_pull_policy": image_pull_policy,
        "restart_policy": restart_policy,
    }
    # Two separate ** expansions keep the TypeError on duplicated keys.
    self.start_embedding_pod_and_redis(**pod_options, **kargs)
    return self._create_redis_cluster()

def stop_embedding_service(self):
pass
def _create_redis_cluster(self):
    """Join the configured redis instances into one Redis Cluster.

    Runs ``redis-cli --cluster create`` over every ``ip:port`` pair in
    ``self._embedding_endpoint`` and answers its confirmation prompt
    with ``yes``.

    Returns:
        The endpoint map once the command has been run, or ``None``
        when no endpoint is configured or the command could not be
        built or launched. A non-zero exit status of ``redis-cli`` is
        logged as an error, but the endpoint map is still returned.
    """
    endpoint = self._embedding_endpoint
    if not endpoint:
        # The constructor default is None; iterating it would raise.
        logger.error("No embedding endpoint to create a redis cluster on")
        return None

    try:
        nodes = " ".join(
            "%s:%d" % (ip, port) for ip in endpoint for port in endpoint[ip]
        )
        command = (
            "echo yes | redis-cli --cluster create %s "
            "--cluster-replicas %d" % (nodes, self._replicas)
        )
        # The shell reports the status of the last pipeline stage,
        # i.e. of redis-cli itself.
        return_code = subprocess.call(
            command, shell=True, stdout=subprocess.DEVNULL
        )
    except Exception as e:
        logger.error(e)
        return None

    if return_code != 0:
        logger.error(
            "redis-cli --cluster create exited with code %d" % return_code
        )
    return endpoint

def stop_embedding_service(self, save="nosave"):
    """Shut down every redis instance listed in the endpoint map.

    Arguments:
        save: Modifier appended to the redis ``shutdown`` command.

    Returns:
        ``True`` once a shutdown command has been run for every
        instance, ``False`` as soon as launching one raises.
    """
    endpoint = self._embedding_endpoint

    # Collect all connection flags up front, before anything is stopped.
    node_flags = []
    for host in endpoint:
        for port in endpoint[host]:
            node_flags.append("-h %s -p %d" % (host, port))

    for flags in node_flags:
        try:
            shutdown = subprocess.Popen(
                ["redis-cli %s shutdown %s" % (flags, save)],
                shell=True,
                stdout=subprocess.DEVNULL,
            )
            shutdown.wait()
        except Exception as err:
            logger.error(err)
            return False

    return True

def _get_embedding_cluster(self):
    """Open a RedisCluster client for the configured endpoint.

    Every ``ip``/``port`` pair in ``self._embedding_endpoint`` is
    passed as a startup node, with the port rendered as a string.

    Returns:
        The ``RedisCluster`` client, or ``None`` if constructing it
        raised (the exception is logged).
    """
    endpoint = self._embedding_endpoint
    startup_nodes = []
    for host in endpoint:
        for port in endpoint[host]:
            startup_nodes.append({"host": host, "port": "%d" % (port)})

    try:
        return RedisCluster(
            startup_nodes=startup_nodes, decode_responses=False
        )
    except Exception as err:
        logger.error(err)
        return None

def _parse_embedding_service_args(self):
    """Parse the embedding service's command-line flags.

    Returns:
        The argparse namespace with ``first_port``,
        ``num_of_redis_instances`` and ``cluster_node_timeout``, each
        converted with ``pos_int``.
    """
    # (flag, default, help text) for every supported option.
    options = (
        (
            "--first_port",
            30001,
            "The first listening port of embedding service",
        ),
        (
            "--num_of_redis_instances",
            6,
            "The number of redis instances",
        ),
        (
            "--cluster_node_timeout",
            2000,
            "The maximum amount of time a Redis Cluster node "
            "can be unavailable",
        ),
    )

    parser = argparse.ArgumentParser(description="Embedding Service")
    for flag, default, help_text in options:
        parser.add_argument(
            flag, default=default, type=pos_int, help=help_text
        )
    return parser.parse_args()

def start_redis_service(self):
    """Launch the local redis-server instances for the embedding service.

    The ports and the cluster node timeout come from the command-line
    flags parsed by ``_parse_embedding_service_args``. One daemonized,
    cluster-enabled redis-server is started per port, from
    ``first_port`` to ``first_port + num_of_redis_instances - 1``.
    A non-zero exit status from a launch is logged as an error; the
    remaining instances are still started.
    """
    args = self._parse_embedding_service_args()
    first_port = args.first_port
    last_port = first_port + args.num_of_redis_instances - 1
    logger.info(
        "Starting redis server on ports: %d - %d, "
        "--cluster_node_timeout %d"
        % (first_port, last_port, args.cluster_node_timeout)
    )
    for port in range(first_port, last_port + 1):
        # Each instance gets its own config, AOF, RDB and log file,
        # all named after its port.
        command = (
            "redis-server --port %(port)d --cluster-enabled yes "
            "--cluster-config-file nodes-%(port)d.conf "
            "--cluster-node-timeout %(timeout)d --appendonly yes "
            "--appendfilename appendonly-%(port)d.aof "
            "--dbfilename dump-%(port)d.rdb --logfile %(port)d.log "
            "--daemonize yes --protected-mode no"
            % {"port": port, "timeout": args.cluster_node_timeout}
        )
        return_code = subprocess.call(
            command, shell=True, stdout=subprocess.DEVNULL
        )
        if return_code != 0:
            # Surface launch failures instead of dropping them silently.
            logger.error(
                "redis-server on port %d exited with code %d"
                % (port, return_code)
            )

# TODO: Currently a single pod hosts the whole redis cluster service; we
# should support a redis cluster service running across multiple pods
# in the future.
def start_embedding_pod_and_redis(
    self,
    command,
    args,
    embedding_service_id=0,
    resource_request="cpu=1,memory=4096Mi",
    resource_limit="cpu=1,memory=4096Mi",
    pod_priority=None,
    volume=None,
    image_pull_policy=None,
    restart_policy="Never",
    **kargs
):
    """Create the embedding service pod and record its endpoint.

    Builds a ``k8s.Client`` from ``kargs``, asks it to create the
    embedding service pod, waits until the pod reports an IP and then
    stores ``{pod_ip: [30001, ..., 30006]}`` in
    ``self._embedding_endpoint``.

    Arguments:
        command: Command passed to the pod.
        args: Arguments passed to the pod.
        embedding_service_id: Passed as ``worker_id`` when creating
            the pod, and used to look the pod up while waiting.
        resource_request: Passed as the pod's ``resource_requests``.
        resource_limit: Passed as the pod's ``resource_limits``.
        pod_priority: Passed through to the pod creation call.
        volume: Passed through to the pod creation call.
        image_pull_policy: Passed through to the pod creation call.
        restart_policy: Passed through to the pod creation call.
        **kargs: Extra keyword arguments for ``k8s.Client``.
    """
    logger.info("Starting pod for embedding service ...")
    self._k8s_client = k8s.Client(event_callback=None, **kargs)
    pod = self._k8s_client.create_embedding_service(
        worker_id=embedding_service_id,
        resource_requests=resource_request,
        resource_limits=resource_limit,
        pod_priority=pod_priority,
        volume=volume,
        image_pull_policy=image_pull_policy,
        command=command,
        args=args,
        restart_policy=restart_policy,
    )

    # TODO: assign address with pod's domain name instead of pod's ip.
    # and should not fix ports
    address_ip = pod.status.pod_ip
    while not address_ip:
        # Pause between polls so that waiting for the pod IP does not
        # hammer the k8s client in a tight loop.
        time.sleep(1)
        pod = self._k8s_client.get_embedding_service_pod(
            embedding_service_id
        )
        address_ip = pod.status.pod_ip
    self._embedding_endpoint = {address_ip: [30001 + i for i in range(6)]}

@staticmethod
def lookup_embedding(**kwargs):
@@ -20,3 +233,11 @@ def lookup_embedding(**kwargs):
@staticmethod
def update_embedding(**kwargs):
    """Not implemented yet: accepts any keyword arguments, does nothing."""
    pass


if __name__ == "__main__":
    # Entry point used inside the pod: bring up the local redis servers.
    redis_host = EmbeddingService()
    redis_host.start_redis_service()

    # TODO: Keep the pod running with kubernetes config
    # The redis servers are daemonized, so this process idles forever.
    while True:
        time.sleep(1)
@@ -0,0 +1,51 @@
import subprocess
import time
import unittest

from elasticdl.python.common.embedding_service import EmbeddingService


def start_redis_instances(first_port=33001, num_instances=6):
    """Start local cluster-enabled redis-server daemons for the test.

    Arguments:
        first_port: Port of the first instance; the remaining
            instances listen on the consecutive ports after it.
        num_instances: How many redis-server processes to launch.

    Returns:
        The endpoint map ``{"127.0.0.1": [port, ...]}`` listing every
        port an instance was launched on.
    """
    # Compute the ports once so the launched servers and the returned
    # endpoint map cannot drift apart.
    ports = [first_port + i for i in range(num_instances)]
    for port in ports:
        command = (
            "redis-server --port %(port)d --cluster-enabled yes "
            "--cluster-config-file nodes-%(port)d.conf "
            "--cluster-node-timeout 200 --appendonly yes --appendfilename "
            "appendonly-%(port)d.aof --dbfilename dump-%(port)d.rdb "
            "--logfile %(port)d.log --daemonize yes --protected-mode no"
            % {"port": port}
        )
        subprocess.call(command, shell=True, stdout=subprocess.DEVNULL)

    return {"127.0.0.1": ports}


class EmbeddingServiceTest(unittest.TestCase):
    """Exercises EmbeddingService against locally started redis servers."""

    def test_embedding_service(self):
        """Create a cluster, write and read one key, then shut down."""
        service = EmbeddingService(start_redis_instances())

        # Form the cluster and give it a moment to come up.
        endpoint = service._create_redis_cluster()
        time.sleep(1)
        self.assertIsNotNone(endpoint)

        # Connect a client to the cluster.
        cluster = service._get_embedding_cluster()
        self.assertIsNotNone(cluster)

        # The first NX write succeeds; a second one on the same key is
        # rejected and the stored value stays in place.
        self.assertTrue(cluster.set("test_key", "OK", nx=True))
        self.assertIsNone(cluster.set("test_key", "OK", nx=True))
        self.assertEqual(b"OK", cluster.get("test_key"))

        # Shut every instance down.
        self.assertTrue(service.stop_embedding_service())


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit a0946f9

Please sign in to comment.
You can’t perform that action at this time.