Skip to content

Commit

Permalink
Merge pull request #42 from tfjgeorge/inplace
Browse files (browse the repository at this point in the history)
Fixes an autograd error when using in-place operations (e.g. `relu(..., inplace=True)` applied directly to a layer's output).
  • Loading branch information
tfjgeorge committed Dec 6, 2021
2 parents 4510c47 + 0ecb989 commit f3ad0dc
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 25 deletions.
38 changes: 17 additions & 21 deletions nngeometry/generator/jacobian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,13 @@ def implicit_Jv(self, v, examples):

def _add_hooks(self, hook_x, hook_gy, mods):
handles = []

def _hook_x(mod, i, o):
hook_x(mod, i)
o.register_hook(lambda g_o: hook_gy(mod, g_o))

for m in mods:
handles.append(m.register_forward_pre_hook(hook_x))
handles.append(m.register_full_backward_hook(hook_gy))
handles.append(m.register_forward_hook(_hook_x))
return handles

def _hook_savex(self, mod, i):
Expand All @@ -637,8 +641,7 @@ def _hook_savex_io(self, mod, i):
else:
self.x_inner[mod] = i[0]

def _hook_compute_flat_grad(self, mod, grad_input, grad_output):
gy = grad_output[0]
def _hook_compute_flat_grad(self, mod, gy):
x = self.xs[mod]
bs = x.size(0)
layer_id = self.m_to_l[mod]
Expand All @@ -649,8 +652,7 @@ def _hook_compute_flat_grad(self, mod, grad_input, grad_output):
start_p:start_p+layer.numel()],
mod, layer, x, gy)

def _hook_compute_diag(self, mod, grad_input, grad_output):
gy = grad_output[0]
def _hook_compute_diag(self, mod, gy):
x = self.xs[mod]
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
Expand All @@ -659,26 +661,23 @@ def _hook_compute_diag(self, mod, grad_input, grad_output):
self.diag_m[start_p:start_p+layer.numel()],
mod, layer, x, gy)

def _hook_compute_quasidiag(self, mod, grad_input, grad_output):
gy = grad_output[0]
def _hook_compute_quasidiag(self, mod, gy):
x = self.xs[mod]
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
diag, cross = self._blocks[layer_id]
FactoryMap[layer.__class__].quasidiag(diag, cross, mod, layer, x, gy)

def _hook_compute_layer_blocks(self, mod, grad_input, grad_output):
gy = grad_output[0]
def _hook_compute_layer_blocks(self, mod, gy):
x = self.xs[mod]
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
block = self._blocks[layer_id]
FactoryMap[layer.__class__].layer_block(block,
mod, layer, x, gy)

def _hook_compute_kfac_blocks(self, mod, grad_input, grad_output):
def _hook_compute_kfac_blocks(self, mod, gy):
mod_class = mod.__class__.__name__
gy = grad_output[0]
x = self.xs[mod]
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
Expand All @@ -693,9 +692,8 @@ def _hook_compute_kfac_blocks(self, mod, grad_input, grad_output):
else:
raise NotImplementedError

def _hook_compute_kfe_diag(self, mod, grad_input, grad_output):
def _hook_compute_kfe_diag(self, mod, gy):
mod_class = mod.__class__.__name__
gy = grad_output[0]
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
x = self.xs[mod]
Expand All @@ -706,13 +704,13 @@ def _hook_compute_kfe_diag(self, mod, grad_input, grad_output):
else:
raise NotImplementedError

def _hook_kxy(self, mod, grad_input, grad_output):
def _hook_kxy(self, mod, gy):
if self.outerloop_switch:
self.gy_outer[mod] = grad_output[0]
self.gy_outer[mod] = gy
else:
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
gy_inner = grad_output[0]
gy_inner = gy
gy_outer = self.gy_outer[mod]
x_outer = self.x_outer[mod]
x_inner = self.x_inner[mod]
Expand All @@ -725,9 +723,8 @@ def _hook_kxy(self, mod, grad_input, grad_output):
self.e_outer:self.e_outer+bs_outer],
mod, layer, x_inner, gy_inner, x_outer, gy_outer)

def _hook_compute_Jv(self, mod, grad_input, grad_output):
def _hook_compute_Jv(self, mod, gy):
if self.compute_switch:
gy = grad_output[0]
x = self.xs[mod]
bs = x.size(0)
layer_id = self.m_to_l[mod]
Expand All @@ -740,8 +737,7 @@ def _hook_compute_Jv(self, mod, grad_input, grad_output):
self._Jv[self.i_output, self.start:self.start+bs],
mod, layer, x, gy, v_weight, v_bias)

def _hook_compute_trace(self, mod, grad_input, grad_output):
gy = grad_output[0]
def _hook_compute_trace(self, mod, gy):
x = self.xs[mod]
layer_id = self.m_to_l[mod]
layer = self.layer_collection.layers[layer_id]
Expand Down
1 change: 0 additions & 1 deletion nngeometry/generator/jacobian/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None,
stride=mod.stride, padding=mod.padding, dilation=mod.dilation)
t = (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1)
print(gw.size(), t.size())
gw -= (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1)
buffer.add_(gw.view(bs, -1))

Expand Down
2 changes: 1 addition & 1 deletion tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(self, x):
x = tF.max_pool2d(x, 2, 2)
x = tF.relu(self.conv2(x))
x = tF.max_pool2d(x, 2, 2)
x = tF.relu(self.conv3(x))
x = tF.relu(self.conv3(x), inplace=True)
x = x.view(-1, 1*1*7)
if self.normalization == 'batch_norm':
x = self.bn2(self.fc1(x))
Expand Down
2 changes: 0 additions & 2 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def test_jacobian_pullback_dense():

def test_jacobian_fdense_vs_pullback():
for get_task in linear_tasks + nonlinear_tasks:
print(get_task)
for centering in [True, False]:
loader, lc, parameters, model, function, n_output = get_task()
generator = Jacobian(layer_collection=lc,
Expand Down Expand Up @@ -565,7 +564,6 @@ def test_jacobian_plowrank():
# We will try to recover mv, which is in the span of the
# low rank matrix
regul = 1e-3
print(get_task)
mmv = PMat_lowrank.mv(mv)
mv_using_inv = PMat_lowrank.solve(mmv + regul*mv, regul=regul)
check_tensors(mv.get_flat_representation(),
Expand Down

0 comments on commit f3ad0dc

Please sign in to comment.