Skip to content

Commit

Permalink
Custom aggregation functions in interactive API (#183)
Browse files Browse the repository at this point in the history
* Specify aggregation function in Interactive API

* Use pickling for aggregation function

* Restore old file structure

* Restore old file structure

* Revert formatting

* Fix linter

* Update signature for aggregation function override

* Fix linter

* Remove notebook

* Remove deserializing in old API

* Store serializer as plan property

* Fixes

* Clear notebook outputs

* Clear notebook metadata

* Add functools.wraps to decorator method

* Revert notebook to develop state

* Add PR-related changes to notebook

* Use decorating wrapper for task interface

* Remove functools.wraps decorators in TaskInterface

* Remove commented out code
  • Loading branch information
itrushkin committed Oct 4, 2021
1 parent 189d8de commit 9e05212
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 17 deletions.
15 changes: 15 additions & 0 deletions docs/overriding_agg_fn.rst
Expand Up @@ -61,6 +61,21 @@ Example of ``plan/plan.yaml`` with modified aggregation function:
metrics:
- loss
Interactive API
================
You can override the aggregation function that will be used for a task.
To do this, decorate the task function with the ``set_aggregation_function`` decorator method of ``TaskInterface`` and pass an instance of an ``AggregationFunctionInterface`` subclass as the parameter.
For example, you can try:

.. code-block:: python

    from openfl.component.aggregation_functions import Median

    TI = TaskInterface()
    agg_fn = Median()

    @TI.register_fl_task(model='model', data_loader='train_loader',
                         device='device', optimizer='optimizer')
    @TI.set_aggregation_function(agg_fn)
    def train(model, train_loader, optimizer, device):
        ...


``AggregationFunctionInterface`` requires a single ``call`` function.
This function receives tensors for a single parameter from multiple collaborators with additional metadata (see definition of :meth:`openfl.component.aggregation_functions.AggregationFunctionInterface.call`) and returns a single tensor that represents the result of aggregation.

Expand Down
Expand Up @@ -378,15 +378,20 @@
"import torch\n",
"\n",
"import tqdm\n",
"from openfl.component.aggregation_functions import Median\n",
"\n",
"# The Interactive API supports registering functions defined in main module or imported.\n",
"def function_defined_in_notebook(some_parameter):\n",
" print(f'Also I accept a parameter and it is {some_parameter}')\n",
"\n",
"#The Interactive API supports overriding of the aggregation function\n",
"aggregation_function = Median()\n",
"\n",
"# Task interface currently supports only standalone functions.\n",
"@TI.add_kwargs(**{'some_parameter': 42})\n",
"@TI.register_fl_task(model='unet_model', data_loader='train_loader', \\\n",
" device='device', optimizer='optimizer') \n",
"@TI.set_aggregation_function(aggregation_function)\n",
"def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):\n",
" if not torch.cuda.is_available():\n",
" device = 'cpu'\n",
Expand Down
1 change: 0 additions & 1 deletion openfl/component/aggregation_functions/interface.py
Expand Up @@ -19,7 +19,6 @@ class AggregationFunctionInterface(metaclass=SingletonABCMeta):
@abstractmethod
def call(self,
local_tensors: List[LocalTensor],
weights: np.ndarray,
db_iterator: Iterator[pd.Series],
tensor_name: str,
fl_round: int,
Expand Down
4 changes: 3 additions & 1 deletion openfl/component/aggregator/aggregator.py
Expand Up @@ -5,6 +5,7 @@
import queue
from logging import getLogger

from openfl.component.aggregation_functions import WeightedAverage
from openfl.databases import TensorDB
from openfl.pipelines import NoCompressionPipeline
from openfl.pipelines import TensorCodec
Expand Down Expand Up @@ -762,7 +763,7 @@ def _compute_validation_related_task_metrics(self, task_name):
# collaborator in our subset, and apply the correct
# transformations to the tensorkey to resolve the aggregated
# tensor for that round
agg_function = self.assigner.get_aggregation_type_for_task(task_name)
task_agg_function = self.assigner.get_aggregation_type_for_task(task_name)
task_key = TaskResultKey(task_name, collaborators_for_task[0], self.round_number)
for tensor_key in self.collaborator_tasks_results[task_key]:
tensor_name, origin, round_number, report, tags = tensor_key
Expand All @@ -774,6 +775,7 @@ def _compute_validation_related_task_metrics(self, task_name):
new_tags = tuple(tags[:-1])
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags)
agg_tensor_name, agg_origin, agg_round_number, agg_report, agg_tags = agg_tensor_key
agg_function = WeightedAverage() if 'metric' in tags else task_agg_function
agg_results = self.tensor_db.get_aggregated_tensor(
agg_tensor_key, collaborator_weight_dict, aggregation_function=agg_function)
if report:
Expand Down
3 changes: 0 additions & 3 deletions openfl/databases/tensor_db.py
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pandas as pd

from openfl.component.aggregation_functions import WeightedAverage
from openfl.utilities import LocalTensor
from openfl.utilities import TensorKey

Expand Down Expand Up @@ -164,8 +163,6 @@ def get_aggregated_tensor(self, tensor_key, collaborator_weight_dict,
for col_name in collaborator_names]

db_iterator = self._iterate()
if 'metric' in tags:
aggregation_function = WeightedAverage() # force simple averaging for metrics
agg_nparray = aggregation_function(local_tensors,
db_iterator,
tensor_name,
Expand Down
42 changes: 32 additions & 10 deletions openfl/federated/plan/plan.py
Expand Up @@ -223,6 +223,7 @@ def __init__(self):

self.hash_ = None
self.name_ = None
self.serializer_ = None

@property
def hash(self): # NOQA
Expand Down Expand Up @@ -259,7 +260,8 @@ def get_assigner(self):

defaults[SETTINGS]['authorized_cols'] = self.authorized_cols
defaults[SETTINGS]['rounds_to_train'] = self.rounds_to_train
defaults[SETTINGS]['tasks'] = self.get_tasks()
task_interface = self.restore_object('tasks_obj.pkl')
defaults[SETTINGS]['tasks'] = self.get_tasks(task_interface)

if self.assigner_ is None:
self.assigner_ = Plan.build(**defaults)
Expand Down Expand Up @@ -294,11 +296,16 @@ def get_aggregator(self, tensor_dict=None):

return self.aggregator_

def get_tasks(self):
def get_tasks(self, task_interface=None):
"""Get federation tasks."""
tasks = self.config.get('tasks', {})
tasks.pop(DEFAULTS, None)
tasks.pop(SETTINGS, None)
if task_interface:
for task in tasks:
agg_fn = task_interface.get_aggregation_function(tasks[task]['function'])
tasks[task]['aggregation_type'] = agg_fn
return tasks
for task in tasks:
aggregation_type = tasks[task].get('aggregation_type')
if aggregation_type is None:
Expand Down Expand Up @@ -524,16 +531,31 @@ def interactive_api_get_server(self, *, tensor_dict, root_certificate, certifica

def deserialize_interface_objects(self):
    """Deserialize objects for TaskRunner.

    Looks up the three interface file names in the ``api_layer`` settings of
    the plan config and restores each one through ``self.restore_object``.

    Returns:
        tuple: ``(model_provider, task_keeper, data_loader)``, in that order.

    Raises:
        KeyError: If the plan config has no ``api_layer`` section or one of
            the expected file-name settings is missing.
    """
    settings = self.config['api_layer']['settings']
    # Order matters: callers unpack the result positionally.
    filenames = (
        'model_interface_file',
        'tasks_interface_file',
        'dataloader_interface_file',
    )
    # A tuple (not a lazy generator) so the result can be unpacked, indexed
    # and iterated more than once.
    return tuple(self.restore_object(settings[filename]) for filename in filenames)

def get_serializer_plugin(self, **kwargs):
    """Return the serializer plugin, building and caching it on first use.

    This plugin is used for serialization of interfaces in the new
    interactive API.

    Args:
        **kwargs: Passed as the second argument to ``Plan.build``; only used
            when the plugin has not been built yet.

    Returns:
        The cached serializer plugin, or ``None`` when the plan config has no
        ``api_layer`` section (legacy API).
    """
    cached = self.serializer_
    if cached is not None:
        return cached
    if 'api_layer' not in self.config:  # legacy API
        return None
    components = self.config['api_layer']['required_plugin_components']
    self.serializer_ = Plan.build(components['serializer_plugin'], kwargs)
    return self.serializer_

def restore_object(self, filename):
    """Deserialize the object stored in ``filename`` using the serializer plugin.

    Args:
        filename: Forwarded unchanged to the plugin's ``restore_object``.

    Returns:
        The restored object, or ``None`` when ``get_serializer_plugin``
        yields no plugin.
    """
    plugin = self.get_serializer_plugin()
    return None if plugin is None else plugin.restore_object(filename)
40 changes: 38 additions & 2 deletions openfl/interface/interactive_api/experiment.py
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

"""Python low-level API module."""
import functools
import os
import time
from collections import defaultdict
Expand All @@ -12,6 +11,8 @@

from tensorboardX import SummaryWriter

from openfl.component.aggregation_functions import AggregationFunctionInterface
from openfl.component.aggregation_functions import WeightedAverage
from openfl.federated import Plan
from openfl.interface.cli import setup_logging
from openfl.interface.cli_helper import WORKSPACE
Expand Down Expand Up @@ -360,6 +361,8 @@ def __init__(self) -> None:
self.task_contract = {}
# Mapping 'task name' -> arguments
self.task_settings = defaultdict(dict)
# Mapping task name -> callable
self.aggregation_functions = {}

def register_fl_task(self, model, data_loader, device, optimizer=None):
"""
Expand Down Expand Up @@ -389,7 +392,6 @@ def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=3
def decorator_with_args(training_method):
# We could pass hooks to the decorator
# @functools.wraps(training_method)
functools.wraps(training_method)

def wrapper_decorator(**task_keywords):
metric_dict = training_method(**task_keywords)
Expand Down Expand Up @@ -423,6 +425,40 @@ def decorator_with_args(training_method):

return decorator_with_args

def set_aggregation_function(self, aggregation_function: AggregationFunctionInterface):
    """Set aggregation function for the task.

    Decorator factory: the returned decorator records ``aggregation_function``
    in ``self.aggregation_functions`` under the decorated function's
    ``__name__`` and returns the function unchanged.

    Args:
        aggregation_function: Aggregation function.

    You might need to override default FedAvg aggregation with built-in aggregation types:
        - openfl.component.aggregation_functions.GeometricMedian
        - openfl.component.aggregation_functions.Median
    or define your own AggregationFunctionInterface subclass.
    See more details on `Overriding the aggregation function`_ documentation page.

    .. _Overriding the aggregation function:
        https://openfl.readthedocs.io/en/latest/overriding_agg_fn.html

    Raises:
        TypeError: If ``aggregation_function`` is not an
            ``AggregationFunctionInterface`` instance.
    """
    # Validate up front: the check does not depend on the decorated function.
    # TypeError is still an Exception, so existing `except Exception` handlers
    # keep working.
    if not isinstance(aggregation_function, AggregationFunctionInterface):
        raise TypeError('aggregation_function must implement '
                        'AggregationFunctionInterface interface.')

    def decorator_with_args(training_method):
        self.aggregation_functions[training_method.__name__] = aggregation_function
        return training_method

    return decorator_with_args

def get_aggregation_function(self, function_name):
    """Get aggregation function for the task function.

    Args:
        function_name(str): Task function name.

    Returns:
        The aggregation function registered under ``function_name``, or a
        ``WeightedAverage()`` instance when none was registered.
    """
    try:
        return self.aggregation_functions[function_name]
    except KeyError:
        # Nothing registered for this task function: use the default.
        return WeightedAverage()


class ModelInterface:
"""
Expand Down

0 comments on commit 9e05212

Please sign in to comment.