Skip to content

Commit

Permalink
Merge pull request #57 from tfjgeorge/deprec_solve
Browse files Browse the repository at this point in the history
changes to linalg.solve and linalg.eigh
  • Loading branch information
tfjgeorge committed May 12, 2023
2 parents ba2dcb1 + b78eec3 commit fd151d9
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 34 deletions.
13 changes: 8 additions & 5 deletions nngeometry/object/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ def __init__(self, generator, data=None, examples=None):
else:
self.data = generator.get_gram_matrix(examples)

def compute_eigendecomposition(self, impl='symeig'):
# TODO: test
if impl == 'symeig':
self.evals, self.evecs = torch.symeig(self.data, eigenvectors=True)
def compute_eigendecomposition(self, impl='eigh'):
    """Compute and cache the eigendecomposition of the Gram matrix.

    The 4d Gram tensor ``self.data`` is flattened to a square 2d matrix
    before decomposing; results are stored in ``self.evals`` and
    ``self.evecs`` (eigenvectors as columns).

    :param impl: ``'eigh'`` uses the symmetric solver
        ``torch.linalg.eigh`` (eigenvalues in ascending order);
        ``'svd'`` uses a singular value decomposition (values in
        descending order).
    :raises NotImplementedError: for any other value of ``impl``.
    """
    s = self.data.size()
    # Flatten the (s0, s1, s2, s3) Gram tensor to a (s0*s1) x (s2*s3)
    # square matrix — presumably s0*s1 == s2*s3; TODO confirm upstream.
    M = self.data.view(s[0] * s[1], s[2] * s[3])
    if impl == 'eigh':
        self.evals, self.evecs = torch.linalg.eigh(M)
    elif impl == 'svd':
        # torch.svd is deprecated: use torch.linalg.svd instead.
        # linalg.svd returns Vh (right singular vectors as *rows*),
        # whereas the old torch.svd returned V (as columns), hence
        # the transpose. full_matrices=True matches some=False.
        _, self.evals, Vh = torch.linalg.svd(M, full_matrices=True)
        self.evecs = Vh.t()
    else:
        raise NotImplementedError

def mv(self, v):
# TODO: test
Expand Down
50 changes: 24 additions & 26 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,9 @@ def __init__(self, generator, data=None, examples=None):
else:
self.data = generator.get_covariance_matrix(examples)

def compute_eigendecomposition(self, impl='symeig'):
# TODO: test
if impl == 'symeig':
self.evals, self.evecs = torch.symeig(self.data, eigenvectors=True)
def compute_eigendecomposition(self, impl='eigh'):
if impl == 'eigh':
self.evals, self.evecs = torch.linalg.eigh(self.data)
elif impl == 'svd':
_, self.evals, self.evecs = torch.svd(self.data, some=False)
else:
Expand All @@ -141,10 +140,10 @@ def solve(self, v, regul=1e-8, impl='solve'):
# TODO: test
if impl == 'solve':
# TODO: reuse LU decomposition once it is computed
inv_v, _ = torch.solve(v.get_flat_representation().view(-1, 1),
self.data +
regul * torch.eye(self.size(0),
device=self.data.device))
inv_v = torch.linalg.solve(self.data +
regul * torch.eye(self.size(0),
device=self.data.device),
v.get_flat_representation().view(-1, 1))
return PVector(v.layer_collection, vector_repr=inv_v[:, 0])
elif impl == 'eigendecomposition':
v_eigenbasis = self.project_to_diag(v)
Expand Down Expand Up @@ -350,10 +349,10 @@ def solve(self, vs, regul=1e-8):
v = torch.cat([v, vs_dict[layer_id][1].view(-1)])
block = self.data[layer_id]

inv_v, _ = torch.solve(v.view(-1, 1),
block +
regul * torch.eye(block.size(0),
device=block.device))
inv_v = torch.linalg.solve(block +
regul * torch.eye(block.size(0),
device=block.device),
v.view(-1, 1))
inv_v_tuple = (inv_v[:layer.weight.numel()]
.view(*layer.weight.size),)
if layer.bias is not None:
Expand Down Expand Up @@ -478,8 +477,8 @@ def solve(self, vs, regul=1e-8, use_pi=True):
a_reg = a + regul**.5 * pi * torch.eye(a.size(0), device=g.device)
g_reg = g + regul**.5 / pi * torch.eye(g.size(0), device=g.device)

solve_g, _ = torch.solve(v, g_reg)
solve_a, _ = torch.solve(solve_g.t(), a_reg)
solve_g, _, _, _ = torch.linalg.lstsq(g_reg, v)
solve_a, _, _, _ = torch.linalg.lstsq(a_reg, solve_g.t())
solve_a = solve_a.t()
if layer.bias is None:
solve_tuple = (solve_a.view(*sw),)
Expand Down Expand Up @@ -571,14 +570,14 @@ def frobenius_norm(self):
return sum([torch.trace(torch.mm(a, a)) * torch.trace(torch.mm(g, g))
for a, g in self.data.values()])**.5

def compute_eigendecomposition(self, impl='symeig'):
def compute_eigendecomposition(self, impl='eigh'):
self.evals = dict()
self.evecs = dict()
if impl == 'symeig':
if impl == 'eigh':
for layer_id in self.generator.layer_collection.layers.keys():
a, g = self.data[layer_id]
evals_a, evecs_a = torch.symeig(a, eigenvectors=True)
evals_g, evecs_g = torch.symeig(g, eigenvectors=True)
evals_a, evecs_a = torch.linalg.eigh(a)
evals_g, evecs_g = torch.linalg.eigh(g)
self.evals[layer_id] = (evals_a, evals_g)
self.evecs[layer_id] = (evecs_a, evecs_g)
else:
Expand Down Expand Up @@ -625,8 +624,8 @@ def __init__(self, generator, data=None, examples=None):
for layer_id, layer in \
self.generator.layer_collection.layers.items():
a, g = kfac_blocks[layer_id]
evals_a, evecs_a = torch.symeig(a, eigenvectors=True)
evals_g, evecs_g = torch.symeig(g, eigenvectors=True)
evals_a, evecs_a = torch.linalg.eigh(a)
evals_g, evecs_g = torch.linalg.eigh(g)
evecs[layer_id] = (evecs_a, evecs_g)
diags[layer_id] = kronecker(evals_g.view(-1, 1),
evals_a.view(-1, 1))
Expand Down Expand Up @@ -843,12 +842,11 @@ def mv(self, v):
torch.mv(data_mat, v.get_flat_representation()))
return PVector(v.layer_collection, vector_repr=v_flat)

def compute_eigendecomposition(self, impl='symeig'):
if impl == 'symeig':
self.evals, V = torch.symeig(torch.mm(self.data, self.data.t()),
eigenvectors=True)
self.evecs = torch.mm(self.data.t(), V) / \
(self.evals**.5).unsqueeze(0)
def compute_eigendecomposition(self, impl='svd'):
    """Compute and cache the eigendecomposition of the low-rank matrix.

    The dense matrix represented by this object is ``J.T @ J`` where
    ``J`` is ``self.data`` flattened to 2d; its eigenvalues are the
    squared singular values of ``J`` and its eigenvectors are the right
    singular vectors. Results go to ``self.evals`` / ``self.evecs``
    (eigenvectors as columns), in descending eigenvalue order.

    :param impl: only ``'svd'`` is supported for the low-rank
        representation.
    :raises NotImplementedError: for any other value of ``impl``.
    """
    # Stack the leading dimensions into rows: (n_out * n_examples, n_params)
    # — assumed layout of self.data; TODO confirm against generator.
    data_mat = self.data.view(-1, self.data.size(2))
    if impl == 'svd':
        # torch.svd is deprecated: use torch.linalg.svd instead.
        # full_matrices=False matches the old some=True (thin SVD);
        # linalg.svd returns Vh (rows), old torch.svd returned V
        # (columns), hence the transpose.
        _, sqrt_evals, Vh = torch.linalg.svd(data_mat, full_matrices=False)
        self.evecs = Vh.t()
        self.evals = sqrt_evals**2
    else:
        raise NotImplementedError

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==1.8.1
torchvision==0.9.1
requests==2.24.0
torch==2.0.1
torchvision>=0.9.1
requests>=2.24.0
68 changes: 68 additions & 0 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,74 @@ def test_jacobian_fdense_vs_pullback():
check_ratio(frob_direct, frob_FMat)


def test_jacobian_eigendecomposition_fdense():
    """FMatDense: evecs @ diag(evals) @ evecs.T must rebuild the flattened
    dense Gram tensor, for both supported decomposition backends."""
    for task_fn in [get_small_conv_transpose_task]:
        for method in ['eigh', 'svd']:
            loader, lc, parameters, model, function, n_output = task_fn()
            jac = Jacobian(layer_collection=lc,
                           model=model,
                           function=function,
                           n_output=n_output,
                           centering=True)
            fmat = FMatDense(generator=jac, examples=loader)
            fmat.compute_eigendecomposition(impl=method)
            evals, evecs = fmat.get_eigendecomposition()

            dense = fmat.get_dense_tensor()
            sz = dense.size()
            rebuilt = evecs @ torch.diag_embed(evals) @ evecs.T
            check_tensors(dense.view(sz[0] * sz[1], sz[2] * sz[3]), rebuilt)

            # An unknown backend must be rejected loudly.
            with pytest.raises(NotImplementedError):
                fmat.compute_eigendecomposition(impl='stupid')


def test_jacobian_eigendecomposition_pdense():
    """PMatDense: the eigendecomposition must reconstruct the dense
    parameter-space matrix, for both supported backends."""
    for task_fn in [get_small_conv_transpose_task]:
        for method in ['eigh', 'svd']:
            loader, lc, parameters, model, function, n_output = task_fn()
            jac = Jacobian(layer_collection=lc,
                           model=model,
                           function=function,
                           n_output=n_output,
                           centering=True)
            pmat = PMatDense(generator=jac, examples=loader)
            pmat.compute_eigendecomposition(impl=method)
            evals, evecs = pmat.get_eigendecomposition()

            rebuilt = evecs @ torch.diag_embed(evals) @ evecs.T
            check_tensors(pmat.get_dense_tensor(), rebuilt)

            # An unknown backend must be rejected loudly.
            with pytest.raises(NotImplementedError):
                pmat.compute_eigendecomposition(impl='stupid')


def test_jacobian_eigendecomposition_plowrank():
    """PMatLowRank: the SVD-based eigendecomposition must produce finite
    values and reconstruct the dense matrix."""
    for task_fn in [get_conv_task]:
        for method in ['svd']:
            loader, lc, parameters, model, function, n_output = task_fn()
            jac = Jacobian(layer_collection=lc,
                           model=model,
                           function=function,
                           n_output=n_output,
                           centering=True)
            pmat = PMatLowRank(generator=jac, examples=loader)
            pmat.compute_eigendecomposition(impl=method)
            evals, evecs = pmat.get_eigendecomposition()

            # Guard against NaNs from degenerate singular values.
            assert not evals.isnan().any()
            assert not evecs.isnan().any()

            rebuilt = evecs @ torch.diag_embed(evals) @ evecs.T
            check_tensors(pmat.get_dense_tensor(), rebuilt)

            # An unknown backend must be rejected loudly.
            with pytest.raises(NotImplementedError):
                pmat.compute_eigendecomposition(impl='stupid')


def test_jacobian_pdense_vs_pushforward():
# NB: sometimes the test with centering=True do not pass,
# which is probably due to the way we compute centering
Expand Down

0 comments on commit fd151d9

Please sign in to comment.