Skip to content

Commit

Permalink
revert arg_dict stuff :c
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthfrey committed Jul 26, 2019
1 parent 1ff2776 commit d3a1fdc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 42 deletions.
74 changes: 33 additions & 41 deletions ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
self.wrapper = None
self.child_wrapper = None
self.child_decorator = None
self.support_arg_dict = True
self.polling_strategy = PollingStrategy('flat')

def __call__(self, *args, **kwargs):
ret = getattr(self, self.get_mode())(*args, **kwargs)
Expand All @@ -61,10 +63,10 @@ def _init_to_graph(self, children: List[Callable], weights: Optional[List[np.nda
child = Model(child, self.name)
Graph.add_node(self.name, child, weight)

def call_child(self, child_name, **kwargs):
def call_child(self, child_name, *args, **kwargs):
child = self.children[child_name]
child = child if self.child_decorator is None else self.child_decorator(child)
ret = child(**kwargs)
ret = child(*args, **kwargs)
return ret if self.child_wrapper is None else self.child_wrapper(ret)

# error helpers
Expand Down Expand Up @@ -143,70 +145,60 @@ def generate_children(self) -> Iterator[Tuple[str, Node]]:
for name, node in self.children.items():
yield name, node

def generate_all_calls(self, arg_dict: Dict, **kwargs) -> Iterator[Tuple[str, None]]:
def generate_all_calls(self, *args, **kwargs) -> Iterator[Tuple[str, None]]:
for name, node in self.generate_children():
if self.get_polling_strategy() == 'structured':
yield name, self.call_child(name, **arg_dict[name])
else:
if isinstance(node, self.__class__) or self.child_decorator is not None:
yield name, self.call_child(name, **kwargs)
yield name, self.call_child(name, *args, **kwargs)
else:
filtered_kwargs = {k: v for k, v in kwargs.items() if k in node.get_arg_names()}
yield name, self.call_child(name, **filtered_kwargs)
yield name, self.call_child(name, *args, **filtered_kwargs)

def generate_all_call_return_values(self, arg_dict: Dict, **kwargs):
return (return_value for _, return_value in self.generate_all_calls(arg_dict, **kwargs))
def generate_all_call_return_values(self, *args, **kwargs):
return (return_value for _, return_value in self.generate_all_calls(*args, **kwargs))

def get_all_call_return_values(self, arg_dict: Dict, **kwargs):
return list(self.generate_all_call_return_values(arg_dict, **kwargs))
def get_all_call_return_values(self, *args, **kwargs):
return list(self.generate_all_call_return_values(*args, **kwargs))

# main callers

def multiplex(self, child: str, **kwargs):
def multiplex(self, child: str, *args, **kwargs):
Ensemble._raise_if_node_not_found(child)
Ensemble._raise_if_node_not_in_ensemble(self.name, child)
return self.call_child(child, **kwargs)
return self.call_child(child, *args, **kwargs)

@poller
def call_children(self, arg_dict: Dict = dict(), **kwargs):
return {k: v for k, v in self.generate_all_calls(arg_dict, **kwargs)}
def call_children(self, *args, **kwargs):
return {k: v for k, v in self.generate_all_calls(*args, **kwargs)}

@poller
def aggregate(self, agg: Callable, arg_dict=dict(), **kwargs):
return agg(self.get_all_call_return_values(arg_dict, **kwargs))
def aggregate(self, agg: Callable, *args, **kwargs):
return agg(self.get_all_call_return_values(*args, **kwargs))

@poller
def mean(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(np.mean, arg_dict, **kwargs)
def mean(self, *args, **kwargs) -> np.float:
return self.aggregate(np.mean, *args, **kwargs)

@poller
def sum(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(np.sum, arg_dict, **kwargs)
def sum(self, *args, **kwargs) -> np.float:
return self.aggregate(np.sum, *args, **kwargs)

@poller
def max(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(max, arg_dict, **kwargs)
def max(self, *args, **kwargs) -> np.float:
return self.aggregate(max, *args, **kwargs)

@poller
def any(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(any, arg_dict, **kwargs)
def any(self, *args, **kwargs) -> np.float:
return self.aggregate(any, *args, **kwargs)

@poller
def all(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(all, arg_dict, **kwargs)
def all(self, *args, **kwargs) -> np.float:
return self.aggregate(all, *args, **kwargs)

# other callers

@poller
def weighted_mean(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
def weighted_mean(self, *args, **kwargs) -> np.float:
agg = partial(np.average, weights=self.weights)
return self.aggregate(agg, arg_dict, **kwargs)
return self.aggregate(agg, *args, **kwargs)

@poller
def weighted_sum(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
def weighted_sum(self, *args, **kwargs) -> np.float:
agg = lambda values: np.dot(values, self.weights)
return self.aggregate(agg, arg_dict, **kwargs)
return self.aggregate(agg, *args, **kwargs)

@poller
def vote(self, arg_dict: Dict = dict(), **kwargs) -> np.float:
return self.aggregate(np.bincount, arg_dict, **kwargs).argmax()
def vote(self, *args, **kwargs) -> np.float:
return self.aggregate(np.bincount, *args, **kwargs).argmax()
3 changes: 2 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def c(z):

e = Ensemble('ab', [a, b])
print(e(x=1, y=1))
print(e(dict(a=dict(x=2), b=dict(y=2))))
#print(e(dict(a=dict(x=2), b=dict(y=2))))
print(e(2))
print(e8)


Expand Down

0 comments on commit d3a1fdc

Please sign in to comment.