Skip to content

Commit

Permalink
Merge pull request #10 from tfjgeorge/mm_diag_dense
Browse files Browse the repository at this point in the history
Mm diag dense
  • Loading branch information
tfjgeorge committed Feb 26, 2021
2 parents e31b90f + b37d1b4 commit 3f5361e
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 0 deletions.
46 changes: 46 additions & 0 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def __rmul__(self, x):
return PMatDense(generator=self.generator,
data=x * self.data)

def mm(self, other):
"""
Matrix-matrix product where `other` is another
instance of PMatDense
:param other: Other FIM matrix
:type other: :class:`nngeometry.object.PMatDense`
:return: The matrix-matrix product
:rtype: :class:`nngeometry.object.PMatDense`
"""
return PMatDense(self.generator,
data=torch.mm(self.data, other.data))


class PMatDiag(PMatAbstract):
def __init__(self, generator, data=None, examples=None):
Expand Down Expand Up @@ -255,6 +269,20 @@ def __rmul__(self, x):
return PMatDiag(generator=self.generator,
data=x * self.data)

def mm(self, other):
"""
Matrix-matrix product where `other` is another
instance of PMatDiag
:param other: Other FIM matrix
:type other: :class:`nngeometry.object.PMatDiag`
:return: The matrix-matrix product
:rtype: :class:`nngeometry.object.PMatDiag`
"""
return PMatDiag(self.generator,
data=self.data * other.data)


class PMatBlockDiag(PMatAbstract):
def __init__(self, generator, data=None, examples=None):
Expand Down Expand Up @@ -369,6 +397,24 @@ def __rmul__(self, x):
return PMatBlockDiag(generator=self.generator,
data=sum_data)

def mm(self, other):
"""
Matrix-matrix product where `other` is another
instance of PMatBlockDiag
:param other: Other FIM matrix
:type other: :class:`nngeometry.object.PMatBlockDiag`
:return: The matrix-matrix product
:rtype: :class:`nngeometry.object.PMatBlockDiag`
"""
prod = dict()
for layer_id, block in self.data.items():
block_other = other.data[layer_id]
prod[layer_id] = torch.mm(block, block_other)
return PMatBlockDiag(self.generator,
data=prod)


class PMatKFAC(PMatAbstract):
def __init__(self, generator, data=None, examples=None):
Expand Down
98 changes: 98 additions & 0 deletions tests/test_representations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch
from tasks import (get_conv_gn_task, get_fullyconnect_task, get_conv_task)
from nngeometry.object.pspace import (PMatDense, PMatDiag, PMatBlockDiag,
PMatImplicit, PMatLowRank, PMatQuasiDiag)
from nngeometry.generator import Jacobian
from utils import check_ratio, check_tensors, check_angle
import pytest


nonlinear_tasks = [get_conv_gn_task, get_fullyconnect_task, get_conv_task]

if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'

@pytest.fixture(autouse=True)
def make_test_deterministic():
torch.manual_seed(1234)
yield

def test_diag():
for get_task in nonlinear_tasks:
loader, lc, parameters, model1, function1, n_output = get_task()
_, _, _, model2, function2, _ = get_task()

generator1 = Jacobian(layer_collection=lc,
model=model1,
function=function1,
n_output=n_output)
generator2 = Jacobian(layer_collection=lc,
model=model2,
function=function1,
n_output=n_output)
M_diag1 = PMatDiag(generator=generator1, examples=loader)
M_diag2 = PMatDiag(generator=generator2, examples=loader)

prod = M_diag1.mm(M_diag2)

M_diag1_tensor = M_diag1.get_dense_tensor()
M_diag2_tensor = M_diag2.get_dense_tensor()

prod_tensor = prod.get_dense_tensor()

check_tensors(torch.mm(M_diag1_tensor, M_diag2_tensor),
prod_tensor)

def test_dense():
for get_task in nonlinear_tasks:
loader, lc, parameters, model1, function1, n_output = get_task()
_, _, _, model2, function2, _ = get_task()

generator1 = Jacobian(layer_collection=lc,
model=model1,
function=function1,
n_output=n_output)
generator2 = Jacobian(layer_collection=lc,
model=model2,
function=function1,
n_output=n_output)
M_dense1 = PMatDense(generator=generator1, examples=loader)
M_dense2 = PMatDense(generator=generator2, examples=loader)

prod = M_dense1.mm(M_dense2)

M_dense1_tensor = M_dense1.get_dense_tensor()
M_dense2_tensor = M_dense2.get_dense_tensor()

prod_tensor = prod.get_dense_tensor()

check_tensors(torch.mm(M_dense1_tensor, M_dense2_tensor),
prod_tensor)

def test_blockdiag():
for get_task in nonlinear_tasks:
loader, lc, parameters, model1, function1, n_output = get_task()
_, _, _, model2, function2, _ = get_task()

generator1 = Jacobian(layer_collection=lc,
model=model1,
function=function1,
n_output=n_output)
generator2 = Jacobian(layer_collection=lc,
model=model2,
function=function1,
n_output=n_output)
M_blockdiag1 = PMatBlockDiag(generator=generator1, examples=loader)
M_blockdiag2 = PMatBlockDiag(generator=generator2, examples=loader)

prod = M_blockdiag1.mm(M_blockdiag2)

M_blockdiag1_tensor = M_blockdiag1.get_dense_tensor()
M_blockdiag2_tensor = M_blockdiag2.get_dense_tensor()

prod_tensor = prod.get_dense_tensor()

check_tensors(torch.mm(M_blockdiag1_tensor, M_blockdiag2_tensor),
prod_tensor)

0 comments on commit 3f5361e

Please sign in to comment.