Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom aggregation functions in interactive API #183

Merged
merged 22 commits into from Oct 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/overriding_agg_fn.rst
Expand Up @@ -61,6 +61,21 @@ Example of ``plan/plan.yaml`` with modified aggregation function:
metrics:
- loss

Interactive API
================
You can override the aggregation function that will be used for a given task.
To do this, apply the ``set_aggregation_function`` decorator method of ``TaskInterface`` to the task function and pass an instance of an ``AggregationFunctionInterface`` subclass as its argument.
For example:

.. code-block:: python

    from openfl.component.aggregation_functions import Median
    TI = TaskInterface()
    agg_fn = Median()
    @TI.register_fl_task(model='model', data_loader='train_loader', \
                         device='device', optimizer='optimizer')
    @TI.set_aggregation_function(agg_fn)
    def train(model, train_loader, optimizer, device):
        ...


``AggregationFunctionInterface`` requires a single ``call`` function.
This function receives tensors for a single parameter from multiple collaborators with additional metadata (see definition of :meth:`openfl.component.aggregation_functions.AggregationFunctionInterface.call`) and returns a single tensor that represents the result of aggregation.

Expand Down
Expand Up @@ -378,15 +378,20 @@
"import torch\n",
"\n",
"import tqdm\n",
"from openfl.component.aggregation_functions import Median\n",
"\n",
"# The Interactive API supports registering functions definied in main module or imported.\n",
"def function_defined_in_notebook(some_parameter):\n",
" print(f'Also I accept a parameter and it is {some_parameter}')\n",
"\n",
"#The Interactive API supports overriding of the aggregation function\n",
"aggregation_function = Median()\n",
"\n",
"# Task interface currently supports only standalone functions.\n",
"@TI.add_kwargs(**{'some_parameter': 42})\n",
"@TI.register_fl_task(model='unet_model', data_loader='train_loader', \\\n",
" device='device', optimizer='optimizer') \n",
"@TI.set_aggregation_function(aggregation_function)\n",
"def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):\n",
" if not torch.cuda.is_available():\n",
" device = 'cpu'\n",
Expand Down
1 change: 0 additions & 1 deletion openfl/component/aggregation_functions/interface.py
Expand Up @@ -19,7 +19,6 @@ class AggregationFunctionInterface(metaclass=SingletonABCMeta):
@abstractmethod
def call(self,
local_tensors: List[LocalTensor],
weights: np.ndarray,
db_iterator: Iterator[pd.Series],
tensor_name: str,
fl_round: int,
Expand Down
4 changes: 3 additions & 1 deletion openfl/component/aggregator/aggregator.py
Expand Up @@ -5,6 +5,7 @@
import queue
from logging import getLogger

from openfl.component.aggregation_functions import WeightedAverage
from openfl.databases import TensorDB
from openfl.pipelines import NoCompressionPipeline
from openfl.pipelines import TensorCodec
Expand Down Expand Up @@ -762,7 +763,7 @@ def _compute_validation_related_task_metrics(self, task_name):
# collaborator in our subset, and apply the correct
# transformations to the tensorkey to resolve the aggregated
# tensor for that round
agg_function = self.assigner.get_aggregation_type_for_task(task_name)
task_agg_function = self.assigner.get_aggregation_type_for_task(task_name)
task_key = TaskResultKey(task_name, collaborators_for_task[0], self.round_number)
for tensor_key in self.collaborator_tasks_results[task_key]:
tensor_name, origin, round_number, report, tags = tensor_key
Expand All @@ -774,6 +775,7 @@ def _compute_validation_related_task_metrics(self, task_name):
new_tags = tuple(tags[:-1])
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags)
agg_tensor_name, agg_origin, agg_round_number, agg_report, agg_tags = agg_tensor_key
agg_function = WeightedAverage() if 'metric' in tags else task_agg_function
agg_results = self.tensor_db.get_aggregated_tensor(
agg_tensor_key, collaborator_weight_dict, aggregation_function=agg_function)
if report:
Expand Down
3 changes: 0 additions & 3 deletions openfl/databases/tensor_db.py
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pandas as pd

from openfl.component.aggregation_functions import WeightedAverage
from openfl.utilities import LocalTensor
from openfl.utilities import TensorKey

Expand Down Expand Up @@ -164,8 +163,6 @@ def get_aggregated_tensor(self, tensor_key, collaborator_weight_dict,
for col_name in collaborator_names]

db_iterator = self._iterate()
if 'metric' in tags:
aggregation_function = WeightedAverage() # force simple averaging for metrics
agg_nparray = aggregation_function(local_tensors,
db_iterator,
tensor_name,
Expand Down
42 changes: 32 additions & 10 deletions openfl/federated/plan/plan.py
Expand Up @@ -223,6 +223,7 @@ def __init__(self):

self.hash_ = None
self.name_ = None
self.serializer_ = None

@property
def hash(self): # NOQA
Expand Down Expand Up @@ -259,7 +260,8 @@ def get_assigner(self):

defaults[SETTINGS]['authorized_cols'] = self.authorized_cols
defaults[SETTINGS]['rounds_to_train'] = self.rounds_to_train
defaults[SETTINGS]['tasks'] = self.get_tasks()
task_interface = self.restore_object('tasks_obj.pkl')
defaults[SETTINGS]['tasks'] = self.get_tasks(task_interface)

if self.assigner_ is None:
self.assigner_ = Plan.build(**defaults)
Expand Down Expand Up @@ -294,11 +296,16 @@ def get_aggregator(self, tensor_dict=None):

return self.aggregator_

def get_tasks(self):
def get_tasks(self, task_interface=None):
"""Get federation tasks."""
tasks = self.config.get('tasks', {})
tasks.pop(DEFAULTS, None)
tasks.pop(SETTINGS, None)
if task_interface:
for task in tasks:
agg_fn = task_interface.get_aggregation_function(tasks[task]['function'])
tasks[task]['aggregation_type'] = agg_fn
return tasks
for task in tasks:
aggregation_type = tasks[task].get('aggregation_type')
if aggregation_type is None:
Expand Down Expand Up @@ -524,16 +531,31 @@ def interactive_api_get_server(self, *, tensor_dict, root_certificate, certifica

def deserialize_interface_objects(self):
    """Deserialize objects for TaskRunner.

    Restores the interface objects whose file names are stored under
    ``api_layer.settings`` in the plan config, using ``self.restore_object``.

    Returns:
        tuple: the restored model interface, tasks interface and dataloader
            interface objects, in that order.

    Raises:
        KeyError: if the plan config has no ``api_layer`` section or one of
            the expected file-name settings is missing.
    """
    api_layer = self.config['api_layer']
    filenames = [
        'model_interface_file',
        'tasks_interface_file',
        'dataloader_interface_file'
    ]
    # A tuple (rather than a lazy generator) keeps the result reusable and
    # lets callers either unpack or index it.
    return tuple(
        self.restore_object(api_layer['settings'][filename])
        for filename in filenames
    )

def get_serializer_plugin(self, **kwargs):
    """Return the serializer plugin, building it on first use.

    This plugin is used for serialization of interfaces in the new
    interactive API. Once built, the plugin is stored in
    ``self.serializer_`` and returned as-is by later calls; ``kwargs`` are
    only used when the plugin is actually built.

    Args:
        **kwargs: Settings handed to ``Plan.build`` when building the plugin.

    Returns:
        The serializer plugin, or ``None`` when the plan config has no
        ``api_layer`` section (legacy API).
    """
    if self.serializer_ is not None:
        return self.serializer_
    if 'api_layer' not in self.config:  # legacy API
        return None
    plugin_components = self.config['api_layer']['required_plugin_components']
    self.serializer_ = Plan.build(plugin_components['serializer_plugin'], kwargs)
    return self.serializer_

def restore_object(self, filename):
    """Deserialize an object from ``filename`` with the serializer plugin.

    Args:
        filename: Name of the file, passed to the plugin's ``restore_object``.

    Returns:
        The object restored by the plugin, or ``None`` when
        ``get_serializer_plugin`` returns ``None``.
    """
    plugin = self.get_serializer_plugin()
    return None if plugin is None else plugin.restore_object(filename)
40 changes: 38 additions & 2 deletions openfl/interface/interactive_api/experiment.py
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

"""Python low-level API module."""
import functools
import os
import time
from collections import defaultdict
Expand All @@ -12,6 +11,8 @@

from tensorboardX import SummaryWriter

from openfl.component.aggregation_functions import AggregationFunctionInterface
from openfl.component.aggregation_functions import WeightedAverage
from openfl.federated import Plan
from openfl.interface.cli import setup_logging
from openfl.interface.cli_helper import WORKSPACE
Expand Down Expand Up @@ -360,6 +361,8 @@ def __init__(self) -> None:
self.task_contract = {}
# Mapping 'task name' -> arguments
self.task_settings = defaultdict(dict)
# Mapping task name -> callable
self.aggregation_functions = {}

def register_fl_task(self, model, data_loader, device, optimizer=None):
"""
Expand Down Expand Up @@ -389,7 +392,6 @@ def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=3
def decorator_with_args(training_method):
# We could pass hooks to the decorator
# @functools.wraps(training_method)
functools.wraps(training_method)

def wrapper_decorator(**task_keywords):
metric_dict = training_method(**task_keywords)
Expand Down Expand Up @@ -423,6 +425,40 @@ def decorator_with_args(training_method):

return decorator_with_args

def set_aggregation_function(self, aggregation_function: AggregationFunctionInterface):
    """Set aggregation function for the task.

    Returns a decorator that records ``aggregation_function`` under the
    decorated function's ``__name__`` and returns that function unchanged.

    Args:
        aggregation_function: Aggregation function.

    Returns:
        A decorator to apply to the task function.

    Raises:
        TypeError: if ``aggregation_function`` is not an
            ``AggregationFunctionInterface`` instance.

    You might need to override default FedAvg aggregation with built-in aggregation types:
        - openfl.component.aggregation_functions.GeometricMedian
        - openfl.component.aggregation_functions.Median
    or define your own AggregationFunctionInterface subclass.
    See more details on `Overriding the aggregation function`_ documentation page.

    .. _Overriding the aggregation function:
        https://openfl.readthedocs.io/en/latest/overriding_agg_fn.html
    """
    # Validate once, up front, so a wrong argument is reported where it is
    # passed. TypeError is still an Exception subclass, so existing
    # ``except Exception`` handlers keep working.
    if not isinstance(aggregation_function, AggregationFunctionInterface):
        raise TypeError('aggregation_function must implement '
                        'AggregationFunctionInterface interface.')

    def decorator_with_args(training_method):
        self.aggregation_functions[training_method.__name__] = aggregation_function
        return training_method

    return decorator_with_args

def get_aggregation_function(self, function_name):
    """Look up the aggregation function registered for a task function.

    Args:
        function_name(str): Task function name.

    Returns:
        The aggregation function stored for ``function_name``, or a
        ``WeightedAverage`` instance when none was registered.
    """
    if function_name in self.aggregation_functions:
        return self.aggregation_functions[function_name]
    return WeightedAverage()


class ModelInterface:
"""
Expand Down