# ECON622: Computational Economics with Data Science Applications

Optimization for Machine Learning

Jesse Perla (University of British Columbia)

# Overview

## Summary

-   This lecture continues from the previous lecture on gradients to
    further explore optimization methods in machine learning, and
    discusses training pipelines and tooling
-   Primary reference materials are:
    -   [ProbML Book 1:
        Introduction](https://probml.github.io/pml-book/book1.html)
    -   [ProbML Book 2: Advanced
        Topics](https://probml.github.io/pml-book/book2.html) including
        Section 6.3
    -   [Mark Schmidt’s ML Lecture
        Notes](https://www.cs.ubc.ca/~schmidtm/Courses/LecturesOnML/)
-   We will also give a sense of a standard machine learning pipeline of
    training, validation, and test data and discuss generalization,
    logging, etc.

## Why the Emphasis on Optimization and Gradients?

-   A huge number of algorithms for economists can be written as
    optimization problems (e.g., MLE, interpolation) or as something
    similar in spirit (e.g. Bayesian Sampling, Reinforcement Learning)
-   Previous lectures on AD showed ways to find VJPs for extremely
    complicated functions. **Differentiate everything, no excuses!**
-   In practice, **all** problems with high-dimensions parameters or
    latents require gradients for algorithms to be feasible
-   We will soon take a further step: **are (unbiased) estimates of the
    gradient good enough for many algorithms?**

# Optimization Crash Course

## Optimization Methods

-   Learning continuous optimization methods is an enormous project
-   See referenced materials and lecture notes
-   Here we will give an overview of some key concepts
-   Be warned! The details matter, so more study is required if you want
    to use these methods in practice

## Crash Course in Unconstrained Optimization

$$
\min_{\theta} \mathcal{L}(\theta)
$$

Will briefly introduce

-   First-order methods
-   Second-order methods
-   Preconditioning
-   Momentum
-   Regularization

## First-Order Methods

-   See [ProbML Book 1](https://probml.github.io/pml-book/book1.html)
    Section 8.2
-   Armed with reverse-mode AD for
    $\mathcal{L} : \mathbb{R}^N \to \mathbb{R}$ we can calculate
    $\nabla \mathcal{L}(\theta)$ with the same computational order as
    $\mathcal{L}(\theta)$
-   Furthermore, given JVPs we know we can calculate these objective
    functions for extremely complicated functions (e.g., nested fixed
    points, and implicit functions)
-   Iterative: take $\theta_0$ and provide $\theta_t \to \theta_{t+1}$
    -   May converge to a stationary point (hopefully close to a global
        argmin)
    -   If it doesn’t converge, the solution may still be an argmin
    -   See references for details on convergence for convex and
        non-convex problems

## Gradient Descent

-   See [Mark Schmidt’s
    Notes](https://www.cs.ubc.ca/~schmidtm/Courses/340-F22/L13.pdf)
-   Gradient descent takes $\theta_0$, and stepsize $\eta_t$ and
    iterates until $\nabla \mathcal{L}(\theta_t)$ is small, or
    $\theta_t$ stationary

$$
\theta_{t+1} = \theta_t - \eta_t \nabla \mathcal{L}(\theta_t)
$$

-   It is the simplest “first-order” method (i.e., ones using just the
    gradient of $\mathcal{L}$)
-   Will call $\eta_t$ a “learning rate schedule”
-   Think of line-search methods as choosing the stepsize $\eta_t$
    optimally. Useful as well for economists, even if used infrequently
    in M

## When and Where Does This Converge?

-   Skipping a million details, see [ProbML Book
    1](https://probml.github.io/pml-book/book1.html) Section 8.2.2 and
    [Mark Schmidt’s
    basic](https://www.cs.ubc.ca/~schmidtm/Courses/340-F22/L14.pdf) and
    [more advanced
    notes](https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S22/S1.pdf)

-   For strictly convex problems this converges to the global minima,
    though sufficient conditions include Robbins-Monro
    $\lim_{T\to \infty} \eta_T \to 0$ and

    $$
    \lim_{T\to\infty}\frac{\sum_{t=1}^T \eta_t}{\sum_{t=1}^T \eta_t^2} = 0
    $$

-   For problems that not globally convex this may go to local optima,
    but if the function is locally strictly convex then it will converge
    to a local optima

-   For other types of functions (e.g.,
    [invex](https://en.wikipedia.org/wiki/Invex_function)) it may still
    converge to the “right” solution in some important sense

## Preconditioned Gradient Descent

-   As we saw analyzing LLS, badly conditioned problems converge slowly
    with iterative methods

-   We can precondition a problem as we did with linear systems, and it
    has the same stationary point

-   Choose some $C_t$ for preconditioned gradient descent $$
    \theta_{t+1} = \theta_t - \eta_t C_t \nabla \mathcal{L}(\theta_t)
    $$

-   We saw before that the Hessian tells us the geometry, so the optimal
    preconditioner must be related to $\nabla^2 \mathcal{L}(\theta_t)$

## Second-Order Methods

-   See [ProbML Book 1](https://probml.github.io/pml-book/book1.html)
    Section 8.3 and [Mark Schmidt’s
    Notes](https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S22/S2.pdf)
-   Adapt $\eta_t C_t$ to use the Hessian (e.g., Newton’s Method)

$$
\theta_{t+1} = \theta_t - \eta_t \left[\nabla^2 \mathcal{L}(\theta_t) \right]^{-1}\nabla \mathcal{L}(\theta_t)
$$

-   Second order methods are rarer because the calculating the Hessian
    is no longer the same computational order as $\mathcal{L}(\theta)$
-   See [ProbML Book 1](https://probml.github.io/pml-book/book1.html)
    Section 8.3.2 for info on Quasi-Newtonian methods which
    approximation Hessian using gradients like BFGS

## Momentum

-   See [ProbML Book 1](https://probml.github.io/pml-book/book1.html)
    Section 8.2.4 and [Mark Schmidt’s
    Notes](https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S22/S2.pdf)
-   Can use “momentum”, which speeds up convergence, helps avoid local
    optima, and moves fast in flat regions
-   Momentum will be a common feature of many ML optimizers (e.g. Adam,
    RMSProp, etc.) as it helps with heavily non-convex problems
-   A classic method is called Nesterov Accelerated Gradient (NAG),
    which is a modification of gradient descent for some
    $\beta_t\in (0,1)$ (e.g., $0.9$)

$$
\begin{aligned}
\hat{\theta}_{t+1} &= \theta_t + \beta_t(\theta_t - \theta_{t-1})\\
\theta_{t+1} &= \hat{\theta}_{t+1} - \eta_t \nabla \mathcal{L}(\hat{\theta}_{t+1})
\end{aligned}
$$

## Does Uniqueness Matter?

-   Remember from our previous lecture on Sobolev norms and
    regularization that we care about functions, not parameters.
-   Consider when $\theta$ is used as parameters for a function
    (e.g. $\hat{f}_{\theta}$)
    -   Then what does a lack of convergence of the $\theta_t$ or
        multiplicity with multiple $\theta$ solutions mean?
    -   Maybe nothing! If
        $||\hat{f}_{\theta_0} - \hat{f}_{\theta_1}||_S$ is small, then
        the functions themselves may be in the same equivalence class.
        Depends on the norm, of course.
-   This topic will be discussed when we consider double-descent curves,
    but the punchline for now is that the training/optimization is a
    means to an end (i.e., generalization) and not an end in itself.

## Regularization

-   See [Mark Schmidt’s
    Notes](https://www.cs.ubc.ca/~schmidtm/Courses/340-F22/L16.pdf). For
    LLS this is the ridge regression
-   We discussed regularization as a way to deal with multiplicity

$$
\min_{\theta}\left[\mathcal{L}(\theta) + \frac{\alpha}{2} ||\theta||^2\right]
$$

-   Gradient descent becomes (called “weight decay” in ML, and “ridge
    regression” if objective is LLS)

$$
\theta_{t+1} = \theta_t - \eta_t \left[\nabla \mathcal{L}(\theta_t) + \alpha \theta_t\right]
$$

-   Mapping of regularized $\theta_t$ to a $f_{\theta_t}$ is subtle if
    nonlinear

# Stochastic Optimization

## Are Gradients Really that Cheap to Calculate?

-   Consider that the objective often involves data (or grid points for
    interpolation)
    -   Denote $x_n$, and observables $y_n$ for $n=1, \ldots N$
-   With VJPs, the computational order of
    $\nabla_{\theta} \mathcal{L}(\theta;\{x_n, y_n\}_{n=1}^N)$ may be
    the same as that of $\mathcal{L}$ itself
-   However, keep in mind that reverse-mode requires storing the
    intermediate values in the “primal” calculation (i.e.,
    $\mathcal{L}(\theta;\{x_n, y_n\}_{n=1}^N)$)
    -   Hence, the memory requirements grow with $N$
    -   This may be a big problem for large datasets or complicated
        calculations, especially with GPUs which have more limited
        memory

## Do We Need the Full Gradient?

-   In practice, it is impossible to calculate the full gradient for
    large datasets

-   In GD, the gradient provided the direction of steepest descent

-   Consider an algorithm with a $g_t$ as an unbiased estimate of the
    gradient

    $$
    \begin{aligned}
    \theta_{t+1} = \theta_t - \eta_t g_t\\
    \mathbb{E}[g_t] = \nabla \mathcal{L}(\theta_t)
    \end{aligned}
    $$

    -   Make the $\eta_t$ smaller to deal with noise if this is
        high-variance
    -   Choose $g_t$ to be far cheaper to calculate than
        $\nabla \mathcal{L}(\theta_t)$

-   Will turn out that this also adds additional regularization, which
    helps with generalization

## Stochastic Optimization

-   To formalize: Up until now our optimizers have been “deterministic”
-   Now we introduce a source of randomness $z \sim q_{\theta}(z)$,
    i.e. it might depend on the estimated parameters $\theta$ later with
    RL/etc.
    -   $z$ could be a source of uncertainty in the environment
    -   $z$ could involve latent variables
    -   $z$ could come from randomness in the optimization process
        (e.g., using subsets of data to form $g_t$)
-   Denote expectations using this distribution as
    $\mathbb{E}_{q_{\theta}(z)}$
-   For now, drop the dependence on $\theta$ for simplicity, though it
    becomes crucial for understanding reinforcement learning/etc.

## Stochastic Objective

-   The full optimization problem is then to minimize this stochastic
    objective

$$
\min_{\theta}\overbrace{\mathbb{E}_{q(z)} \tilde{\mathcal{L}}(\theta, z)}^{\equiv \mathcal{L}(\theta)}
$$

-   Under appropriate regularity conditions, could use GD on this
    objective

$$
\nabla \mathcal{L}(\theta) = \mathbb{E}_{q(z)}\left[\nabla \tilde{\mathcal{L}}(\theta, z)\right]
$$

-   But in practice, it is rare that we can marginalize out the $z$

## Unbiased Draws from the Gradient

-   Assume we can sample $z_t \sim q(z)$ IID
-   Then with enough regularity the gradient using just $z_t$ is
    unbiased

$$
\mathbb{E}_{q(z)}\left[  \nabla \tilde{\mathcal{L}}(\theta_t, z_t) \right] = \nabla \mathcal{L}(\theta_t)
$$

-   That is, on average $\nabla \tilde{\mathcal{L}}(\theta_t, z_t)$ is
    in the right direction for minimizing $\mathcal{L}(\theta_t)$
-   This basic approach of finding unbiased estimators of the gradient
    (and finding ways to lower the variance) is at the heart of most ML
    optimization algorithms

## Stochastic Gradient Descent

-   See [ProbML Book 1](https://probml.github.io/pml-book/book1.html)
    Section 8.4, [ProbML Book
    2](https://probml.github.io/pml-book/book2.html) Section 6.3, and
    [Mark Schmidt’s
    Notes](https://www.cs.ubc.ca/~schmidtm/Courses/340-F22/L23.pdf)
-   Given the previous slide, given IID samples $z_t \sim q$, the
    gradient is unbiased and we have the simplest version of stochastic
    gradient descent (SGD)

$$
\theta_{t+1} = \theta_t - \eta_t \nabla \tilde{\mathcal{L}}(\theta_t, z_t)
$$

-   Which converges to the minima of $\min_{\theta} \mathcal{L}(\theta)$
    under appropriate conditions
-   We can layer on all of the other features we discussed (e.g.,
    momentum, preconditioning, etc) with SGD, but some become especially
    important (e.g. the $\eta_t$ schedule)

## Finite-Sum Objectives

-   Consider a special case of the loss function which is the sum of $N$
    terms. For example with empirical risk minimization used in LLS/etc.

    -   $z_n \equiv (x_n, y_n)$ are typically data, observables, or grid
        points
    -   $\ell(\theta, x_n, y_n)$ is a loss function for a single data
        point (e.g., forecasting using some $f_{\theta}$)

    $$
    \mathcal{L}(\theta) = \frac{1}{N}\sum_{n=1}^N \tilde{\mathcal{L}}(\theta, z_n) \equiv \frac{1}{N}\sum_{n=1}^N \ell(\theta, x_n, y_n)
    $$

    -   For example, LLS is
        $\ell(\theta, x_n, y_n) = ||y_n - \theta \cdot x_n||^2_2$

-   In this case, the randomness of $z_t$ is which data point is chosen

## SGD for Finite-Sum Objectives

-   Hence consider sampling $z_t \equiv (x_t, y_t)$ from our data.
    -   In principle, IID with replacement
-   Then run SGD on one data point at a time

$$
\theta_{t+1} = \theta_t - \eta_t \nabla_{\theta} \ell(\theta_t, x_t, y_t)
$$

-   This may converges to the minima of $\mathcal{L}(\theta)$, and
    potentially the storage requirements for calculations the gradient
    are radically reduced
-   You can guess that the $\eta_t$ parameter is especially sensitive to
    the variance of the gradient estimate

## Decrease Variance with Multiple Draws

-   With a single draw, the variance of the gradient estimate may be
    high

$$
\mathbb{E}\left[\nabla_{\theta} \ell(\theta_t, x_t, y_t)- \nabla \mathcal{L}(\theta_t)\right]^2
$$

-   One tool to decrease the variance is just more monte-carlo draws.
    With finite-sum objectives draw $B \subseteq \{1,\ldots N\}$ indices

$$
\frac{1}{|B|}\sum_{n \in B} \nabla_{\theta} \ell(\theta_t, x_n, y_n)
$$

-   Classic SGD: $|B|=1$; GD: $B = \{1, \ldots N\}$ and in between is
    called “minibatch SGD”. Usually minibatch is implied with “SGD”

## Minibatch SGD

-   Algorithm is to draw $B_t$ indices at each step and execute SGD $$
    \begin{aligned}
    g_t \equiv \frac{1}{|B_t|}\sum_{n \in B_t} \nabla_{\theta} \ell(\theta_t, x_n, y_n)\\
    \theta_{t+1} = \theta_t - \eta_t g_t
    \end{aligned}
    $$

-   Note that we never need to calculate $\mathcal{L}(\theta_t)$
    directly, so can write our code to all operate on batches $B_t$

-   Then layer other tricks on top (e.g., momentum, preconditioning,
    etc.)

    -   In principle you could also use minibatch with second-order or
        quasi-newtonian methods but much rarer

## Choosing Batches

-   Choosing the $B_t$ process may be tricky. You could sample from
    $\{1,\ldots N\}$
    -   with replacement
    -   without replacement
    -   without replacement after shuffling the data, and then ensure
        you have gone through all of the data before repeating
    -   etc.
-   Just remember the goal: variance reduction on gradient estimates
-   You want it to be unbiased in principle (consider partitioning the
    data into batches and operating sequentially?)
-   More art than science in many cases, because it requires many priors

## “Grad Student Descent”

-   This is how virtually all deep learning works. Just swap SGD with
    slightly fancier algorithms using momentum, tinker with parameters,
    etc.
-   In practice, all of these optimizer settings (e.g., how large for
    $|B_t|$, $\eta_t$, convergence criteria, etc.) are fragile and
    require a lot of tuning
    -   Part of a a process called **hyperparameter optimization (HPO)**
        where you try to find the best non-model parameters for your
        goals
    -   Same issue with all numerical methods in economics
        (e.g. convergence criteria of fixed point iteration, initial
        conditions)
-   The concern is not just that it is time-consuming for researchers
    (and ML “Grad Students”), but that it is easy for priors to sneak in
    and bias results

## What was our Goal?

-   We will address this more formally next lecture, but it is worth
    stepping back to think about our goals. Loosely:
    -   If we are solving an empirical risk minimization problem (like
        regressions, etc.) or interpolation, then our goal is to use the
        “data” to find a function $\hat{f}_{\theta}$ that is close to
        the “true” function $f^*$
-   Fitting $\hat{f}_{\theta}$ is easy, but we want it to **generalize**
    within the true distribution
    -   But we don’t know that distribution (hence the “empirical”)
    -   So a typical approach is to emulate this by splitting the data
        we have
    -   But HPO is dangerous because if we are not careful we can
        “contaminate” our process for finding $\hat{f}_{\theta}$ using
        some of the data we intend to check it with. Which might lead to
        overfitting/etc.

# Training Loops

## Splitting the Data

A standard way to do this for Empirical Risk
Minimization/Regressions/etc. is to split it into three parts:

1.  **Training** data used in fitting our approximations
    -   This is just a means to an end in ML and economics
2.  **Validation** data used for HPO and checking convergence criteria
    -   Be cautious to avoid using it for training
3.  **Test** data used to evaluate the generalization performance
    -   Ensure we don’t accidentally use it in training or validation

Not all problems will have this structure, and not all with have
validation data.

## Why Separate Validation and Test?

-   As we will see in deep learning, with massive over-parameterization
    you typically can interpolate all of the training data.
    -   Minimizing training loss is a means to an end, which usually
        ends at zero
-   The validation data might be used to check stopping criteria by
    checking how well the approximation generalizes to data outside of
    training
-   But if we are using it for a stopping criteria or HPO, then is is
    **contaminated**!
    -   Distorts our picture of generalization if we combine it into
        test data

## What about Interpolation Problems?

-   When simply trying to find interpolating functions which solve
    functional equations, the risk of prior contamination is less clear
-   However, you may still want to separate out validation and test grid
    points because any data you use for HPO or convergence criteria
    can’t be used to understand generalization.
-   For example consider:
    1.  Fit until “training” loss is zero
    2.  Keep running stochastic optimizer until “validation” loss is
        zero
-   In that case, it crudely interpolating the validation data, which
    makes it equivalent to training data? Not useful for generalization
    -   May find that the model generalized better if you **stopped
        earlier**

## Level of Abstraction for Optimizers

-   While you can setup a standard optimization objective and optimizer,
    most ML frameworks work at a lower level
-   The key reasons are that:
    -   Minibatching (usually just called “batches”) requires more
        flexibility in implementation to be efficient
    -   Stopping criteria is more complicated with highly
        overparameterized models
    -   Logging and validation logic requires more flexibility
    -   Often you will want to take a snapshot of the current best
        solution and continue later for refinement (or to solve in
        parallel)

## Steps and Epochs

-   There is a great deal of flexibility in how you setup the optimizer
-   But a common approach is to randomly shuffle the data, create a set
    of batches $B_t$ (without replacement), and then iterate through
    them
-   Terminology (when relevant)
    -   Every iteration of SGD for a given batch is a **step**
    -   If you have gone through the entire dataset once, we say that
        you have completed an **epoch**
-   At the end of an epoch is a good time to log, check the validation
    loss, and potentially stop the training

## Software Components used in ML

Some common software components for optimization are

1.  **Autodifferentiation** and libraries of functions provide the
    approximation class
2.  **Data loaders** which will take care of providing batches to the
    optimizers
3.  **Optimizers** are typically iterative, have an internal state, and
    you can update with one sample of the gradient for that batch
4.  **Logging** and visualization tools to track progress because the
    optimization process may be slow and you want to do HPO
5.  **HPO** software using training, validation, and possibly test loss

## Logging and Visualization

-   Several tools exist for logging to babysit optimizers, find good
    hyperparameters, etc. including
    [Tensorboard](https://www.tensorflow.org/tensorboard)
    -   But we will use [Weights and Biases](https://wandb.ai/site)
        (W&B) because it is a market leader, free for academics and
        seems to be the frontrunner
-   Many algorithms and frameworks exist for HPO:
    -   [Weights and Biases](https://wandb.ai/site) (W&B) has a built-in
        HPO framework using random search and bayesian optimization
    -   [Optuna](https://optuna.org/) and [Ray
        Tune](https://docs.ray.io/en/master/tune/index.html) is a
        popular open-source HPO framework
    -   [Ray Tune](https://docs.ray.io/en/master/tune/index.html) is a
        popular open-source HPO framework
-   HPO frameworks will often use the
    [command-line](https://github.com/shadawck/awesome-cli-frameworks#python)
    to run new jobs. [Python
    Fire](https://github.com/google/python-fire)

## Broad Frameworks for Machine Learning

-   You can just hand-code loops/etc. which seems the best approach for
    JAX
    -   Even with Pytorch, it isn’t obvious that a framework is better
        ex-post, though ex-ante it can help you try different
        permutations easily
-   [Pytorch Lightning](https://www.pytorchlightning.ai/) is a popular
    framework which will formalize the training loops even across
    distributed systems and make CLI, HPO, logging, etc. convenient
    -   It remains fairly flexible because it is just wrapping Pytorch
-   [Keras](https://keras.io/) is a similar framework with the ability
    to target multiple backends (e.g., Pytorch, JAX)
    -   The challenge is that it is much less flexible for non-typical
        research
-   [Hydra](https://github.com/facebookresearch/hydra) is a framework
    for more serious engineering code

# Detailed Pytorch Linear Regression Example

## Linear Regression Examples

-   Of course SGD is a terrible way to do a standard linear regression,
    but it will help us understand the mechanics with a well-understood
    problem
-   These examples simulate data for some:
    $y_n = x_n \cdot \theta + \sigma \epsilon$ for
    $\epsilon \sim N(0,1)$
-   They then show various features of the optimization pipeline and
    software to implement the LLS ERM objective

$$
\min_{\theta} \frac{1}{N} \sum_{n=1}^N \left[y_n - x_n \cdot \theta\right]^2
$$

-   Which we note has a finite-sum objective, which lets us use
    minibatch SGD with gradient estimates
-   Install in your environment with `pip install -r requirements.txt`

## Packages

-   Before showing all of the variations, here we will implement an
    inline SGD version with minibatches

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import equinox as eqx
from torch.utils.data import TensorDataset, DataLoader, TensorDataset

## Simulate Data

$$
y \sim N(x \cdot \theta, \sigma^2)
$$

In [3]:
N = 500  # samples
M = 2
sigma = 0.001
theta = torch.randn(M)
X = torch.randn(N, M)
Y = X @ theta + sigma * torch.randn(N)
dataset = TensorDataset(X, Y)
print(dataset[0]) # returns tuples of (x_n, y_n)

(tensor([ 0.4478, -0.7485]), tensor(0.2630))

## Dataloaders Provide Batches

-   Code which serves up random batches of data for gradient estimates
    are called “dataloaders”
-   The following code shuffles the data and provides batches of size
    `batch_size`
-   It returns a [Python
    generator](https://docs.python.org/3/howto/functional.html#generator-expressions-and-list-comprehensions)
    which can be iterated until it hits the end of the data

In [4]:
batch_size = 8
train_loader = DataLoader(dataset,
                batch_size=batch_size,
                shuffle=True)
# e.g. iterate and get first element
print(next(iter(train_loader)))

[tensor([[ 2.0556, -0.9117],
        [-0.4503, -0.7447],
        [ 0.9591,  0.0824],
        [ 1.0008, -0.4637],
        [-2.1695, -1.3079],
        [ 0.5311, -1.9840],
        [-0.5480,  0.1194],
        [-0.0364,  0.5186]]), tensor([-0.3133,  0.6384, -0.4519, -0.1410,  1.6973,  0.9745,  0.1600, -0.2956])]

## Loss Function for Gradient Descent

-   Reminder: need to provide AD-able functions which give a gradient
    estimate, not necessarily the objective itself!
-   In particular, for LLS we simply can find the MSE between the
    prediction and the data for the batch itself

In [5]:
def residuals(model, X, Y):  # batches or full data
    Y_hat = model(X).squeeze()
    return ((Y_hat - Y) ** 2).mean()

## Hypothesis Class

-   The “Hypothesis Class” for our ERM approximation is linear in this
    case.
-   Anything with differentiable parameters is a sub-class of the
    `nn.Module` in Pytorch. Special case of Neural Networks
-   In this case, we can just use a prebuilt linear approximation from
    $\mathbb{R}^M \to \mathbb{R}^1$ without an affine constant term.
-   The underlying parameters will have a random initialization, which
    becomes **crucial** with overparameterized models (but wouldn’t be
    important here)

In [6]:
model = nn.Linear(M, 1, bias=False)  # random initialization
print(model)

Linear(in_features=2, out_features=1, bias=False)

## Optimizer

-   First-order optimizers take steps using gradient estimates
-   In many ML applications you will want control over the process, so
    will manually call the function to collect the gradient estimate
    then call the optimizer to take the next step
-   Here we will just use SGD with a fixed learning rate
-   Note that the optimizer is constructed to look directly at the
    `parameters` for your underlying model(s)

In [7]:
optimizer = optim.SGD(
    model.parameters(), lr=0.001
)
print(optimizer)

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

## Training Loop

-   Finally, we can loop for multiple “epochs” of passes through the
    data with the dataloader, calling the optimizer to update each time
-   The `for ... in train_loader:` will repeat until the end of the data
    and continue to the next epoch (i.e., pass through data)
-   Each batch updates a `step` using the optimizer, which is unaware of
    epochs/batches/etc.

In [8]:
for epoch in range(300):
    for X_batch, Y_batch in train_loader:
        optimizer.zero_grad() 
        loss = residuals(model, X_batch, Y_batch)  # primal
        loss.backward()  # backprop/reverse-mode AD
        # Now the model.parameters have gradients updated, so...
        optimizer.step()  # Update the optimizers internal parameters
print(f"||theta - theta_hat|| = {torch.norm(theta - model.weight.squeeze())}")        

||theta - theta_hat|| = 6.593696161871776e-05

# More Pytorch Linear Regression Examples

## Gradient Descent

-   See
    [examples/linear_regression_pytorch_gd.py](examples/linear_regression_pytorch_gd.py)
-   Simulates data and shows the basic training loop
-   Using the “full batch” to calculate the residuals

## Stochastic Gradient Descent

-   See
    [examples/linear_regression_pytorch_sgd.py](examples/linear_regression_pytorch_sgd.py)
-   This takes the existing code, but adds in code to calculate gradient
    estimates using minibatches
-   Note that this is creating batches by shuffling the data and the
    going through it `batch_size` chunks at a time
-   When it gets to the end of that data, it is the end of the `epoch`

## Adam + Bells and Whistles

-   See
    [examples/linear_regression_pytorch_adam.py](examples/linear_regression_pytorch_adam.py)
-   This extends the previous version and adds in the full
    train/val/test datasplit
    -   The `val_loss` is collected and displayed at the end of each
        epoch
-   It also shows a learning rate scheduler and a few utilities for
    logging and early stopping

## Logging with Weights and Biases

-   See
    [examples/linear_regression_pytorch_logging.py](examples/linear_regression_pytorch_logging.py)

-   This adds in support for Weights and Biases, and also demonstrates
    the use of a custom `nn.Module` for the hypothesis class

    1.  Go to [wandb.ai](https://wandb.ai/) and create an account,
        ideally linked to your github
    2.  Ensure you have installed the packages with
        `pip install -r requirements.txt`
    3.  Run `wandb login` in terminal to connect to your account

-   You will then be able to run these files and see results on
    [wandb.ai](https://wandb.ai/)

## Pytorch Lightning

-   See
    [examples/linear_regression_pytorch_lightning.py](examples/linear_regression_pytorch_lightning.py)
-   See
    [examples/linear_regression_pytorch_lightning_defaults.yaml](examples/linear_regression_pytorch_lightning_defaults.yaml)
    for default HPO and parameters
-   This is using many features in [Pytorch
    Lightning](https://www.pytorchlightning.ai/) to simplify the code
-   The optimizer, stopping rules, logging, learning rate scheduler,
    etc. are all handled by the framework and can be configured in that
    file or on the CLI.
-   Can change values on commandline to override the `yaml` file,

``` bash
python lectures/examples/linear_regression_pytorch_lightning.py --optimizer.lr=0.0001 --model.N=500
python lectures/examples/linear_regression_pytorch_lightning.py --trainer.max_epochs=500
```

## Sweeps with W&B

-   See
    [examples/linear_regression_pytorch_sweep.yaml](lectures/examples/linear_regression_pytorch_sweep.yaml)
    for a [sweep file](https://docs.wandb.ai/guides/sweeps) which
    provides a HPO experiment to run and log
    -   Executes variations on `--model.batch_size=32` and
        `--model.lr=0.001` etc
    -   Tries to choose them to minimize the `val_loss` as a HPO
        objective
    -   First, create the sweep,
        `wandb sweep lectures/examples/linear_regression_pytorch_sweep.yaml`
    -   Then run `wandb agent <sweep_id>` with returned sweep id
    -   Call `wandb agent <sweep_id>` on multiple computers to run in
        parallel
-   Any CLI implementation can be used with these sorts of frameworks

# Detailed JAX Linear Regression Example

## Packages

-   `optax` is a common package for ML optimization methods

In [9]:
import jax
import jax.numpy as jnp
from jax import grad, jit, value_and_grad, vmap
from jax import random
import optax

## Simulate Data

-   Few differences here, except for manual use of the `key`
-   Remember that if you use the same `key` you get the same value.
-   See [JAX
    docs](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)
    for more details

In [10]:
N = 500  # samples
M = 2
sigma = 0.001
key = random.PRNGKey(42)
# Pattern: split before using key, replace name "key"
key, *subkey = random.split(key, num=4)
theta = random.normal(subkey[0], (M,))
X = random.normal(subkey[1], (N, M))
Y = X @ theta + sigma * random.normal(subkey[2], (N,))  # Adding noise

## Dataloaders Provide Batches

-   For more complicated data (e.g. images, text) JAX can use other
    packages, but it doesn’t have a canonical dataloader at this point
-   But in this case we can manually create this, using
    [`yield`](https://docs.python.org/3/howto/functional.html#generators)

In [11]:
def data_loader(key, X, Y, batch_size):
    N = X.shape[0]
    assert N == Y.shape[0]
    indices = jnp.arange(N)
    indices = random.permutation(key, indices)
    # Loop over batches and yield
    for i in range(0, N, batch_size):
        b_indices = indices[i:i + batch_size]
        yield X[b_indices], Y[b_indices]
# e.g. iterate and get first element
dl_test = data_loader(key, X, Y, 4)
print(next(iter(dl_test)))

(Array([[ 0.05907545, -1.7277497 ],
       [-1.3816313 ,  0.33074763],
       [ 0.76224667, -0.07191363],
       [-0.46871892,  0.24884424]], dtype=float32), Array([-1.3571022 ,  0.07165897,  0.0480412 ,  0.13437417], dtype=float32))

## Hypothesis Class

-   The “Hypothesis Class” for our ERM approximation is linear in this
    case
-   JAX is functional and non-mutating, so you must write stateless code
-   We will move towards a more general class with the `equinox`
    package, but for now we will implement the model with the parameters
    directly
-   The underlying parameters will have a random initialization, which
    becomes **crucial** with overparameterized models (but wouldn’t be
    important here)

In [12]:
def predict(theta, X):
    return jnp.matmul(X, theta) #or jnp.dot(X, theta)

# Need to randomize our own theta_0 parameters
key, subkey = random.split(key)
theta_0 = random.normal(subkey, (M,))
print(f"theta_0 = {theta_0}, theta = {theta}")

theta_0 = [ 1.7535115  -0.07298409], theta = [0.1378821  0.79073715]

## Loss Function for Gradient Descent

-   Reminder: need to provide AD-able functions which give a gradient
    estimate, not necessarily the objective itself!
-   In particular, for LLS we simply can find the MSE between the
    prediction and the data for the batch itself
-   For now, we are passing the `params` rather than the `model` itself

In [13]:
def vectorized_residuals(params, X, Y):
    Y_hat = predict(params, X)
    return jnp.mean((Y_hat - Y) ** 2)

## Optimizer

-   The `optimizer.init(theta_0)` provides the initial state for the
    iterations
-   With SGD it is empty, but with momentum/etc. it will have internal
    state

In [14]:
lr = 0.001
batch_size = 16
num_epochs = 201

# optax.adam(lr) is worse here
optimizer = optax.sgd(lr)
opt_state = optimizer.init(theta_0)
print(f"Optimizer state:{opt_state}")
params = theta_0 # initial condition

Optimizer state:(EmptyState(), EmptyState())

## Using Optimizer for a Step

-   Here we write a (compiled) utility function which:
    1.  Calculates the loss and gradient estimates for the batch
    2.  Updates the optimizer state
    3.  Applies the updates to the parameters
    4.  Returns the updated parameters, optimizer state, and loss
-   The reason to set this up as a function is to maintain JAXs “pure”
    style

In [15]:
@jax.jit
def make_step(params, opt_state, X, Y):
  loss_value, grads = jax.value_and_grad(vectorized_residuals)(params, X, Y)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value

## Training Loop Version 1

-   Note that unlike Pytorch the gradients are passed as parameters

In [16]:
for epoch in range(num_epochs):
    key, subkey = random.split(key) # changing key for shuffling each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        params, opt_state, train_loss = make_step(params, opt_state, X_batch, Y_batch)  
    if epoch % 100 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")

Epoch 0,||theta - theta_hat|| = 1.714521050453186
Epoch 100,||theta - theta_hat|| = 0.0020757941529154778
Epoch 200,||theta - theta_hat|| = 6.367397145368159e-05
||theta - theta_hat|| = 6.367397145368159e-05

## Auto-Vectorizing

-   In the above case the `vectorized_residuals` was able to use a
    directly vectorized function.
-   However in many cases it will be more convenient to write code for a
    single element of the finite-sum objectives
-   Now we will rewrite our objective to demonstrate how to use `vmap`

In [17]:
def residual(theta, x, y):
    y_hat = predict(theta, x)
    return (y_hat - y) ** 2

@jit
def residuals(theta, X, Y):
    # Use vmap, fixing the 1st argument
    batched_residuals = jax.vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(theta, X, Y))
print(residual(theta_0, X[0], Y[0]))
print(residuals(theta_0, X, Y))

0.021030858
3.546122

## New Step and Initialization

-   This simply changes the function used for the `value_and_grad` call
    to use the new `residuals` function and resets our optimizer

In [18]:
@jax.jit
def make_step(params, opt_state, X, Y):     
  loss_value, grads = jax.value_and_grad(residuals)(params, X, Y)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value
optimizer = optax.sgd(lr) # better than optax.adam here
opt_state = optimizer.init(theta_0)
params = theta_0

## Training Loop Version 2

-   Otherwise the training loop is the same

In [19]:
for epoch in range(num_epochs):
    key, subkey = random.split(key) # changing key for shuffling each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        params, opt_state, train_loss = make_step(params, opt_state, X_batch, Y_batch)  
    if epoch % 100 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")

Epoch 0,||theta - theta_hat|| = 1.7061582803726196
Epoch 100,||theta - theta_hat|| = 0.002027335576713085
Epoch 200,||theta - theta_hat|| = 6.216356268851086e-05
||theta - theta_hat|| = 6.216356268851086e-05

## JAX Examples

-   See
    [examples/linear_regression_jax_sgd.py](examples/linear_regression_jax_sgd.py)
    -   This implements the inline code above without the vmap
-   See
    [examples/linear_regression_jax_vmap.py](examples/linear_regression_jax_vmap.py)
    -   This implements the `vmap` as above
    -   This also adds in an [learning rate
        schedule](https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules)
-   See
    [examples/linear_regression_jax_equinox.py](examples/linear_regression_jax_equinox.py)
    -   See the following example for more

# JAX Linear Regression with Equinox

## Using an Equinox Model

-   While it seems convenient to work in a functional style, when we
    move towards nested, deep approximations it can become cumbersome
-   An excellent package to both define hypothesis classes and to work
    with their gradients is
    [equinox](https://github.com/patrick-kidger/equinox)
-   Key to the design of equinox is that if we use that the nested
    parameters of an approximation class are PyTrees, then we can find
    the derivative with respect to that entire type. See our previous
    example on differentiating PyTrees

## Equinox Macros

-   There are several macros in equinox which make it easier to work
    with differentiable PyTrees, all of which have the name `filter_`
    -   Loosely: these macros go through the underlying python
        datastructures and filter out values into static values vs. ones
        which could be perturbed
    -   For example, if I have a type with a `jnp.array` for the
        differentiable weights and a `int` which stores the number of
        hidden dimensions, then the `filter_grad` will flag the
        `jnp.array` as differentiable and the `int` will have the
        gradient type as `None`
-   When in doubt, always replace them. e.g. `@jax.jit` can be replied
    with `@equinox.filter_jit`, `jax.vmap` should be replaced with
    `eqx.filter_vmap`, etc.

## Hypothesis Class

-   We are moving towards Neural Networks, which are a very broad class
    of approximations.
-   Here lets just use a linear approximation with no constant term
-   As always, the initial randomization will become increasingly
    important

In [20]:
key, subkey = random.split(key)
model = eqx.nn.Linear(M, 1, use_bias = False, key = subkey)
print(model.weight)

[[-0.01143495  0.6449629 ]]

## Residuals using the Model

-   The model now contains all of the, potentially nested, parameters
    for the approximation class
-   It provides call notation to evaluate the function with those
    parameters

In [21]:
def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals(model, X, Y):
    batched_residuals = vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(model, X, Y))
# Alternatively, could have combined and still jit/grad
def residuals_2(model, X, Y):
    Y_hat = vmap(model)(X).squeeze()
    return jnp.mean((Y - Y_hat) ** 2)

## Gradients of Models

-   As discussed, we can find the gradients of richer objects than just
    arrays
-   Optimizer updates use perturbations of the underlying PyTree
-   Updates can be applied because the type of the gradients matches the
    underlying PyTree

In [22]:
grads = jax.grad(residuals)(model, X, Y)
print(grads)
print(grads.weight)

Linear(
  weight=f32[1,2],
  bias=None,
  in_features=2,
  out_features=1,
  use_bias=False
)
[[-0.30091333 -0.31419277]]

## Setup of the Optimizer and `make_step`

-   The `make_step` isn’t very different, except for using a few
    `equinox` utility functions

In [23]:
lr = 0.001
optimizer = optax.sgd(lr)
# Needs to remove the non-differentiable parts of the "model" object
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, X, Y):     
  loss_value, grads = eqx.filter_value_and_grad(residuals)(model, X, Y)
  updates, opt_state = optimizer.update(grads, opt_state, model)
  model = eqx.apply_updates(model, updates)
  return model, opt_state, loss_value

## Training Loop

In [24]:
num_epochs = 300
batch_size = 64
key, subkey = random.split(key) # will keep same key for shuffling each epoch
for epoch in range(num_epochs):
    key, subkey = random.split(key) # changing key for shuffling each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)    
    if epoch % 100 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")

Epoch 0,||theta - theta_hat|| = 0.20521600544452667
Epoch 100,||theta - theta_hat|| = 0.038740210235118866
Epoch 200,||theta - theta_hat|| = 0.0073235719464719296
||theta - theta_hat|| = 0.0013871045084670186

## Custom Types

-   To prepare for layers of NN, we see that this class could have been
    written manually, layering as we see fit
-   See
    [examples/linear_regression_jax_equinox.py](examples/linear_regression_jax_equinox.py)
    for more, but we could have manually created the following

In [25]:
class MyLinear(eqx.Module):
    weight: jax.Array
    def __init__(self, in_size, out_size, key):
        self.weight = jax.random.normal(key, (out_size, in_size))
    # Similar to Pytorch's forward
    def __call__(self, x):
        return self.weight @ x

model = MyLinear(M, 1, key = subkey)