Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use client_id instead of cert name in no tls case #166

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions openfl/interface/interactive_api/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Federation:
their local data and network setting to enable communication in federation.
"""

def __init__(self, client_id, director_node_fqdn=None, director_port=None, tls=True,
def __init__(self, director_node_fqdn=None, director_port=None, client_id=None, tls=True,
cert_chain=None, api_cert=None, api_private_key=None) -> None:
"""
Initialize federation.
Expand All @@ -31,7 +31,6 @@ def __init__(self, client_id, director_node_fqdn=None, director_port=None, tls=T
- director_node_fqdn: Address and port a director's service is running on.
User passes here an address with a port.
"""
self.client_id = client_id
if director_node_fqdn is None:
self.director_node_fqdn = getfqdn_env()
else:
Expand All @@ -45,9 +44,9 @@ def __init__(self, client_id, director_node_fqdn=None, director_port=None, tls=T

# Create Director client
self.dir_client = DirectorClient(
client_id=client_id,
director_host=director_node_fqdn,
director_port=director_port,
client_id=client_id,
tls=tls,
root_certificate=cert_chain,
private_key=api_private_key,
Expand Down
12 changes: 1 addition & 11 deletions openfl/protocols/director.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ import "google/protobuf/duration.proto";
import "federation.proto";


message RequestHeader {
string sender = 1;
}

// Envoy Messages

message NodeInfo {
Expand Down Expand Up @@ -54,7 +50,6 @@ message ExperimentData {
// API Messages

message ExperimentInfo {
RequestHeader header = 1;
string name = 2;
repeated string collaborator_names = 3;
ExperimentData experiment_data = 4;
Expand All @@ -71,7 +66,6 @@ message GetTrainedModelRequest {
BEST_MODEL = 0;
LAST_MODEL = 1;
}
RequestHeader header = 1;
string experiment_name = 2;
ModelType model_type = 3;
}
Expand All @@ -80,12 +74,9 @@ message TrainedModelResponse {
ModelProto model_proto = 1;
}

message GetDatasetInfoRequest {
RequestHeader header = 1;
}
message GetDatasetInfoRequest {}

message StreamMetricsRequest {
RequestHeader header = 1;
string experiment_name = 2;
}

Expand All @@ -98,7 +89,6 @@ message StreamMetricsResponse {
}

message RemoveExperimentRequest {
RequestHeader header = 1;
string experiment_name = 2;
}

Expand Down
2,085 changes: 1,002 additions & 1,083 deletions openfl/protocols/director_pb2.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions openfl/protocols/federation_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 3 additions & 14 deletions openfl/protocols/federation_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
Expand All @@ -8,10 +6,7 @@


class AggregatorStub(object):
"""Copyright (C) 2020 Intel Corporation
Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

"""
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel):
"""Constructor.
Expand All @@ -37,10 +32,7 @@ def __init__(self, channel):


class AggregatorServicer(object):
"""Copyright (C) 2020 Intel Corporation
Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

"""
"""Missing associated documentation comment in .proto file."""

def GetTasks(self, request, context):
"""Missing associated documentation comment in .proto file."""
Expand Down Expand Up @@ -86,10 +78,7 @@ def add_AggregatorServicer_to_server(servicer, server):

# This class is part of an EXPERIMENTAL API.
class Aggregator(object):
"""Copyright (C) 2020 Intel Corporation
Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

"""
"""Missing associated documentation comment in .proto file."""

@staticmethod
def GetTasks(request,
Expand Down
75 changes: 75 additions & 0 deletions openfl/protocols/interceptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""gRPC interceptors module."""
import collections

import grpc


class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor):

def __init__(self, interceptor_function):
self._fn = interceptor_function

def intercept_unary_unary(self, continuation, client_call_details, request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, False)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response

def intercept_unary_stream(self, continuation, client_call_details,
request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, True)
response_it = continuation(new_details, next(new_request_iterator))
return postprocess(response_it) if postprocess else response_it

def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, False)
response = continuation(new_details, new_request_iterator)
return postprocess(response) if postprocess else response

def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, True)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it


def _create_generic_interceptor(intercept_call):
return _GenericClientInterceptor(intercept_call)


class _ClientCallDetails(
collections.namedtuple(
'_ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials')
),
grpc.ClientCallDetails
):
pass


def headers_adder(headers):
"""Create interceptor with added headers."""

def intercept_call(client_call_details, request_iterator, request_streaming,
response_streaming):
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
for header, value in headers.items():
metadata.append((
header,
value,
))
client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials)
return client_call_details, request_iterator, None

return _create_generic_interceptor(intercept_call)
16 changes: 12 additions & 4 deletions openfl/protocols/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ def model_proto_to_bytes_and_metadata(model_proto):
round_number = None
for tensor_proto in model_proto.tensors:
bytes_dict[tensor_proto.name] = tensor_proto.data_bytes
metadata_dict[tensor_proto.name] = [{'int_to_float': proto.int_to_float,
'int_list': proto.int_list,
'bool_list': proto.bool_list} for proto in
tensor_proto.transformer_metadata]
metadata_dict[tensor_proto.name] = [{
'int_to_float': proto.int_to_float,
'int_list': proto.int_list,
'bool_list': proto.bool_list
}
for proto in tensor_proto.transformer_metadata
]
if round_number is None:
round_number = tensor_proto.round_number
else:
Expand Down Expand Up @@ -255,3 +258,8 @@ def proto_to_datastream(proto, logger, max_buffer_size=(2 * 1024 * 1024)):
chunk = npbytes[i: i + buffer_size]
reply = DataStream(npbytes=chunk, size=len(chunk))
yield reply


def get_headers(context) -> dict:
"""Get headers from context."""
return {header[0]: header[1] for header in context.invocation_metadata()}
31 changes: 16 additions & 15 deletions openfl/transport/grpc/director_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openfl.pipelines import NoCompressionPipeline
from openfl.protocols import director_pb2
from openfl.protocols import director_pb2_grpc
from openfl.protocols import interceptors
from openfl.protocols.utils import construct_model_proto
from openfl.protocols.utils import deconstruct_model_proto

Expand Down Expand Up @@ -113,14 +114,22 @@ def send_health_check(self, *, collaborator_name: str, is_experiment_running: bo
class DirectorClient:
"""Director client class for users."""

def __init__(self, *, client_id, director_host, director_port, tls=True,
def __init__(self, *, director_host, director_port, client_id=None, tls=True,
root_certificate=None, private_key=None, certificate=None) -> None:
"""Initialize director client object."""
director_addr = f'{director_host}:{director_port}'
channel_opt = [('grpc.max_send_message_length', 512 * 1024 * 1024),
('grpc.max_receive_message_length', 512 * 1024 * 1024)]
if not tls:
if client_id is None:
raise Exception('"client_id" is mandatory in case of tls == False')
channel = grpc.insecure_channel(director_addr, options=channel_opt)
headers = {
'client_id': client_id,
}
header_interceptor = interceptors.headers_adder(headers)
channel = grpc.intercept_channel(channel, header_interceptor)

else:
if not (root_certificate and private_key and certificate):
raise Exception('No certificates provided')
Expand All @@ -143,9 +152,6 @@ def __init__(self, *, client_id, director_host, director_port, tls=True,
channel = grpc.secure_channel(director_addr, credentials, options=channel_opt)
self.stub = director_pb2_grpc.FederationDirectorStub(channel)

self.client_id = client_id
self.header = director_pb2.RequestHeader(sender=self.client_id)

def set_new_experiment(self, name, col_names, arch_path,
initial_tensor_dict=None):
"""Send the new experiment to director to launch."""
Expand All @@ -170,7 +176,6 @@ def _get_experiment_info(self, arch_path, name, col_names, model_proto):
raise StopIteration
# TODO: add hash or/and size to check
experiment_info = director_pb2.ExperimentInfo(
header=self.header,
name=name,
collaborator_names=col_names,
model_proto=model_proto
Expand All @@ -182,19 +187,19 @@ def _get_experiment_info(self, arch_path, name, col_names, model_proto):

def get_dataset_info(self):
"""Request the dataset info from the director."""
resp = self.stub.GetDatasetInfo(director_pb2.GetDatasetInfoRequest(header=self.header))
resp = self.stub.GetDatasetInfo(director_pb2.GetDatasetInfoRequest())
return resp.sample_shape, resp.target_shape

def _get_trained_model(self, experiment_name, model_type):
"""Get trained model RPC."""
get_model_request = director_pb2.GetTrainedModelRequest(
header=self.header,
experiment_name=experiment_name,
model_type=model_type)
model_type=model_type,
)
model_proto_response = self.stub.GetTrainedModel(get_model_request)
tensor_dict, _ = deconstruct_model_proto(
model_proto_response.model_proto,
NoCompressionPipeline()
NoCompressionPipeline(),
)
return tensor_dict

Expand All @@ -210,9 +215,7 @@ def get_last_model(self, experiment_name):

def stream_metrics(self, experiment_name):
"""Stream metrics RPC."""
request = director_pb2.StreamMetricsRequest(
header=self.header,
experiment_name=experiment_name)
request = director_pb2.StreamMetricsRequest(experiment_name=experiment_name)
for metric_message in self.stub.StreamMetrics(request):
yield {
'metric_origin': metric_message.metric_origin,
Expand All @@ -224,9 +227,7 @@ def stream_metrics(self, experiment_name):

def remove_experiment_data(self, name):
"""Remove experiment data RPC."""
request = director_pb2.RemoveExperimentRequest(
header=self.header,
experiment_name=name)
request = director_pb2.RemoveExperimentRequest(experiment_name=name)
response = self.stub.RemoveExperimentData(request)
return response.acknowledgement

Expand Down
Loading