Skip to content

Commit

Permalink
adds copy_to_model method to PVector object
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed May 24, 2020
1 parent 1f648ef commit e02db89
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions nngeometry/object/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ def from_model(model):
dict_repr[layer_id] = (mod.weight,)
return PVector(layer_collection, dict_repr=dict_repr)

def copy_to_model(self, model):
"""
Updates `model` parameter values with the current vector
Note. This is an inplace operation
"""
dict_repr = self.get_dict_representation()
layer_collection = LayerCollection.from_model(model)
l_to_m, _ = layer_collection.get_layerid_module_maps(model)
for layer_id, layer in layer_collection.layers.items():
mod = l_to_m[layer_id]
if layer.bias is not None:
mod.bias.data.copy_(dict_repr[layer_id][1])
else:
mod.weight.data.copy_(dict_repr[layer_id][0])

@staticmethod
# TODO: fix and test
def from_model_grad(model):
Expand Down

0 comments on commit e02db89

Please sign in to comment.