Skip to content

Commit

Permalink
adds solve to PMatQuasiDiag
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Sep 17, 2020
1 parent 86e9ab8 commit fa7f21c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 23 deletions.
38 changes: 22 additions & 16 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,27 +857,33 @@ def solve(self, vs, regul=1e-8):
diag, cross = self.data[layer_id]

v_weight = vs_dict[layer_id][0]
d_weight = diag[:layer.weight.numel()].view(*v_weight.size()) + regul
solve_bias = None
# keep original size
s_w = v_weight.size()
v_weight = v_weight.view(s_w[0], -1)

d_weight = diag[:layer.weight.numel()].view(s_w[0], -1) + regul
solve_b = None
if layer.bias is None:
solve_weight = v_weight / d_weight
solve_w = v_weight / d_weight
else:
v_bias = vs_dict[layer_id][1]
d_bias = diag[layer.weight.numel():] + regul
if len(cross.size()) == 2:
d_bias_expanded = d_bias.view(-1, 1)
v_bias_expanded = v_bias.view(-1, 1)
elif len(cross.size()) == 4:
d_bias_expanded = d_bias.view(-1, 1, 1, 1)
v_bias_expanded = v_bias.view(-1, 1, 1, 1)

# solve_weight = v_weight / d_weight
# solve_bias = v_bias / d_bias
solve_weight = (v_weight * d_bias_expanded - v_bias_expanded * cross) / \
(d_weight * d_bias_expanded - cross**2)
print((d_weight * d_bias_expanded - cross**2))
solve_bias = (v_bias - (solve_weight * cross).view(v_bias.size(0), -1).sum(dim=1)) / d_bias
cross = cross.view(s_w[0], -1)

solve_b_denom = d_bias - bdot(cross / d_weight, cross)
solve_b = ((v_bias - bdot(cross / d_weight, v_weight))
/ solve_b_denom)

solve_w = (v_weight - solve_b.unsqueeze(1) * cross) / d_weight

out_dict[layer_id] = (solve_weight, solve_bias)
out_dict[layer_id] = (solve_w.view(*s_w), solve_b)
return PVector(layer_collection=vs.layer_collection,
dict_repr=out_dict)


def bdot(A, B):
    """
    Batched dot product over the last dimension.

    For 2-D inputs of shape (batch, n) this returns a tensor of shape
    (batch,) whose i-th entry is the dot product of A[i] and B[i].

    The contraction is expressed relative to the last dimension, so
    inputs with extra leading batch dimensions (..., n) yield an
    output of shape (...), and two 1-D vectors yield a 0-d tensor.
    Leading dimensions follow the broadcasting rules of torch.matmul.

    :param A: tensor whose last dimension holds the vectors
    :param B: tensor whose last dimension matches the one of A
    :return: tensor of dot products, without the last dimension
    """
    # View A as row vectors (..., 1, n) and B as column vectors
    # (..., n, 1) so that matmul performs one (1 x n) @ (n x 1)
    # product per batch entry.
    rows = A.unsqueeze(-2)
    cols = B.unsqueeze(-1)
    # Each product is a (1 x 1) matrix: drop both singleton dimensions.
    return torch.matmul(rows, cols).squeeze(-1).squeeze(-1)
21 changes: 14 additions & 7 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PMatImplicit, PMatLowRank, PMatQuasiDiag)
from nngeometry.generator import Jacobian
from nngeometry.object.vector import random_pvector, random_fvector, PVector
from utils import check_ratio, check_tensors
from utils import check_ratio, check_tensors, check_angle
import pytest


Expand Down Expand Up @@ -608,15 +608,15 @@ def test_jacobian_pquasidiag_vs_pdense():
for i in range(sb):
# check the strips bias/weight
check_tensors(matrix_dense[start+i*s_in:start+(i+1)*s_in,
start+sw+i:start+sw+i],
start+sw+i],
matrix_qd[start+i*s_in:start+(i+1)*s_in,
start+sw+i:start+sw+i])
start+sw+i])

# verify that the rest is 0
assert torch.norm(matrix_qd[start+i*s_in:start+(i+1)*s_in,
start+sw:start+sw+i]) < 1e-5
start+sw:start+sw+i]) < 1e-10
assert torch.norm(matrix_qd[start+i*s_in:start+(i+1)*s_in,
start+sw+i+1:]) < 1e-5
start+sw+i+1:]) < 1e-10

# compare upper triangular block with lower triangular one
check_tensors(matrix_qd[start:start+sw+sb, start+sw:],
Expand Down Expand Up @@ -644,8 +644,15 @@ def test_jacobian_pquasidiag():

check_ratio(torch.trace(dense_tensor), PMat_qd.trace())

mv = PMat_qd.mv(v)
check_tensors(torch.mv(dense_tensor, v_flat),
PMat_qd.mv(v).get_flat_representation())
mv.get_flat_representation())

check_ratio(torch.dot(torch.mv(dense_tensor, v_flat), v_flat),
PMat_qd.vTMv(v))
PMat_qd.vTMv(v))

# Test solve
regul = 1e-8
v_back = PMat_qd.solve(mv + regul * v, regul=regul)
check_tensors(v.get_flat_representation(),
v_back.get_flat_representation())

0 comments on commit fa7f21c

Please sign in to comment.