In [16]:
import torch

# Seed the global RNG so the random `w` and `b` (and every printed value that
# depends on them) are reproducible on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Tiny single-layer model: x (5,) @ w (5, 3) + b (3,) -> z (3,),
# scored with MSE against an all-zeros target y.
x = -1 * torch.ones(5)                     # input: constant vector of -1s
y = torch.zeros(3)                         # target
w = torch.randn(5, 3, requires_grad=True)  # weights: autograd leaf
b = torch.randn(3, requires_grad=True)     # bias: autograd leaf
z = torch.matmul(x, w) + b
loss = torch.nn.functional.mse_loss(z, y)
print("w", w)
print("b", b)
print("z", z)
print("loss", loss)
# grad_fn names the operation that produced each non-leaf tensor;
# loss.backward() traverses these nodes to compute gradients.
print(z.grad_fn)
print(loss.grad_fn)

w tensor([[ 0.4484,  1.3778,  2.2954],
        [-2.1837, -0.0734,  2.1886],
        [ 1.3695,  1.0754,  1.1764],
        [-0.7611, -0.1212, -2.1143],
        [-0.2216,  1.6680,  0.1538]], requires_grad=True)
b tensor([ 0.3694, -0.0062, -0.9621], requires_grad=True)
z tensor([ 1.7178, -3.9328, -4.6619], grad_fn=<AddBackward0>)
loss tensor(13.3838, grad_fn=<MseLossBackward0>)
<AddBackward0 object at 0x0000023597C52100>
<MseLossBackward0 object at 0x0000023597DC4850>


In [18]:
# Backpropagate from the scalar loss; this fills .grad on the leaves w and b.
# With the default retain_graph=False the graph is freed afterwards, so this cell
# cannot be re-run without first re-running the forward-pass cell.
# Because x is a constant vector of -1s, every row of w.grad equals -b.grad.
loss.backward()
for leaf in (w, b):
    print(leaf.grad)

tensor([[-1.1452,  2.6219,  3.1080],
        [-1.1452,  2.6219,  3.1080],
        [-1.1452,  2.6219,  3.1080],
        [-1.1452,  2.6219,  3.1080],
        [-1.1452,  2.6219,  3.1080]])
tensor([ 1.1452, -2.6219, -3.1080])


# Disabling Gradient Tracking

In [21]:
# Normally z is tracked, because it is computed from w and b (requires_grad=True).
z = x @ w + b
print(z.requires_grad)

# Inside torch.no_grad() the same computation is not recorded by autograd.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


In [23]:
# Alternative to no_grad: build z with tracking, then detach a copy from the graph.
z = x @ w + b
z_det = z.detach()  # shares z's data but has no autograd history
print(z_det.requires_grad)

False


# Optional Reading: Tensor Gradients and Jacobian Products

In [57]:
# `out` is a non-scalar (5, 4) tensor, so backward() needs an explicit `gradient`
# tensor v; autograd then stores the vector-Jacobian product v^T . J in inp.grad.
# retain_graph=True keeps the graph so backward() can be called on `out` again later.
inp = torch.eye(4, 5, requires_grad=True)
squared = (inp + 1).pow(2)
out = squared.t()
out.backward(torch.ones_like(out), retain_graph=True)
for label, value in (("inp", inp), ("out", out), ("First call inp.grad", inp.grad)):
    print(f"{label}\n{value}")

inp
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.]], requires_grad=True)
out
tensor([[4., 1., 1., 1.],
        [1., 4., 1., 1.],
        [1., 1., 4., 1.],
        [1., 1., 1., 4.],
        [1., 1., 1., 1.]], grad_fn=<TBackward0>)
First call inp.grad
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])


In [59]:
# Each backward() call ADDS v^T . J to inp.grad rather than overwriting it.
# Assuming the previous cell ran exactly once, inp.grad grows to 2x and then 3x
# the single-call gradient. retain_graph=True keeps the graph for each next call.
for ordinal in ("Second", "Third"):
    out.backward(torch.ones_like(out), retain_graph=True)
    print(f"{ordinal} call inp.grad\n{inp.grad}")

Second call inp.grad
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])
third call inp.grad
tensor([[12.,  6.,  6.,  6.,  6.],
        [ 6., 12.,  6.,  6.,  6.],
        [ 6.,  6., 12.,  6.,  6.],
        [ 6.,  6.,  6., 12.,  6.]])


In [61]:
# Clear the accumulated gradient in place, so the next backward() produces
# the single-call gradient again.
inp.grad.zero_()
out.backward(gradient=torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")


Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
