In [1]:
import torch

In [5]:
# Seed PyTorch's RNG so the random values below are reproducible on
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# A leaf tensor tracked by autograd: operations on it record a grad_fn.
x = torch.rand(10, requires_grad=True)
print(x)

# Same kind of tensor without gradient tracking (requires_grad defaults to False).
x_no_requires_grad = torch.rand(10)
print(x_no_requires_grad)

tensor([0.4577, 0.6337, 0.7402, 0.4093, 0.3120, 0.8809, 0.9714, 0.7267, 0.4113,
        0.9406], requires_grad=True)
tensor([0.2036, 0.5271, 0.1620, 0.6432, 0.8239, 0.9221, 0.5321, 0.1278, 0.0251,
        0.7651])


In [6]:
# The input does not require grad, so autograd records nothing: the result has no grad_fn.
x_no_requires_grad.pow(2).mean()

tensor(0.3174)

In [9]:
# b carries a grad_fn (MeanBackward0) because x has requires_grad=True:
# autograd recorded the ops, so we can backpropagate through them.
b = (x**2).mean()

print(b)

# backward() *accumulates* into x.grad. Clear it first so re-running this cell
# yields the true gradient rather than a multiple of it.
x.grad = None
b.backward()

tensor(0.4725, grad_fn=<MeanBackward0>)


In [14]:
# backward() populated the .grad attribute of the leaf tensor x.
print(x.grad)

# Analytical gradient of b = mean(x**2) = sum(x_i**2) / n, i.e. db/dx_i = 2 * x_i / n.
# Use x.numel() instead of a hard-coded 10 so this stays correct if x changes size.
mathematical_derivative = 2 * x / x.numel()
print(mathematical_derivative)

# The autograd gradient matches the analytical derivative (up to float rounding).
assert torch.allclose(x.grad, mathematical_derivative)

tensor([0.0915, 0.1267, 0.1480, 0.0819, 0.0624, 0.1762, 0.1943, 0.1453, 0.0823,
        0.1881])
tensor([0.0915, 0.1267, 0.1480, 0.0819, 0.0624, 0.1762, 0.1943, 0.1453, 0.0823,
        0.1881], grad_fn=<DivBackward0>)


In [26]:
# The memory measurements in this section use torch.mps, so they need the MPS
# (Apple-silicon GPU) backend; fail early with a clear message otherwise.
if not torch.backends.mps.is_available():
    raise RuntimeError("This section requires the MPS backend (Apple-silicon GPU).")

# 2**20 = 1,048,576 elements (~one million). Note: this rebinds x, replacing
# the 10-element tensor used above.
x = torch.rand(2 ** 20, requires_grad=True, device='mps')

# 2**20 float32 elements * 4 bytes each = 4 MiB allocated.
torch.mps.current_allocated_memory() / 1024 / 1024


4.0

In [27]:
# One extra op: relu allocates a fresh 2**20-element float32 output (another
# 4 MiB) held by b, so allocated memory doubles from 4 MiB to 8 MiB.
b = x.relu()
torch.mps.current_allocated_memory() / 2**20

8.0

In [28]:
# Chain 100 ReLUs. Each call allocates a new 4 MiB output, and the autograd
# graph keeps every intermediate alive for the backward pass, so memory grows
# by ~4 MiB per layer: 4 MiB (x) + 100 * 4 MiB = 404 MiB.
# (`_layer` instead of `_` avoids clobbering IPython's last-output variable.)
b = x
for _layer in range(100):
    b = b.relu()
torch.mps.current_allocated_memory() / 2**20

404.0

In [29]:
# Reduce to a scalar and backpropagate through the whole ReLU chain.
torch.autograd.backward(b.sum())
torch.mps.current_allocated_memory() / 2**20

# backward() (called without retain_graph=True) releases the tensors the graph
# saved for the backward pass, so the intermediate ReLU outputs are freed. What
# is left is roughly x, the final output (still referenced by the variable b),
# and the new x.grad: 3 * 4 MiB ~= 12 MiB.
#
# Because those saved tensors are gone, calling backward on this graph a second
# time raises an error; pass retain_graph=True to the first call to keep them.

12.000244140625

In [31]:
# Gradient of sum(relu(...relu(x))) w.r.t. x is 1 wherever x > 0. torch.rand
# draws from [0, 1), so essentially every entry is 1. An entry that happened to
# be exactly 0 would get gradient 0, because relu's backward only passes
# gradient where its output is positive.
x.grad

tensor([1., 1., 1.,  ..., 1., 1., 1.], device='mps:0')