Skip to content

Commit

Permalink
adds documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Aug 21, 2020
1 parent 7414610 commit 048b73a
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/api/pspace-representations.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
Parameter space matrix representations
======================================

.. autoclass:: nngeometry.object.PMat.PMatAbstract
.. autoclass:: nngeometry.object.pspace.PMatAbstract
:members:

Concrete representations
========================

.. automodule:: nngeometry.object.PMat
.. automodule:: nngeometry.object.pspace
:members:
:exclude-members: PMatAbstract
2 changes: 0 additions & 2 deletions docs/api/vectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ In NNGeometry, vectors are not just a bunch of scalars, but they have a semantic
- :class:`nngeometry.object.vector.PVector` objects are vectors living in the parameter space of a neural network model. An example of such vector is :math:`\delta \mathbf w` in the EWC penalty :math:`\delta \mathbf w^\top F \delta \mathbf w`.
- :class:`nngeometry.object.vector.FVector` objects are vectors living in the function space of a neural network model. An example of such vector is :math:`\mathbf{f}=\left(f\left(x_{1}\right),\ldots,f\left(x_{n}\right)\right)^{\top}` where :math:`f` is a neural network and :math:`x_1,\ldots,x_n` are examples from a training dataset.

API documentation
=================

.. automodule:: nngeometry.object.vector
:members:
Expand Down
4 changes: 2 additions & 2 deletions examples/FIM for EWC.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "pytenv",
"display_name": "Python 3 (pytenv)",
"language": "python",
"name": "pytenv"
},
Expand All @@ -676,7 +676,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/Fisher Information Matrix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
20 changes: 20 additions & 0 deletions nngeometry/layercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@


class LayerCollection:
"""
This class describes a set or subset of layers, that can be used
in order to instantiate :class:`nngeometry.object.PVector` or
:class:`nngeometry.object.PSpaceDense` objects
:param layers:
"""

def __init__(self, layers=None):
if layers is None:
Expand All @@ -16,6 +23,13 @@ def __init__(self, layers=None):
raise NotImplementedError

def from_model(model):
"""
Constructs a new LayerCollection object by using all parameters
of the model passed as argument.
:param model: The PyTorch model
:type model: `nn.Module`
"""
lc = LayerCollection()
for layer, mod in model.named_modules():
mod_class = mod.__class__.__name__
Expand Down Expand Up @@ -75,6 +89,12 @@ def _module_to_layer(mod):
return BatchNorm2dLayer(num_features=mod.num_features)

def numel(self):
"""
Total number of scalar parameters in this LayerCollection object
:return: number of scalar parameters
:rtype: int
"""
return self._numel

def __getitem__(self, layer_id):
Expand Down

0 comments on commit 048b73a

Please sign in to comment.