In [1]:
import torch
from functools import partial


# SGD

## use SGD to estimate quadratic equation

We are trying to estimate a quadratic equation. To do so, we initialize random parameters of the equation. We use mean square error as loss function to calculate the gradient. Then we use SGD to 

In [2]:
# mean square error
def mse(preds, acts):
    return ((preds - acts) ** 2).mean()


def quad(a, b, c, x):
    return a * x**2 + b * x + x


def mk_quad(a, b, c):
    return partial(quad, a, b, c)


# target model
f = mk_quad(2, 3, 4)
f(2)

# assume some data points
x = torch.linspace(-2, 2, 20)[:, None]
torch.manual_seed(42)

# Generate a tensor of random numbers with the same shape as f(x)
# torch.rand_like(f(x)) generates random numbers between 0 and 1
# with the same shape as f(x). We scale and shift it to the desired range.
random_numbers = torch.rand_like(f(x)) * 10 - 5

# dataset
y = f(x) + random_numbers


# loss function
def quad_mse(params):
    f = mk_quad(*params)
    return mse(f(x), y)


# initial params
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()

loss = quad_mse(params)
loss

loss.backward()
params.grad

tensor([11.2822,  6.9424,  0.0000])

Let's calculate the the SGD and loss manually. 

Here is the loss function:
$$
\text{mse}(f(x), y) = \frac{1}{n} \sum_{i=1}^n (f(x_i) - y_i)^2
$$

where $n$ is the number of data points.

To calculate the gradient of the loss, we start with a generic expression

$$
f(x_i) = a x_i^2 + b x_i + c
$$

then

$$
\frac{\partial f(x_i)}{\partial a}
= x_i^2.
$$

Then we have the loss function

$$
L = \frac{1}{n} \sum_{i=1}^n (f(x_i) - y_i)^2.
$$

So for parameter $a$

$$
\frac{\partial L}{\partial a}
= \frac{2}{n} \sum_{i=1}^n (f(x_i) - y_i) \cdot x_i^2.
$$

The process is similar for $b$ and $c$.


In [3]:
print("target params", f)
print("initial params", params)
print("initial values of x: ", x[:2])

target params functools.partial(<function quad at 0x7fe6424a3060>, 2, 3, 4)
initial params tensor([4., 5., 7.], requires_grad=True)
initial values of x:  tensor([[-2.0000],
        [-1.7895]])


In [5]:
lr = 0.23
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()
for _ in range(10):
    loss = quad_mse(params)
    print("loss ", loss.item())
    loss.backward()
    params.data -= lr * params.grad.data
    params.grad = None

loss  23.300430297851562
loss  12.930697441101074
loss  10.261430740356445
loss  8.984521865844727
loss  8.224483489990234
loss  7.7517900466918945
loss  7.455584526062012
loss  7.269739627838135
loss  7.153112888336182
loss  7.079920768737793
