Skip to content

Commit

Permalink
jacobian generator doc
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Jul 27, 2020
1 parent 25fe8f5 commit d79b246
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
6 changes: 2 additions & 4 deletions docs/api/generators.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
Generators
==========

The spirit of NNGeometry is that you do not directly manipulate Generator objects, but instead instantiate representations such as `PSpaceDense` or `PSpaceKFAC` so that you do not have to worry about implementing i.e. matrix-vector products for a `PSpaceKFAC` representation.
The spirit of NNGeometry is that you do not directly manipulate Generator objects, these can be considered as a backend that you do not need to worry about once instantiated. You instead instantiate concrete representations such as `PSpaceDense` or `PSpaceKFAC` and directly call linear algebra operations on these concrete representations.

.. automodule:: nngeometry.generator
.. automodule:: nngeometry.generator.jacobian
:members:
:undoc-members:
:show-inheritance:
22 changes: 22 additions & 0 deletions nngeometry/generator/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@


class Jacobian:
"""
Computes jacobians :math:`\mathbf{J}_{ijk}=\\frac{\partial f\left(x_{j}\\right)_{i}}{\delta\mathbf{w}_{k}}`,
FIM matrices :math:`\mathbf{F}_{k,k'}=\\frac{1}{n}\sum_{i,j}\mathbf{J}_{ijk}\mathbf{J}_{ijk'}`
and NTK matrices :math:`\mathbf{K}_{iji'j'}=\sum_{k}\mathbf{J}_{ijk}\mathbf{J}_{ijk'}`.
This generator is written in pure PyTorch and exploits some tricks in order to make computations
more efficient.
:param layer_collection:
:type layer_collection: :class:`.layercollection.LayerCollection`
:param model:
:type model: Pytorch `nn.Module`
:param loader:
:type loader: Pytorch `utils.data.DataLoader`
:param function: A function :math:`f\left(X,Y,Z\\right)` where :math:`X,Y,Z` are minibatchs
returned by the dataloader (Note that in some cases :math:`Y,Z` are not required)
:type function: python function
:param n_output: How many output is there for each example of your function. E.g. in 10 class
classification this would probably be 10.
:type n_output: integer
"""
def __init__(self, layer_collection, model, loader, function, n_output=1,
centering=False):
self.model = model
Expand Down

0 comments on commit d79b246

Please sign in to comment.