Task Assigner initial implementation #343

Merged
134 changes: 134 additions & 0 deletions docs/running_the_federation.rst
@@ -479,6 +479,9 @@ The following are parameters of the :code:`start()` method in FLExperiment:
:code:`data_loader`
This parameter is defined earlier by the :code:`DataInterface` object.

:code:`task_assigner`
This parameter is optional. You can pass a `Custom task assigner function`_.

:code:`rounds_to_train`
This parameter defines the number of aggregation rounds to be conducted before the experiment is considered finished.

@@ -520,6 +523,137 @@ When the experiment has completed:

You may use the same federation object to report another experiment or even schedule several experiments that will be executed in series.

Custom task assigner function
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
OpenFL has an entity called the Task Assigner, which is responsible for assigning the aggregator's tasks to collaborators.
There are three default tasks: :code:`train`, :code:`locally_tuned_model_validate`, and
:code:`aggregated_model_validate`.
When you register a training function and pass an optimizer, a :code:`train` task is generated:

.. code-block:: python

task_keeper = TaskInterface()


@task_keeper.register_fl_task(model='net_model', data_loader='train_loader',
device='device', optimizer='optimizer')
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
torch.manual_seed(0)
...

When you register a validation function, it generates two tasks: :code:`locally_tuned_model_validate` and
:code:`aggregated_model_validate`.
:code:`locally_tuned_model_validate` is applied by a collaborator to the locally trained model, while
:code:`aggregated_model_validate` is applied to the globally aggregated model.
If no train task is registered, only :code:`aggregated_model_validate` is generated.
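
For example, a validation function might be registered as follows (a minimal sketch; the parameter names and the
:code:`val_loader` data loader key are illustrative assumptions, not requirements):

.. code-block:: python

    @task_keeper.register_fl_task(model='net_model', data_loader='val_loader', device='device')
    def validate(net_model, val_loader, device):
        # Evaluate the model on the validation loader and return a dict of metric names to values
        ...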

Since version 1.3, it is possible to create a custom task assigner function to implement your own task assignment logic.
You can get the registered tasks from :code:`task_keeper` by calling the :code:`get_registered_tasks` method:

.. code-block:: python

tasks = task_keeper.get_registered_tasks()
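
The returned :code:`tasks` object maps task names to the corresponding task objects. Assuming the train and validate
functions above have been registered, the mapping will look roughly like this:

.. code-block:: python

    print(tasks.keys())
    # dict_keys(['train', 'locally_tuned_model_validate', 'aggregated_model_validate'])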


And then implement your own assigner function:

.. code-block:: python

def random_assigner(collaborators, round_number, **kwargs):
"""Assigning task groups randomly while ensuring target distribution"""
import random
random.shuffle(collaborators)
collaborator_task_map = {}
for idx, col in enumerate(collaborators):
            # select 70% of the collaborators for training and validation, the remaining 30% for validation only
            if (idx + 1) / len(collaborators) <= 0.7:
                collaborator_task_map[col] = list(tasks.values())  # all three tasks
else:
collaborator_task_map[col] = [tasks['aggregated_model_validate']]
return collaborator_task_map

Then pass that function to the :code:`fl_experiment.start` method:

.. code-block:: python

fl_experiment.start(
model_provider=model_interface,
task_keeper=task_keeper,
data_loader=fed_dataset,
task_assigner=random_assigner,
rounds_to_train=50,
opt_treatment='CONTINUE_GLOBAL',
device_assignment_policy='CUDA_PREFERRED'
)


This function is passed to the assigner, which uses it to assign tasks to collaborators on every round.

Here is another example. If you only want to exclude certain collaborators from the experiment (collaborators left out of
the returned map receive no tasks for the round), you can define the following assigner function:

.. code-block:: python

def filter_assigner(collaborators, round_number, **kwargs):
collaborator_task_map = {}
exclude_collaborators = ['env_two', 'env_three']
for collaborator_name in collaborators:
if collaborator_name in exclude_collaborators:
continue
collaborator_task_map[collaborator_name] = [
tasks['train'],
tasks['locally_tuned_model_validate'],
tasks['aggregated_model_validate']
]
return collaborator_task_map


You can also use static shard information to exclude collaborators without CUDA devices from training:

.. code-block:: python

shard_registry = federation.get_shard_registry()
def filter_by_shard_registry_assigner(collaborators, round_number, **kwargs):
collaborator_task_map = {}
for collaborator in collaborators:
col_status = shard_registry.get(collaborator)
if not col_status or not col_status['is_online']:
continue
node_info = col_status['shard_info'].node_info
            # Assign the train task if the collaborator has a GPU with more than 8 GB of total memory
if len(node_info.cuda_devices) > 0 and node_info.cuda_devices[0].memory_total > 8 * 1024**3:
collaborator_task_map[collaborator] = [
tasks['train'],
tasks['locally_tuned_model_validate'],
tasks['aggregated_model_validate'],
]
else:
collaborator_task_map[collaborator] = [
tasks['aggregated_model_validate'],
]
return collaborator_task_map


Finally, here is an assigner that adds an extra validation-only round at the end of the experiment:

.. code-block:: python

rounds_to_train = 3
total_rounds = rounds_to_train + 1 # use fl_experiment.start(..., rounds_to_train=total_rounds,...)

def assigner_with_last_round_validation(collaborators, round_number, **kwargs):
collaborator_task_map = {}
for collaborator in collaborators:
if round_number == total_rounds - 1:
collaborator_task_map[collaborator] = [
tasks['aggregated_model_validate'],
]
else:
collaborator_task_map[collaborator] = [
tasks['train'],
tasks['locally_tuned_model_validate'],
tasks['aggregated_model_validate']
]
return collaborator_task_map
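
As the comment in the snippet above indicates, the experiment is then started with :code:`rounds_to_train=total_rounds`;
the other arguments mirror the earlier example:

.. code-block:: python

    fl_experiment.start(
        model_provider=model_interface,
        task_keeper=task_keeper,
        data_loader=fed_dataset,
        task_assigner=assigner_with_last_round_validation,
        rounds_to_train=total_rounds,
        opt_treatment='CONTINUE_GLOBAL',
        device_assignment_policy='CUDA_PREFERRED'
    )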


.. _running_the_federation_aggregator_based:
30 changes: 16 additions & 14 deletions openfl/component/aggregator/aggregator.py
@@ -276,9 +276,7 @@ def get_tasks(self, collaborator_name):
time_to_quit = False

# otherwise, get the tasks from our task assigner
tasks = self.assigner.get_tasks_for_collaborator(
collaborator_name,
self.round_number) # fancy task assigners may want aggregator state
tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number)

# if no tasks, tell the collaborator to sleep
if len(tasks) == 0:
@@ -288,10 +286,17 @@ def get_tasks(self, collaborator_name):
return tasks, self.round_number, sleep_time, time_to_quit

# if we do have tasks, remove any that we already have results for
tasks = [
t for t in tasks if not self._collaborator_task_completed(
collaborator_name, t, self.round_number)
]
if isinstance(tasks[0], str):
# backward compatibility
tasks = [
t for t in tasks if not self._collaborator_task_completed(
collaborator_name, t, self.round_number)
]
else:
tasks = [
t for t in tasks if not self._collaborator_task_completed(
collaborator_name, t.name, self.round_number)
]

# Do the check again because it's possible that all tasks have
# been completed
@@ -751,7 +756,8 @@ def _compute_validation_related_task_metrics(self, task_name):
# This handles getting the subset of collaborators that may be
# part of the validation task
collaborators_for_task = self.assigner.get_collaborators_for_task(
task_name, self.round_number)
task_name, self.round_number
)
# The collaborator data sizes for that task
collaborator_weights_unnormalized = {
c: self.collaborator_task_weight[TaskResultKey(task_name, c, self.round_number)]
@@ -775,7 +781,6 @@ def _compute_validation_related_task_metrics(self, task_name):
assert (tags[-1] == collaborators_for_task[0]), (
f'Tensor {tensor_key} in task {task_name} has not been processed correctly'
)

# Strip the collaborator label, and lookup aggregated tensor
new_tags = tuple(tags[:-1])
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags)
@@ -843,7 +848,6 @@ def _end_of_round_check(self):
self._compute_validation_related_task_metrics(task_name)

# Once all of the task results have been processed
# Increment the round number
self.round_number += 1

# Save the latest model
@@ -873,11 +877,9 @@ def _is_task_done(self, task_name):

def _is_round_done(self):
"""Check that round is done."""
tasks_for_round = self.assigner.get_all_tasks_for_round(
self.round_number
)
tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number)

return all([self._is_task_done(t) for t in tasks_for_round])
return all([self._is_task_done(task_name) for task_name in tasks_for_round])

def _log_big_warning(self):
"""Warn user about single collaborator cert mode."""
75 changes: 75 additions & 0 deletions openfl/component/assigner/custom_assigner.py
@@ -0,0 +1,75 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Custom Assigner module."""


import logging
from collections import defaultdict

from openfl.component.aggregation_functions import WeightedAverage

logger = logging.getLogger(__name__)


class Assigner:
"""Custom assigner class."""

def __init__(
self,
*,
assigner_function,
aggregation_functions_by_task,
authorized_cols,
rounds_to_train
):
"""Initialize."""
self.agg_functions_by_task = aggregation_functions_by_task
self.agg_functions_by_task_name = {}
self.authorized_cols = authorized_cols
self.rounds_to_train = rounds_to_train
self.all_tasks_for_round = defaultdict(dict)
self.collaborators_for_task = defaultdict(lambda: defaultdict(list))
self.collaborator_tasks = defaultdict(lambda: defaultdict(list))
self.assigner_function = assigner_function

self.define_task_assignments()

def define_task_assignments(self):
"""Abstract method."""
for round_number in range(self.rounds_to_train):
tasks_by_collaborator = self.assigner_function(
self.authorized_cols,
round_number,
                number_of_collaborators=len(self.authorized_cols)
)
for collaborator_name, tasks in tasks_by_collaborator.items():
self.collaborator_tasks[round_number][collaborator_name].extend(tasks)
for task in tasks:
self.all_tasks_for_round[round_number][task.name] = task
self.collaborators_for_task[round_number][task.name].append(collaborator_name)
if self.agg_functions_by_task:
self.agg_functions_by_task_name[
task.name
] = self.agg_functions_by_task.get(task.function_name, WeightedAverage())

def get_tasks_for_collaborator(self, collaborator_name, round_number):
"""Abstract method."""
return self.collaborator_tasks[round_number][collaborator_name]

def get_collaborators_for_task(self, task_name, round_number):
"""Abstract method."""
return self.collaborators_for_task[round_number][task_name]

def get_all_tasks_for_round(self, round_number):
"""
Return tasks for the current round.

        Currently all tasks are performed on each round,
        but there may be a reason to change this.
"""
return [task.name for task in self.all_tasks_for_round[round_number].values()]

def get_aggregation_type_for_task(self, task_name):
"""Extract aggregation type from self.tasks."""
agg_fn = self.agg_functions_by_task_name.get(task_name, WeightedAverage())
return agg_fn
32 changes: 32 additions & 0 deletions openfl/component/assigner/tasks.py
@@ -0,0 +1,32 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Task module."""


from dataclasses import dataclass
from dataclasses import field


@dataclass
class Task:
"""Task base dataclass."""

name: str
function_name: str
task_type: str
apply_local: bool = False
    parameters: dict = field(default_factory=dict)  # We can expand it in the future


@dataclass
class TrainTask(Task):
"""TrainTask class."""

task_type: str = 'train'


@dataclass
class ValidateTask(Task):
"""Validation Task class."""

task_type: str = 'validate'
20 changes: 17 additions & 3 deletions openfl/component/collaborator/collaborator.py
@@ -184,8 +184,22 @@ def get_tasks(self):
def do_task(self, task, round_number):
"""Do the specified task."""
# map this task to an actual function name and kwargs
func_name = self.task_config[task]['function']
kwargs = self.task_config[task]['kwargs']
if hasattr(self.task_runner, 'TASK_REGISTRY'):
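            # Interactive API path: `task` is expected to be a Task object
            # (see openfl/component/assigner/tasks.py) rather than a plain string name.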
func_name = task.function_name
task_name = task.name
kwargs = {}
if task.task_type == 'validate':
if task.apply_local:
kwargs['apply'] = 'local'
else:
kwargs['apply'] = 'global'
else:
if isinstance(task, str):
task_name = task
else:
task_name = task.name
func_name = self.task_config[task_name]['function']
kwargs = self.task_config[task_name]['kwargs']

# this would return a list of what tensors we require as TensorKeys
required_tensorkeys_relative = self.task_runner.get_required_tensorkeys_for_function(
@@ -250,7 +264,7 @@ def do_task(self, task, round_number):

# send the results for this tasks; delta and compression will occur in
# this function
self.send_task_results(global_output_tensor_dict, round_number, task)
self.send_task_results(global_output_tensor_dict, round_number, task_name)

def get_numpy_dict_for_tensorkeys(self, tensor_keys):
"""Get tensor dictionary for specified tensorkey set."""
2 changes: 1 addition & 1 deletion openfl/component/director/experiment.py
@@ -83,7 +83,7 @@ async def start(
logger.info(f'Experiment "{self.name}" was finished successfully.')
except Exception as e:
self.status = Status.FAILED
logger.error(f'Experiment "{self.name}" was failed with error: {e}.')
            logger.exception(f'Experiment "{self.name}" failed with error: {e}.')

def _create_aggregator_grpc_server(
self, *,