Skip to content

Commit

Permalink
fixes from_model_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Jun 16, 2020
1 parent e02db89 commit 95c5415
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions nngeometry/object/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,17 @@ def copy_to_model(self, model):
mod.weight.data.copy_(dict_repr[layer_id][0])

@staticmethod
# TODO: fix and test
def from_model_grad(model):
dict_repr = dict()
for mod in get_individual_modules(model)[0]:
if mod.bias is not None:
dict_repr[mod] = (mod.weight.grad, mod.bias.grad)
layer_collection = LayerCollection.from_model(model)
l_to_m, _ = layer_collection.get_layerid_module_maps(model)
for layer_id, layer in layer_collection.layers.items():
mod = l_to_m[layer_id]
if layer.bias is not None:
dict_repr[layer_id] = (mod.weight.grad, mod.bias.grad)
else:
dict_repr[mod] = (mod.weight.grad)
return PVector(model, dict_repr=dict_repr)
dict_repr[layer_id] = (mod.weight.grad,)
return PVector(layer_collection, dict_repr=dict_repr)

def clone(self):
if self.dict_repr is not None:
Expand Down

0 comments on commit 95c5415

Please sign in to comment.