Skip to content

Commit

Permalink
adds get_diag, frobenius norm and trace to PMatQuasiDiag
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Aug 26, 2020
1 parent f64d55e commit dddc290
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
24 changes: 21 additions & 3 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def get_dense_tensor(self):
def get_diag(self):
return self.data

def solve(self, v, regul=1e-8):
"""
solves v = Ax in x
"""
# TODO: test
solution = v.to_flat_representation() / (self.data + regul)
return PVector(layer_collection=v.layer_collection,
vector_repr=solution)

def __add__(self, other):
sum_diags = self.data + other.data
return PMatDiag(generator=self.generator,
Expand Down Expand Up @@ -683,10 +692,18 @@ def get_dense_tensor(self):
return M

def frobenius_norm(self):
raise NotImplementedError
norm2 = 0
for layer_id in self.generator.layer_collection.layers.keys():
diag, cross = self.data[layer_id]
norm2 += torch.dot(diag, diag)
if cross is not None:
norm2 += 2 * torch.dot(cross.view(-1), cross.view(-1))

return norm2 ** .5

def get_diag(self):
raise NotImplementedError
return torch.cat([self.data[l_id][0] for l_id in
self.generator.layer_collection.layers.keys()])

def inverse(self):
raise NotImplementedError
Expand All @@ -695,7 +712,8 @@ def mv(self):
raise NotImplementedError

def trace(self):
raise NotImplementedError
return sum([self.data[l_id][0].sum() for l_id in
self.generator.layer_collection.layers.keys()])

def vTMv(self):
raise NotImplementedError
18 changes: 18 additions & 0 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,21 @@ def test_jacobian_pquasidiag_vs_pdense():
# compare upper triangular block with lower triangular one
check_tensors(matrix_qd[start:start+sw+sb, start+sw:],
matrix_qd[start+sw:, start:start+sw+sb].t())

def test_jacobian_pquasidiag():
for get_task in [get_fullyconnect_task]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
function=function,
n_output=n_output)
PMat_qd = PMatQuasiDiag(generator)
dense_tensor = PMat_qd.get_dense_tensor()

check_tensors(torch.diag(dense_tensor), PMat_qd.get_diag())

check_ratio(torch.norm(dense_tensor), PMat_qd.frobenius_norm())

check_ratio(torch.trace(dense_tensor), PMat_qd.trace())

0 comments on commit dddc290

Please sign in to comment.