Skip to content

Commit

Permalink
improves pspace_repr documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Nov 22, 2021
1 parent b40b6a4 commit efcc378
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 19 deletions.
12 changes: 4 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
author = 'Thomas George'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

Expand Down Expand Up @@ -63,7 +56,10 @@
'sphinx.ext.graphviz',
'sphinx.ext.intersphinx',
# 'sphinx.ext.linkcode'
'torch'
'torch',
'sphinxcontrib.bibtex'
]

master_doc = 'index'

bibtex_bibfiles = ['refs.bib']
29 changes: 24 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,27 @@ the tutorials or explore the API reference.
NNGeometry is under developement, as such it is possible that core components change when
between versions.

Tutorials
=========
Quick example
=============

Computing the Fisher Information Matrix on a given PyTorch model using a KFAC representation, and then computing its trace is as simple as:

>>> F_kfac = FIM(model=model,
loader=loader,
representation=PMatKFAC,
n_output=10,
variant='classif_logits',
device='cuda')
>>> print(F_kfac.trace())

If we instead wanted to choose a :class:`nngeometry.object.pspace.PMatBlockDiag` representation, we can just replace ``representation=PMatKFAC`` with ``representation=PMatBlockDiag`` in the above.

This example is further detailed in :doc:`/quick_example`. Other available parameter space representations are listed in :doc:`/pspace_repr`.

More examples
=============

More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples

Indices and tables
==================
Expand All @@ -32,12 +49,14 @@ In-depth
========

.. toctree::
overview.rst
quick_example.rst
install.rst
pspace_repr.rst
:maxdepth: 1

api/index.rst

Quick start
===========
References
==========

.. bibliography::
37 changes: 36 additions & 1 deletion docs/pspace_repr.rst
Original file line number Diff line number Diff line change
@@ -1,20 +1,55 @@
Parameter space representations
===============================

Parameter space representations are :math:`d \times d` objects that define metrics in parameter space such as:

- Fisher Information Matrices/Gauss-Newton matrix
- Gradient 2nd moment (e.g. the sometimes called *Empirical Fisher*)
- Other covariances such as in Bayesian Deep Learning

These matrices are often too large to fit in memory, for instance when :math:`d` is in the order of :math:`10^6 - 10^8`
as is typical in current deep networks. Here is a list of parameter space representations that are available in NNGeometry,
computed on a small network, represented as images where each pixel represent a component of the matrix, and the color is
the magnitude of these components. These matrices are normalized by their diagonal (i.e. these are correlation matrices) for
better visualization:

:class:`nngeometry.object.pspace.PMatDense` representation: this is the usual dense matrix. Memory cost: :math:`d \times d`

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatDense.png
:width: 400

:class:`nngeometry.object.pspace.PMatBlockDiag` representation: a block-diagonal representation where diagonal blocks are
dense matrices corresponding to parameters of a single layer, and cross-layer interactions are ignored (their coefficients are
set to :math:`0`). Memory cost: :math:`\sum_l d_l \times d_l` where :math:`d_l` is the number of parameters of layer :math:`l`.

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatBlockDiag.png
:width: 400

:class:`nngeometry.object.pspace.PMatKFAC` representation :cite:p:`martens2015optimizing, grosse2016kronecker`: a block-diagonal representation where diagonal blocks are
factored as the Kronecker product of two smaller matrices, and cross-layer interactions are ignored (their coefficients are
set to :math:`0`). Memory cost: :math:`\sum_l g_l \times g_l + a_l \times a_l` where :math:`a_l` is the number of neurons of the
input of layer :math:`l` and :math:`g_l` is the number of pre-activations of the output of layer :math:`l`.

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatKFAC.png
:width: 400

:class:`nngeometry.object.pspace.PMatEKFAC` representation :cite:p:`george2018fast`: a block-diagonal representation where diagonal blocks are
factored as a diagonal matrix in a Kronecker factored eigenbasis, and cross-layer interactions are ignored (their coefficients are
set to :math:`0`). Memory cost: :math:`\sum_l g_l \times g_l + a_l \times a_l + d_l` where :math:`a_l` is the number of neurons of the
input of layer :math:`l` and :math:`g_l` is the number of pre-activations of the output of layer :math:`l`, and :math:`d_l` is

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatEKFAC.png
:width: 400

:class:`nngeometry.object.pspace.PMatDiag` representation: a diagonal representation that ignores all interactions between parameters.
Memory cost: :math:`d`

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatDiag.png
:width: 400

:class:`nngeometry.object.pspace.PMatQuasiDiag` representation :cite:p:`ollivier2015riemannian`: a diagonal representation where for each neuron, a coefficient is also
stored that measures the interaction between this neuron's weights and the corresponding bias.
Memory cost: :math:`2 \times d`

.. image:: https://github.com/tfjgeorge/nngeometry/raw/master/examples/repr_img/PMatQuasiDiag.png
:width: 400
:width: 400
10 changes: 5 additions & 5 deletions docs/overview.rst → docs/quick_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Computing the FIM requires the following arguments:
>>> Fv = F_kfac.mv(v)

Note that switching from the :class:`.object.PMatKFAC` representation to any other representation such as :class:`.object.PMatDense` is as simple as passing ``representation=PMatDense`` when building the ``F_kfac`` object.
More examples
=============
More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples

More examples
=============

More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples
37 changes: 37 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
@article{george2018fast,
title={Fast Approximate Natural Gradient Descent in a Kronecker Factored Eigenbasis},
author={George, Thomas and Laurent, C{\'e}sar and Bouthillier, Xavier and Ballas, Nicolas and Vincent, Pascal},
journal={Advances in Neural Information Processing Systems},
volume={31},
pages={9550--9560},
year={2018}
}

@inproceedings{grosse2016kronecker,
title={A kronecker-factored approximate fisher matrix for convolution layers},
author={Grosse, Roger and Martens, James},
booktitle={International Conference on Machine Learning},
pages={573--582},
year={2016},
organization={PMLR}
}

@inproceedings{martens2015optimizing,
title={Optimizing neural networks with kronecker-factored approximate curvature},
author={Martens, James and Grosse, Roger},
booktitle={International conference on machine learning},
pages={2408--2417},
year={2015},
organization={PMLR}
}

@article{ollivier2015riemannian,
title={Riemannian metrics for neural networks I: feedforward networks},
author={Ollivier, Yann},
journal={Information and Inference: A Journal of the IMA},
volume={4},
number={2},
pages={108--153},
year={2015},
publisher={Oxford University Press}
}

0 comments on commit efcc378

Please sign in to comment.