Skip to content

Commit

Permalink
Merge pull request #8 from tfjgeorge/kfac_mm
Browse files Browse the repository at this point in the history
adds matrix-matrix product to PMatKFAC
  • Loading branch information
tfjgeorge committed Feb 26, 2021
2 parents 5d2ca21 + 6307d37 commit fd245ea
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
18 changes: 18 additions & 0 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,24 @@ def compute_eigendecomposition(self, impl='symeig'):
def get_eigendecomposition(self):
return self.evals, self.evecs

def mm(self, other):
"""
Matrix-matrix product where `other` is another
instance of PMatKFAC
:param other: Other FIM matrix
:type other: :class:`nngeometry.object.PMatKFAC`
:return: The matrix-matrix product
:rtype: :class:`nngeometry.object.PMatKFAC`
"""
prod = dict()
for layer_id, (a, g) in self.data.items():
(a_other, g_other) = other.data[layer_id]
prod[layer_id] = (torch.mm(a, a_other),
torch.mm(g, g_other))
return PMatKFAC(self.generator, data=prod)


class PMatEKFAC(PMatAbstract):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/test_jacobian_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,30 @@ def test_pspace_kfac_eigendecomposition():
assert angle_v_Mv < 1 + eps and angle_v_Mv > 1 - eps
norm_mv = torch.norm(Mv.get_flat_representation())
check_ratio(evals[l_id][0][i_a] * evals[l_id][1][i_g], norm_mv)


def test_kfac():
for get_task in [get_fullyconnect_task, get_conv_task]:
loader, lc, parameters, model1, function1, n_output = get_task()
_, _, _, model2, function2, _ = get_task()

generator1 = Jacobian(layer_collection=lc,
model=model1,
function=function1,
n_output=n_output)
generator2 = Jacobian(layer_collection=lc,
model=model2,
function=function1,
n_output=n_output)
M_kfac1 = PMatKFAC(generator=generator1, examples=loader)
M_kfac2 = PMatKFAC(generator=generator2, examples=loader)

prod = M_kfac1.mm(M_kfac2)

M_kfac1_tensor = M_kfac1.get_dense_tensor(split_weight_bias=True)
M_kfac2_tensor = M_kfac2.get_dense_tensor(split_weight_bias=True)

prod_tensor = prod.get_dense_tensor(split_weight_bias=True)

check_tensors(torch.mm(M_kfac1_tensor, M_kfac2_tensor),
prod_tensor)

0 comments on commit fd245ea

Please sign in to comment.