# Sampler Diagnostics Demo (Normal Crossing)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/diagnostics.ipynb)

This notebook demonstrates several diagnostic tools for the sampling and RLCT estimation process. Each diagnostic is illustrated on a normal crossing: a polynomial model $f(x) = w_1^a w_2^b x$ for some exponents $(a, b)$, where $w_1$ and $w_2$ are the weights to be learned. The targets are Gaussian noise centered at zero and independent of $x$, so the model achieves its lowest loss when $w_1 = 0$ or $w_2 = 0$.
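
To make the degenerate minimum concrete, here is a small pure-Python sketch of the expected squared-error loss for this model. It is not part of the notebook's pipeline: the function name `expected_loss` is ours, and the default standard deviations (2 for the inputs, 0.25 for the label noise) are taken from the constants used in the data-generation cell.

```python
def expected_loss(w1, w2, a=1, b=2, x_std=2.0, noise_std=0.25):
    """Expected MSE of f(x) = w1**a * w2**b * x against pure-noise labels.

    With x ~ N(0, x_std^2) and y ~ N(0, noise_std^2) independent,
    E[(c*x - y)^2] = c^2 * x_std^2 + noise_std^2, where c = w1**a * w2**b.
    """
    c = (w1 ** a) * (w2 ** b)
    return c ** 2 * x_std ** 2 + noise_std ** 2


floor = expected_loss(0.0, 0.0)        # irreducible noise term only
on_w1_axis = expected_loss(0.7, 0.0)   # w2 = 0: still at the floor
on_w2_axis = expected_loss(0.0, -1.3)  # w1 = 0: still at the floor
off_axes = expected_loss(0.5, 0.5)     # both weights nonzero: strictly worse

print(floor, on_w1_axis, on_w2_axis, off_axes)
```

Every point on either axis attains the same minimal loss, which is why the set of minima is a crossing of two lines rather than an isolated point.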

We'll also be using the SGLD optimizer.

In [1]:
%pip install matplotlib seaborn
%cd .. 
%pip install -e .
%cd examples

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from devinterp.optim.sgld import SGLD
from devinterp.utils import plot_trace
from devinterp.slt import sample
from devinterp.slt import OnlineLLCEstimator

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# plotting
sns.set_style("whitegrid")
CMAP = sns.color_palette("muted", as_cmap=True)
PRIMARY, SECONDARY, TERTIARY = CMAP[:3]
plt.rcParams["figure.figsize"] = (12, 9)  # note: this cell may need to be re-run after creating a plot to take effect

# constants
SIGMA = 0.25
NUM_TRAIN_SAMPLES = 1000
BATCH_SIZE = NUM_TRAIN_SAMPLES
CRITERION = F.mse_loss

NUM_CHAINS = 20
NUM_DRAWS = 2000

In [None]:
# Set up RLCT estimation
class PolyModel(nn.Module):
    def __init__(self, powers):
        super(PolyModel, self).__init__()
        self.weights = nn.Parameter(
            torch.tensor([1.0, 0.3], dtype=torch.float32, requires_grad=True, device=DEVICE)
        )
        self.powers = powers

    def forward(self, x):
        multiplied = torch.prod(self.weights**self.powers)
        x = x * multiplied
        return x


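def generate_dataset_for_seed(seed=0):
    # use a local generator so that `seed` actually determines the dataset
    generator = torch.Generator().manual_seed(seed)
    x = torch.normal(0, 2, size=(NUM_TRAIN_SAMPLES,), generator=generator)
    # labels are pure noise, independent of x
    y = torch.normal(0, SIGMA, size=(NUM_TRAIN_SAMPLES,), generator=generator)
    train_data = TensorDataset(x, y)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    return train_loader, train_data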


def run_callbacks(
    train_loader,
    train_data,
    weights=(0.0, 0.0),
    lr=0.0005,
    powers=torch.tensor([1, 2]).to(DEVICE),
    callbacks=None,
    device=DEVICE,
):
    model = PolyModel(powers.to(device)).to(device)
    # start sampling from the requested point in weight space
    model.weights = nn.Parameter(
        torch.tensor(weights, dtype=torch.float32, requires_grad=True, device=device)
    )

    optim_kwargs = {
        "lr": lr,
        "elasticity": 1.0,
        "temperature": "adaptive",
        "num_samples": len(train_data),
        "save_noise": True,  # required by NoiseNorm
    }

    if callbacks is None:
        llc_estimator = OnlineLLCEstimator(NUM_CHAINS, NUM_DRAWS, len(train_loader.dataset), device=device)
        callbacks = [llc_estimator]

    sample(
        model=model,
        loader=train_loader,
        criterion=CRITERION,
        optimizer_kwargs=optim_kwargs,
        sampling_method=SGLD,
        num_chains=NUM_CHAINS,
        num_draws=NUM_DRAWS,
        callbacks=callbacks,
        device=device,
    )

    # collect the results of every callback that exposes a `sample` method
    results = {}

    for callback in callbacks:
        if hasattr(callback, "sample"):
            results.update(callback.sample())

    return results


def get_rlct(
    train_loader,
    train_data,
    weights=(0.0, 0.0),
    lr=0.0005,
    powers=torch.tensor([1, 2]).to(DEVICE),
    device=DEVICE,
):
    llc_estimator = OnlineLLCEstimator(NUM_CHAINS, NUM_DRAWS, len(train_loader.dataset), device=device)
    results = run_callbacks(
        train_loader=train_loader,
        train_data=train_data,
        weights=weights,
        lr=lr,  # previously accepted but ignored
        powers=powers,  # previously overridden with a hard-coded tensor
        callbacks=[llc_estimator],
        device=device,
    )
    # the last entry of the running mean is the final estimate
    return results["llc/means"][-1].item()

## Testing RLCT estimation

Let's start with estimating some RLCTs at known values for a simple normal crossing with $(a, b) = (1, 2)$.
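
For reference, the known values can be reproduced by hand. The sketch below assumes the standard result for a normal crossing $K(w) = \prod_j w_j^{2k_j}$ with a prior that is positive at the point of interest: only the coordinates that vanish at the point contribute, and the local RLCT is $\min_j 1/(2k_j)$ over those coordinates. The helper name `local_rlct` is ours and is not part of `devinterp`.

```python
def local_rlct(point, powers, tol=1e-12):
    """Local RLCT of K(w) = prod_j w_j**(2*k_j) at a point on its zero set.

    Coordinates that are nonzero at `point` act as nonvanishing factors locally,
    so only the vanishing coordinates contribute: lambda = min_j 1 / (2 * k_j).
    Returns None when no coordinate vanishes (the point is not a minimum).
    """
    candidates = [1.0 / (2 * k) for w, k in zip(point, powers) if abs(w) < tol]
    return min(candidates) if candidates else None


powers = (1, 2)  # (a, b) = (1, 2), so K(w) is proportional to w1**2 * w2**4
for point in ([0.0, 0.0], [1.0, 0.0], [0.0, 1.0]):
    print(point, local_rlct(point, powers))
```

The helper returns `None` off the zero set, so it says nothing about a starting point such as `[0.1, 0.1]`.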

In [None]:
train_loader, train_data = generate_dataset_for_seed(seed=0)

In [None]:
sample_points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.1, 0.1]]
known_rlcts = [0.25, 0.25, 0.5, 0.25]
estimated_rlcts = [
    get_rlct(train_loader, train_data, weights=sample_point)
    for sample_point in sample_points
]

In [None]:
estimated_rlcts

## SamplerCallbacks

RLCT estimation is implemented as a `SamplerCallback`. A list of `SamplerCallback`s can be passed to `sample`, which calls each callback at every draw of every chain. RLCT estimation passes a single `SamplerCallback` to `sample`: either `LLCEstimator` or `OnlineLLCEstimator`.
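
The dispatch pattern can be pictured with a toy loop. This is a sketch under our own assumptions, not `devinterp`'s implementation: `MeanLossCallback`, `toy_sample`, the call signature `(chain, draw, loss)` and the result keys are all hypothetical, and the loss is a random placeholder. Only the final collection step mirrors the `hasattr(callback, "sample")` pattern used in `run_callbacks`.

```python
import random


class MeanLossCallback:
    """Hypothetical callback: records one loss per (chain, draw)."""

    def __init__(self, num_chains, num_draws):
        self.losses = [[0.0] * num_draws for _ in range(num_chains)]

    def __call__(self, chain, draw, loss):
        self.losses[chain][draw] = loss

    def sample(self):
        flat = [v for row in self.losses for v in row]
        return {"loss/trace": self.losses, "loss/mean": sum(flat) / len(flat)}


def toy_sample(num_chains, num_draws, callbacks, seed=0):
    """Stand-in for a sampler loop: every callback sees every draw of every chain."""
    rng = random.Random(seed)
    for chain in range(num_chains):
        for draw in range(num_draws):
            loss = rng.random()  # placeholder for the loss at this draw
            for callback in callbacks:
                callback(chain=chain, draw=draw, loss=loss)


callback = MeanLossCallback(num_chains=3, num_draws=5)
toy_sample(num_chains=3, num_draws=5, callbacks=[callback])

results = {}
for cb in [callback]:
    if hasattr(cb, "sample"):  # same collection pattern as run_callbacks
        results.update(cb.sample())
print(sorted(results))
```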

Another such `SamplerCallback` is the `OnlineWBICEstimator`.

In [None]:
from devinterp.slt.wbic import OnlineWBICEstimator

wbic_estimator = OnlineWBICEstimator(
    num_chains=NUM_CHAINS, 
    num_draws=NUM_DRAWS, 
    n=NUM_TRAIN_SAMPLES,
    device=DEVICE
)

train_loader, train_data = generate_dataset_for_seed(seed=0)

results = run_callbacks(train_loader, train_data, weights=[0.0, 0.0], callbacks=[wbic_estimator])

In [None]:
# final WBIC estimation
results['wbic/means'][-1]

The trace of a sampled statistic is an array of shape `[NUM_CHAINS, NUM_DRAWS]`, where the `[i, j]` entry is the value computed by the `SamplerCallback` at the *j*th draw of the *i*th chain.
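
As a minimal illustration of this indexing convention, the following sketch builds a toy trace as a nested list (standing in for the tensor a callback would return) and averages it over chains at each draw. The values are made up so that the chain and draw of each entry can be read off directly.

```python
num_chains, num_draws = 3, 4

# Toy trace: entry [i][j] is the statistic at draw j of chain i.
trace = [[10 * i + j for j in range(num_draws)] for i in range(num_chains)]

value = trace[2][1]  # draw 1 of chain 2

# Average over chains at each draw (one value per draw).
mean_per_draw = [
    sum(trace[i][j] for i in range(num_chains)) / num_chains
    for j in range(num_draws)
]
print(value, mean_per_draw)
```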

It can sometimes be useful to plot the trace as a sanity check and to see where things might be going wrong. In the following graph, the estimated WBIC values level off and appear to converge, which is what we'd expect to see.

In [None]:
wbic_trace = results['wbic/trace']
plot_trace(wbic_trace, 'WBIC')

## Other diagnostics

There are various diagnostic tools implemented as `SamplerCallback`s. Some are regular `SamplerCallback`s that can be passed to `sample` on their own, and others are "derivative" callbacks that run diagnostics on a particular `SamplerCallback` instance.

Regular implementations:
- `OnlineWBICEstimator`: estimates the WBIC (seen above)
- `WeightNorm`: tracks the L2 norm of the model weights during sampling
- `GradientNorm`: tracks the L2 norm of the gradients during sampling
- `NoiseNorm`: tracks the L2 norm of the SGLD noise term during sampling
  - (n.b. to compute the noise norm, you must pass `save_noise=True` to the SGLD optimizer)
- `GradientDistribution`: shows a histogram/heatmap of gradient values at each SGLD step for specific named model parameters, useful e.g. for checking that gradients haven't exploded or collapsed

Derivative implementations:
- `OnlineTraceStatistics`: computes the mean/std of the trace of another `SamplerCallback` across draws and across chains
- `OnlineLossStatistics`: computes various loss statistics for an `OnlineLLCEstimator`

Additionally, since derivative callbacks depend on a base callback, they must be positioned later in the list of callbacks so that they're called after the base callback. A helper function `validate_callbacks` checks whether a list of callbacks satisfies this condition.
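
The ordering condition can be sketched as follows. This is a hypothetical re-implementation for illustration only, not the library's `validate_callbacks`: the classes `Base` and `Derived` are stand-ins, and we assume a derivative callback exposes its base through a `base_callback` attribute (matching the keyword argument used in this notebook). Unlike the library helper, this version returns `False` instead of raising an error.

```python
class Base:
    """Hypothetical stand-in for a regular callback."""


class Derived:
    """Hypothetical stand-in for a derivative callback."""

    def __init__(self, base_callback):
        self.base_callback = base_callback


def ordering_is_valid(callbacks):
    """True if every derivative callback appears after its base callback."""
    seen = []
    for cb in callbacks:
        base = getattr(cb, "base_callback", None)
        if base is not None and not any(base is s for s in seen):
            return False
        seen.append(cb)
    return True


base = Base()
derived = Derived(base_callback=base)
print(ordering_is_valid([base, derived]))  # True
print(ordering_is_valid([derived, base]))  # False
```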

### WeightNorm, GradientNorm, NoiseNorm examples

In [None]:
from devinterp.slt.norms import GradientNorm, NoiseNorm, WeightNorm

gradient_norm = GradientNorm(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, device=DEVICE)
noise_norm = NoiseNorm(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, device=DEVICE)
weight_norm = WeightNorm(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, device=DEVICE)

norm_callbacks = [gradient_norm, noise_norm, weight_norm]

train_loader, train_data = generate_dataset_for_seed(seed=0)

results = run_callbacks(train_loader, train_data, weights=[0.0, 0.0], callbacks=norm_callbacks)

The norms are stored as traces that we can plot below. It may be useful to, for example, compare the gradient norm and noise norm to ensure one isn't completely dominating the other. The weight trace can also be helpful to check whether the sampler appears to be spreading out in the weight space.
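
Alongside the plots, a single summary number can make the gradient/noise comparison quicker to read. The sketch below uses toy nested lists in place of the real traces, and the helper `mean_of_trace` is ours. A ratio that is very large or very small would suggest that one term is dominating the other.

```python
def mean_of_trace(trace):
    """Mean over all chains and draws of a [num_chains][num_draws] nested list."""
    flat = [v for chain in trace for v in chain]
    return sum(flat) / len(flat)


# Toy stand-ins for the gradient-norm and noise-norm traces (2 chains, 2 draws).
toy_grad_trace = [[0.25, 0.5], [0.75, 0.5]]
toy_noise_trace = [[1.0, 1.0], [1.0, 1.0]]

ratio = mean_of_trace(toy_grad_trace) / mean_of_trace(toy_noise_trace)
print(ratio)  # 0.5: on average the gradient norm is half the noise norm here
```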

In [None]:
grad_trace = results['gradient_norm/trace']
noise_trace = results['noise_norm/trace']
weight_trace = results['weight_norm/trace']

In [None]:
plot_trace(grad_trace, 'gradient norm')

In [None]:
plot_trace(noise_trace, 'noise norm')

In [None]:
# note: the weight norm starts near 0 in this graph because the weights are initialized to [0, 0]
plot_trace(weight_trace, 'weight_norm')

### GradientDistribution

`GradientDistribution` shows the histogram of gradient values at each SGLD time step, with a darker color indicating a bin with a higher count. Below, we can see that some gradient values are relatively large, but most cluster around 0 (the darker blue line). In this case, gradients don't seem to be exploding or collapsing during sampling.
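
As a rough picture of the data behind such a plot, the sketch below bins simulated gradient values separately at each step, giving one row of counts per step. It is a toy in pure Python, not `GradientDistribution`'s implementation: the `histogram` helper, the bin range and the Gaussian "gradients" are all assumptions made for illustration.

```python
import random


def histogram(values, lo, hi, num_bins):
    """Count values into num_bins equal-width bins on [lo, hi]; clip outliers."""
    counts = [0] * num_bins
    width = (hi - lo) / num_bins
    for v in values:
        idx = int((v - lo) / width)
        idx = min(max(idx, 0), num_bins - 1)  # out-of-range values go to the edge bins
        counts[idx] += 1
    return counts


rng = random.Random(0)
num_steps, grads_per_step = 5, 200

# One row of bin counts per step; a heatmap would shade each cell by its count.
heatmap = [
    histogram([rng.gauss(0.0, 1.0) for _ in range(grads_per_step)], -4.0, 4.0, 8)
    for _ in range(num_steps)
]
for row in heatmap:
    print(row)
```

Rows whose mass stays concentrated in the central bins correspond to the dark band around 0 described above; mass drifting into the edge bins over time would be a sign of growing gradients.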

In [None]:
from devinterp.slt.gradient import GradientDistribution

grad_dist = GradientDistribution(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, min_bins=40)
callbacks = [grad_dist]
results = run_callbacks(train_loader, train_data, weights=[0.0, 0.0], callbacks=callbacks)

In [None]:
grad_dist.plot('weights')

### OnlineLossStatistics, OnlineTraceStatistics

In [None]:
from devinterp.slt.callback import validate_callbacks
from devinterp.slt.loss import OnlineLossStatistics
from devinterp.slt.norms import GradientNorm, NoiseNorm, WeightNorm
from devinterp.slt.trace import OnlineTraceStatistics

# llc estimator for OnlineLossStatistics
llc_estimator = OnlineLLCEstimator(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, n=NUM_TRAIN_SAMPLES, device=DEVICE)
loss_statistics = OnlineLossStatistics(base_callback=llc_estimator)

# weight norm for OnlineTraceStatistics
weight_norm = WeightNorm(num_chains=NUM_CHAINS, num_draws=NUM_DRAWS, device=DEVICE)
trace_statistics = OnlineTraceStatistics(base_callback=weight_norm, attribute='weight_norms')

In [None]:
# validate_callbacks raises an error if a derivative callback is passed before its base callback
callbacks = [loss_statistics, llc_estimator]
validate_callbacks(callbacks)

In [None]:
# it returns True if the callbacks meet the ordering condition
callbacks = [llc_estimator, loss_statistics, weight_norm, trace_statistics]
validate_callbacks(callbacks)

In [None]:
results = run_callbacks(train_loader, train_data, weights=[0.0, 0.0], callbacks=callbacks)

**OnlineTraceStatistics**

`OnlineTraceStatistics` is a general-purpose "derivative" `SamplerCallback` that takes any base statistic stored as a trace and computes its mean and standard deviation in two ways: across *draws* and across *chains*.

The black dotted lines and gray overlays in the previous charts are a visualization of what it looks like to compute the mean and std across *draws*.

The mean and std statistics computed across a *chain* are the cumulative mean and std of that single chain up to a given draw step. For example, the final mean and std for a chain are the mean and std of the base statistic over all draw steps of that chain.
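To make the two kinds of aggregation concrete, here is a minimal NumPy sketch on a synthetic trace. It is not the `OnlineTraceStatistics` implementation: the random `trace` array, the variable names, the use of the population std, and the reading of the per-draw statistics as an aggregate over chains at each draw step are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_chains, num_draws = 4, 6

# Hypothetical trace of some base statistic, shape (num_chains, num_draws).
trace = rng.normal(loc=1.0, scale=0.1, size=(num_chains, num_draws))

# Cumulative mean within each chain: entry [i, t] averages draws 0..t of chain i.
draws_so_far = np.arange(1, num_draws + 1)
chain_mean = np.cumsum(trace, axis=1) / draws_so_far

# Cumulative population std within each chain, from running first and second moments.
chain_sq_mean = np.cumsum(trace**2, axis=1) / draws_so_far
chain_std = np.sqrt(np.maximum(chain_sq_mean - chain_mean**2, 0.0))

# Per-draw statistics: at each draw step, aggregate over all chains.
draw_mean = trace.mean(axis=0)
draw_std = trace.std(axis=0)

print(chain_mean.shape, chain_std.shape)  # same shape as the trace
print(draw_mean.shape, draw_std.shape)    # one value per draw step
```

In this sketch the last column of `chain_mean` equals the plain mean over all draws of each chain, which mirrors the `means[:, -1]` indexing used in the next cell.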

In [None]:
# Mean and std for each chain of the computed weight norms
means = results['weight_norms/chain/mean']
stds = results['weight_norms/chain/std']

print('The shapes match the shape of the trace.')
print(means.shape)
print(stds.shape)

print('\nThe mean of all draw steps across a given chain i is the ith index below.')
final_means = means[:, -1]
print(final_means)

print('\nThe std of all draw steps across chain 5:')
std = stds[5, -1]
print(std)

**OnlineLossStatistics**

Since we're using minibatches rather than working in the limit of infinite training data, the random selection of each minibatch introduces noise. We can estimate and visualize this noise with a histogram of the initial losses of each chain in our sampler, since these are losses from different minibatches evaluated at the same fixed weights. Note that we get this by calling a method of `OnlineLossStatistics` rather than reading from `results`.

Note that these values may be so close together that `plt.hist` has trouble rendering them, but we can see from the values themselves that the minibatch noise is very small.
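The idea can be reproduced without the sampler. The sketch below is an assumption-laden stand-in, not what `loss_hist_by_draw` does internally: the synthetic data, the noise scale, the exponents `a = b = 1`, and the helper `minibatch_loss` are all hypothetical. It evaluates the loss of $f(x) = w_1^a w_2^b x$ at fixed weights on several random minibatches and bins the results with `np.histogram`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dataset: targets are pure Gaussian noise around the origin.
num_samples, batch_size, num_chains = 1000, 100, 8
x = rng.normal(size=num_samples)
y = rng.normal(scale=0.1, size=num_samples)

def minibatch_loss(w1, w2, a=1, b=1):
    """MSE of f(x) = w1^a * w2^b * x on one randomly drawn minibatch."""
    idx = rng.choice(num_samples, size=batch_size, replace=False)
    pred = (w1**a) * (w2**b) * x[idx]
    return float(np.mean((pred - y[idx]) ** 2))

# Every "chain" starts from the same weights, so any spread in these
# initial losses comes from the choice of minibatch alone.
init_losses = np.array([minibatch_loss(0.0, 0.0) for _ in range(num_chains)])
noise_std = init_losses.std()

# Bin the initial losses; counts sum to the number of chains.
counts, bin_edges = np.histogram(init_losses, bins=10)

print(init_losses)
print("estimated minibatch noise (std):", noise_std)
```

The spread `noise_std` is the kind of quantity that can later serve as a tolerance when deciding whether a loss is meaningfully below the initial loss.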

In [None]:
init_losses = loss_statistics.loss_hist_by_draw(draw=0, bins=10)
init_losses

`OnlineLossStatistics` also provides a few ways to check the "health" of your loss chains. For example, your chains should ideally not reach loss values lower than your initial loss. You can inspect the loss values directly by plotting the loss trace, or you can check the cumulative percentage of draws whose loss falls below the initial loss ("negative steps") through a few statistics computed by `OnlineLossStatistics`, one of which is `"loss/percent_neg_steps"`.
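As a rough illustration of what a cumulative percentage of negative steps means, the following sketch computes it by hand for a small, made-up loss trace. This is one plausible reading of the statistic, assuming a step counts as negative when its loss is strictly below the chain's first recorded loss; it is not taken from the `OnlineLossStatistics` source.

```python
import numpy as np

# Hypothetical loss trace, shape (num_chains, num_draws); column 0 is the initial loss.
loss_trace = np.array([
    [1.00, 1.10, 0.95, 1.20, 1.30],  # dips below its initial loss once
    [1.00, 0.90, 0.80, 0.85, 0.70],  # stays below its initial loss after draw 0
])

init_loss = loss_trace[:, :1]         # keep a trailing axis for broadcasting
is_negative = loss_trace < init_loss  # True where a draw is below the initial loss

# Cumulative percent of draws so far that fell below the initial loss.
draws_so_far = np.arange(1, loss_trace.shape[1] + 1)
cum_percent_neg = 100.0 * np.cumsum(is_negative, axis=1) / draws_so_far

print(cum_percent_neg[:, -1])  # → [20. 80.]
```

A chain whose curve climbs toward 100% is spending most of its time below the initial loss, which is the behaviour the plots below are meant to flag.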

In [None]:
loss_trace = results['loss/trace']
plot_trace(loss_trace, 'loss')

In [None]:
cum_perc_neg_steps = results['loss/percent_neg_steps']
plot_trace(cum_perc_neg_steps, 'cumulative % negative')

Recall that there is still some minibatch noise, so we can also check how often losses fall below the initial loss by more than the estimated minibatch noise. Most negative draws end up within the threshold, and any chains that remain consistently negative can be pruned from the estimates.
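A thresholded version of the same check might look like the sketch below. Everything about the threshold here is an assumption for illustration: the noise is estimated as the std of the initial losses across chains, the tolerance `num_stds = 2.0` is arbitrary, and the 50% pruning cutoff is hypothetical. The library's actual rule may differ.

```python
import numpy as np

# Hypothetical loss trace, shape (num_chains, num_draws); column 0 is the initial loss.
loss_trace = np.array([
    [1.00, 0.99, 1.05, 0.98, 1.10],  # small dips, comparable to the noise
    [1.02, 0.90, 0.85, 0.80, 0.75],  # consistently far below its initial loss
    [0.98, 1.00, 0.97, 1.03, 0.99],  # small dips, comparable to the noise
])

init_losses = loss_trace[:, 0]
noise_std = init_losses.std()        # assumed estimate of the minibatch noise
num_stds = 2.0                       # hypothetical tolerance, in units of noise_std
threshold = num_stds * noise_std

draws_so_far = np.arange(1, loss_trace.shape[1] + 1)

def cumulative_percent(mask):
    """Cumulative percent of True entries along the draw axis."""
    return 100.0 * np.cumsum(mask, axis=1) / draws_so_far

# Plain negative steps versus steps that are negative by more than the threshold.
plain = cumulative_percent(loss_trace < init_losses[:, None])
thresholded = cumulative_percent(loss_trace < init_losses[:, None] - threshold)

# Hypothetical pruning rule: keep chains that are thresholded-negative under half the time.
keep_chain = thresholded[:, -1] < 50.0

print(plain[:, -1])        # → [40. 80. 20.]
print(thresholded[:, -1])  # → [ 0. 80.  0.]
print(keep_chain)          # → [ True False  True]
```

In this toy example the small dips in the first and third chains disappear once the threshold is applied, while the second chain stays flagged and would be the one to prune.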

In [None]:
thresholded_neg_steps = results['loss/percent_thresholded_neg_steps']
plot_trace(thresholded_neg_steps, 'cumulative % of thresholded negative')
