[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timaeus-research/devinterp/blob/stan/jet/notebooks/DLN_SGLD_trajectory_notebook.ipynb)

# Normal Crossing SGLD trajectory notebook

This notebook uses SGLD to sample from the posterior around a point, for a polynomial model characterized by $y_{pred} = w_1^a * w_2^b * x$ for some $(a, b)$, where $w_1$ and $w_2$ are weights to be learned and $x$ is our (1-dimensional) input.

Both the input $x$ and the target $y$ are drawn from independent Gaussian noise, and the loss is the mean squared error between $y$ and $y_{pred}$, so the model achieves its lowest loss when $w_1=0$ or $w_2=0$.


In [1]:
!pip install git+https://github.com/timaeus-research/devinterp.git@stan/jet

%pip install seaborn

Collecting git+https://github.com/timaeus-research/devinterp.git@stan/jet
  Cloning https://github.com/timaeus-research/devinterp.git (to revision stan/jet) to /tmp/pip-req-build-2s8s6305
  Running command git clone --filter=blob:none --quiet https://github.com/timaeus-research/devinterp.git /tmp/pip-req-build-2s8s6305
  Running command git checkout -b stan/jet --track origin/stan/jet
  Switched to a new branch 'stan/jet'
  Branch 'stan/jet' set up to track remote branch 'stan/jet' from 'origin'.
  Resolved https://github.com/timaeus-research/devinterp.git to commit 7c09f4dead8e939f31a5b7afa41f6837f796498b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting PyMoments (from devinterp==1.2.0)
  Downloading PyMoments-1.0.1-py3-none-any.whl.metadata (6.0 kB)
Downloading PyMoments-1.0.1-py3-none-any.whl (95 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# @title
import locale


def getpreferredencoding(do_setlocale=True):
    """Return ``"UTF-8"`` unconditionally.

    Stand-in for ``locale.getpreferredencoding``: ``do_setlocale`` is accepted
    so the signature stays call-compatible, but it is never consulted.
    """
    return "UTF-8"


# Monkey-patch the standard library so that every later call to
# `locale.getpreferredencoding` reports UTF-8, whatever the host locale is.
locale.getpreferredencoding = getpreferredencoding

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib.colors import LinearSegmentedColormap

from devinterp.optim.sgld import SGLD
from devinterp.optim.sgnht import SGNHT
from devinterp.slt.llc import OnlineLLCEstimator
from devinterp.utils import default_nbeta, evaluate_mse
from devinterp.slt.sampler import sample

from devinterp.jet_tools.plot import *
from devinterp.jet_tools.models import *
from devinterp.jet_tools.diffs import *
from devinterp.jet_tools.utils import *

# **Auxiliary functions**

In [3]:
# Global seaborn look applied to every figure in this notebook.
sns.set_style("whitegrid")

# Named colours: the first three entries of seaborn's "muted" palette.
# (The code further down feeds SECONDARY to `hex_to_rgb`, i.e. treats the
# entries as hex colour strings.)
CMAP = sns.color_palette("muted", as_cmap=True)
PRIMARY = CMAP[0]
SECONDARY = CMAP[1]
TERTIARY = CMAP[2]


def hex_to_rgb(hex_color):
    """Convert a hex colour string into an ``(r, g, b)`` tuple of floats in [0, 1].

    Leading ``#`` characters are stripped; the first six remaining characters
    are read as three two-digit hexadecimal channels and each is divided by 255.

    Raises:
        ValueError: if a channel is missing or is not valid hexadecimal.
    """
    digits = hex_color.lstrip("#")
    red, green, blue = (
        int(digits[start : start + 2], 16) / 255.0 for start in range(0, 6, 2)
    )
    return (red, green, blue)


# Blend SECONDARY towards white to obtain the light end of the contour colormap.
lighter_factor = 0.9  # in [0, 1]; larger values give a colour closer to white
lighter_SECONDARY = (
    *(x + (1 - x) * lighter_factor for x in hex_to_rgb(SECONDARY)[:3]),
    1.0,  # fourth component appended to the three blended RGB channels
)

# Two-stop gradient (SECONDARY -> lightened SECONDARY) quantised into n_bins levels.
n_bins = 20
colors = [SECONDARY, lighter_SECONDARY]
contour_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=n_bins)

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [5]:
# Configuration of the synthetic regression data set.
SEED = 0
SIGMA = 1  # forwarded as `sigma`; presumably the noise scale - TODO confirm
NUM_TRAIN_SAMPLES = 1000
BATCH_SIZE = NUM_TRAIN_SAMPLES  # a single batch spans the whole training set
EVALUATE = evaluate_mse  # evaluation function handed to `sample` further down

# Returns four objects; judging by the names these are the loader, the
# underlying dataset and the raw inputs / targets - TODO confirm.
train_loader, train_data, x, y = generate_dataset_for_seed(
    seed=SEED,
    sigma=SIGMA,
    batch_size=BATCH_SIZE,
    num_samples=NUM_TRAIN_SAMPLES,
)

## **Marginal and joint distributions of jets**

In [6]:
# Sanity check for `ith_place_nth_diff` on a tiny hand-made trajectory.
# The expected values [1, -1] equal the second finite differences of the
# sequence 0, 1, 3, 4 (first differences 1, 2, 1).
test_array = np.array([[[0.0], [1.0], [3.0], [4.0]]])
should_be = np.array([1.0, -1.0])

# Residual against the expected values; all zeros means the check passed.
_residual = ith_place_nth_diff(test_array, 1, n=2).reshape(-1) - should_be.reshape(-1)
print(_residual)

[0. 0.]


# **SGLD trajectories sampling**


In [7]:
# Weight-initialisation scheme for the DLN; the same value is later passed to
# the statistics plots. Options listed by the author:
# "xavier", "zeros", "random", "kaiming".
method = "xavier"

# 2 -> 2 -> 2 network with two layers.
dln_config = DLNConfig(
    input_dim=2,
    hidden_dim=2,
    output_dim=2,
    n_layers=2,
    initialization_method=method,
)
dln_model = DLN(dln_config)

# Data set built from the model itself, served in mini-batches of 64.
dln_dataset = create_dataset(dln_model, seed=SEED)
dln_dataloader = torch.utils.data.DataLoader(dataset=dln_dataset, batch_size=64)

Created continuous dataset (shuffle parameter is ignored)


In [8]:
# Both weights start at the origin.
INITIAL_PARAMETERS = [0.0, 0.0]

# Exponents (a, b) of the polynomial model. NB: assuming PolyModel implements
# y_pred = w_1**a * w_2**b * x as described in the introduction, an MSE loss
# against zero-mean Gaussian targets is effectively w_1**(2a) * w_2**(2b),
# i.e. w_1**2 * w_2**6 for the values below (up to scale and an additive
# constant) - TODO confirm against PolyModel.
A_B = [1, 3]

nbeta = default_nbeta(len(train_data))  # = n/log(n)

# The device is both passed to the constructor and applied with `.to()`; it is
# not visible from this notebook which of the two PolyModel relies on, so both
# are kept.
model = PolyModel(torch.tensor(A_B, device=DEVICE), DEVICE).to(DEVICE)

# Replace the model's weights with the chosen starting point. The tensor is
# created directly on the target device; `nn.Parameter` enables gradients.
model.weights = nn.Parameter(
    torch.tensor(INITIAL_PARAMETERS, dtype=torch.float32, device=DEVICE)
)

In [9]:
# SGLD sampling budget: number of independent chains and draws per chain.
num_chains_sgld, num_draws_sgld = 1, 10_000

# SGLD step size (handed to the optimizer as `lr`).
epsilon_sgld = 5e-3

In [None]:
# Seed the global NumPy and PyTorch RNGs before sampling.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Callback that records the weights of `dln_model` during sampling.
sgld_weights = WeightCallback(
    model=dln_model,
    num_chains=num_chains_sgld,
    num_draws=num_draws_sgld,
)

# Callback that tracks the LLC estimate while the chain runs.
# NOTE(review): `nbeta` was computed from len(train_data) (the polynomial data
# set), whereas this run samples `dln_model` on `dln_dataloader` - confirm that
# this is intended.
online_llc_estimator = OnlineLLCEstimator(
    num_chains=num_chains_sgld,
    num_draws=num_draws_sgld,
    init_loss=0.0,
    nbeta=nbeta,
    device=DEVICE,
)

# Run SGLD; the results are read back from the callbacks in the next cell.
trace_sgld = sample(
    dln_model,
    dln_dataloader,
    evaluate=EVALUATE,
    sampling_method=SGLD,
    optimizer_kwargs={"lr": epsilon_sgld, "nbeta": nbeta, "bounding_box_size": 1.0},
    num_chains=num_chains_sgld,  # independent sampling runs
    num_draws=num_draws_sgld,  # length of each sampling run
    callbacks=[online_llc_estimator, sgld_weights],
    device=DEVICE,
    seed=SEED,
    verbose=True,
)

Chain 0:   1%|          | 123/10000 [00:05<03:05, 53.35it/s]

In [None]:
# Weights recorded by the SGLD weight callback, as a NumPy array.
wt_sgld = np.array(sgld_weights.get_results()["ws/trace"])

# LLC trace from the online estimator.
llc_chain = online_llc_estimator.get_results()["llc/trace"]
# Running-mean variant, disabled by the author (does not work yet with multiple chains):
# llc_chain = np.cumsum(llc_chain) / np.arange(1, len(llc_chain[0]) + 1)

# Sanity checks: indices of any NaN entries in the recorded weights, and the
# product of step size and number of draws.
print(np.isnan(wt_sgld).nonzero())
print(epsilon_sgld * num_draws_sgld)

# Optional reduction kept by the author: take the first element of the first
# element of every weights tensor (noted shape: (3, 10000, 2, 3, 3)):
# wt_sgld = wt_sgld[:,:,:,0,0]
# print(wt_sgld.shape)

# **SGLD plots**

## **Plotting SGLD trajectories and jet coordinates**

In [None]:
# Trajectory plot for the SGLD samples. The last argument is the sampler label
# (previously misspelled "SLGD"); the meaning of the two list arguments is
# defined by `plot_multi_trajectories` - TODO confirm.
plot_multi_trajectories(wt_sgld, [1, 10], [0, 1, 2], "SGLD")

## **Plot cumulant statistics of marginal distribution of jet coordinates SGLD**


In [None]:
# Collapse everything after the (chain, draw) axes into a single trailing axis.
flattened = wt_sgld.reshape(num_chains_sgld, num_draws_sgld, -1)
print(flattened.shape)

# Second- and third-order statistics plots on the flattened samples; `method`
# is the DLN initialisation scheme chosen earlier.
plot_second_order_first_place_stats(flattened, 1, method=method)
plot_third_order_stats_per_dim(flattened, 1, method=method)

In [None]:
# Second-order two-place statistics plot for the SGLD samples.
# NOTE(review): the meaning of the second argument (1) is defined by the helper - confirm.
plot_second_order_two_place_stats(wt_sgld, 1)

In [None]:
# Third-order statistics plot for the SGLD samples.
# NOTE(review): the meaning of the second argument (1) is defined by the helper - confirm.
plot_third_order_stats(wt_sgld, 1)

## **Vector fields SGLD**

In [None]:
# Vector-field plot of the jets for the SGLD samples, labelled "SGLD".
# NOTE(review): the roles of [10, 100, 500], [1] and 15 are defined by the helper - confirm.
plot_vector_field_jets(wt_sgld, [10, 100, 500], [1], 15, "SGLD")

# **SGNHT trajectories sampling**

In [None]:
# SGNHT sampling budget: number of independent chains and draws per chain.
num_chains_sgnht, num_draws_sgnht = 1, 10_000

# SGNHT step size (handed to the optimizer as `lr`).
epsilon_sgnht = 1e-3

In [None]:
# Callback that records the weights of the polynomial `model` during sampling.
sgnht_weights = WeightCallback(
    num_chains=num_chains_sgnht, num_draws=num_draws_sgnht, model=model
)

# Online LLC estimator for the SGNHT run. It is sized with the SGNHT draw count
# (the previous code passed `num_draws_sgld`, which only worked because both
# counts are currently equal).
# Note: this rebinds `online_llc_estimator`, replacing the SGLD estimator.
online_llc_estimator = OnlineLLCEstimator(
    num_chains=num_chains_sgnht,
    num_draws=num_draws_sgnht,
    nbeta=nbeta,
    init_loss=0.0,
    device=DEVICE,
)

# Run SGNHT on the polynomial model; the results are read back from the
# callbacks in the next cell.
trace_sgnht = sample(
    model,
    train_loader,
    evaluate=EVALUATE,
    optimizer_kwargs=dict(
        lr=epsilon_sgnht, nbeta=nbeta, diffusion_factor=0.005, bounding_box_size=1.0
    ),
    sampling_method=SGNHT,
    num_chains=num_chains_sgnht,  # independent sampling runs
    num_draws=num_draws_sgnht,  # length of a sampling run
    verbose=True,
    device=DEVICE,
    callbacks=[online_llc_estimator, sgnht_weights],
    seed=SEED,
)

In [None]:
# Weights recorded by the SGNHT weight callback, as a NumPy array.
wt_sgnht = np.array(sgnht_weights.get_results()["ws/trace"])

# Sanity checks: indices of any NaN entries in the recorded weights, and the
# product of the SGNHT step size and number of draws (the previous code
# multiplied by the SGLD step size by mistake).
print(np.nonzero(np.isnan(wt_sgnht)))
print(epsilon_sgnht * num_draws_sgnht)

# **SGNHT plots**

## **Plotting SGNHT trajectories and jet coordinates**




In [None]:
# Trajectory plot for the SGNHT samples, labelled "SGNHT".
# NOTE(review): the roles of [1, 10] and [0, 1, 2] are defined by the helper - confirm.
plot_multi_trajectories(wt_sgnht, [1, 10], [0, 1, 2], "SGNHT")

## **Plot cumulant statistics of marginal distribution of jet coordinates SGNHT**

In [None]:
# Second-order two-place statistics plot for the SGNHT samples.
# NOTE(review): the meaning of the second argument (1) is defined by the helper - confirm.
plot_second_order_two_place_stats(wt_sgnht, 1)

In [None]:
# Third-order statistics plot for the SGNHT samples.
# NOTE(review): the meaning of the second argument (1) is defined by the helper - confirm.
plot_third_order_stats(wt_sgnht, 1)

## **Vector fields SGNHT**

In [None]:
# Vector-field plot of the jets for the SGNHT samples, labelled "SGNHT".
# NOTE(review): the roles of [10, 100, 500], [1] and 15 are defined by the helper - confirm.
plot_vector_field_jets(wt_sgnht, [10, 100, 500], [1], 15, "SGNHT")