In [4]:
import torch
from functools import partial


# SGD

## use SGD to estimate quadratic equation

We are trying to estimate a quadratic equation. To do so, we initialize random parameters of the equation. We use mean square error as loss function to calculate the gradient. Then we use SGD to 

In [21]:
# mean square error
def mse(preds, acts):
    return ((preds - acts) ** 2).mean()


def quad(a, b, c, x):
    return a * x**2 + b * x + x


def mk_quad(a, b, c):
    return partial(quad, a, b, c)


# target model
f = mk_quad(2, 3, 4)
f(2)

# assume some data points
x = torch.linspace(-2, 2, 20)[:, None]
torch.manual_seed(42)

# Generate a tensor of random numbers with the same shape as f(x)
# torch.rand_like(f(x)) generates random numbers between 0 and 1
# with the same shape as f(x). We scale and shift it to the desired range.
random_numbers = torch.rand_like(f(x)) * 10 - 5

# dataset
y = f(x) + random_numbers


# loss function
def quad_mse(params):
    f = mk_quad(*params)
    return mse(f(x), y)


# initial params
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()

loss = quad_mse(params)
loss

loss.backward()
params.grad

tensor([11.2822,  6.9424,  0.0000])

Let's calculate the the SGD and loss manually. 

Here is the loss function:
$$
\text{mse}(f(x), y) = \frac{1}{n} \sum_{i=1}^n (f(x_i) - y_i)^2
$$

where $n$ is the number of data points.

To calculate the gradient of the loss, we start with a generic expression

$$
f(x_i) = a x_i^2 + b x_i + c
$$

then

$$
\frac{\partial f(x_i)}{\partial a}
= x_i^2.
$$

Then we have the loss function

$$
L = \frac{1}{n} \sum_{i=1}^n (f(x_i) - y_i)^2.
$$

So for parameter $a$

$$
\frac{\partial L}{\partial a}
= \frac{2}{n} \sum_{i=1}^n (f(x_i) - y_i) \cdot x_i^2.
$$

The process is similar for $b$ and $c$.


In [24]:
print("target params", f)
print("initial params", params)
print("initial values of x: ", x[:2])

target params functools.partial(<function quad at 0x7f470ba72a20>, 2, 3, 4)
initial params tensor([4., 5., 7.], requires_grad=True)
initial values of x:  tensor([[-2.0000],
        [-1.7895]])


In [None]:
lr = 0.0023
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()
for _ in range(1000):
    loss = quad_mse(params)
    print("loss ", loss.item())
    loss.backward()
    params.data -= lr * params.grad.data
    params.grad = None

loss  23.300430297851562
loss  22.899810791015625
loss  22.511001586914062
loss  22.133596420288086
loss  21.767236709594727
loss  21.411556243896484
loss  21.066207885742188
loss  20.73085594177246
loss  20.405170440673828
loss  20.088848114013672
loss  19.78158187866211
loss  19.483074188232422
loss  19.19304847717285
loss  18.911235809326172
loss  18.63736915588379
loss  18.37118911743164
loss  18.112457275390625
loss  17.86093521118164
loss  17.616390228271484
loss  17.37860107421875
loss  17.14735984802246
loss  16.922447204589844
loss  16.70367431640625
loss  16.4908447265625
loss  16.283777236938477
loss  16.08228302001953
loss  15.886190414428711
loss  15.695341110229492
loss  15.509554862976074
loss  15.328689575195312
loss  15.152587890625
loss  14.981100082397461
loss  14.8140869140625
loss  14.65141487121582
loss  14.492950439453125
loss  14.338556289672852
loss  14.188115119934082
loss  14.041513442993164
loss  13.898626327514648
loss  13.759344100952148
loss  13.623559951