Commit
Merge pull request #22 from tfjgeorge/grad
adds grad for PVectors
tfjgeorge committed Apr 26, 2021
2 parents 4a50ca2 + 4669859 commit bdcf8a7
Showing 2 changed files with 79 additions and 1 deletion.
43 changes: 42 additions & 1 deletion nngeometry/utils.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from nngeometry.object.vector import PVector


def per_example_grad_conv(mod, x, gy):
@@ -20,4 +21,44 @@ def display_correl(M, axis):
    dM = (diag + diag.mean() / 100) ** .5
    correl = torch.abs(M) / dM.unsqueeze(0) / dM.unsqueeze(1)

    axis.imshow(correl.cpu())


def grad(output, vec, *args, **kwargs):
    """
    Computes the gradient of `output` with respect to the `PVector` `vec`.

    .. warning:: This function only works when `vec` was created from leaf
        nodes in the graph (e.g. model parameters).

    :param output: the scalar quantity to be differentiated
    :param vec: a `PVector`
    :return: a `PVector` of gradients of `output` w.r.t. `vec`
    """
    if vec.dict_repr is not None:
        # map all parameters to a flat list, remembering where each
        # layer's parameters start and how many it has
        params = []
        pos = []
        lengths = []
        current_pos = 0
        for k in vec.dict_repr.keys():
            p = vec.dict_repr[k]
            params += list(p)
            pos.append(current_pos)
            lengths.append(len(p))
            current_pos = current_pos + len(p)

        grad_list = torch.autograd.grad(output, params, *args, **kwargs)
        dict_repr_grad = dict()

        # rebuild the dict representation: each layer maps to (weight,)
        # or (weight, bias)
        for k, p, l in zip(vec.dict_repr.keys(), pos, lengths):
            if l == 1:
                dict_repr_grad[k] = (grad_list[p],)
            elif l == 2:
                dict_repr_grad[k] = (grad_list[p], grad_list[p+1])

        return PVector(vec.layer_collection,
                       dict_repr=dict_repr_grad)
    else:
        raise RuntimeError('grad only works when the vector is created ' +
                           'from leaf nodes in the computation graph')
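
A quick usage sketch of the new helper (not part of the commit): the two-layer model and shapes below are made up for illustration, and it mirrors what test_grad_dict_repr in the next file exercises.

import torch
import torch.nn as nn

from nngeometry.object.vector import PVector
from nngeometry.utils import grad

# hypothetical toy model, only for illustration
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
x = torch.randn(8, 10)

vec = PVector.from_model(model)       # wraps the model's parameters (leaf tensors)
scalar_output = model(x).sum()        # any scalar depending on those parameters

g = grad(scalar_output, vec)          # a PVector with the same layout as `vec`
print(g.get_flat_representation().shape)

Passing retain_graph=True (as the test does) is only needed if the same graph is differentiated again afterwards, e.g. by a subsequent call to backward().
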
37 changes: 37 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,37 @@
import torch
from nngeometry.object.vector import PVector, random_pvector
from nngeometry.utils import grad
from utils import check_tensors
import pytest
from tasks import get_conv_gn_task, to_device

@pytest.fixture(autouse=True)
def make_test_deterministic():
    torch.manual_seed(1234)
    yield


def test_grad_dict_repr():
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    d, _ = next(iter(loader))
    scalar_output = model(to_device(d)).sum()
    vec = PVector.from_model(model)

    grad_nng = grad(scalar_output, vec, retain_graph=True)

    scalar_output.backward()
    grad_direct = PVector.from_model_grad(model)

    check_tensors(grad_direct.get_flat_representation(),
                  grad_nng.get_flat_representation())


def test_grad_flat_repr():
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    vec = random_pvector(lc)
    scalar_output = vec.norm()

    with pytest.raises(RuntimeError):
        grad(scalar_output, vec)
