## [WIP] Epsilon-Beta Visualizations
This notebook aims to visualize $\hat{\lambda}_n^\beta$ for various values of $\beta$ (inverse temperature) and $\epsilon$ (step size). Adrian's Thesis roughly states that:
- $\beta$ can be tuned via graphing $\hat{\lambda}_n^\beta$ for a sweep of $\beta$, and using $\beta$ in a range around the critical points on the graph.
- $\epsilon$ should be the greatest possible value that doesn't cause excessive numerical instability or cause the SGLD chains to fail to converge. A MALA proposal acceptance rate (see `sgld_calibration.ipynb`) between 0.9 and 0.95 is roughly optimal.

## Set-up

In [None]:
# Install the third-party packages this notebook imports (devinterp, transformers, torchvision).
# NOTE(review): no versions are pinned here, so a re-run may pick up different releases — consider pinning.
%pip install devinterp transformers torchvision

In [7]:
import torch
import torchvision
from transformers import AutoModelForImageClassification

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD
from devinterp.utils import plot_trace

# Default to the CPU and switch to CUDA only when PyTorch reports it as available.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"


def transformers_cross_entropy(inputs, outputs):
    """Cross-entropy criterion for model outputs that wrap their logits.

    Despite the parameter names, ``inputs`` is the object whose ``.logits``
    attribute holds the predictions, and ``outputs`` is passed to
    ``cross_entropy`` as the target.

    Args:
        inputs: Object exposing a ``.logits`` tensor.
        outputs: Target tensor forwarded to ``torch.nn.functional.cross_entropy``.

    Returns:
        The value returned by ``torch.nn.functional.cross_entropy``.
    """
    # The model returns a wrapper object rather than a bare tensor, so unwrap it first.
    logits = inputs.logits
    return torch.nn.functional.cross_entropy(logits, outputs)

# Pretrained MNIST classifier, loaded by name via `from_pretrained`.
model = AutoModelForImageClassification.from_pretrained("fxmarty/resnet-tiny-mnist")

# MNIST dataset stored under ../data (downloaded if needed); the only transform
# applied to each image is the conversion to a tensor.
mnist_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
data = torchvision.datasets.MNIST(root="../data", download=True, transform=mnist_transform)

# Shuffled mini-batches of 256 examples.
loader = torch.utils.data.DataLoader(data, shuffle=True, batch_size=256)

In [104]:
from devinterp.utils import optimal_temperature
import numpy as np

# Inverse-temperature grid: 12 log-spaced multiples of the reference value
# returned by optimal_temperature(loader), from 10**-1.5 (about 1/32) up to
# 10**2 (100) times that value.
beta_range = np.logspace(-1.5, 2, 12) * optimal_temperature(loader)
# Step-size grid: 12 log-spaced epsilons from 1e-6 to 1e-2.
epsilon_range = np.logspace(-6, -2, 12)

In [105]:
from devinterp.slt.mala import MalaAcceptanceRate
from tqdm import tqdm, trange
NUM_CHAINS = 5  # independent SGLD chains per configuration
NUM_DRAWS = 300  # samples drawn per chain

# Every (epsilon, beta) pair, with epsilon varying slowest, so the rows of
# `all_sweep_stats` are ordered exactly like a nested epsilon/beta loop.
sweep_grid = [(epsilon, beta) for epsilon in epsilon_range for beta in beta_range]

all_sweep_stats = []
for epsilon, beta in tqdm(sweep_grid):
    # A new acceptance-rate callback is built for each configuration and
    # handed to the estimator below.
    mala_estimator = MalaAcceptanceRate(
        num_chains=NUM_CHAINS,
        num_draws=NUM_DRAWS,
        temperature=beta,
        learning_rate=epsilon,
        device=DEVICE,
    )

    sgld_kwargs = dict(lr=epsilon, localization=100.0, temperature=beta)
    learning_coeff_stats = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        criterion=transformers_cross_entropy,
        sampling_method=SGLD,
        optimizer_kwargs=sgld_kwargs,
        num_chains=NUM_CHAINS,
        num_draws=NUM_DRAWS,
        num_burnin_steps=0,  # nothing is discarded at the start of a chain
        num_steps_bw_draws=1,  # one optimizer step between consecutive samples
        device=DEVICE,
        online=True,
        callbacks=[mala_estimator],
        verbose=False,
    )
    mala_estimator_stats = mala_estimator.sample()

    # One flat record per configuration: the estimator stats, the callback
    # stats, and the hyperparameters that produced them.
    sweep_stats = {
        **learning_coeff_stats,
        **mala_estimator_stats,
        "epsilon": epsilon,
        "beta": beta,
    }
    all_sweep_stats.append(sweep_stats)


You are taking more draws than burn-in steps, your LLC estimates will likely be underestimates. Please check LLC chain convergence.


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)

100%|██████████| 144/144 [1:30:52<00:00, 37.86s/it]


In [107]:
import pandas as pd
import plotly.express as px
# One row per (epsilon, beta) configuration from the sweep.
df = pd.DataFrame(all_sweep_stats)
# Each "llc/trace" is sliced along its second axis, i.e. over its trailing entries.
# Relative spread (std / mean) over the last 20 entries along that axis.
df["llc/std_over_mean"] = df["llc/trace"].apply(lambda x: x[:, -20:].std() / x[:, -20:].mean())
# Final estimate: mean over the last 10 entries. The trailing colon matters:
# `x[:, -10]` would select only the single entry ten from the end.
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Bind the figure to `fig` so that it is this plot a later `fig.write_html(...)` exports.
fig = px.scatter_3d(
    df,
    x="epsilon",
    y="beta",
    z="llc/final",
    color="llc/std_over_mean",
    log_x=True,
    log_y=True,
    log_z=True,
    title="Local learning coefficient vs. epsilon and beta",
    # Fix the colour scale to [0, 0.15]
    range_color=[0, 0.15],
)
fig

In [108]:
# Export whichever figure is currently bound to `fig` to an HTML file.
# NOTE(review): confirm `fig` is the plot you intend to export and not a figure
# left over from a different cell.
fig.write_html("epsilon_beta_sweep.html")

In [None]:
import seaborn as sns
# Inspect a single configuration: the fourth row from the end of the sweep table.
selected_run = df.iloc[-4]
# Draw one line per row of the stored trace (one row per chain).
for chain in selected_run["llc/trace"]:
    sns.lineplot(data=chain)
print(selected_run["llc/std_over_mean"])

In [103]:
import pandas as pd
import plotly.express as px
# One row per (epsilon, beta) configuration from the sweep.
df = pd.DataFrame(all_sweep_stats)
# Each "llc/trace" is sliced along its second axis, i.e. over its trailing entries.
# Relative spread (std / mean) over the last 20 entries along that axis.
df["llc/std_over_mean"] = df["llc/trace"].apply(lambda x: x[:, -20:].std() / x[:, -20:].mean())
# Final estimate: mean over the last 10 entries. The trailing colon matters:
# `x[:, -10]` would select only the single entry ten from the end.
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Bind the figure to `fig` so that it is this plot a later `fig.write_html(...)` exports.
fig = px.scatter_3d(
    df,
    x="epsilon",
    y="beta",
    z="llc/final",
    color="llc/std_over_mean",
    log_x=True,
    log_y=True,
    log_z=True,
    title="Local learning coefficient vs. epsilon and beta",
    # Fix the colour scale to [0, 0.15]
    range_color=[0, 0.15],
)
fig

In [106]:
# Export whichever figure is currently bound to `fig` to an HTML file.
# NOTE(review): confirm `fig` is the plot you intend to export and not a figure
# left over from a different cell.
fig.write_html("epsilon_beta_sweep.html")

In [76]:
import plotly.graph_objects as go
import pandas as pd
# One row per (epsilon, beta) configuration from the sweep.
df = pd.DataFrame(all_sweep_stats)
# Final estimate: mean over the last 10 entries along the trace's second axis.
# The trailing colon matters: `x[:, -10]` would select a single entry only.
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Work in log10 space on all three axes.
df["llc/log_final"] = np.log10(df["llc/final"])
df["log_epsilon"] = np.log10(df["epsilon"])
df["log_beta"] = np.log10(df["beta"])

# The sweep appends rows with epsilon as the outer loop and beta as the inner
# loop, so the flat column reshapes to [epsilon index, beta index].
log_llc_grid = df["llc/log_final"].values.reshape(len(epsilon_range), len(beta_range))

# 3d surface with projected contours. go.Surface takes z as a 2D grid whose rows
# follow y and whose columns follow x, so x and y must be the unique grid
# coordinates (one per column / row), not the 144-long per-run columns, and the
# grid is transposed to [beta index, epsilon index].
fig = go.Figure(data=[go.Surface(
    x=np.log10(epsilon_range),
    y=np.log10(beta_range),
    z=log_llc_grid.T,
    colorscale='Viridis',
    opacity=0.6,
    contours=dict(z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project=dict(z=True)))
)])
# Axis titles state the log10 scale, since the plotted coordinates are logarithms.
# NOTE(review): the axis ranges are hard-coded; check that they cover the swept
# beta values and the observed LLC values.
fig.update_layout(scene=dict(
    xaxis_title='log10(epsilon)',
    yaxis_title='log10(beta)',
    zaxis_title='log10(llc/final)',
    xaxis=dict(nticks=8, range=[-6, 0]),
    yaxis=dict(nticks=8, range=[0, 4]),
    zaxis=dict(nticks=4, range=[0, 4]),
))

fig.show()

In [33]:
# Mean of the last 5 entries (along the trace's second axis) for the first sweep
# configuration. Read it from `all_sweep_stats`: `sweep_stats` is a flat per-run
# dict, so its first value cannot be indexed with "llc/trace".
all_sweep_stats[0]["llc/trace"][:, -5:].mean()

8.317402