Skip to content

Commit

Permalink
move gRPC initialization to where it's used
Browse files Browse the repository at this point in the history
Closes #56
  • Loading branch information
mortendahl committed Oct 26, 2020
1 parent c04b476 commit 43a597e
Show file tree
Hide file tree
Showing 7 changed files with 3 additions and 16 deletions.
4 changes: 0 additions & 4 deletions examples/keras/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import asyncio
import logging

from grpc.experimental import aio

from moose.cluster.cluster_spec import load_cluster_spec
from moose.logger import get_logger
from moose.worker import Worker
Expand All @@ -20,8 +18,6 @@
get_logger().setLevel(level=logging.DEBUG)

if __name__ == "__main__":
aio.init_grpc_aio()

get_logger().info(f"Starting on {args.host}:{args.port}")
cluster_spec = load_cluster_spec(args.cluster_spec)
worker = Worker(args.name, args.host, args.port, cluster_spec)
Expand Down
3 changes: 0 additions & 3 deletions examples/mp-spdz/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import argparse
import logging

from grpc.experimental import aio

from moose.compiler.edsl import HostPlacement
from moose.compiler.edsl import add
from moose.compiler.edsl import computation
Expand Down Expand Up @@ -77,7 +75,6 @@ def my_comp():
concrete_comp = my_comp.trace_func()

if __name__ == "__main__":
aio.init_grpc_aio()
runtime = RemoteRuntime(args.cluster_spec)
runtime.evaluate_computation(
computation=concrete_comp,
Expand Down
4 changes: 0 additions & 4 deletions examples/mp-spdz/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import asyncio
import logging

from grpc.experimental import aio

from moose.cluster.cluster_spec import load_cluster_spec
from moose.logger import get_logger
from moose.worker import Worker
Expand All @@ -20,8 +18,6 @@
get_logger().setLevel(level=logging.DEBUG)

if __name__ == "__main__":
aio.init_grpc_aio()

get_logger().info(f"Starting on {args.host}:{args.port}")
cluster_spec = load_cluster_spec(args.cluster_spec)
worker = Worker(args.name, args.host, args.port, cluster_spec)
Expand Down
3 changes: 0 additions & 3 deletions examples/python-functions/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import logging
import os

from grpc.experimental import aio

from moose.compiler.edsl import HostPlacement
from moose.compiler.edsl import add
from moose.compiler.edsl import computation
Expand Down Expand Up @@ -63,7 +61,6 @@ def my_comp():
concrete_comp = my_comp.trace_func()

if __name__ == "__main__":
aio.init_grpc_aio()
if args.runtime == "test":
runtime = TestRuntime(workers=concrete_comp.devices())
elif args.runtime == "remote":
Expand Down
1 change: 1 addition & 0 deletions moose/channels/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Channel:
def __init__(self, endpoint, buffer, ca_cert, ident_cert, ident_key):
self._buffer = buffer

aio.init_grpc_aio()
if ca_cert:
credentials = grpc.ssl_channel_credentials(
root_certificates=ca_cert,
Expand Down
1 change: 1 addition & 0 deletions moose/executor/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class RemoteExecutor:
def __init__(self, endpoint, ca_cert, ident_cert, ident_key):
aio.init_grpc_aio()
if ca_cert:
credentials = grpc.ssl_channel_credentials(
root_certificates=ca_cert,
Expand Down
3 changes: 1 addition & 2 deletions moose/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def __init__(
ident_key_filename,
allow_insecure_networking=False,
):
aio.init_grpc_aio()

ca_cert = load_certificate(ca_cert_filename)
ident_cert = load_certificate(ident_cert_filename)
ident_key = load_certificate(ident_key_filename)
Expand All @@ -93,6 +91,7 @@ def __init__(
executor = AsyncExecutor(name=name, channel_manager=channel_manager)

# set up server
aio.init_grpc_aio()
# self._server = aio.server(interceptors=(MyInterceptor(),))
self._server = aio.server()

Expand Down

0 comments on commit 43a597e

Please sign in to comment.