Skip to content

Commit

Permalink
memory usage improvement for KFAC and EKFAC
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Aug 7, 2020
1 parent 815bb1c commit 4c36ecd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
14 changes: 9 additions & 5 deletions nngeometry/generator/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,15 @@ def get_kfac_blocks(self):
torch.autograd.grad(output[self.i_output], [inputs],
retain_graph=retain_graph,
only_inputs=True)
blocks = {layer_id: (self._blocks[layer_id][0] / n_examples *
self.n_output**.5,
self._blocks[layer_id][1] / n_examples /
self.n_output**.5)
for layer_id in self.layer_collection.layers.keys()}
for layer_id in self.layer_collection.layers.keys():
self._blocks[layer_id][0].div_(n_examples / self.n_output**.5)
self._blocks[layer_id][1].div_(self.n_output**.5 * n_examples)
blocks = self._blocks
# blocks = {layer_id: (self._blocks[layer_id][0] / n_examples *
# self.n_output**.5,
# self._blocks[layer_id][1] / n_examples /
# self.n_output**.5)
# for layer_id in self.layer_collection.layers.keys()}

# remove hooks
del self._blocks
Expand Down
1 change: 1 addition & 0 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def __init__(self, generator, data=None):
evecs[layer_id] = (evecs_a, evecs_g)
diags[layer_id] = kronecker(evals_g.view(-1, 1),
evals_a.view(-1, 1))
del a, g, kfac_blocks[layer_id]
self.data = (evecs, diags)
else:
self.data = data
Expand Down

0 comments on commit 4c36ecd

Please sign in to comment.