Example: a single neuron

$$ y = wx + b $$

Loss:
$$ L = \frac{1}{2}(y - t)^2 $$

We write out the derivatives by hand and then apply the gradient-descent update manually. Because L depends on y, and y depends on w and b, the derivatives follow from the chain rule.

Derivative of the loss w.r.t. the output:

$$ \frac{\partial L}{\partial y} = y - t $$
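The factor of ½ in the loss is a common convenience: it cancels the 2 produced by the power rule. Written out step by step:

$$ \frac{\partial L}{\partial y} = \frac{1}{2} \cdot 2\,(y - t) \cdot \frac{\partial (y - t)}{\partial y} = (y - t) \cdot 1 = y - t $$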

Derivative of the output w.r.t. the parameters:

$$ \frac{\partial y}{\partial w} = x $$

$$ \frac{\partial y}{\partial b} = 1 $$

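Gradients of the loss w.r.t. the parameters (chain rule: multiply the two factors above):

$$ \frac{\partial L}{\partial w} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial w} = (y - t)\, x $$

$$ \frac{\partial L}{\partial b} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial b} = (y - t) $$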
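One quick way to sanity-check hand-derived gradients is to compare them with central finite differences. This check is not part of the original notebook. It is a minimal sketch that assumes the same values used in the manual cell below (w = 0.5, b = -0.1, x = 2.0, t = 1.0). The helper names `loss_fn`, `analytic_grads` and `numeric_grads` are hypothetical.

```python
# Finite-difference check of the analytic gradients derived above.
# Assumption: same values as the manual cell (w=0.5, b=-0.1, x=2.0, t=1.0).

def loss_fn(w, b, x, t):
    """Squared-error loss L = 1/2 (y - t)^2 for the neuron y = w*x + b."""
    y = w * x + b
    return 0.5 * (y - t) ** 2

def analytic_grads(w, b, x, t):
    """dL/dw = (y - t) * x and dL/db = (y - t), from the chain rule."""
    y = w * x + b
    return (y - t) * x, (y - t)

def numeric_grads(w, b, x, t, h=1e-6):
    """Central differences (L(p + h) - L(p - h)) / (2h) for each parameter p."""
    dw = (loss_fn(w + h, b, x, t) - loss_fn(w - h, b, x, t)) / (2 * h)
    db = (loss_fn(w, b + h, x, t) - loss_fn(w, b - h, x, t)) / (2 * h)
    return dw, db

w, b, x, t = 0.5, -0.1, 2.0, 1.0
print("analytic:", analytic_grads(w, b, x, t))
print("numeric: ", numeric_grads(w, b, x, t))
```

L is quadratic in each parameter, so the central difference is exact apart from floating-point rounding. The two lines should therefore agree to many decimal places, with both close to (-0.2, -0.1).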

Manual calculation

In [16]:
# parameters and data
w = 0.5
b = -0.1
x = 2.0
t = 1.0  # target
lr = 0.1  # learning rate

# forward pass
y = w * x + b
loss = 0.5 * (y - t)**2

# manual gradients (the hand-written equivalent of loss.backward())
dL_dy = y - t
dL_dw = dL_dy * x
dL_db = dL_dy

print("raw manual gradients:", dL_dw, dL_db)

# gradient-descent update
w_updated = w - lr * dL_dw
b_updated = b - lr * dL_db

print("updated manual params:", w_updated, b_updated)

raw manual gradients: -0.19999999999999996 -0.09999999999999998
updated manual params: 0.52 -0.09000000000000001
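The cell above performs a single update. To see gradient descent actually drive the loss down, the same forward/backward/update cycle can be repeated in a loop. This is a minimal pure-Python sketch. It assumes the same starting values and learning rate, and `num_steps = 20` is an arbitrary illustrative choice.

```python
# Repeat the manual forward/backward/update cycle for several steps.
# Assumption: same starting values and learning rate as the manual cell.
w, b = 0.5, -0.1
x, t = 2.0, 1.0
lr = 0.1
num_steps = 20  # hypothetical choice for illustration

losses = []
for step in range(num_steps):
    # forward pass
    y = w * x + b
    loss = 0.5 * (y - t) ** 2
    losses.append(loss)

    # manual gradients (chain rule)
    dL_dy = y - t
    dL_dw = dL_dy * x
    dL_db = dL_dy

    # gradient-descent update
    w -= lr * dL_dw
    b -= lr * dL_db

print("first loss:", losses[0])
print("last loss: ", losses[-1])
print("final prediction:", w * x + b)
```

In this particular setup, one update changes the prediction by -lr (x² + 1)(y - t). Each step therefore multiplies the error y - t by 1 - lr (x² + 1) = 1 - 0.1 · 5 = 0.5, and the loss shrinks by a factor of 4 per step. A larger learning rate would make this factor negative or larger than 1 in magnitude, so the iteration would oscillate or diverge.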


Autograd - PyTorch computation

In [21]:
import torch

w = torch.tensor(0.5, requires_grad=True)
b = torch.tensor(-0.1, requires_grad=True)
x = torch.tensor(2.0)
t = torch.tensor(1.0)
lr = 0.1

optimizer = torch.optim.SGD([w, b], lr=lr)

# forward pass
y = w * x + b
loss = 0.5 * (y - t)**2

# backward pass: populates w.grad and b.grad
loss.backward()

print("raw torch gradients:", w.grad.item(), b.grad.item())

# manual update: no_grad() stops autograd from recording these operations;
# this creates new tensors, so w and b themselves are not modified
with torch.no_grad():
    w_updated = w - lr * w.grad
    b_updated = b - lr * b.grad

print("Manual updated torch params:", w_updated.item(), b_updated.item())

# alternatively, let the optimizer update w and b in place
# (plain SGD: p <- p - lr * p.grad)
optimizer.step()

print("Optimised updated torch params:", w.item(), b.item())

raw torch gradients: -0.20000004768371582 -0.10000002384185791
Manual updated torch params: 0.5199999809265137 -0.08999999612569809
Optimised updated torch params: 0.5199999809265137 -0.08999999612569809
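The autograd gradients differ from the manual ones around the eighth significant digit. The reason is that `torch.tensor` creates `float32` tensors by default, while Python floats are 64-bit. The following minimal sketch assumes the same values and requests `torch.float64` explicitly. The two sets of gradients should then agree to double precision.

```python
import torch

# Same neuron as above, but with 64-bit tensors instead of the float32 default.
w = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
b = torch.tensor(-0.1, dtype=torch.float64, requires_grad=True)
x = torch.tensor(2.0, dtype=torch.float64)
t = torch.tensor(1.0, dtype=torch.float64)

# forward and backward pass
y = w * x + b
loss = 0.5 * (y - t) ** 2
loss.backward()

# hand-derived gradients with plain Python floats (64-bit), same values as above
dL_dy = (0.5 * 2.0 + (-0.1)) - 1.0
manual = (dL_dy * 2.0, dL_dy)

print("default dtype of torch.tensor(0.5):", torch.tensor(0.5).dtype)
print("float64 torch gradients:", w.grad.item(), b.grad.item())
print("manual gradients:       ", manual)
```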
