Skip to content

Commit

Permalink
fixes incorrect shape of returned PVector for KFAC and EKFAC mv and s…
Browse files Browse the repository at this point in the history
…olve
  • Loading branch information
tfjgeorge committed Sep 26, 2020
1 parent a9edc0c commit 30cfb46
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def solve(self, vs, regul=1e-8, use_pi=True):
vs_dict = vs.get_dict_representation()
out_dict = dict()
for layer_id, layer in self.generator.layer_collection.layers.items():
v = vs_dict[layer_id][0].view(vs_dict[layer_id][0].size(0), -1)
vw = vs_dict[layer_id][0]
sw = vw.size()
v = vw.view(sw[0], -1)
if layer.bias is not None:
v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1)
a, g = self.data[layer_id]
Expand All @@ -399,9 +401,9 @@ def solve(self, vs, regul=1e-8, use_pi=True):
solve_a, _ = torch.solve(solve_g.t(), a_reg)
solve_a = solve_a.t()
if layer.bias is None:
solve_tuple = (solve_a,)
solve_tuple = (solve_a.view(*sw),)
else:
solve_tuple = (solve_a[:, :-1].contiguous(),
solve_tuple = (solve_a[:, :-1].contiguous().view(*sw),
solve_a[:, -1].contiguous())
out_dict[layer_id] = solve_tuple
return PVector(layer_collection=vs.layer_collection,
Expand Down Expand Up @@ -454,15 +456,18 @@ def mv(self, vs):
vs_dict = vs.get_dict_representation()
out_dict = dict()
for layer_id, layer in self.generator.layer_collection.layers.items():
v = vs_dict[layer_id][0].view(vs_dict[layer_id][0].size(0), -1)
vw = vs_dict[layer_id][0]
sw = vw.size()
v = vw.view(sw[0], -1)
if layer.bias is not None:
v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1)
a, g = self.data[layer_id]
mv = torch.mm(torch.mm(g, v), a)
if layer.bias is None:
mv_tuple = (mv,)
mv_tuple = (mv.view(*sw),)
else:
mv_tuple = (mv[:, :-1].contiguous(), mv[:, -1].contiguous())
mv_tuple = (mv[:, :-1].contiguous().view(*sw),
mv[:, -1].contiguous())
out_dict[layer_id] = mv_tuple
return PVector(layer_collection=vs.layer_collection,
dict_repr=out_dict)
Expand Down Expand Up @@ -570,16 +575,19 @@ def mv(self, vs):
for l_id, l in self.generator.layer_collection.layers.items():
diag = diags[l_id]
evecs_a, evecs_g = evecs[l_id]
v = vs_dict[l_id][0].view(vs_dict[l_id][0].size(0), -1)
vw = vs_dict[l_id][0]
sw = vw.size()
v = vw.view(sw[0], -1)
if l.bias is not None:
v = torch.cat([v, vs_dict[l_id][1].unsqueeze(1)], dim=1)
v_kfe = torch.mm(torch.mm(evecs_g.t(), v), evecs_a)
mv_kfe = v_kfe * diag.view(*v_kfe.size())
mv = torch.mm(torch.mm(evecs_g, mv_kfe), evecs_a.t())
if l.bias is None:
mv_tuple = (mv,)
mv_tuple = (mv.view(*sw),)
else:
mv_tuple = (mv[:, :-1].contiguous(), mv[:, -1].contiguous())
mv_tuple = (mv[:, :-1].contiguous().view(*sw),
mv[:, -1].contiguous())
out_dict[l_id] = mv_tuple
return PVector(layer_collection=vs.layer_collection,
dict_repr=out_dict)
Expand Down Expand Up @@ -622,16 +630,19 @@ def solve(self, vs, regul=1e-8):
for l_id, l in self.generator.layer_collection.layers.items():
diag = diags[l_id]
evecs_a, evecs_g = evecs[l_id]
v = vs_dict[l_id][0].view(vs_dict[l_id][0].size(0), -1)
vw = vs_dict[l_id][0]
sw = vw.size()
v = vw.view(sw[0], -1)
if l.bias is not None:
v = torch.cat([v, vs_dict[l_id][1].unsqueeze(1)], dim=1)
v_kfe = torch.mm(torch.mm(evecs_g.t(), v), evecs_a)
inv_kfe = v_kfe / (diag.view(*v_kfe.size()) + regul)
inv = torch.mm(torch.mm(evecs_g, inv_kfe), evecs_a.t())
if l.bias is None:
inv_tuple = (inv,)
inv_tuple = (inv.view(*sw),)
else:
inv_tuple = (inv[:, :-1].contiguous(), inv[:, -1].contiguous())
inv_tuple = (inv[:, :-1].contiguous().view(*sw),
inv[:, -1].contiguous())
out_dict[l_id] = inv_tuple
return PVector(layer_collection=vs.layer_collection,
dict_repr=out_dict)
Expand Down

0 comments on commit 30cfb46

Please sign in to comment.