Skip to content

Commit

Permalink
Remove functolls.wraps decorators in TaskInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
itrushkin committed Sep 28, 2021
1 parent 33747b4 commit 715f643
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions openfl/interface/interactive_api/experiment.py
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

"""Python low-level API module."""
import functools
import os
import time
from collections import defaultdict
Expand Down Expand Up @@ -365,7 +364,7 @@ def __init__(self) -> None:
# Mapping task name -> callable
self.aggregation_functions = {}

def register_fl_task(self, training_method):
def register_fl_task(self, model, data_loader, device, optimizer=None):
"""
Register FL tasks.
Expand All @@ -389,8 +388,10 @@ def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=3
...
`
"""
@functools.wraps(training_method)
def wrapper(model, data_loader, device, optimizer=None):
# The highest level wrapper for allowing arguments for the decorator
def decorator_with_args(training_method):
# We could pass hooks to the decorator
# @functools.wraps(training_method)

def wrapper_decorator(**task_keywords):
metric_dict = training_method(**task_keywords)
Expand All @@ -404,9 +405,9 @@ def wrapper_decorator(**task_keywords):
# We do not alter user environment
return training_method

return wrapper
return decorator_with_args

def add_kwargs(self, training_method):
def add_kwargs(self, **task_kwargs):
"""
Register tasks settings.
Expand All @@ -415,16 +416,16 @@ def add_kwargs(self, training_method):
This one is a decorator because we need task name and
to be consistent with the main registering method
"""
@functools.wraps(training_method)
def wrapper(**task_kwargs):
# The highest level wrapper for allowing arguments for the decorator
def decorator_with_args(training_method):
# Saving the task's settings to be written in plan
self.task_settings[training_method.__name__] = task_kwargs

return training_method

return wrapper
return decorator_with_args

def set_aggregation_function(self, training_method):
def set_aggregation_function(self, aggregation_function: AggregationFunctionInterface):
"""Set aggregation function for the task.
Args:
Expand All @@ -438,15 +439,14 @@ def set_aggregation_function(self, training_method):
.. _Overriding the aggregation function:
https://openfl.readthedocs.io/en/latest/overriding_agg_fn.html
"""
@functools.wraps(training_method)
def wrapper(aggregation_function: AggregationFunctionInterface):
def decorator_with_args(training_method):
# @functools.wraps(training_method)
if not isinstance(aggregation_function, AggregationFunctionInterface):
raise Exception('aggregation_function must implement '
'AggregationFunctionInterface interface.')
self.aggregation_functions[training_method.__name__] = aggregation_function
return training_method

return wrapper
return decorator_with_args

def get_aggregation_function(self, function_name):
"""Get aggregation type for the task function.
Expand Down

0 comments on commit 715f643

Please sign in to comment.