Skip to content

Commit

Permalink
changed PSpace prefix to PMat, and FSpace to FMat
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Jul 29, 2020
1 parent d79b246 commit 47a802b
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 185 deletions.
2 changes: 1 addition & 1 deletion docs/api/generators.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Generators
==========

The spirit of NNGeometry is that you do not directly manipulate Generator objects, these can be considered as a backend that you do not need to worry about once instantiated. You instead instantiate concrete representations such as `PSpaceDense` or `PSpaceKFAC` and directly call linear algebra operations on these concrete representations.
The spirit of NNGeometry is that you do not directly manipulate Generator objects, these can be considered as a backend that you do not need to worry about once instantiated. You instead instantiate concrete representations such as `PMatDense` or `PMatKFAC` and directly call linear algebra operations on these concrete representations.

.. automodule:: nngeometry.generator.jacobian
:members:
6 changes: 3 additions & 3 deletions docs/api/pspace-representations.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
Parameter space matrix representations
======================================

.. autoclass:: nngeometry.object.pspace.PSpaceAbstract
.. autoclass:: nngeometry.object.PMat.PMatAbstract
:members:

Concrete representations
========================

.. automodule:: nngeometry.object.pspace
.. automodule:: nngeometry.object.PMat
:members:
:exclude-members: PSpaceAbstract
:exclude-members: PMatAbstract
6 changes: 3 additions & 3 deletions docs/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Let us now illustrate this by computing the FIM using the KFAC representation.
>>> F_kfac = FIM(layer_collection=layer_collection,
model=model,
loader=loader,
representation=PSpaceKFAC,
representation=PMatKFAC,
n_output=10,
variant='classif_logits',
device='cuda')
Expand All @@ -26,7 +26,7 @@ Computing the FIM requires the following arguments:
- The :class:`.layercollection.LayerCollection` ``layer_collection`` object describes the structure of the parameters that we will be manipulating. If we are interested in computing the FIM for the last 2 layers for instance, it will describe the structure of these last 2 layers and the size of their parameters.
- The :class:`torch.nn.Module` ``model`` object is the PyTorch model used as our neural network.
- The :class:`torch.utils.data.DataLoader` ``loader`` object is the dataloader that contains examples used for computing the FIM.
- The :class:`.object.pspace.PSpaceKFAC` argument specifies which representation to use in order to store the FIM.
- The :class:`.object.PMat.PMatKFAC` argument specifies which representation to use in order to store the FIM.

We will next define a vector in parameter space, by using the current value given by our model:

Expand All @@ -36,4 +36,4 @@ Computing the FIM requires the following arguments:

>>> Fv = F_kfac.mv(v)

Note that switching from the :class:`.object.pspace.PSpaceKFAC` representation to any other representation such as :class:`.object.pspace.PSpaceDense` is as simple as passing ``representation=PSpaceDense`` when building the ``F_kfac`` object.
Note that switching from the :class:`.object.PMat.PMatKFAC` representation to any other representation such as :class:`.object.PMat.PMatDense` is as simple as passing ``representation=PMatDense`` when building the ``F_kfac`` object.
14 changes: 7 additions & 7 deletions examples/FIM for EWC.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"metadata": {},
"outputs": [],
"source": [
"from nngeometry.object.pspace import PSpaceKFAC, PSpaceDiag, PSpaceBlockDiag, PSpaceDense\n",
"from nngeometry.object.PMat import PMatKFAC, PMatDiag, PMatBlockDiag, PMatDense\n",
"from nngeometry.object.vector import PVector\n",
"from nngeometry.metrics import FIM"
]
Expand All @@ -123,23 +123,23 @@
"F_kfac = FIM(layer_collection=layer_collection,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceKFAC,\n",
" representation=PMatKFAC,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
"\n",
"F_blockdiag = FIM(layer_collection=layer_collection,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceBlockDiag,\n",
" representation=PMatBlockDiag,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
"\n",
"F_dense = FIM(layer_collection=layer_collection,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceDense,\n",
" representation=PMatDense,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
Expand All @@ -148,7 +148,7 @@
"F_diag = FIM(layer_collection=layer_collection,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceDiag,\n",
" representation=PMatDiag,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')"
Expand Down Expand Up @@ -561,15 +561,15 @@
"F_linear_kfac = FIM(layer_collection=layer_collection_linear,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceKFAC,\n",
" representation=PMatKFAC,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
"\n",
"F_bn_blockdiag = FIM(layer_collection=layer_collection_bn,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceBlockDiag,\n",
" representation=PMatBlockDiag,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/Fisher Information Matrix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
}
],
"source": [
"from nngeometry.pspace import M2Gradients\n",
"from nngeometry.PMat import M2Gradients\n",
"\n",
"m2_generator = M2Gradients(model=model, dataloader=loader, loss_function=loss_fim_mc_estimate)\n",
"n_parameters = m2_generator.get_n_parameters()\n",
Expand Down
8 changes: 4 additions & 4 deletions examples/GGN MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"source": [
"from tasks import get_linear_task\n",
"from nngeometry.generator.jacobian import Jacobian\n",
"from nngeometry.object.pspace import PSpaceDense, PSpaceDiag, PSpaceKFAC\n",
"from nngeometry.object.PMat import PMatDense, PMatDiag, PMatKFAC\n",
"from nngeometry.object.vector import random_pvector\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
Expand Down Expand Up @@ -104,7 +104,7 @@
"GGN_dense = FIM(layer_collection=lc,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceDense,\n",
" representation=PMatDense,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
Expand Down Expand Up @@ -141,7 +141,7 @@
"GGN_diag = FIM(layer_collection=lc,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceDiag,\n",
" representation=PMatDiag,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
Expand Down Expand Up @@ -178,7 +178,7 @@
"GGN_kfac = FIM(layer_collection=lc,\n",
" model=model,\n",
" loader=loader,\n",
" representation=PSpaceKFAC,\n",
" representation=PMatKFAC,\n",
" n_output=10,\n",
" variant='classif_logits',\n",
" device='cuda')\n",
Expand Down
8 changes: 4 additions & 4 deletions nngeometry/object/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .pspace import (PSpaceDense, PSpaceBlockDiag, PSpaceDiag,
PSpaceLowRank, PSpaceImplicit,
PSpaceKFAC, PSpaceEKFAC)
from .PMat import (PMatDense, PMatBlockDiag, PMatDiag,
PMatLowRank, PMatImplicit,
PMatKFAC, PMatEKFAC)
from .vector import (PVector, FVector)
from .fspace import (FSpaceDense,)
from .FMat import (FMatDense,)
from .map import (PushForwardDense, PushForwardImplicit,
PullBackDense)
8 changes: 4 additions & 4 deletions nngeometry/object/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from .vector import FVector, PVector


class FSpaceAbstract(ABC):
class FMatAbstract(ABC):

@abstractmethod
def __init__(self, generator):
return NotImplementedError


class FSpaceDense(FSpaceAbstract):
class FMatDense(FMatAbstract):
def __init__(self, generator, data=None):
self.generator = generator
if data is not None:
Expand Down Expand Up @@ -70,11 +70,11 @@ def get_dense_tensor(self):
def __add__(self, other):
# TODO: test
sum_data = self.data + other.data
return FSpaceDense(generator=self.generator,
return FMatDense(generator=self.generator,
data=sum_data)

def __sub__(self, other):
# TODO: test
sub_data = self.data - other.data
return FSpaceDense(generator=self.generator,
return FMatDense(generator=self.generator,
data=sub_data)
50 changes: 25 additions & 25 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .vector import PVector


class PSpaceAbstract(ABC):
class PMatAbstract(ABC):
"""
A :math:`d \\times d` matrix in parameter space. This abstract class
defines common methods used in concrete representations.
Expand Down Expand Up @@ -86,7 +86,7 @@ def size(self, dim=None):
raise IndexError


class PSpaceDense(PSpaceAbstract):
class PMatDense(PMatAbstract):
def __init__(self, generator, data=None):
self.generator = generator
if data is not None:
Expand Down Expand Up @@ -126,7 +126,7 @@ def inverse(self, regul=1e-8):
inv_tensor = torch.inverse(self.data +
regul * torch.eye(self.size(0),
device=self.data.device))
return PSpaceDense(generator=self.generator,
return PMatDense(generator=self.generator,
data=inv_tensor)

def mv(self, v):
Expand Down Expand Up @@ -164,20 +164,20 @@ def get_diag(self):

def __add__(self, other):
sum_data = self.data + other.data
return PSpaceDense(generator=self.generator,
return PMatDense(generator=self.generator,
data=sum_data)

def __sub__(self, other):
sub_data = self.data - other.data
return PSpaceDense(generator=self.generator,
return PMatDense(generator=self.generator,
data=sub_data)

def __rmul__(self, x):
return PSpaceDense(generator=self.generator,
return PMatDense(generator=self.generator,
data=x * self.data)


class PSpaceDiag(PSpaceAbstract):
class PMatDiag(PMatAbstract):
def __init__(self, generator=None, data=None):
self.generator = generator
if data is not None:
Expand All @@ -187,7 +187,7 @@ def __init__(self, generator=None, data=None):

def inverse(self, regul=1e-8):
inv_tensor = 1. / (self.data + regul)
return PSpaceDiag(generator=self.generator,
return PMatDiag(generator=self.generator,
data=inv_tensor)

def mv(self, v):
Expand All @@ -212,20 +212,20 @@ def get_diag(self):

def __add__(self, other):
sum_diags = self.data + other.data
return PSpaceDiag(generator=self.generator,
return PMatDiag(generator=self.generator,
data=sum_diags)

def __sub__(self, other):
sub_diags = self.data - other.data
return PSpaceDiag(generator=self.generator,
return PMatDiag(generator=self.generator,
data=sub_diags)

def __rmul__(self, x):
return PSpaceDiag(generator=self.generator,
return PMatDiag(generator=self.generator,
data=x * self.data)


class PSpaceBlockDiag(PSpaceAbstract):
class PMatBlockDiag(PMatAbstract):
def __init__(self, generator, data=None):
self.generator = generator
if data is not None:
Expand Down Expand Up @@ -278,7 +278,7 @@ def inverse(self, regul=1e-8):
regul *
torch.eye(b.size(0), device=b.device))
inv_data[layer_id] = inv_b
return PSpaceBlockDiag(generator=self.generator,
return PMatBlockDiag(generator=self.generator,
data=inv_data)

def frobenius_norm(self):
Expand All @@ -299,22 +299,22 @@ def vTMv(self, vector):
def __add__(self, other):
sum_data = {l_id: d + other.data[l_id]
for l_id, d in self.data.items()}
return PSpaceBlockDiag(generator=self.generator,
return PMatBlockDiag(generator=self.generator,
data=sum_data)

def __sub__(self, other):
sum_data = {l_id: d - other.data[l_id]
for l_id, d in self.data.items()}
return PSpaceBlockDiag(generator=self.generator,
return PMatBlockDiag(generator=self.generator,
data=sum_data)

def __rmul__(self, x):
sum_data = {l_id: x * d for l_id, d in self.data.items()}
return PSpaceBlockDiag(generator=self.generator,
return PMatBlockDiag(generator=self.generator,
data=sum_data)


class PSpaceKFAC(PSpaceAbstract):
class PMatKFAC(PMatAbstract):
def __init__(self, generator, data=None):
self.generator = generator
if data is None:
Expand Down Expand Up @@ -342,7 +342,7 @@ def inverse(self, regul=1e-8, use_pi=True):
regul**.5 / pi *
torch.eye(g.size(0), device=g.device))
inv_data[layer_id] = (inv_a, inv_g)
return PSpaceKFAC(generator=self.generator,
return PMatKFAC(generator=self.generator,
data=inv_data)

def get_dense_tensor(self, split_weight_bias=True):
Expand Down Expand Up @@ -440,7 +440,7 @@ def get_eigendecomposition(self):
return self.evals, self.evecs


class PSpaceEKFAC:
class PMatEKFAC:
"""
EKFAC representation from
*George, Laurent et al., Fast Approximate Natural Gradient Descent
Expand Down Expand Up @@ -542,17 +542,17 @@ def inverse(self, regul=1e-8):
evecs, diags = self.data
inv_diags = {i: 1. / (d + regul)
for i, d in diags.items()}
return PSpaceEKFAC(generator=self.generator,
return PMatEKFAC(generator=self.generator,
data=(evecs, inv_diags))

def __rmul__(self, x):
evecs, diags = self.data
diags = {l_id: x * d for l_id, d in diags.items()}
return PSpaceEKFAC(generator=self.generator,
return PMatEKFAC(generator=self.generator,
data=(evecs, diags))


class PSpaceImplicit(PSpaceAbstract):
class PMatImplicit(PMatAbstract):
def __init__(self, generator):
self.generator = generator

Expand All @@ -578,7 +578,7 @@ def get_diag(self):
raise NotImplementedError


class PSpaceLowRank(PSpaceAbstract):
class PMatLowRank(PMatAbstract):
def __init__(self, generator, data=None):
self.generator = generator
if data is not None:
Expand Down Expand Up @@ -635,10 +635,10 @@ def get_diag(self):
return (self.data**2).sum(dim=(0, 1))

def __rmul__(self, x):
return PSpaceLowRank(generator=self.generator,
return PMatLowRank(generator=self.generator,
data=x**.5 * self.data)


class KrylovLowRankMatrix(PSpaceAbstract):
class KrylovLowRankMatrix(PMatAbstract):
def __init__(self, generator):
raise NotImplementedError()

0 comments on commit 47a802b

Please sign in to comment.