# Stochastic gradient descent

## General description

Stochastic gradient descent (SGD) is a computational trick to accelerate the optimization of objective functions of the form
$$f(para) = \sum_{i=1}^m g(data_i; para).$$

This optimization problem is computationally expensive when $g$ or $g'$ is costly to evaluate and $m$ is huge, because every gradient descent step has to evaluate $g'$ at all $m$ data points.

To address this, SGD picks a random subset $I$ of the data points at each iteration and applies a gradient descent step to the subset objective function
$$f_I(para) = \sum_{i \in I} g(data_i; para).$$
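
Written as an update rule, one SGD iteration could look like the following. The step size $\alpha$ and the iteration index $t$ are notation introduced here for illustration; this is a sketch rather than a definition taken from the text above.
$$para_{t+1} = para_t - \alpha \sum_{i \in I_t} \nabla_{para}\, g(data_i; para_t),$$
where a fresh random subset $I_t$ is drawn at every iteration. If $I_t$ is drawn uniformly with a fixed size $|I|$, each data point is included with probability $|I|/m$, so the expected subset gradient is $\frac{|I|}{m}$ times the full gradient. On average, each step therefore points in the same direction as a gradient descent step.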

## Example: linear regression

Suppose we have data generated by $y_i = k \cdot x_i + \varepsilon_i$ with $\varepsilon_i \sim N(0,1)$ for $i = 1, \ldots, m$. We want to find the $k$ that minimizes the mean squared error
$$f(k) = \frac 1m \sum_{i=1}^m (y_i - k x_i)^2.$$

Note we have $f'(k) = -\frac 2m \sum_{i=1}^m x_i (y_i - k x_i)$.
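
The derivative formula can be sanity-checked numerically. The sketch below is illustrative only: the names `x_chk`, `y_chk`, `f`, `fprime`, the test point `k_test` and the step `h` are introduced here and are not part of the original notebook. It compares the formula with a central finite difference on data generated by the same recipe.

```julia
using Random, Statistics

Random.seed!(1)
m_chk = 1000
x_chk = collect(1:m_chk) / m_chk * 10
y_chk = 3 * x_chk + randn(m_chk)

# Mean squared error and its derivative with respect to k.
f(k) = mean((y_chk .- k .* x_chk) .^ 2)
fprime(k) = -2 * mean(x_chk .* (y_chk .- k .* x_chk))

# Central finite difference at an arbitrary test point.
k_test = 1.5
h = 1e-6
fd = (f(k_test + h) - f(k_test - h)) / (2h)

println("analytic f'(k) = ", fprime(k_test))
println("finite difference = ", fd)
```

Since $f$ is quadratic in $k$, the two numbers should agree up to floating-point rounding.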

In [1]:
using Random
Random.seed!(1)

m = 1000

x = collect(1:m) / m * 10
y = 3 * x + randn(m);

**Stopping criteria**: we cannot naively apply the stopping criteria of gradient descent to SGD. (why?)
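
One possible answer, sketched under assumptions: gradient descent typically stops when the gradient is close to zero, but a subset gradient is a noisy estimate that does not vanish at the minimizer. The code below regenerates data with the same recipe as the example and compares the full gradient with subset gradients at the least-squares minimizer. The `_demo` names, `k_star` and `grad` are hypothetical helpers introduced here.

```julia
using Random, Statistics

Random.seed!(1)
m_demo = 1000
x_demo = collect(1:m_demo) / m_demo * 10
y_demo = 3 * x_demo + randn(m_demo)

# Least-squares minimizer of the full objective (solves f'(k) = 0).
k_star = sum(x_demo .* y_demo) / sum(x_demo .^ 2)

# Gradient of the mean squared error restricted to the indices idx.
grad(k, idx) = -2 * mean(x_demo[idx] .* (y_demo[idx] .- k .* x_demo[idx]))

full_grad = grad(k_star, 1:m_demo)
# 1000 independent subset gradients, each from 10 random indices.
subset_grads = [grad(k_star, rand(1:m_demo, 10)) for _ in 1:1000]

println("full gradient at k_star: ", full_grad)
println("std of subset gradients at k_star: ", std(subset_grads))
```

The full gradient is zero up to rounding, while the subset gradients keep fluctuating, so a rule such as "stop when the gradient is small" may never trigger or may trigger by chance.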

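We will discuss a proper stopping criterion for SGD later. For now, let's use the following stopping rule and assume it can be computed very efficiently. The rule compares the current mean squared error with the error at the true slope $k = 3$, which we only know because the data are synthetic.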

In [2]:
using Statistics

# Oracle stopping rule: stop once the MSE at k is within tol of the MSE at the true slope 3.
function stop_now(x, y, k, tol=1e-2)
    mean((y - k * x).^2) < mean((y - 3 * x).^2) + tol
end

stop_now (generic function with 2 methods)

In [3]:
function gradient_descent(x, y, k0 = 0, α=0.01; tol=1e-3, maxiter=500)
    k = k0
    iter = 0
    
    for i = 1:maxiter
        gradient = -2 * mean(x .* (y - k * x))
        
        if stop_now(x, y, k, tol)
            break
        end
        
        k -= α*gradient
        iter += 1
    end
    
    println("k = $(k)")
    return iter
end


iter_num = gradient_descent(x, y)
comp_count = iter_num * m

println("converges after $(iter_num) iterations; call g' $(comp_count) times.")

k = 2.99354312434587
converges after 6 iterations; call g' 6000 times.


In [4]:
function stochastic_gradient_descent(x, y, k0 = 0, α=0.01; subset_size = 10, tol=1e-3, maxiter=500)
    k = k0
    iter = 0
    
    for i = 1:maxiter
        
        # draw subset_size random indices (with replacement) at each iteration
        subset = rand(1:length(x), subset_size)
        # use the gradient of the subset objective function
        gradient = -2 * mean(x[subset] .* (y[subset] - k * x[subset]))
        
        if stop_now(x, y, k, tol)
            break
        end
        
        k -= α*gradient
        iter += 1
    end
    println("k = $(k)")
    return iter
end


Random.seed!(1)

subset_size = 10
iter_num = stochastic_gradient_descent(x, y, subset_size = subset_size)
comp_count = iter_num * subset_size

println("converges after $(iter_num) iterations; call g' $(comp_count) times.")

k = 2.9933621095869487
converges after 15 iterations; call g' 150 times.


## Stopping criteria

Let's call the data set $(x, y)$ above the training set. Assume we have another data set, called the validation set, that is much smaller than the training set but is generated in the same way, e.g.

In [5]:
val_size = 10

x_val = rand(val_size) * 10
y_val = 3 * x_val + randn(val_size);

The validation set is much smaller than the training set, so the same objective evaluated on it,
$$f_{val}(para) = \sum_{i \in \text{validation}} g(data_i; para),$$
is cheap to compute.  
Moreover, since the validation set is generated by the same mechanism, it should be representative of the training set.

Combining these two points, we can propose a reasonable and computationally efficient stopping rule called **early stopping**: compute $f_{val}$ at each iteration and stop once $f_{val}$ starts increasing.

![early_stopping](https://i.imgur.com/eP0gppr.png)
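
A minimal sketch of early stopping for the linear regression example might look as follows. It is not the notebook's implementation: the function name `sgd_early_stopping`, the helper `mse` and the `patience` parameter are assumptions introduced here. Because SGD updates are noisy, a single increase of $f_{val}$ may be noise, so this variant stops only after `patience` consecutive iterations without improvement and returns the best $k$ seen so far.

```julia
using Random, Statistics

# Mean squared error of the model y ≈ k * x on a data set.
mse(x, y, k) = mean((y .- k .* x) .^ 2)

function sgd_early_stopping(x, y, x_val, y_val; k0=0.0, α=0.01,
                            subset_size=10, patience=5, maxiter=500)
    k = k0
    best_k = k0
    best_loss = mse(x_val, y_val, k0)
    bad_steps = 0   # consecutive iterations without improvement on the validation set
    iter = 0

    for _ in 1:maxiter
        # one SGD step on a random subset of the training set
        subset = rand(1:length(x), subset_size)
        gradient = -2 * mean(x[subset] .* (y[subset] .- k .* x[subset]))
        k -= α * gradient
        iter += 1

        # the validation loss is cheap because the validation set is small
        loss = mse(x_val, y_val, k)
        if loss < best_loss
            best_loss = loss
            best_k = k
            bad_steps = 0
        else
            bad_steps += 1
            bad_steps >= patience && break
        end
    end

    return best_k, iter
end

# Same data-generating recipe as in the example above.
Random.seed!(1)
m = 1000
x = collect(1:m) / m * 10
y = 3 * x + randn(m)
x_val = rand(10) * 10
y_val = 3 * x_val + randn(10)

best_k, iters = sgd_early_stopping(x, y, x_val, y_val)
println("best k = $(best_k) after $(iters) iterations")
```

Unlike `stop_now`, this rule never uses the true slope; it only needs the small validation set.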

## Some final notes

The kind of optimization problem that SGD tries to solve is not exactly the same as the traditional optimization problem of minimizing $f(x)$. SGD builds its objective function on a data set and **aims to find the mechanism that generates this data set rather than to minimize the objective function itself**. These two goals are similar in many ways but not identical. This is also why we can introduce a validation data set that is independent of $f(x)$ to help decide when to stop optimizing $f(x)$.