In [16]:
import torch 

In [17]:
def cmp(s, dt, t):
  """Report how closely a hand-derived gradient matches autograd's.

  Args:
    s: Label printed at the start of the report line.
    dt: Manually computed gradient tensor.
    t: Tensor whose ``.grad`` attribute is used as the reference value.

  Prints one line with the exact-equality flag, the ``torch.allclose``
  flag and the largest absolute element-wise difference. Returns nothing.
  """
  reference = t.grad
  ex = (dt == reference).all().item()
  app = torch.allclose(dt, reference)
  maxdiff = torch.max(torch.abs(dt - reference)).item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [18]:
# Toy regression data: every target is exactly twice its input.
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = 2.0 * x

In [19]:
# Randomly initialised slope and intercept of the model H = W1 * x + B1.
W1 = torch.randn(1, requires_grad=True)
B1 = torch.randn(1, requires_grad=True)
n = y.shape[0]  # number of samples the loss is averaged over

# Forward pass, one primitive operation per line so that every
# intermediate tensor can be paired with its own gradient.
H = W1 * x + B1
difference = H - y
difference_squared = difference.pow(2)
summed = torch.sum(difference_squared)
loss = summed / n  # mean of the squared differences

# Ask autograd to keep .grad on the intermediate tensors as well; for the
# leaf tensors W1 and B1 the call is a no-op.
for t in (difference, difference_squared, H, W1, B1, summed):
  t.retain_grad()
loss.backward()
loss

tensor(54.2126, grad_fn=<DivBackward0>)

In [20]:
# Manual backward pass: apply the chain rule from the loss back to the
# parameters, mirroring the forward pass one operation at a time.

# loss = summed / n  ->  d loss / d summed = 1 / n
dsummed = torch.ones_like(summed) * (1/n)
# summed = difference_squared.sum()  ->  every element receives dsummed.
# Expand explicitly so the gradient has the same shape as its tensor
# instead of relying on broadcasting of a 0-d scalar.
ddifferenced_squared = torch.ones_like(difference_squared) * dsummed
# difference_squared = difference ** 2  ->  local derivative 2 * difference
ddifference = 2 * difference * ddifferenced_squared
# difference = H - y  ->  +1 with respect to H, -1 with respect to y
dH = 1 * ddifference
dy = -1 * ddifference
# H = W1 * x + B1: W1 and B1 are broadcast over every sample, so their
# gradients are accumulated over the sample axis.
dW1 = x @ dH
dB1 = 1 * dH.sum(0)

In [21]:
# Check every hand-derived gradient against the value autograd stored in
# .grad; each label names the gradient being checked so that the rows of
# the report can be told apart.
cmp('dsummed', dsummed, summed)
cmp('ddifference_squared', ddifferenced_squared, difference_squared)
cmp('ddifference', ddifference, difference)
cmp('dH', dH, H)
cmp('dW', dW1, W1)
cmp('dB', dB1, B1)

summed          | exact: True  | approximate: True  | maxdiff: 0.0
ddifference_squared | exact: True  | approximate: True  | maxdiff: 0.0
ddifference     | exact: True  | approximate: True  | maxdiff: 0.0
dH              | exact: True  | approximate: True  | maxdiff: 0.0
dW              | exact: True  | approximate: True  | maxdiff: 0.0
dW              | exact: True  | approximate: True  | maxdiff: 0.0


In [22]:
# Draw a fresh random slope and intercept, both tracked by autograd.
W1 = torch.randn(1).requires_grad_()
B1 = torch.randn(1).requires_grad_()

In [23]:
# Hyperparameters of the hand-written gradient descent.
NUM_STEPS = 1000
LEARNING_RATE = 0.1

# The gradients below are derived by hand, so autograd must not record
# anything: W1 and B1 were created with requires_grad=True, and without
# no_grad() every reassignment of W1/B1 would extend a computation graph
# that keeps growing across all steps.
with torch.no_grad():
    for i in range(NUM_STEPS):

        # forward pass
        H = W1 * x + B1
        difference = H - y
        difference_squared = difference ** 2
        summed = difference_squared.sum()
        loss = summed / n

        # backward pass (chain rule, in reverse order of the forward pass);
        # y is a fixed target, so no gradient is computed for it
        dsummed = torch.ones_like(summed) * (1/n)
        ddifferenced_squared = torch.ones_like(difference_squared) * dsummed
        ddifference = 2 * difference * ddifferenced_squared
        dH = 1 * ddifference
        dW1 = x @ dH
        dB1 = 1 * dH.sum(0)

        # gradient descent step
        W1 = W1 - LEARNING_RATE * dW1
        B1 = B1 - LEARNING_RATE * dB1

print(f'loss: {loss.item()}')


loss: 0.0
