Skip to content

Commit

Permalink
typo in PMatBlockDiag
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Sep 16, 2020
1 parent a16e4e7 commit cc058da
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def solve(self, vs, regul=1e-8):

inv_v, _ = torch.solve(v.view(-1, 1),
block +
regul * torch.eye(sblock.size(0),
regul * torch.eye(block.size(0),
device=block.device))
inv_v_tuple = (inv_v[:layer.weight.numel()]
.view(*layer.weight.size),)
Expand Down Expand Up @@ -793,9 +793,6 @@ def get_diag(self):
return torch.cat([self.data[l_id][0] for l_id in
self.generator.layer_collection.layers.keys()])

def solve(self, v):
    """Solving is not supported here: always raises NotImplementedError."""
    raise NotImplementedError

def trace(self):
    """Return the trace, accumulated layer by layer.

    For every layer id of the generator's layer collection, the first
    tensor stored in ``self.data`` for that layer (the same tensors that
    ``get_diag`` concatenates) is summed, and those per-layer sums are
    added together.
    """
    layer_ids = self.generator.layer_collection.layers.keys()
    total = 0
    for layer_id in layer_ids:
        layer_diag = self.data[layer_id][0]
        total = total + layer_diag.sum()
    return total
Expand Down Expand Up @@ -852,3 +849,35 @@ def mv(self, vs):
out_dict[layer_id] = (mv_weight, mv_bias)
return PVector(layer_collection=vs.layer_collection,
dict_repr=out_dict)

def solve(self, vs, regul=1e-8):
    """Solve the regularized linear system for ``vs``, layer by layer.

    ``regul`` is added to every diagonal entry before solving.

    For a layer without bias the system is purely diagonal. For a layer
    with bias, each output unit is treated as one small symmetric system
    in which the unit's bias is coupled to every weight of that unit
    through ``cross``::

        d_weight[j] * x[j] + cross[j] * y       = v_weight[j]   (each j)
        sum_j cross[j] * x[j] + d_bias * y      = v_bias

    The bias unknown ``y`` is obtained first by eliminating the weights
    (Schur complement), then the weights are recovered by substitution.
    When a unit has a single weight this reduces to the closed-form 2x2
    inverse.

    NOTE(review): the coupling structure above is inferred from how the
    bias row is summed over all weights of a unit -- confirm that it
    matches what ``mv`` computes for this class.

    :param vs: vector to solve for; must provide
        ``get_dict_representation()`` returning, per layer id, a tuple
        ``(weight_part, bias_part)`` and a ``layer_collection`` attribute.
    :param regul: damping added to the diagonal entries.
    :return: a ``PVector`` on ``vs.layer_collection`` holding the solution;
        the bias entry of a layer is ``None`` when that layer has no bias.
    """
    vs_dict = vs.get_dict_representation()
    out_dict = dict()
    for layer_id, layer in self.generator.layer_collection.layers.items():
        diag, cross = self.data[layer_id]

        v_weight = vs_dict[layer_id][0]
        n_weight = layer.weight.numel()
        # The leading entries of `diag` belong to the weights, the
        # trailing ones to the bias.
        d_weight = diag[:n_weight].view(*v_weight.size()) + regul

        if layer.bias is None:
            out_dict[layer_id] = (v_weight / d_weight, None)
            continue

        v_bias = vs_dict[layer_id][1]
        d_bias = diag[n_weight:] + regul
        n_out = v_bias.size(0)

        # Eliminate the weights: each unit's bias sees the reduced scalar
        # system  (d_bias - sum c^2 / d_w) * y = v_bias - sum c * v_w / d_w.
        cross_over_d = cross / d_weight
        bias_denom = d_bias - \
            (cross_over_d * cross).reshape(n_out, -1).sum(dim=1)
        bias_rhs = v_bias - \
            (cross_over_d * v_weight).reshape(n_out, -1).sum(dim=1)
        solve_bias = bias_rhs / bias_denom

        # Broadcast the per-unit bias solution over the remaining weight
        # dimensions, whatever their number (linear, conv, ...), instead
        # of special-casing 2-D and 4-D tensors only.
        bias_shape = (-1,) + (1,) * (v_weight.dim() - 1)
        solve_weight = \
            (v_weight - cross * solve_bias.view(bias_shape)) / d_weight

        out_dict[layer_id] = (solve_weight, solve_bias)
    return PVector(layer_collection=vs.layer_collection,
                   dict_repr=out_dict)

0 comments on commit cc058da

Please sign in to comment.