Skip to content

Commit

Permalink
deletes inverse from PMatAbstract and adds solve instead, which makes…
Browse files Browse the repository at this point in the history
… more sense since it is more general
  • Loading branch information
tfjgeorge committed Aug 27, 2020
1 parent c9e5aa3 commit 504809a
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def vTMv(self, v):
raise NotImplementedError

@abstractmethod
def inverse(self, regul):
def solve(self, v, regul):
"""
Inverse of the matrix
Solves Fx = v in x
:param regul: Tikhonov regularization
:type regul: float
:param v: v
:type regul: PVector
"""
raise NotImplementedError

Expand Down Expand Up @@ -468,8 +470,11 @@ def compute_eigendecomposition(self, impl='symeig'):
def get_eigendecomposition(self):
return self.evals, self.evecs

def solve(self, v):
raise NotImplementedError


class PMatEKFAC:
class PMatEKFAC(PMatAbstract):
"""
EKFAC representation from
*George, Laurent et al., Fast Approximate Natural Gradient Descent
Expand Down Expand Up @@ -572,6 +577,12 @@ def trace(self):
def frobenius_norm(self):
return sum([(d**2).sum() for d in self.data[1].values()])**.5

def solve(self, v):
raise NotImplementedError

def get_diag(self, v):
raise NotImplementedError

def inverse(self, regul=1e-8):
evecs, diags = self.data
inv_diags = {i: 1. / (d + regul)
Expand Down Expand Up @@ -615,7 +626,7 @@ def frobenius_norm(self):
def get_dense_tensor(self):
raise NotImplementedError

def inverse(self, regul):
def solve(self, v):
raise NotImplementedError

def get_diag(self):
Expand Down Expand Up @@ -672,7 +683,7 @@ def frobenius_norm(self):
self.data.view(-1, self.data.size(2)).t())
return torch.norm(A)

def inverse(self, regul):
def solve(self, v):
raise NotImplementedError

def get_diag(self):
Expand Down Expand Up @@ -735,7 +746,7 @@ def get_diag(self):
return torch.cat([self.data[l_id][0] for l_id in
self.generator.layer_collection.layers.keys()])

def inverse(self):
def solve(self, v):
raise NotImplementedError

def trace(self):
Expand Down

0 comments on commit 504809a

Please sign in to comment.