Skip to content

Commit

Permalink
Merge pull request #9 from tfjgeorge/solve_lowrank
Browse files Browse the repository at this point in the history
adds solve to lowrank repr
  • Loading branch information
tfjgeorge committed Feb 26, 2021
2 parents 0d0a1e7 + a768e4c commit e31b90f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
7 changes: 5 additions & 2 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,8 +790,11 @@ def frobenius_norm(self):
self.data.view(-1, self.data.size(2)).t())
return torch.norm(A)

def solve(self, v):
raise NotImplementedError
def solve(self, b, regul=1e-8):
u, s, v = torch.svd(self.data.view(-1, self.data.size(2)))
x = torch.mv(v, torch.mv(v.t(), b.get_flat_representation()) /
(s**2 + regul))
return PVector(b.layer_collection, vector_repr=x)

def get_diag(self):
return (self.data**2).sum(dim=(0, 1))
Expand Down
13 changes: 11 additions & 2 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def test_jacobian_plowrank():
n_output=n_output)
PMat_lowrank = PMatLowRank(generator=generator, examples=loader)
dw = random_pvector(lc, device=device)
dw = dw / dw.norm()
dense_tensor = PMat_lowrank.get_dense_tensor()

# Test get_diag
Expand All @@ -545,14 +546,22 @@ def test_jacobian_plowrank():

# Test mv
mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
mv = PMat_lowrank.mv(dw)
check_tensors(mv_direct,
PMat_lowrank.mv(dw).get_flat_representation())
mv.get_flat_representation())

# Test vTMV
check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
PMat_lowrank.vTMv(dw))

# Test solve TODO
# Test solve
# We will try to recover mv, which is in the span of the
# low rank matrix
regul = 1e-3
mmv = PMat_lowrank.mv(mv)
mv_using_inv = PMat_lowrank.solve(mmv, regul=regul)
check_tensors(mv.get_flat_representation(),
mv_using_inv.get_flat_representation(), eps=1e-2)
# Test inv TODO

# Test add, sub, rmul
Expand Down
14 changes: 6 additions & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@ def check_ratio(vref, v2, eps=1e-3):

def check_tensors(tref, t2, eps=1e-3, only_print_diff=False):
if torch.norm(tref) == 0:
if only_print_diff:
print(torch.norm(t2 - tref))
else:
assert torch.norm(t2 - tref) < eps
relative_diff = torch.norm(t2 - tref)
else:
relative_diff = torch.norm(t2 - tref) / torch.norm(tref)
if only_print_diff:
print(relative_diff)
else:
assert relative_diff < eps
if only_print_diff:
print(relative_diff)
else:
assert relative_diff < eps
return relative_diff


def check_angle(v1, v2, eps=1e-3):
Expand Down

0 comments on commit e31b90f

Please sign in to comment.