Skip to content

Commit

Permalink
codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Sep 15, 2020
1 parent a207389 commit a16e4e7
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,15 +287,18 @@ def solve(self, vs, regul=1e-8):
v = vs_dict[layer_id][0].view(-1)
if layer.bias is not None:
v = torch.cat([v, vs_dict[layer_id][1].view(-1)])
block = self.data[layer_id]

inv_v, _ = torch.solve(v.view(-1, 1),
self.data[layer_id] +
regul * torch.eye(self.data[layer_id].size(0),
device=self.data[layer_id].device))
inv_v_tuple = (inv_v[:layer.weight.numel()].view(*layer.weight.size),)
block +
regul * torch.eye(sblock.size(0),
device=block.device))
inv_v_tuple = (inv_v[:layer.weight.numel()]
.view(*layer.weight.size),)
if layer.bias is not None:
inv_v_tuple = (inv_v_tuple[0],
inv_v[layer.weight.numel():].view(*layer.bias.size),)
inv_v[layer.weight.numel():]
.view(*layer.bias.size),)

out_dict[layer_id] = inv_v_tuple
return PVector(layer_collection=vs.layer_collection,
Expand Down Expand Up @@ -762,7 +765,8 @@ def get_dense_tensor(self):
out_s = cross.size(0)
in_s = cross.numel() // out_s

block_bias = torch.cat((cross.view(cross.size(0), -1).t().reshape(-1, 1),
block_bias = torch.cat((cross.view(cross.size(0), -1).t()
.reshape(-1, 1),
torch.zeros((out_s * in_s, out_s),
device=device)),
dim=1)
Expand Down

0 comments on commit a16e4e7

Please sign in to comment.