# Setup & Overview

In [1]:
# Uncomment if you want to automatically install all dependencies via pip
# !pip install -r requirements.txt

In [2]:
import functools
import os
import shutil

import matplotlib as mpl
import matplotlib.pyplot as plt
import nbfigtulz as ftl

import numpy as np

import torch
import torch.nn as nn

In [3]:
from data.x3 import generate_data
from utils import get_ccycle, train, train_loop
import models

In [4]:
# Seed NumPy and PyTorch from one constant so data generation and training
# start from a fixed random state.
SEED = 42
NPRNG = np.random.default_rng(SEED)  # handed to `generate_data` as `rng` further below
torch.manual_seed(SEED);  # the trailing ';' suppresses the cell output

In [5]:
from utils import get_best_device
# Device that is later passed to `train(..., device=DEVICE)`.
DEVICE = get_best_device()

In [6]:
# Figure sizes forwarded as `resize=` to `ftl.save_fig`; switch `FIG_SIZE` to
# choose between the two presets.
FIG_SIZE_SMALL = (3.0, 2.3)  # small images for paper
FIG_SIZE_LARGE = (4.0, 3.0)  # larger images for presentation
FIG_SIZE = FIG_SIZE_LARGE

In this notebook, we take a deep dive into the $x^3$ example and thoroughly analyze the NN during training with [PyTorch](https://pytorch.org).
We also test a simplified version of DER, which we refer to as ${}_\mathcal{S}$DER. The same NN is used but $\alpha$ is discarded and $\beta$ and $\nu$ are fused into $\sigma = \beta / \nu$. The updated loss function $\mathcal{L}_i$ for ${}_\mathcal{S}$DER reads:
$$\mathcal{L}_i = \log \sigma^2 + (1 + \lambda \nu) \frac{(x - \gamma)^2}{\sigma^2}$$

Change `DER_TYPE` below to `'sDER'` to use the simplified version ${}_\mathcal{S}$DER. In this notebook, this value is used in several branches which ensure, e.g., that images are written to separate locations.

In [7]:
DER_TYPE = 'DER'  # use 'sDER' for the simplified variant; every other value than 'DER' takes the sDER branches below

In [8]:
img_dir = f'img/x3_{DER_TYPE.lower()}'

# Start from an empty image directory. This is done in pure Python instead of
# `!rm -rf ... && mkdir -p ...` so the cell does not depend on a POSIX shell.
shutil.rmtree(img_dir, ignore_errors=True)  # a missing directory is fine
os.makedirs(img_dir, exist_ok=True)
ftl.config['img_dir'] = img_dir

print(f'Images are stored in: {img_dir}')

Images are stored in: img/x3_der


In [9]:
BATCH_SIZE = 100
N_SAMPLES = 50
N_EPOCHS = 500

# (learning rate, lambda) for the selected flavour; any value of `DER_TYPE`
# other than 'DER' picks the second pair.
LR, COEFF = (5e-4, .01) if DER_TYPE == 'DER' else (.005, 2.)

In [10]:
# Output layer and loss function matching the selected flavour; only the
# chosen pair of attributes is looked up on `models`.
DERLayer, loss_der = (
    (models.DERLayer, models.loss_der)
    if DER_TYPE == 'DER'
    else (models.SDERLayer, models.loss_sder)
)

# Training

We now generate training and test data on $x \in [-4, +4]$ ...

In [11]:
def get_dataloader(n):
    """Build the training and test data loaders.

    Two data sets are requested from ``generate_data(-4, 4, n, ...)``, the
    training set first, so the shared ``NPRNG`` stream is consumed in a fixed
    order. Both are wrapped in shuffling ``DataLoader`` objects with
    ``BATCH_SIZE`` samples per batch that drop the last incomplete batch.

    Returns:
        Tuple ``(train_dl, test_dl)``.
    """
    def draw():
        return generate_data(-4, 4, n, rng=NPRNG, dtype=np.float32)

    def as_loader(dataset):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
        )

    train_data = draw()
    test_data = draw()

    return as_loader(train_data), as_loader(test_data)


# Module-level loaders; `train_der` below picks them up as globals.
train_dl, test_dl = get_dataloader(n=1000)

... define a simple $\ell_1$ error loss (only used for plotting) and write a convenient helper function to kick off the training.

In [12]:
def error_fct(y, y_pred):
    """Return the mean absolute deviation between ``y[:, 0]`` and ``y_pred``."""
    residual = y[:, 0] - y_pred
    return residual.abs().mean()


def train_der(n_epochs):
    """Train a freshly constructed network for ``n_epochs`` epochs.

    The network is ``models.Model(4)`` followed by the selected ``DERLayer``;
    the loss is ``loss_der`` with ``coeff`` fixed to ``COEFF``. Everything else
    (loaders, learning rate, device) is taken from the notebook globals.

    Returns:
        Whatever ``train`` returns.
    """
    net = torch.nn.Sequential(models.Model(4), DERLayer())
    criterion = functools.partial(loss_der, coeff=COEFF)

    settings = dict(
        n_epochs=n_epochs,
        model=net,
        loss_fct=criterion,
        error_fct=error_fct,  # only used for logging
        lr=LR,
        train_dl=train_dl,
        test_dl=test_dl,
        scan_lim=(-7, 7.),
        device=DEVICE,
    )
    return train(**settings)

In [13]:
@ftl.with_context
def plot_loss(loss, error, *, file_name):
    """Plot loss and error curves of all runs on a log-scaled y-axis.

    Args:
        loss: Sequence of loss curves, one entry per run.
        error: Sequence of error curves, indexed in parallel to ``loss``.
        file_name: Name under which the figure is saved.

    Returns:
        The result of ``ftl.save_fig``.
    """
    fig, ax = plt.subplots()

    # The number of runs comes from the `loss` argument; the previous code
    # referenced an undefined name here and raised a NameError.
    n = len(loss)
    for i in range(n):
        ax.plot(loss[i], color='C0', alpha=.1)
        ax.plot(error[i], color='C1', alpha=.1)

    ax.set_xlabel('Epoch')
    ax.set_ylim(1, 20)
    ax.set_yscale('log')
    ax.grid()

    # Proxy artists give the legend opaque entries even though the curves
    # themselves are drawn almost transparent.
    ax.legend(
        [mpl.lines.Line2D([0], [0], color=f'C{i}') for i in range(2)],
        ['Total loss', 'Abs. error']
    )

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)

We now start the training; time to grab a coffee ☕. If this takes too long (> ☕☕☕), decrease `N_SAMPLES` and restart the training.

In [14]:
# Repeat the training via `train_der` with N_SAMPLES / N_EPOCHS as configured above.
# NOTE(review): presumably `loss` holds the per-run loss/error curves and
# `x_scan` the network outputs scanned over x for every run and epoch --
# confirm against `utils.train_loop`.
model, loss, x_scan = train_loop(train_der, n_samples=N_SAMPLES, n_epochs=N_EPOCHS, quiet=False)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [08:12<00:00,  9.86s/it]


In [15]:
@ftl.with_context
def make_fig(loss, error, *, file_name):
    """Draw every loss and error curve on a logarithmic y-axis and save it.

    ``loss`` and ``error`` are indexed in parallel; entry ``i`` of each is
    drawn as one semi-transparent line.
    """
    fig, ax = plt.subplots()
    ax.set_yscale('log')

    n_runs = len(loss)
    opacity = max(.1, 2. / n_runs)
    for run in range(n_runs):
        for curves, color in ((loss, 'C0'), (error, 'C1')):
            ax.plot(curves[run], color=color, alpha=opacity)

    ax.set_xlabel('Epoch')
    ax.grid()

    # Opaque proxy lines for the legend.
    handles = [mpl.lines.Line2D([0], [0], color=color) for color in ('C0', 'C1')]
    ax.legend(handles, ['Total loss', 'Abs. error'])

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


# Training curves first, test curves second.
ftl.img_grid([
    make_fig(loss[loss_key], loss[error_key], file_name=name)
    for loss_key, error_key, name in (
        ('loss_train', 'error_train', 'loss_train'),
        ('loss_test', 'error_test', 'loss_test'),
    )
])

Above, *Abs. error* refers to the $\ell_1$ loss we defined previously.

For convenience we define a helper variable that holds the average of all samples for each epoch:

In [16]:
# Same layout as `x_scan`, but with the leading axis of 'y' averaged out.
x_scan_avg = dict(
    x=x_scan['x'],
    y=np.average(x_scan['y'], axis=0),
)

We are hip and use [Python 3.8 f-strings' = support](https://docs.python.org/3/whatsnew/3.8.html#f-strings-support-for-self-documenting-expressions-and-debugging) to print the shapes of `x_scan` and `x_scan_avg` ...

In [17]:
# The `=` specifier echoes the expression text itself; the leading spaces in
# two of the strings only align the printed lines.
print(f'    {x_scan["x"].shape=}')
print(f'{x_scan_avg["x"].shape=}')
print(f'    {x_scan["y"].shape=}')
print(f'{x_scan_avg["y"].shape=}')

    x_scan["x"].shape=(300,)
x_scan_avg["x"].shape=(300,)
    x_scan["y"].shape=(50, 500, 300, 4)
x_scan_avg["y"].shape=(500, 300, 4)


# Analysis

Let's plot the evolution of the parameters and their final distribution!

In [18]:
def idx2epoch(epochs, *, n_max):
    """Normalise a list of requested epochs.

    Args:
        epochs: Iterable of 1-based epoch numbers; ``-1`` stands for the last
            epoch ``n_max``.
        n_max: Number of available epochs.

    Returns:
        Sorted list of unique epochs ``k`` with ``0 < k <= n_max``.
    """
    # Resolve -1 *before* de-duplicating, otherwise a request such as
    # [-1, n_max] would yield the last epoch twice.
    resolved = {n_max if k == -1 else k for k in epochs}
    return sorted(k for k in resolved if 0 < k <= n_max)


@ftl.with_context
def plot_evolution(x, y, f, *, epochs, file_name, x_min=None, x_max=None, **kwargs):
    """Overlay snapshots at selected epochs in one figure and save it.

    Args:
        x: Scan positions; used to restrict the plotted range.
        y: Array indexed as ``y[epoch - 1]``; its second axis runs along ``x``.
        f: Callback invoked as ``f(x=..., y=..., ax=..., color=..., **kwargs)``
            once per selected epoch.
        epochs: Requested epochs, normalised through ``idx2epoch``.
        file_name: Name under which the figure is saved.
        x_min: Lower bound of the plotted range; defaults to ``x[0]``.
        x_max: Upper bound of the plotted range; defaults to ``x[-1]``.

    Returns:
        The result of ``ftl.save_fig``.
    """
    lower = x[0] if x_min is None else x_min
    upper = x[-1] if x_max is None else x_max

    in_range = (lower <= x) & (x <= upper)
    x, y = x[in_range], y[:, in_range]

    selected = idx2epoch(epochs, n_max=y.shape[0])
    colors = get_ccycle(len(selected))

    fig, ax = plt.subplots()
    for pos, epoch in enumerate(selected):
        # Epochs are 1-based, the leading axis of `y` is 0-based.
        f(x=x, y=y[epoch - 1], ax=ax, color=colors[pos], **kwargs)

    handles = [mpl.lines.Line2D([0], [0], color=colors[pos]) for pos in range(len(selected))]
    labels = [f'Epoch: {epoch}' for epoch in selected]
    ax.legend(handles, labels)

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)

In [19]:
def make_fig(*, x, y, ax, color):
    """Draw the deviation of the first output column from x^3 onto ``ax``."""
    gamma = y[:, 0]
    ax.plot(x, gamma - x**3, color=color)

    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\gamma - x^3$')


# Evolution of the run-averaged deviation, restricted to x in [-4, 4];
# the epoch -1 is resolved to the last epoch by `idx2epoch`.
img = plot_evolution(
    x_scan_avg['x'],
    x_scan_avg['y'],
    make_fig,
    epochs=[1, 10, 50, 100, -1],
    file_name='error_evol',
    x_min=-4,
    x_max=4.,
)

In [20]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``y[..., 0] - x**3`` for every entry along the first axis of ``y``.

    Only positions with ``-4 <= x <= 4`` are drawn; the figure is saved as
    'error'.
    """
    inside = (x >= -4.) & (x <= 4.)
    x, y = x[inside], y[:, inside]

    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\gamma - x^3$')

    n_runs = y.shape[0]
    opacity = max(.1, 2. / n_runs)

    deviation = y[..., 0] - x**3
    for run in range(n_runs):
        ax.plot(x, deviation[run], color='C0', alpha=opacity)

    return ftl.save_fig(fig, 'error', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [21]:
def make_fig(*, x, y, ax, color):
    """Draw the square root of the second output column (nu) onto ``ax``."""
    sqrt_nu = np.sqrt(y[:, 1])
    ax.plot(x, sqrt_nu, color=color)

    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\nu}$')


# Evolution of the run-averaged sqrt(nu) over the full scan range;
# the epoch -1 is resolved to the last epoch by `idx2epoch`.
img = plot_evolution(
    x_scan_avg['x'],
    x_scan_avg['y'],
    make_fig,
    epochs=[1, 10, 50, 100, 200, -1],
    file_name='nu_evol'
)

In [22]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``sqrt(y[..., 1])`` for every entry along the first axis of ``y``.

    The figure is saved as 'nu'.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\nu}$')

    n_runs = y.shape[0]
    opacity = max(.1, 2. / n_runs)

    nu = y[..., 1]
    for run in range(n_runs):
        ax.plot(x, np.sqrt(nu[run]), color='C0', alpha=opacity)

    return ftl.save_fig(fig, 'nu', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [23]:
def make_fig(*, x, y, ax, color):
    """Draw ``sqrt(alpha - 1)`` onto ``ax``, with alpha the third output column."""
    shifted_alpha = y[:, 2] - 1.
    ax.plot(x, np.sqrt(shifted_alpha), color=color)

    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\alpha - 1}$')


# Evolution of the run-averaged sqrt(alpha - 1) over the full scan range;
# the epoch -1 is resolved to the last epoch by `idx2epoch`.
img = plot_evolution(
    x_scan_avg['x'],
    x_scan_avg['y'],
    make_fig,
    epochs=[1, 10, 50, 100, 200, -1],
    file_name='alpha_evol'
)

In [24]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``y[..., 2]`` (alpha) for every entry along the first axis of ``y``.

    The figure is saved as 'alpha'.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    # Mathtext needs a balanced pair of dollar signs; with the opening one
    # missing the label was not rendered as a Greek letter.
    ax.set_ylabel(r'$\alpha$')

    alpha = y[..., 2]

    n = y.shape[0]
    for i in range(n):
        ax.plot(x, alpha[i], color='C0', alpha=max(.1, 2. / n))

    return ftl.save_fig(fig, 'alpha', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [25]:
def make_fig(*, x, y, ax, color):
    """Draw the square root of the fourth output column (beta) onto ``ax``."""
    sqrt_beta = np.sqrt(y[:, 3])
    ax.plot(x, sqrt_beta, color=color)

    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta}$')


# Evolution of the run-averaged sqrt(beta) over the full scan range;
# the epoch -1 is resolved to the last epoch by `idx2epoch`.
img = plot_evolution(
    x_scan_avg['x'],
    x_scan_avg['y'],
    make_fig,
    epochs=[1, 10, 50, 100, 200, -1],
    file_name='beta_evol'
)

In [26]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``sqrt(y[..., 3])`` for every entry along the first axis of ``y``.

    The figure is saved as 'beta'.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta}$')

    n_runs = y.shape[0]
    opacity = max(.1, 2. / n_runs)

    beta = y[..., 3]
    for run in range(n_runs):
        ax.plot(x, np.sqrt(beta[run]), color='C0', alpha=opacity)

    return ftl.save_fig(fig, 'beta', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [27]:
def make_fig(*, x, y, ax, color):
    """Draw ``sqrt(beta (1 + nu) / nu / alpha)`` onto ``ax``.

    nu, alpha and beta are read from columns 1, 2 and 3 of ``y``.
    """
    nu, alpha, beta = y[:, 1], y[:, 2], y[:, 3]
    width = np.sqrt(beta * (1. + nu) / nu / alpha)
    ax.plot(x, width, color=color)

    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')


# Evolution of the run-averaged width sqrt(beta (1 + nu) / nu / alpha) over
# the full scan range; the epoch -1 is resolved to the last epoch by `idx2epoch`.
img = plot_evolution(
    x_scan_avg['x'],
    x_scan_avg['y'],
    make_fig,
    epochs=[1, 10, 50, 100, 200, -1],
    file_name='width_evol'
)

In [28]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``sqrt(beta (1 + nu) / nu / alpha)`` for every entry of ``y``.

    nu, alpha and beta are read from the last-axis indices 1, 2 and 3; the
    figure is saved as 'width'.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')

    nu, alpha, beta = y[..., 1], y[..., 2], y[..., 3]

    n_runs = y.shape[0]
    for run in range(n_runs):
        width = np.sqrt(beta[run] * (1. + nu[run]) / nu[run] / alpha[run])
        ax.plot(x, width, color='C0', alpha=max(.2, 3. / n_runs))

    return ftl.save_fig(fig, 'width', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [29]:
@ftl.with_context
def make_fig(x, y):
    """Plot ``sqrt(beta (1 + nu) / nu / alpha)`` restricted to ``-4 <= x <= 4``.

    nu, alpha and beta are read from the last-axis indices 1, 2 and 3; the
    figure is saved as 'width_narrow'.
    """
    inside = (x >= -4.) & (x <= 4.)
    x, y = x[inside], y[:, inside]

    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')

    nu, alpha, beta = y[..., 1], y[..., 2], y[..., 3]

    n_runs = y.shape[0]
    for run in range(n_runs):
        width = np.sqrt(beta[run] * (1. + nu[run]) / nu[run] / alpha[run])
        ax.plot(x, width, color='C0', alpha=max(.2, 3. / n_runs))

    return ftl.save_fig(fig, 'width_narrow', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
make_fig(x_scan['x'], x_scan['y'][:, -1])

width_narrow.png

In [30]:
@ftl.with_context
def make_fig(x, y):
    """Plot two uncertainty estimates per run on twin y-axes.

    For every entry of ``y`` (columns 1, 2, 3 read as nu, alpha, beta) the
    left axis shows ``sqrt(beta / (alpha - 1))`` labelled 'Aleatoric' and the
    log-scaled right axis shows that value divided by ``sqrt(nu)`` labelled
    'Epistemic'. The figure is saved as 'uncertainties_sota'.
    """
    fig, ax = plt.subplots()
    ax2 = ax.twinx()
    axs = [ax, ax2]

    ax2.set_yscale('log')

    for i in range(len(y)):
        nu = y[i][:, 1]
        alpha = y[i][:, 2]
        beta = y[i][:, 3]

        u = np.sqrt(beta / (alpha - 1))
        u2 = u / np.sqrt(nu)

        axs[0].plot(x, u, color='C0', alpha=.2)
        # Reuse `u2` instead of computing the same quotient a second time.
        axs[1].plot(x, u2, color='C1', alpha=.2)

    ax.set_xlabel('$x$')
    axs[0].set_ylabel('Aleatoric')
    axs[1].set_ylabel('Epistemic')

    ax.legend([
        mpl.lines.Line2D([0], [0], color='C0'),
        mpl.lines.Line2D([0], [0], color='C1'),
    ], ['Aleatoric', 'Epistemic'])

    return ftl.save_fig(fig, 'uncertainties_sota', resize=FIG_SIZE)


# This plot is only produced for the full DER variant. For any other variant
# `img` is reset to None so the cell does not display a stale image left over
# from an earlier cell.
img = make_fig(x_scan['x'], x_scan['y'][:, -1]) if DER_TYPE == 'DER' else None

img

uncertainties_sota.png

Obviously, the width of Student's $t$-distribution is much better suited for estimating the epistemic uncertainty. Hence, we update the definitions of the aleatoric uncertainty $u_\text{al}$ and the epistemic uncertainty $u_\text{ep}$:
$$\begin{align*}
u_\text{al} &\leftarrow w_\text{St} \equiv \sqrt{\frac{\beta (1 + \nu)}{\alpha \nu}} \\
u_\text{ep} &\leftarrow \sqrt{\frac{1}{\nu}}
\end{align*}$$
where the latter is rooted in our previous findings that $\nu$ measures the point-wise convergence speed.

In [31]:
@ftl.with_context
def make_fig(x, y):
    """Plot the redefined uncertainty estimates per run on twin y-axes.

    For every entry of ``y`` (columns 1, 2, 3 read as nu, alpha, beta) the
    left axis, limited to [0, 40], shows ``sqrt(beta (nu + 1) / nu / alpha)``
    labelled 'Aleatoric'; the log-scaled right axis shows ``1 / sqrt(nu)``
    labelled 'Epistemic'. The figure is saved as 'uncertainties_new'.
    """
    fig, ax_left = plt.subplots()
    ax_right = ax_left.twinx()

    ax_left.set_ylim(0, 40)
    ax_right.set_yscale('log')

    for run in range(len(y)):
        params = y[run]
        nu, alpha, beta = params[:, 1], params[:, 2], params[:, 3]

        aleatoric = np.sqrt(beta * (nu + 1) / nu / alpha)
        epistemic = 1. / np.sqrt(nu)

        ax_left.plot(x, aleatoric, color='C0', alpha=.2)
        ax_right.plot(x, epistemic, color='C1', alpha=.2)

    ax_left.set_xlabel('$x$')
    ax_left.set_ylabel('Aleatoric')
    ax_right.set_ylabel('Epistemic')

    # Opaque proxy lines for the legend.
    handles = [mpl.lines.Line2D([0], [0], color=color) for color in ('C0', 'C1')]
    ax_left.legend(handles, ['Aleatoric', 'Epistemic'])

    return ftl.save_fig(fig, 'uncertainties_new', resize=FIG_SIZE)


# `[:, -1]` picks the last index of the second axis of the scan results.
make_fig(x_scan['x'], x_scan['y'][:, -1])

uncertainties_new.png

Before, we only looked at 1D projections of the parameters. In fact, we can learn even more about the behaviour of DER during training by looking at 2D projections:
Below, we see a massive overshooting in $\nu$ in the early stages of training that predominantly drives the final epistemic uncertainty.
Note that the overshooting is **not** due to the momentum of Adam, since $\mathcal{L}$ decreases monotonically, but due to $\partial_\nu \mathcal{L} \propto |y-x^3|$.

In [32]:
@ftl.with_context
def make_fig(xy, x, *, file_name):
    """Plot the width against the prediction error for selected positions.

    Args:
        xy: Mapping with the scan positions under 'x' and the outputs under
            'y'; ``xy['y'][:, i]`` is read as the trajectory at position ``i``
            with columns gamma, nu, alpha, beta.
        x: Iterable of positions; for each one the nearest scan position is used.
        file_name: Name under which the figure is saved.

    Returns:
        The result of ``ftl.save_fig``.
    """
    fig, ax = plt.subplots()
    ax.grid()

    for k in x:
        # Index of the scan position closest to the requested value.
        i = np.abs(xy['x'] - k).argmin()
        m = xy['y'][:, i]

        error = m[:, 0] - k ** 3
        nu = m[:, 1]
        alpha = m[:, 2]
        beta = m[:, 3]
        w2 = beta / alpha * (nu + 1) / nu

        ax.plot(error, np.sqrt(w2), label=f'$x={k:+}$')

    # Legend and labels do not depend on the loop variable, so they are set
    # once after all curves have been drawn.
    ax.legend()
    ax.set_xlabel(r'$\gamma - x^3$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


ftl.img_grid([
    make_fig(x_scan_avg, x=[4, 3, 2], file_name='error_width_evol_+2+3+4'),
    make_fig(x_scan_avg, x=[2, 1, 0], file_name='error_width_evol_+2+1+0'),
    make_fig(x_scan_avg, x=[-1, 0, 1], file_name='error_width_evol_-1+0+1'),
])

In [33]:
@ftl.with_context
def make_fig(xy, x, *, file_name):
    """Plot the trajectory in the (nu, beta/alpha) plane for selected positions.

    For ``DER_TYPE == 'DER'`` the axes are ``sqrt(nu / (1 + nu))`` and
    ``sqrt(beta / alpha)``; otherwise ``sqrt(nu)`` and ``sqrt(beta)``. A dashed
    reference line ``y = 3 x`` is drawn across the covered nu range.

    Args:
        xy: Mapping with the scan positions under 'x' and the outputs under
            'y'; ``xy['y'][:, i]`` is read as the trajectory at position ``i``
            with columns gamma, nu, alpha, beta.
        x: Sequence of positions; for each one the nearest scan position is used.
        file_name: Name under which the figure is saved.

    Returns:
        The result of ``ftl.save_fig``.
    """
    is_der = DER_TYPE == 'DER'

    fig, ax = plt.subplots()
    ax.grid()

    min_nu = None
    max_nu = None

    for j, k in enumerate(x):
        # Index of the scan position closest to the requested value.
        i = np.abs(xy['x'] - k).argmin()
        m = xy['y'][:, i]

        nu = m[:, 1]
        alpha = m[:, 2]
        beta = m[:, 3]

        # Track the overall nu range for the reference line below.
        if min_nu is None or min_nu > np.min(nu):
            min_nu = np.min(nu)

        if max_nu is None or max_nu < np.max(nu):
            max_nu = np.max(nu)

        if is_der:
            px = np.sqrt(nu / (1. + nu))
            py = np.sqrt(beta / alpha)
        else:
            px = np.sqrt(nu)
            py = np.sqrt(beta)

        color = f'C{j}'
        ax.plot(px, py, color=color, label=f'$x={k:+}$', alpha=.9)
        # A marker on every 10th point along the trajectory.
        ax.plot(px[::10], py[::10], '.', color=color, alpha=.6)

    # Without any trajectory there is no nu range to span the reference line.
    if min_nu is not None:
        if is_der:
            x_ref = np.linspace(np.sqrt(min_nu / (1. + min_nu)), np.sqrt(max_nu / (1. + max_nu)), 10)
        else:
            x_ref = np.linspace(np.sqrt(min_nu), np.sqrt(max_nu), 10)

        # Raw string: '\m' is not a valid escape sequence in a normal literal.
        ax.plot(x_ref, 3 * x_ref, 'k--', label=r'$w_\mathrm{St}=3$', alpha=.8, linewidth=1)

    ax.legend()

    if is_der:
        ax.set_xlabel(r'$\sqrt{\nu / (1 + \nu)}$')
        ax.set_ylabel(r'$\sqrt{\beta / \alpha}$')
    else:
        ax.set_xlabel(r'$\sqrt{\nu}$')
        ax.set_ylabel(r'$\sqrt{\beta}$')

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


ftl.img_grid([
    make_fig(x_scan_avg, x=[4, 3, 2], file_name='ba_nu_evol_+2+3+4'),
    make_fig(x_scan_avg, x=[2, 1, 0], file_name='ba_nu_evol_+2+1+0'),
    make_fig(x_scan_avg, x=[-1, 0, 1], file_name='ba_nu_evol_-1+0+1'),
])

In [34]:
@ftl.with_context
def make_fig(xy, x, *, file_name):
    """Plot the (nu, beta/alpha) trajectories at one position for every run.

    The scan position nearest to ``x`` is selected via ``xy['y'][:, :, idx]``;
    one line is drawn per entry along the first axis. For
    ``DER_TYPE == 'DER'`` the axes are ``sqrt(nu / (1 + nu))`` and
    ``sqrt(beta / alpha)``; otherwise ``sqrt(nu)`` and ``sqrt(beta)``.
    """
    is_der = DER_TYPE == 'DER'

    fig, ax = plt.subplots()
    ax.grid()

    idx = np.abs(xy['x'] - x).argmin()
    at_x = xy['y'][:, :, idx]

    nu, alpha, beta = at_x[..., 1], at_x[..., 2], at_x[..., 3]

    for run in range(at_x.shape[0]):
        if is_der:
            px = np.sqrt(nu[run] / (1. + nu[run]))
            py = np.sqrt(beta[run] / alpha[run])
        else:
            px = np.sqrt(nu[run])
            py = np.sqrt(beta[run])
        ax.plot(px, py, alpha=.3)

    if is_der:
        x_label, y_label = r'$\sqrt{\nu / (1 + \nu)}$', r'$\sqrt{\beta / \alpha}$'
    else:
        x_label, y_label = r'$\sqrt{\nu}$', r'$\sqrt{\beta}$'
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


make_fig(x_scan, x=2.5, file_name='ba_nu_evol_+2.5')

ba_nu_evol_+2.5.png

In [35]:
@ftl.with_context
def make_fig(xy, x, *, file_name):
    """Plot nu against the prediction error on a log-scaled y-axis.

    For every requested position in ``x`` the nearest scan position is
    selected; ``xy['y'][:, idx]`` is read with column 0 as gamma and column 1
    as nu.
    """
    fig, ax = plt.subplots()
    ax.set_yscale('log')
    ax.grid()

    for pos, target in enumerate(x):
        idx = np.abs(xy['x'] - target).argmin()
        trajectory = xy['y'][:, idx]

        error = trajectory[:, 0] - target ** 3
        nu = trajectory[:, 1]

        line_color = f'C{pos}'
        ax.plot(error, nu, color=line_color, label=f'$x={target:+}$', alpha=.9)
        # A marker on every 10th point along the trajectory.
        ax.plot(error[::10], nu[::10], '.', color=line_color, alpha=.6)

    ax.legend()
    ax.set_xlabel(r'$\gamma - x^3$')
    ax.set_ylabel(r'$\nu$')

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


ftl.img_grid([
    make_fig(x_scan_avg, x=[4, 3, 2], file_name='error_nu_evol_+2+3+4'),
    make_fig(x_scan_avg, x=[2, 1, 0], file_name='error_nu_evol_+2+1+0'),
    make_fig(x_scan_avg, x=[-1, 0, 1], file_name='error_nu_evol_-1+0+1'),
])