From 9e0521252253eab09571dc2be40f46bfeaf9746a Mon Sep 17 00:00:00 2001 From: Ilya Trushkin Date: Mon, 4 Oct 2021 12:45:22 +0300 Subject: [PATCH] Custom aggregation functions in interactive API (#183) * Specify aggregation function in Interactive API * Use pickling for aggregation function * Restore old file structure * Restore old file structure * Revert formatting * Fix linter * Update signature for aggregation function override * Fix linter * Remove notebook * Remove deserializing in old API * Store serializer as plan property * Fixes * Clear notebook outputs * Clear notebook metadata * Add functools.wraps to decorator method * Revert notebook to develop state * Add PR-related changes to notebook * Use decorating wrapper for task interface * Remove functools.wraps decorators in TaskInterface * Remove commented out code --- docs/overriding_agg_fn.rst | 15 +++++++ .../Updated_Kvasir_with_Director.ipynb | 5 +++ .../aggregation_functions/interface.py | 1 - openfl/component/aggregator/aggregator.py | 4 +- openfl/databases/tensor_db.py | 3 -- openfl/federated/plan/plan.py | 42 ++++++++++++++----- .../interface/interactive_api/experiment.py | 40 +++++++++++++++++- 7 files changed, 93 insertions(+), 17 deletions(-) diff --git a/docs/overriding_agg_fn.rst b/docs/overriding_agg_fn.rst index 5fdb7ff4bb..c36e18b7af 100644 --- a/docs/overriding_agg_fn.rst +++ b/docs/overriding_agg_fn.rst @@ -61,6 +61,21 @@ Example of ``plan/plan.yaml`` with modified aggregation function: metrics: - loss +Interactive API +================ +You can override the aggregation function that will be used for the task this function corresponds to. +In order to do this, call the ``set_aggregation_function`` decorator method of ``TaskInterface`` and pass an ``AggregationFunctionInterface`` subclass instance as a parameter. 
+For example, you can try: + +.. code-block:: python + from openfl.component.aggregation_functions import Median + TI = TaskInterface() + agg_fn = Median() + @TI.register_fl_task(model='model', data_loader='train_loader', \ + device='device', optimizer='optimizer') + @TI.set_aggregation_function(agg_fn) + + ``AggregationFunctionInterface`` requires a single ``call`` function. This function receives tensors for a single parameter from multiple collaborators with additional metadata (see definition of :meth:`openfl.component.aggregation_functions.AggregationFunctionInterface.call`) and returns a single tensor that represents the result of aggregation. diff --git a/openfl-tutorials/interactive_api/Director_Pytorch_Kvasir_UNET/workspace/Updated_Kvasir_with_Director.ipynb b/openfl-tutorials/interactive_api/Director_Pytorch_Kvasir_UNET/workspace/Updated_Kvasir_with_Director.ipynb index 8d611b69b0..6150f6e7e3 100644 --- a/openfl-tutorials/interactive_api/Director_Pytorch_Kvasir_UNET/workspace/Updated_Kvasir_with_Director.ipynb +++ b/openfl-tutorials/interactive_api/Director_Pytorch_Kvasir_UNET/workspace/Updated_Kvasir_with_Director.ipynb @@ -378,15 +378,20 @@ "import torch\n", "\n", "import tqdm\n", + "from openfl.component.aggregation_functions import Median\n", "\n", "# The Interactive API supports registering functions definied in main module or imported.\n", "def function_defined_in_notebook(some_parameter):\n", " print(f'Also I accept a parameter and it is {some_parameter}')\n", "\n", + "# The Interactive API supports overriding of the aggregation function\n", + "aggregation_function = Median()\n", + "\n", "# Task interface currently supports only standalone functions.\n", "@TI.add_kwargs(**{'some_parameter': 42})\n", "@TI.register_fl_task(model='unet_model', data_loader='train_loader', \\\n", " device='device', optimizer='optimizer') \n", + "@TI.set_aggregation_function(aggregation_function)\n", "def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, 
some_parameter=None):\n", " if not torch.cuda.is_available():\n", " device = 'cpu'\n", diff --git a/openfl/component/aggregation_functions/interface.py b/openfl/component/aggregation_functions/interface.py index 0fe8c1aa59..2d74b37686 100644 --- a/openfl/component/aggregation_functions/interface.py +++ b/openfl/component/aggregation_functions/interface.py @@ -19,7 +19,6 @@ class AggregationFunctionInterface(metaclass=SingletonABCMeta): @abstractmethod def call(self, local_tensors: List[LocalTensor], - weights: np.ndarray, db_iterator: Iterator[pd.Series], tensor_name: str, fl_round: int, diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 28dafea2bc..066ed4b64c 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -5,6 +5,7 @@ import queue from logging import getLogger +from openfl.component.aggregation_functions import WeightedAverage from openfl.databases import TensorDB from openfl.pipelines import NoCompressionPipeline from openfl.pipelines import TensorCodec @@ -762,7 +763,7 @@ def _compute_validation_related_task_metrics(self, task_name): # collaborator in our subset, and apply the correct # transformations to the tensorkey to resolve the aggregated # tensor for that round - agg_function = self.assigner.get_aggregation_type_for_task(task_name) + task_agg_function = self.assigner.get_aggregation_type_for_task(task_name) task_key = TaskResultKey(task_name, collaborators_for_task[0], self.round_number) for tensor_key in self.collaborator_tasks_results[task_key]: tensor_name, origin, round_number, report, tags = tensor_key @@ -774,6 +775,7 @@ def _compute_validation_related_task_metrics(self, task_name): new_tags = tuple(tags[:-1]) agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags) agg_tensor_name, agg_origin, agg_round_number, agg_report, agg_tags = agg_tensor_key + agg_function = WeightedAverage() if 'metric' in tags else 
task_agg_function agg_results = self.tensor_db.get_aggregated_tensor( agg_tensor_key, collaborator_weight_dict, aggregation_function=agg_function) if report: diff --git a/openfl/databases/tensor_db.py b/openfl/databases/tensor_db.py index b27d06c2bd..bc334740d8 100644 --- a/openfl/databases/tensor_db.py +++ b/openfl/databases/tensor_db.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd -from openfl.component.aggregation_functions import WeightedAverage from openfl.utilities import LocalTensor from openfl.utilities import TensorKey @@ -164,8 +163,6 @@ def get_aggregated_tensor(self, tensor_key, collaborator_weight_dict, for col_name in collaborator_names] db_iterator = self._iterate() - if 'metric' in tags: - aggregation_function = WeightedAverage() # force simple averaging for metrics agg_nparray = aggregation_function(local_tensors, db_iterator, tensor_name, diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 420f09d217..00d7aa517b 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -223,6 +223,7 @@ def __init__(self): self.hash_ = None self.name_ = None + self.serializer_ = None @property def hash(self): # NOQA @@ -259,7 +260,8 @@ def get_assigner(self): defaults[SETTINGS]['authorized_cols'] = self.authorized_cols defaults[SETTINGS]['rounds_to_train'] = self.rounds_to_train - defaults[SETTINGS]['tasks'] = self.get_tasks() + task_interface = self.restore_object('tasks_obj.pkl') + defaults[SETTINGS]['tasks'] = self.get_tasks(task_interface) if self.assigner_ is None: self.assigner_ = Plan.build(**defaults) @@ -294,11 +296,16 @@ def get_aggregator(self, tensor_dict=None): return self.aggregator_ - def get_tasks(self): + def get_tasks(self, task_interface=None): """Get federation tasks.""" tasks = self.config.get('tasks', {}) tasks.pop(DEFAULTS, None) tasks.pop(SETTINGS, None) + if task_interface: + for task in tasks: + agg_fn = task_interface.get_aggregation_function(tasks[task]['function']) + 
tasks[task]['aggregation_type'] = agg_fn + return tasks for task in tasks: aggregation_type = tasks[task].get('aggregation_type') if aggregation_type is None: @@ -524,16 +531,31 @@ def interactive_api_get_server(self, *, tensor_dict, root_certificate, certifica def deserialize_interface_objects(self): """Deserialize objects for TaskRunner.""" - serializer = Plan.build( - self.config['api_layer']['required_plugin_components']['serializer_plugin'], {}) + api_layer = self.config['api_layer'] filenames = [ 'model_interface_file', 'tasks_interface_file', 'dataloader_interface_file' ] - interface_objects = [ - serializer.restore_object(self.config['api_layer']['settings'][filename]) - for filename in filenames - ] - model_provider, task_keeper, data_loader = interface_objects - return model_provider, task_keeper, data_loader + return (self.restore_object(api_layer['settings'][filename]) for filename in filenames) + + def get_serializer_plugin(self, **kwargs): + """Get serializer plugin. + + This plugin is used for serialization of interfaces in new interactive API + """ + if self.serializer_ is None: + if 'api_layer' not in self.config: # legacy API + return None + required_plugin_components = self.config['api_layer']['required_plugin_components'] + serializer_plugin = required_plugin_components['serializer_plugin'] + self.serializer_ = Plan.build(serializer_plugin, kwargs) + return self.serializer_ + + def restore_object(self, filename): + """Deserialize an object.""" + serializer_plugin = self.get_serializer_plugin() + if serializer_plugin is None: + return None + obj = serializer_plugin.restore_object(filename) + return obj diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index b94c4cb1b1..fa1f73f3d9 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 """Python low-level API module.""" -import 
functools import os import time from collections import defaultdict @@ -12,6 +11,8 @@ from tensorboardX import SummaryWriter +from openfl.component.aggregation_functions import AggregationFunctionInterface +from openfl.component.aggregation_functions import WeightedAverage from openfl.federated import Plan from openfl.interface.cli import setup_logging from openfl.interface.cli_helper import WORKSPACE @@ -360,6 +361,8 @@ def __init__(self) -> None: self.task_contract = {} # Mapping 'task name' -> arguments self.task_settings = defaultdict(dict) + # Mapping task name -> callable + self.aggregation_functions = {} def register_fl_task(self, model, data_loader, device, optimizer=None): """ @@ -389,7 +392,6 @@ def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=3 def decorator_with_args(training_method): # We could pass hooks to the decorator # @functools.wraps(training_method) - functools.wraps(training_method) def wrapper_decorator(**task_keywords): metric_dict = training_method(**task_keywords) @@ -423,6 +425,40 @@ def decorator_with_args(training_method): return decorator_with_args + def set_aggregation_function(self, aggregation_function: AggregationFunctionInterface): + """Set aggregation function for the task. + + Args: + aggregation_function: Aggregation function. + + You might need to override default FedAvg aggregation with built-in aggregation types: + - openfl.component.aggregation_functions.GeometricMedian + - openfl.component.aggregation_functions.Median + or define your own AggregationFunctionInterface subclass. + See more details on `Overriding the aggregation function`_ documentation page. + .. 
_Overriding the aggregation function: + https://openfl.readthedocs.io/en/latest/overriding_agg_fn.html + """ + def decorator_with_args(training_method): + if not isinstance(aggregation_function, AggregationFunctionInterface): + raise Exception('aggregation_function must implement ' + 'AggregationFunctionInterface interface.') + self.aggregation_functions[training_method.__name__] = aggregation_function + return training_method + return decorator_with_args + + def get_aggregation_function(self, function_name): + """Get aggregation type for the task function. + + Args: + function_name(str): Task function name. + Returns: + Return value. Aggregation function. + """ + if function_name not in self.aggregation_functions: + return WeightedAverage() + return self.aggregation_functions[function_name] + class ModelInterface: """