In [1]:
import torch as t

# Seed the RNG so the initial weights and the random batches are the same on
# every Restart & Run All.
SEED = 0
t.manual_seed(SEED)

# Small MLP that maps a 16-dim input back to 16 dims.
model = t.nn.Sequential(t.nn.Linear(16, 128), t.nn.ReLU(), t.nn.Linear(128, 16))

batch_size, n_batches = 10, 2
data_batches = [t.randn((batch_size, 16)) for _ in range(n_batches)]

############### Forward pass ###############
# Run every forward pass up front and keep each output (with its autograd
# graph still attached) so the backward passes can be done later.

stored_outputs = {}

for i, inputs in enumerate(data_batches):
    outputs = model(inputs)
    stored_outputs[i] = outputs

############## Backwards pass ##############
# Each batch is its own regression target. The first backward/step succeeds;
# optimizer.step() then updates the parameters in place, so the graph stored
# for the next batch no longer matches them and its backward() raises.

loss_fn = t.nn.MSELoss()
optimizer = t.optim.SGD(model.parameters(), lr=0.1)

# Holds the exception raised by backward(), if any, so later cells can still run
# and the failure remains inspectable.
backward_error = None

for i, targets in enumerate(data_batches):
    optimizer.zero_grad()

    outputs = stored_outputs[i]

    # Argument order is kept as (batch, outputs); MSE is symmetric, so the loss
    # value does not depend on the order.
    loss = loss_fn(targets, outputs)

    try:
        loss.backward()
    except RuntimeError as err:
        # The exception is the point of this cell: show it instead of letting
        # it abort the notebook run.
        backward_error = err
        print(f"{type(err).__name__}: {err}")
        break

    optimizer.step()

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 16]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

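I think the problem is related to the tensors that autograd saves on the graph. Each stored output keeps a graph that references the model weights as they were during its forward pass; `optimizer.step()` then updates those weights in place, so by the time the second batch's `backward()` runs, the saved weights are at a newer version than the graph expects. The best source of info I've found about saved tensors is here: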

https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html

The hacky fix is to re-run the forward pass for each batch immediately before running its backward pass, so that the saved tensors are repopulated from the current weights. There's probably a better fix that involves packing, saving, and unpacking the saved tensors appropriately for each microbatch (e.g. using the saved-tensor hooks described in the tutorial linked above).

In [2]:
# Build a fresh graph for the first batch so its saved state can be inspected.
x = data_batches[0]
outputs = model(x)
loss = loss_fn(x, outputs)

# The loss's backward node holds on to tensors it will need during backward();
# `_saved_self` is one of these stateful tensors kept alive by the graph.
mse_node = loss.grad_fn
mse_node._saved_self

tensor([[-0.7427,  0.5930,  2.7530,  0.3157,  0.5558, -0.7625, -0.5604, -0.6728,
         -0.0853, -1.8455,  1.6673,  0.1490,  1.0437, -2.1890, -0.3083,  0.9335],
        [-0.1113, -1.7596, -0.4557, -0.2848,  1.0087,  1.6762, -0.7310,  0.1322,
          0.5212,  0.5044, -0.4175, -0.2241,  1.2632,  0.5341,  0.5677, -1.6662],
        [-0.6822, -0.1219, -0.9987, -1.1098,  1.2427, -0.2316, -1.2185,  0.1468,
          0.8322,  0.7007,  0.0628,  0.6061,  0.9081,  1.2711, -2.2495, -0.2998],
        [-1.4726, -0.1904,  1.7050,  0.6522,  2.2435,  0.7128,  0.8931,  1.0558,
          0.4479, -2.5197, -1.5794,  2.7752, -1.0689, -1.6108, -1.3693,  2.0114],
        [-0.5615,  1.5408, -1.4485, -0.0782,  0.3234,  1.0927, -0.0450, -1.7949,
          0.1155,  0.7487,  0.0355,  0.8411,  1.1565, -1.7849,  0.4582,  1.1205],
        [-0.0702,  0.2253,  0.6500,  0.1752,  0.3533,  0.6881, -2.0929,  1.7572,
         -0.1277,  0.6316,  0.1293,  0.1294, -0.5409, -1.0919,  0.6151,  0.1582],
        [-1.1471, -0.4

In [3]:
# More data that lingers on the graph: take the second entry of the loss node's
# `next_functions` (presumably the node that produced `outputs` — confirm) and
# read the strides it recorded for its `mat2` operand.
upstream_edge = loss.grad_fn.next_functions[1]
f = upstream_edge[0]
f._saved_mat2_strides

(1, 128)