Skip to content

Commit

Permalink
Merge pull request #54 from tfjgeorge/get_kfe
Browse files Browse the repository at this point in the history
adds get_KFE method to PMatEKFAC
  • Loading branch information
tfjgeorge committed Mar 15, 2023
2 parents 930996b + 576b538 commit ba2dcb1
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,25 +642,43 @@ def get_dense_tensor(self, split_weight_bias=True):
involves more operations. Otherwise the coefficients corresponding
to the bias are mixed between coefficients of the weight matrix
"""
evecs, diags = self.data
_, diags = self.data
s = self.generator.layer_collection.numel()
M = torch.zeros((s, s), device=self.generator.get_device())
for layer_id, layer in self.generator.layer_collection.layers.items():
evecs_a, evecs_g = evecs[layer_id]
KFE_layers = self.get_KFE(split_weight_bias=split_weight_bias)
for layer_id, _ in self.generator.layer_collection.layers.items():
diag = diags[layer_id]
start = self.generator.layer_collection.p_pos[layer_id]
sAG = diag.numel()
KFE = KFE_layers[layer_id]
M[start:start+sAG, start:start+sAG].add_(
torch.mm(KFE, torch.mm(torch.diag(diag.view(-1)),
KFE.t())))
return M

def get_KFE(self, split_weight_bias=True):
"""
Returns a dict index by layers, of dense eigenvectors constructed from
Kronecker-factored eigenvectors
- split_weight_bias (bool): if True then the parameters are ordered in
the same way as in the dense or blockdiag representation, but it
involves more operations. Otherwise the coefficients corresponding
to the bias are mixed between coefficients of the weight matrix
"""
evecs, _ = self.data
KFE = dict()
for layer_id, _ in self.generator.layer_collection.layers.items():
evecs_a, evecs_g = evecs[layer_id]
start = self.generator.layer_collection.p_pos[layer_id]
if split_weight_bias:
kronecker(evecs_g, evecs_a[:-1, :])
kronecker(evecs_g, evecs_a[-1:, :].contiguous())
KFE = torch.cat([kronecker(evecs_g, evecs_a[:-1, :]),
KFE[layer_id] = torch.cat([kronecker(evecs_g, evecs_a[:-1, :]),
kronecker(evecs_g, evecs_a[-1:, :])], dim=0)
else:
KFE = kronecker(evecs_g, evecs_a)
M[start:start+sAG, start:start+sAG].add_(
torch.mm(KFE, torch.mm(torch.diag(diag.view(-1)),
KFE.t())))
return M
KFE[layer_id] = kronecker(evecs_g, evecs_a)
return KFE

def update_diag(self, examples):
"""
Expand Down

0 comments on commit ba2dcb1

Please sign in to comment.