Skip to content

Commit

Permalink
Merge pull request #7 from tfjgeorge/add_to_model
Browse files Browse the repository at this point in the history
adds add_to_model for PVector, fixes bug in copy_to_model
  • Loading branch information
tfjgeorge committed Feb 25, 2021
2 parents 3fd2c7a + eeadea4 commit 5d2ca21
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
18 changes: 16 additions & 2 deletions nngeometry/object/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,22 @@ def copy_to_model(self, model):
mod = l_to_m[layer_id]
if layer.bias is not None:
mod.bias.data.copy_(dict_repr[layer_id][1])
else:
mod.weight.data.copy_(dict_repr[layer_id][0])
mod.weight.data.copy_(dict_repr[layer_id][0])

def add_to_model(self, model):
"""
Updates `model` parameter values by adding the current PVector
Note. This is an inplace operation
"""
dict_repr = self.get_dict_representation()
layer_collection = LayerCollection.from_model(model)
l_to_m, _ = layer_collection.get_layerid_module_maps(model)
for layer_id, layer in layer_collection.layers.items():
mod = l_to_m[layer_id]
if layer.bias is not None:
mod.bias.data.add_(dict_repr[layer_id][1])
mod.weight.data.add_(dict_repr[layer_id][0])

@staticmethod
def from_model_grad(model):
Expand Down
23 changes: 22 additions & 1 deletion tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from nngeometry.layercollection import LayerCollection
import torch.nn as nn
import torch.nn.functional as tF
from utils import check_ratio
from utils import check_ratio, check_tensors
import pytest


Expand Down Expand Up @@ -164,3 +164,24 @@ def test_norm():
check_ratio(torch.norm(v.get_flat_representation()), v.norm())


def test_from_to_model():
model1 = ConvNet()
model2 = ConvNet()

w1 = PVector.from_model(model1).clone()
w2 = PVector.from_model(model2).clone()

model3 = ConvNet()
w1.copy_to_model(model3)
# now model1 and model3 should be the same

for p1, p3 in zip(model1.parameters(), model3.parameters()):
check_tensors(p1, p3)

###
diff_1_2 = w2 - w1
diff_1_2.add_to_model(model3)
# now model2 and model3 should be the same

for p2, p3 in zip(model2.parameters(), model3.parameters()):
check_tensors(p2, p3)

0 comments on commit 5d2ca21

Please sign in to comment.