# Setup & Overview

In [1]:
# Uncomment to install all dependencies into the kernel's environment via pip
# %pip install -r requirements.txt

In [2]:
import functools
import shutil
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import nbfigtulz as ftl

import numpy as np

import torch
import torch.nn as nn

In [3]:
from data.binpulse import generate_data
from utils import get_ccycle, train, train_loop
import models

In [4]:
SEED = 42
# Seed both RNG sources used below: torch's global generator and a dedicated
# NumPy Generator that is passed explicitly to `generate_data`.
torch.manual_seed(SEED)
NPRNG = np.random.default_rng(SEED)

In [5]:
from utils import get_best_device
DEVICE = get_best_device()

In [6]:
# Figure sizes handed to `ftl.save_fig(..., resize=...)`: (paper, presentation).
FIG_SIZE_SMALL, FIG_SIZE_LARGE = (3.0, 2.3), (4.0, 3.0)
FIG_SIZE = FIG_SIZE_LARGE  # switch to FIG_SIZE_SMALL for paper-sized images

In this notebook we repeat our analysis of the $x^3$ example. However, this time we use a more pathological (synthetic) data set with non-constant genuine aleatoric uncertainty and a large, yet sparsely sampled spike in an otherwise flat data distribution. The NN is deliberately kept simple so that it will not be able to describe the spike but will show a large epistemic uncertainty in this region.

As before, `DER_TYPE` below selects the variant: `'sDER'` uses the simplified version ${}_\mathcal{S}$DER, while `'DER'` uses the original DER.

In [7]:
DER_TYPE = 'sDER'  # set to 'DER' for the original (non-simplified) variant

In [8]:
img_dir = f'img/binpulse_{DER_TYPE.lower()}'
# Start from an empty directory so figures from earlier runs do not linger.
# Pure-Python equivalent of `rm -rf {img_dir} && mkdir -p {img_dir}`: portable
# across platforms and does not pass an interpolated path through a shell.
shutil.rmtree(img_dir, ignore_errors=True)
Path(img_dir).mkdir(parents=True, exist_ok=True)
ftl.config['img_dir'] = img_dir

print(f'Images are stored in: {img_dir}')

Images are stored in: img/binpulse_sder


In [9]:
BATCH_SIZE = 100
N_EPOCHS = 600
N_SAMPLES = 50  # number of independent training runs passed to `train_loop`
LR = 1e-3     # learning rate
COEFF = 1e-2  # lambda; passed as `coeff` to the DER loss

# Select the output layer and matching loss for the chosen DER variant.
if DER_TYPE == 'DER':
    DERLayer = models.DERLayer
    loss_der = models.loss_der
elif DER_TYPE == 'sDER':
    DERLayer = models.SDERLayer
    loss_der = models.loss_sder
else:
    # Fail loudly instead of silently falling back to sDER on a typo
    # (e.g. 'der' or 'SDER').
    raise ValueError(f"DER_TYPE must be 'DER' or 'sDER', got {DER_TYPE!r}")

# Training

In [10]:
def get_dataloader(n, *, batch_size=None, rng=None):
    """Build train/test DataLoaders from two consecutive draws of `generate_data`.

    Args:
        n: number of samples requested from `generate_data` for each split.
        batch_size: mini-batch size; ``None`` (default) uses the global
            BATCH_SIZE, looked up at call time.
        rng: NumPy random Generator forwarded to `generate_data`; ``None``
            (default) uses the global NPRNG.

    Returns:
        ``(train_dl, test_dl)``. Both loaders shuffle and drop an incomplete
        final batch, so if ``n`` is not a multiple of ``batch_size`` up to
        ``batch_size - 1`` samples per split are never seen.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if rng is None:
        rng = NPRNG

    data_kwargs = {
        'rng': rng,
        'dtype': np.float32,
    }
    # Train is drawn first, then test, from the same generator; keep this
    # order so results stay reproducible for a given SEED.
    train_data = generate_data(n, **data_kwargs)
    test_data = generate_data(n, **data_kwargs)

    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': True,
        'drop_last': True,
    }
    train_dl = torch.utils.data.DataLoader(train_data, **loader_kwargs)
    test_dl = torch.utils.data.DataLoader(test_data, **loader_kwargs)

    return train_dl, test_dl


train_dl, test_dl = get_dataloader(n=1000)

In [11]:
def error_fct(y, y_pred):
    """Mean absolute error between column 0 of ``y`` and ``y_pred``.

    Args:
        y: 2-D tensor; only its first column is compared.
        y_pred: prediction tensor. NOTE(review): assumed 1-D of length
            ``y.shape[0]``. A ``(batch, 1)`` tensor would silently broadcast to
            ``(batch, batch)``, so confirm against what `train` passes.

    Returns:
        0-d tensor with the mean of ``|y[:, 0] - y_pred|``.
    """
    residual = y[:, 0] - y_pred
    return residual.abs().mean()


def train_der(n_epochs):
    """Train one freshly initialised DER/sDER network for ``n_epochs``.

    The network is ``models.Model(4)`` followed by the output layer chosen via
    DER_TYPE. Its loss is the matching DER loss with coefficient COEFF (lambda).
    The function is passed to `train_loop`, which presumably calls it once per
    run. Returns whatever `train` returns.
    """
    # Construction order is unchanged (Model first, then DERLayer), so the
    # seeded weight initialisation draws random numbers in the same order.
    # NOTE(review): the meaning of `4` is defined in models.Model; consider
    # naming it.
    net = torch.nn.Sequential(models.Model(4), DERLayer())
    objective = functools.partial(loss_der, coeff=COEFF)

    train_kwargs = dict(
        model=net,
        loss_fct=objective,
        error_fct=error_fct,
        lr=LR,
        train_dl=train_dl,
        test_dl=test_dl,
        scan_lim=(0., 1.),  # presumably the x-range of the prediction scan
        device=DEVICE,
    )
    return train(n_epochs=n_epochs, **train_kwargs)

Below, we kick off the training. Reduce `N_SAMPLES` (the number of independent training runs) if this takes too long.

In [12]:
# Run `train_der` for N_SAMPLES runs of N_EPOCHS epochs each.
model, loss, x_scan = train_loop(
    train_der,
    n_samples=N_SAMPLES,
    n_epochs=N_EPOCHS,
    quiet=False,
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [08:34<00:00, 10.29s/it]


In [13]:
@ftl.with_context
def make_fig(loss, error, *, file_name):
    """Plot per-run loss and absolute-error curves over training epochs.

    Total loss is drawn on the left axis and absolute error on the right (twin)
    axis. Each run gets one faint line, so the spread across runs stays visible.

    Args:
        loss: sequence of per-run loss curves; each is plotted against its
            index (the epoch).
        error: sequence of per-run absolute-error curves, same layout as
            ``loss``.
        file_name: base name passed to ``ftl.save_fig``.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    ax2 = ax.twinx()

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Total loss')
    ax2.set_ylabel('Abs. error')

    n = len(loss)
    # Keep individual runs faint but never below 10% opacity; max(n, 1)
    # avoids a ZeroDivisionError for an empty input.
    alpha = max(.1, 2. / max(n, 1))
    for loss_run, error_run in zip(loss, error):
        ax.plot(loss_run, color='C0', alpha=alpha)
        # The error belongs on the twin axis labelled 'Abs. error'. Plotting
        # it on `ax` left `ax2` empty and its label meaningless.
        ax2.plot(error_run, color='C1', alpha=alpha)

    # Legend on the twin axis: it is drawn above `ax`, so the error curves do
    # not overdraw the legend box.
    ax2.legend(
        [mpl.lines.Line2D([0], [0], color=f'C{i}') for i in range(2)],
        ['Total loss', 'Abs. error']
    )

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)


ftl.img_grid([
    make_fig(loss['loss_train'], loss['error_train'], file_name='loss_train'),
    make_fig(loss['loss_test'], loss['error_test'], file_name='loss_test'),
])

Above, note how quickly $\gamma_i$ converges. As we will see below, the NN does not even try to predict the spike.

In [14]:
# Average the scanned network outputs over the independent runs (axis 0),
# e.g. (runs, epochs, points, outputs) -> (epochs, points, outputs).
# NOTE: quantities derived later from these averaged parameters are not the
# same as averages of the per-run derived quantities.
x_scan_avg = dict(
    x=x_scan['x'],
    y=np.average(x_scan['y'], axis=0),
)

In [15]:
# One line per array: shapes of the per-run scan vs. its run-average.
print(
    f'    {x_scan["x"].shape=}',
    f'{x_scan_avg["x"].shape=}',
    f'    {x_scan["y"].shape=}',
    f'{x_scan_avg["y"].shape=}',
    sep='\n',
)

    x_scan["x"].shape=(300,)
x_scan_avg["x"].shape=(300,)
    x_scan["y"].shape=(50, 600, 300, 4)
x_scan_avg["y"].shape=(600, 300, 4)


In [16]:
@ftl.with_context
def make_fig(dl, x, y):
    """Scatter the raw samples of a DataLoader and overlay a prediction curve.

    Args:
        dl: iterable of batch tensors; columns 0 and 1 are plotted as x and y.
        x: 1-D scan positions of the prediction.
        y: array indexed ``[point, output]``; column 0 (gamma) is drawn.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    batches = [batch.cpu().numpy() for batch in dl]
    data = np.concatenate(batches, axis=0)

    fig, ax = plt.subplots()
    ax.plot(data[:, 0], data[:, 1], '.', alpha=.1)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')

    # Second line on the axes, so with the default colour cycle it is drawn
    # in C1, matching the legend proxy below.
    ax.plot(x, y[:, 0])

    legend_handles = [
        mpl.lines.Line2D([0], [0], marker='.', linestyle='None', color='C0'),
        mpl.lines.Line2D([0], [0], color='C1'),
    ]
    ax.legend(legend_handles, ['Data', 'Prediction'])

    return ftl.save_fig(fig, 'pred_with_data', resize=FIG_SIZE)


make_fig(train_dl, x_scan_avg['x'], x_scan_avg['y'][-1])

pred_with_data.png

# Analysis

In [17]:
def idx2epoch(epochs, *, n_max):
    """Convert requested epochs to a sorted list of unique 1-based epoch numbers.

    ``-1`` is shorthand for the last epoch (``n_max``). After that
    substitution, entries outside ``1..n_max`` are dropped: 0, other negative
    values, and values above ``n_max``.

    Args:
        epochs: iterable of ints (1-based epoch numbers, or -1).
        n_max: number of available epochs.

    Returns:
        Sorted list of unique ints in ``[1, n_max]``.
    """
    # Substitute -1 *before* de-duplicating, so that e.g. [600, -1] with
    # n_max=600 yields [600] rather than [600, 600] (duplicate curves/legend).
    resolved = {n_max if k == -1 else k for k in epochs}
    return sorted(k for k in resolved if 0 < k <= n_max)

In [18]:
@ftl.with_context
def plot_evolution(x, y, f, *, epochs, file_name, x_min=None, x_max=None, **kwargs):
    """Overlay one curve per selected training epoch on a single figure.

    Args:
        x: 1-D array of scan positions (the default range is ``x[0]..x[-1]``,
            which assumes ``x`` is ascending).
        y: array indexed ``[epoch, point, ...]`` (e.g. ``x_scan_avg['y']``).
        f: drawing callback, called as
            ``f(x=..., y=..., ax=..., color=..., **kwargs)`` once per epoch
            with that epoch's slice of ``y``.
        epochs: 1-based epoch numbers; -1 means the last one (see `idx2epoch`).
        file_name: base name handed to ``ftl.save_fig``.
        x_min, x_max: optional inclusive bounds on the plotted x-range.
        **kwargs: forwarded unchanged to ``f``.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    lo = x[0] if x_min is None else x_min
    hi = x[-1] if x_max is None else x_max

    in_range = (lo <= x) & (x <= hi)
    x, y = x[in_range], y[:, in_range]

    selected = idx2epoch(epochs, n_max=y.shape[0])
    colors = get_ccycle(len(selected))

    fig, ax = plt.subplots()
    for j, epoch in enumerate(selected):
        # Epochs are 1-based, array rows 0-based.
        f(x=x, y=y[epoch - 1], ax=ax, color=colors[j], **kwargs)

    handles = [mpl.lines.Line2D([0], [0], color=colors[j]) for j in range(len(selected))]
    labels = [f'Epoch: {epoch}' for epoch in selected]
    ax.legend(handles, labels)

    return ftl.save_fig(fig, file_name, resize=FIG_SIZE)

In [19]:
def make_fig(*, x, y, ax, color):
    """Draw gamma (column 0 of ``y``) against ``x`` on ``ax`` in ``color``.

    Callback for `plot_evolution`, which calls it once per selected epoch on
    the same axes with ``y`` indexed ``[point, output]``.
    """
    gamma = y[:, 0]
    ax.plot(x, gamma, color=color)

    # Must be explicit: a bare `ax.grid()` *toggles* the grid. Since this runs
    # once per epoch on shared axes, the final state would depend on whether
    # the number of epochs is odd or even.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\gamma$')


img = plot_evolution(
    x_scan_avg['x'], x_scan_avg['y'], make_fig,
    epochs=[1, 10, 50, 100, -1],  # -1 selects the last epoch
    file_name='pred_evol',
)

In [20]:
@ftl.with_context
def make_fig(x, y):
    """Plot the prediction gamma of every individual run as faint lines.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]``; column 0 is gamma.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    # Explicit True: a bare `ax.grid()` toggles, so it would hide the grid
    # whenever the active style already enables it.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\gamma$')

    n = y.shape[0]
    alpha = max(.1, 2. / n)  # line opacity: faint, but at least 10%

    gamma = y[..., 0]
    for gamma_run in gamma:
        ax.plot(x, gamma_run, color='C0', alpha=alpha)

    return ftl.save_fig(fig, 'pred', resize=FIG_SIZE)


img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [21]:
def make_fig(*, x, y, ax, color):
    """Draw beta (column 3 of ``y``) against ``x`` on ``ax`` in ``color``.

    Callback for `plot_evolution`, which calls it once per selected epoch on
    the same axes with ``y`` indexed ``[point, output]``.
    """
    beta = y[:, 3]
    ax.plot(x, beta, color=color)

    # Must be explicit: a bare `ax.grid()` *toggles* the grid on every call,
    # so the result would depend on the parity of the number of epochs.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\beta$')


img = plot_evolution(
    x_scan_avg['x'], x_scan_avg['y'], make_fig,
    epochs=[1, 10, 50, 100, -1],  # -1 selects the last epoch
    file_name='beta_evol',
)

In [22]:
@ftl.with_context
def make_fig(x, y):
    """Plot beta of every individual run as faint lines.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]``; column 3 is beta.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    # Explicit True: a bare `ax.grid()` toggles relative to the active style.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\beta$')

    n = y.shape[0]
    alpha = max(.1, 2. / n)  # line opacity: faint, but at least 10%

    beta = y[..., 3]
    for beta_run in beta:
        ax.plot(x, beta_run, color='C0', alpha=alpha)

    return ftl.save_fig(fig, 'beta', resize=FIG_SIZE)


img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [23]:
def make_fig(*, x, y, ax, color):
    """Draw nu (column 1 of ``y``) against ``x`` on ``ax`` in ``color``.

    Callback for `plot_evolution`, which calls it once per selected epoch on
    the same axes with ``y`` indexed ``[point, output]``.
    """
    nu = y[:, 1]
    ax.plot(x, nu, color=color)

    # Must be explicit: a bare `ax.grid()` *toggles* the grid on every call,
    # so the result would depend on the parity of the number of epochs.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\nu$')


img = plot_evolution(
    x_scan_avg['x'], x_scan_avg['y'], make_fig,
    epochs=[1, 10, 50, 100, -1],  # -1 selects the last epoch
    file_name='nu_evol',
)

In [24]:
@ftl.with_context
def make_fig(x, y):
    """Plot nu of every individual run as faint lines.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]``; column 1 is nu.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    # Explicit True: a bare `ax.grid()` toggles relative to the active style.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\nu$')

    n = y.shape[0]
    alpha = max(.2, 3. / n)  # line opacity: faint, but at least 20%

    nu = y[..., 1]
    for nu_run in nu:
        ax.plot(x, nu_run, color='C0', alpha=alpha)

    return ftl.save_fig(fig, 'nu', resize=FIG_SIZE)


img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [25]:
@ftl.with_context
def make_fig(x, y):
    """Plot 1 / sqrt(nu) of every individual run as faint lines.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]``; column 1 is nu.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    # Explicit True: a bare `ax.grid()` toggles relative to the active style.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$1 / \sqrt{\nu}$')

    n = y.shape[0]
    alpha = max(.2, 3. / n)  # line opacity: faint, but at least 20%

    nu = y[..., 1]
    for nu_run in nu:
        ax.plot(x, 1. / np.sqrt(nu_run), color='C0', alpha=alpha)

    return ftl.save_fig(fig, 'epistemic', resize=FIG_SIZE)


make_fig(x_scan['x'], x_scan['y'][:, -1])

epistemic.png

In [26]:
@ftl.with_context
def make_fig(x, y):
    """Compare (sqrt(nu (alpha - 1)))^-1 (left axis) with (sqrt(nu))^-1 (right axis).

    One faint line per run for each quantity.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]`` with nu in column 1 and
            alpha in column 2. Values with alpha <= 1 give NaN on the left axis.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    ax2 = ax.twinx()

    # Explicit True: a bare `ax.grid()` toggles relative to the active style.
    ax.grid(True)
    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$\left( \sqrt{\nu (\alpha - 1)} \right)^{-1}$')
    ax2.set_ylabel(r'$\left( \sqrt{\nu} \right)^{-1}$')

    nu = y[..., 1]
    alpha = y[..., 2]

    n = y.shape[0]
    # Line opacity; deliberately not called `alpha`, which is the DER parameter here.
    line_alpha = max(.2, 3. / n)
    for nu_run, alpha_run in zip(nu, alpha):
        ax.plot(x, 1. / np.sqrt(nu_run * (alpha_run - 1.)), color='C0', alpha=line_alpha)
        ax2.plot(x, np.sqrt(1. / nu_run), color='C1', alpha=line_alpha)

    # Legend on the twin axis: it is drawn on top, so the C1 curves cannot
    # cover the legend box.
    ax2.legend([
        mpl.lines.Line2D([0], [0], color='C0'),
        mpl.lines.Line2D([0], [0], color='C1'),
    ], [r'$\left( \sqrt{\nu (\alpha - 1)} \right)^{-1}$', r'$\left( \sqrt{\nu} \right)^{-1}$'])

    return ftl.save_fig(fig, 'nuamo', resize=FIG_SIZE)


if DER_TYPE == 'DER':
    # Inside an `if`, the call is not the cell's last expression, so Jupyter
    # would not render the returned image. `display` (provided by IPython in
    # notebook kernels) shows it explicitly.
    display(make_fig(x_scan['x'], x_scan['y'][:, -1]))

In [27]:
def make_fig(*, x, y, ax, color):
    """Draw sqrt(beta (1 + nu) / nu / alpha) against ``x`` on ``ax`` in ``color``.

    ``y`` is indexed ``[point, output]``, with nu, alpha and beta in columns
    1, 2 and 3. This is a callback for `plot_evolution`, which calls it once
    per selected epoch on the same axes.
    """
    nu, alpha, beta = y[:, 1], y[:, 2], y[:, 3]
    ax.plot(x, np.sqrt(beta * (1. + nu) / nu / alpha), color=color)

    # Must be explicit: a bare `ax.grid()` *toggles* the grid on every call,
    # so the result would depend on the parity of the number of epochs.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')


img = plot_evolution(
    x_scan_avg['x'], x_scan_avg['y'], make_fig,
    epochs=[1, 10, 50, 100, -1],  # -1 selects the last epoch
    file_name='width_evol',
)

In [28]:
@ftl.with_context
def make_fig(x, y, ref_levels=(0.01, 0.1)):
    """Plot sqrt(beta (1 + nu) / nu / alpha) of every run, with reference lines.

    Args:
        x: 1-D scan positions.
        y: array indexed ``[run, point, output]`` with nu, alpha and beta in
            columns 1, 2 and 3.
        ref_levels: y-values of the dashed horizontal reference lines. The
            defaults are the values from the original plot; presumably they
            mark the genuine noise levels of the binpulse data. TODO: confirm
            against `data.binpulse.generate_data`.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    # Explicit True: a bare `ax.grid()` toggles relative to the active style.
    ax.grid(True)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\sqrt{\beta (1 + \nu) / \nu / \alpha}$')

    for level in ref_levels:
        ax.axhline(y=level, color='k', linestyle='--', linewidth=1, alpha=.8)

    nu = y[..., 1]
    alpha = y[..., 2]
    beta = y[..., 3]

    n = y.shape[0]
    # Line opacity; deliberately not called `alpha`, which is the DER parameter here.
    line_alpha = max(.2, 3. / n)
    for i in range(n):
        ax.plot(x, np.sqrt(beta[i] * (1. + nu[i]) / nu[i] / alpha[i]), color='C0', alpha=line_alpha)

    return ftl.save_fig(fig, 'width', resize=FIG_SIZE)


img2 = make_fig(x_scan['x'], x_scan['y'][:, -1])
ftl.img_grid([img, img2])

In [29]:
@ftl.with_context
def make_fig(x, y, normalize=False):
    """Overlay aleatoric and epistemic uncertainty of every run on twin y-axes.

    The aleatoric term ``sqrt(beta (nu + 1) / nu / alpha)`` goes on the left
    axis in C0. The epistemic term ``1 / sqrt(nu)`` goes on the right axis in C1.

    Args:
        x: 1-D scan positions.
        y: sequence of per-run arrays indexed ``[point, output]`` with nu,
            alpha and beta in columns 1, 2 and 3.
        normalize: if True, min-max scale each curve to [0, 1] and use linear
            axes, because a log axis cannot show the 0 that the scaling
            produces. If False (default), plot raw values on log axes.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig, ax = plt.subplots()
    ax2 = ax.twinx()

    # Min-max scaling maps every curve's minimum to exactly 0, which a log axis
    # silently drops; only use log scaling for the raw values.
    scale = 'linear' if normalize else 'log'
    ax.set_yscale(scale)
    ax2.set_yscale(scale)

    def _minmax(v):
        """Scale ``v`` to [0, 1]; a constant curve maps to zeros instead of NaN."""
        v_min = np.min(v)
        span = np.max(v) - v_min
        return (v - v_min) / span if span > 0 else np.zeros_like(v)

    for y_run in y:
        nu = y_run[:, 1]
        alpha = y_run[:, 2]
        beta = y_run[:, 3]

        aleatoric = np.sqrt(beta * (nu + 1) / nu / alpha)
        epistemic = 1. / np.sqrt(nu)

        if normalize:
            aleatoric = _minmax(aleatoric)
            epistemic = _minmax(epistemic)

        ax.plot(x, aleatoric, color='C0', alpha=.2)
        ax2.plot(x, epistemic, color='C1', alpha=.2)

    ax.set_xlabel('$x$')
    # Colour-code the axis labels to match their curves; this replaces the
    # previously commented-out legend.
    ax.set_ylabel('Aleatoric', color='C0')
    ax2.set_ylabel('Epistemic', color='C1')

    return ftl.save_fig(fig, 'uncertainties_new', resize=FIG_SIZE)


make_fig(x_scan['x'], x_scan['y'][:, -1])

uncertainties_new.png

High-level summary: if the aleatoric uncertainty peaks, one cannot tell without looking at $\nu$ whether this is due to genuinely large data uncertainty or to large model uncertainty.