Skip to content
This repository has been archived by the owner on Mar 31, 2019. It is now read-only.

Commit

Permalink
convenience reapply
Browse files Browse the repository at this point in the history
  • Loading branch information
justheuristic committed May 11, 2017
1 parent 9caa406 commit 3e94cbf
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions agentnet/utils/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,24 @@ def clone_network(original_network, bottom_layers=None,


return network_clone


def reapply(layer_or_layers, new_bottom, share_params=True, name_prefix=None):
"""
Applies a part of lasagne network to a new place. Wraps clone_network
:param layer_or_layers: layers to be re-applied
:param new_bottom: a dict {old_layer:new_layer} that defines which layers should be substituted by which other layers
:param share_params: if True, cloned network will use same shared variables for weights.
Otherwise new shared variables will be created and set to original NN values.
WARNING! shared weights must be accessible via lasagne.layers.get_all_params with no flags
If you want custom other parameters to be shared, agentnet.utils.clone_network
:param name_prefix: if not None, adds this prefix to all the layers and params of the cloned network
:return: a new layer or layers that represent re-applying layer_or_layers to new_bottom
"""
assert isinstance(new_bottom,dict)
for layer in lasagne.layers.get_all_layers(layer_or_layers,list(new_bottom.keys())):
if isinstance(layer,lasagne.layers.InputLayer):
assert layer in new_bottom, "must explicitly provide all new_bottom for each branch of original network. " \
"Assert caused by {}. For dirty hacks, use clone_network.".format(layer.name or layer)
return clone_network(layer_or_layers, new_bottom, share_params=share_params, share_inputs=False, name_prefix=name_prefix)

0 comments on commit 3e94cbf

Please sign in to comment.