In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import seaborn as sns
from src.experiments.common import Experiment
import torch
from hydra.utils import instantiate
from src.inference.mcmc import MCMCInference
from src.inference.mcmc.samplers import HMC
import matplotlib.pyplot as plt
from src.models.linear import LinearRegressor
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from src.bayesian.prior_sets import get_normal

In [None]:
def eval_poly(x, coeffs):
    """Evaluate the polynomial ``coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ...``.

    ``x`` may be a scalar or a tensor (evaluated elementwise); ``coeffs`` is an
    indexable sequence of coefficients in increasing order of power.
    """
    higher_order = 0
    for power, coeff in enumerate(coeffs[1:], start=1):
        higher_order = higher_order + coeff * x**power
    return coeffs[0] + higher_order

In [None]:
# Most recent run of the "simulated" experiment; its sub-runs (`run.runs`)
# provide the dataset config and the saved samples used below.
run = Experiment("simulated").latest_run()

In [None]:
# Rebuild the dataset from the first sub-run's Hydra config and pull every item
# at once. NOTE(review): presumably X is the design matrix and Y the (n, 1)
# targets — inferred from later indexing (X[:, 1], Y[:, 0]); confirm.
dataset = instantiate(run.runs[0].config.data.dataset)
X, Y = dataset[:]

In [None]:
# Simulated observations against the true generating polynomial.
# torch.linspace requires an explicit step count in recent PyTorch; 100 matches
# the old implicit default. `xx` is reused as the plotting grid further down.
xx = torch.linspace(-3, 3, 100)
plt.plot(xx, eval_poly(xx, dataset.coeffs))
# NOTE(review): assumes column 1 of X holds the raw input x (column 0 the intercept).
plt.scatter(X[:, 1], Y[:, 0]);

In [None]:
# Conjugate Bayesian linear regression with a N(mu_0, L_0^{-1}) prior on the
# coefficients. No noise-variance term appears, i.e. unit noise variance is assumed.
n_features = X.shape[1]
L_0 = torch.eye(n_features, dtype=X.dtype)  # prior precision
mu_0 = torch.zeros(n_features, dtype=X.dtype)  # prior mean
y_vec = Y.squeeze(-1)
XtX = X.T @ X
Xty = X.T @ y_vec
# Ordinary least-squares estimate, kept for reference; lstsq avoids forming an
# explicit inverse of X^T X.
ols = torch.linalg.lstsq(X, y_vec.unsqueeze(-1)).solution.squeeze(-1)
# Posterior precision and mean. Since X^T X @ ols == X^T y, use X^T y directly
# and solve the linear system rather than inverting L_n.
L_n = XtX + L_0
mu_n = torch.linalg.solve(L_n, Xty + L_0 @ mu_0)
posterior = torch.distributions.MultivariateNormal(mu_n, precision_matrix=L_n)

In [None]:
def get_marginal(dist: torch.distributions.MultivariateNormal, i, j=None):
    """Return the marginal of a multivariate normal over one or two coordinates.

    Args:
        dist: Joint Gaussian exposing ``mean`` and ``covariance_matrix``.
        i: Index of the first coordinate.
        j: Optional index of a second coordinate.

    Returns:
        A univariate ``Normal`` for coordinate ``i`` when ``j`` is None,
        otherwise a bivariate ``MultivariateNormal`` over ``(i, j)``.
    """
    cov = dist.covariance_matrix
    if j is None:
        return torch.distributions.Normal(dist.mean[i], cov[i, i].sqrt())
    idx = [i, j]
    # Index the covariance directly (instead of rebuilding it with torch.tensor)
    # so dtype, device and autograd history are preserved.
    return torch.distributions.MultivariateNormal(dist.mean[idx], cov[idx][:, idx])

In [None]:
from dataclasses import asdict

def to_matrix(samples):
    """Stack the tensors held as values of the mapping ``samples`` into one NumPy array.

    Values are stacked along a new leading axis in the mapping's iteration order.
    """
    rows = [*samples.values()]
    return torch.stack(rows).numpy()

# Combine every sub-run that has a sampler configured into one frame indexed by
# (sampler, batch_size, sample). Each row is one entry of the run's
# saved_samples.pt, stacked by to_matrix; columns are the positions of that
# stored vector (presumably the model coefficients — confirm).
# NOTE(review): torch.load unpickles the file; only use on trusted experiment output.
sample_data = pd.concat(
    pd.DataFrame(to_matrix(torch.load(r.path / "saved_samples.pt")))
    .rename_axis(index="sample")
    .assign(
        sampler=r.config["inference"]["sampler"]["_target_"],
        batch_size=r.config["data"]["batch_size"],
    )
    .set_index(["sampler", "batch_size"], append=True)
    .reorder_levels(["sampler", "batch_size", "sample"])
    for r in run.runs if "sampler" in r.config["inference"]
)


In [None]:
# Scratch: preview a cubehelix palette. The plots below use the named
# palettes configured in plot_cfg instead.
sns.cubehelix_palette(start=2.6)

In [None]:
# Per-run plotting options keyed by (sampler target, batch size), matching the
# first two index levels of sample_data. (A former `iter(map(...))` line here was
# a no-op: the lazy map was never consumed and its value was never displayed.)
plot_cfg = {
    ('src.inference.mcmc.samplers.HMC', 15): {"color_palette": "Blues_r"},
    ('src.inference.mcmc.samplers.HMC', 5): {"color_palette": "Greens_r"},
    ('src.inference.mcmc.samplers.SGHMC', 5): {"color_palette": "Oranges_r"},
}


In [None]:
def plot_univariate(x, **kwargs):
    """Overlay the analytic posterior marginal density on the current axes.

    Intended for ``PairGrid.map_diag``. ``x.name`` is used as the coordinate
    index into the global ``posterior``; the density is drawn in black over the
    sample range of ``x``. Extra keyword arguments passed by seaborn are ignored.
    """
    marg = get_marginal(posterior, x.name)
    # torch.linspace needs an explicit step count in recent PyTorch; 100 matches
    # the old implicit default.
    grid = torch.linspace(float(x.min()), float(x.max()), 100)
    plt.plot(grid, marg.log_prob(grid).exp(), c="black")

def plot_bivariate(x, y, **kwargs):
    """Draw black contour lines of the analytic (x, y) posterior marginal on the current axes.

    Intended for ``PairGrid.map_offdiag``. ``x.name`` and ``y.name`` are used as
    coordinate indices into the global ``posterior``; the density is evaluated on
    a 400x400 grid spanning the sample ranges. Extra keyword arguments are ignored.
    """
    bivariate = get_marginal(posterior, x.name, y.name)
    axis_grids = [torch.linspace(min(s), max(s), 400) for s in (x, y)]
    mesh = torch.stack(torch.meshgrid(*axis_grids), dim=-1)
    density = bivariate.log_prob(mesh).exp()
    plt.contour(mesh[..., 0], mesh[..., 1], density, colors="black")


In [None]:
# One pairplot per (sampler, batch_size) run: sample histograms with the analytic
# posterior marginals overlaid in black.
plots = {}
for key, data in sample_data.groupby(level=["sampler", "batch_size"]):
    # Runs without an entry in plot_cfg fall back to the current palette instead
    # of raising KeyError.
    palette = plot_cfg.get(key, {}).get("color_palette")
    with sns.color_palette(palette):
        plots[key] = sns.pairplot(data, kind="hist", diag_kws={"stat": "density"})
        plots[key].map_diag(plot_univariate)
        plots[key].map_offdiag(plot_bivariate)


In [None]:
# Joint histogram of coefficients i and j for each run, with the analytic
# bivariate posterior contours drawn on the joint axes.
i = 1
j = 3
joint_plots = {}
for key, data in sample_data[[i, j]].groupby(level=[0, 1]):
    # Same fallback as the pairplot cell for runs missing from plot_cfg.
    palette = plot_cfg.get(key, {}).get("color_palette")
    with sns.color_palette(palette):
        x = data[i]
        y = data[j]
        joint_plots[key] = sns.jointplot(x=x, y=y, kind="hist")
        plt.sca(joint_plots[key].ax_joint)
        plot_bivariate(x, y)


In [None]:

import pandas as pd
import seaborn as sns

# Posterior predictive check: push posterior coefficient samples through the
# polynomial design matrix on the plotting grid `xx`, then show the 5% / 50% / 95%
# predictive quantiles against the data and the true polynomial.
n_samples = 10_000
param_samples = posterior.sample((n_samples,))
# Design matrix [1, x, x^2, ...] with one column per posterior coefficient,
# shape (len(xx), n_coeffs).
XX = torch.stack([xx**power for power in range(param_samples.shape[-1])], dim=-1)
# (len(xx), n_coeffs) @ (n_samples, n_coeffs, 1) -> (n_samples, len(xx), 1)
predictions = XX @ param_samples.unsqueeze(-1)

(
    pd.DataFrame(predictions.squeeze(-1).numpy(), columns=xx.numpy())
    .melt(
        var_name="x",
        value_name="y",
    )
    .groupby("x")
    .quantile([0.05, 0.5, 0.95])
    .unstack()
    .droplevel(0, axis="columns")
    .reset_index()
    .melt(
        id_vars="x",
        var_name="quantile",
    )
    .pipe((sns.relplot, "data"), x="x", y="value", hue="quantile", kind="line")
)
# Overlay the observations (Y, not a `y` leaked from the joint-plot loop) and the
# true generating polynomial (dataset.coeffs; a bare `coeffs` is undefined here).
plt.scatter(X[:, 1], Y[:, 0])
plt.plot(xx, eval_poly(xx, dataset.coeffs));
# FIXME: `sgd_inf` is never defined in this notebook, so this overlay raised
# NameError; re-enable once an SGD point estimate is fitted above.
# plt.plot(xx, eval_poly(xx, sgd_inf.model.linear.weight.detach()))

In [None]:
# Scalar prior hyperparameters. NOTE: this rebinds `mu_0` (a zero vector in the
# analytic-posterior cell) to a plain int.
mu_0, sigma_0 = 0, 1.

In [None]:
# Inspect the linear model class. `LinearModel` is not defined in this notebook
# and raised NameError; the imported linear model is `LinearRegressor`.
LinearRegressor

In [None]:
# FIXME: `dm` (presumably a data module) is never defined in this notebook, so
# this cell raises NameError on a fresh kernel.
dm.train_data.dataset.tensors[1]

In [None]:
from src.bayesian.priors import ScaleMixturePrior, NormalPrior

In [None]:
# Inspect the ScaleMixturePrior density on [-20, 20]. A dedicated grid name is
# used so the global plotting grid `xx` (read by the predictive cell) is not
# clobbered. NOTE(review): meaning of the positional arguments (0, 0, 2, 0) is
# defined in src.bayesian.priors — confirm there.
prior = ScaleMixturePrior(0, 0, 2, 0)
prior_grid = torch.linspace(-20, 20, 300)
plt.plot(prior_grid, prior.log_prob(prior_grid).exp());

In [None]:
# Latest run of the "iris" experiment (earlier pinned to the results directory
# ../experiment_results/iris/2021-11-20/23-35-21).
runs = Experiment("iris").latest_run()

In [None]:
# First MCMC sub-run driven by the SGHMC sampler (raises StopIteration if none).
sghmc = next(
    filter(
        lambda candidate: (
            candidate.config["inference"]["_target_"] == "src.inference.mcmc.MCMCInference"
            and candidate.config["inference"]["sampler"]["_target_"]
            == "src.inference.mcmc.samplers.SGHMC"
        ),
        runs.runs,
    )
)


In [None]:
# First mini-batch (batch_size != -1) MCMC sub-run using plain HMC
# (raises StopIteration if none).
hmc_batched = next(
    filter(
        lambda candidate: (
            candidate.config["data"]["batch_size"] != -1
            and candidate.config["inference"]["_target_"] == "src.inference.mcmc.MCMCInference"
            and candidate.config["inference"]["sampler"]["_target_"]
            == "src.inference.mcmc.samplers.HMC"
        ),
        runs.runs,
    )
)


In [None]:
# First full-batch (batch_size == -1) MCMC sub-run using plain HMC
# (raises StopIteration if none).
hmc_full = next(
    filter(
        lambda candidate: (
            candidate.config["data"]["batch_size"] == -1
            and candidate.config["inference"]["_target_"] == "src.inference.mcmc.MCMCInference"
            and candidate.config["inference"]["sampler"]["_target_"]
            == "src.inference.mcmc.samplers.HMC"
        ),
        runs.runs,
    )
)


In [None]:
import seaborn as sns

def get_samples(run):
    """Load a run's saved samples as a DataFrame with one row per sample.

    Reads ``run.path / "saved_samples.pt"``, a mapping whose values hold
    ``"linear.weight"`` (2-D) and ``"linear.bias"`` tensors, and flattens each
    sample into named scalar columns:

    * weight column ``k`` gets letter ``"abc..."[k]`` with suffix ``_<row + 1>``
      (so ``a_1 = weight[0, 0]``, ``a_2 = weight[1, 0]``, ``b_1 = weight[0, 1]``);
    * the bias gets the next letter after the weight columns.

    For a 2x2 weight and 2-element bias this yields exactly
    ``a_1, a_2, b_1, b_2, c_1, c_2``; other layer shapes are supported as well.
    """
    # NOTE: torch.load unpickles the file; only use on trusted experiment output.
    samples = torch.load(run.path / "saved_samples.pt")

    def flatten_sample(sample):
        weight = sample["linear.weight"].numpy()
        bias = sample["linear.bias"].numpy()
        flat = {}
        for col in range(weight.shape[1]):
            letter = chr(ord("a") + col)
            for row, value in enumerate(weight[:, col], start=1):
                flat[f"{letter}_{row}"] = value
        bias_letter = chr(ord("a") + weight.shape[1])
        for row, value in enumerate(bias, start=1):
            flat[f"{bias_letter}_{row}"] = value
        return flat

    return pd.DataFrame.from_records([flatten_sample(sample) for sample in samples.values()])

# Flattened samples of the full-batch HMC run, used by the pairplots below.
hmc_full_data = get_samples(hmc_full)


In [None]:
# Trace plot of the SGHMC samples. Loaded directly with get_samples: the same data
# as mcmc_samples.loc["sghmc"], but without depending on `mcmc_samples`, which is
# only built in the next cell.
get_samples(sghmc).plot()

In [None]:
# Samples of all three MCMC variants stacked under an outer "algorithm" index level.
mcmc_samples = pd.concat(
    [get_samples(algorithm_run) for algorithm_run in (hmc_full, hmc_batched, sghmc)],
    keys=["full", "batched", "sghmc"],
    names=["algorithm"],
)


In [None]:
# Compare the three samplers' joint distributions, coloured by algorithm.
sns.pairplot(
    data=mcmc_samples.reset_index(level="algorithm").reset_index(drop=True),
    hue="algorithm",
    kind="kde",
)


In [None]:
# NOTE(review): `sghmc` is a run record picked from `runs.runs`; elsewhere only
# its `.config` and `.path` are used, so it is unclear it has a `.plot()` method.
# Possibly `get_samples(sghmc).plot()` was intended — confirm.
sghmc.plot()

In [None]:
# KDE pairplot of the full-batch HMC parameters.
sns.pairplot(data=hmc_full_data, kind="kde")

In [None]:
# Split "a_1" -> ("a", "1") so the columns become a (letter, index) MultiIndex.
# Guarded so re-running the cell does not fail on already-split columns.
if not isinstance(hmc_full_data.columns, pd.MultiIndex):
    hmc_full_data.columns = pd.MultiIndex.from_tuples([name.split("_") for name in hmc_full_data.columns])


In [None]:
# Scratch: loads seaborn's "penguins" example dataset; nothing below uses it.
sns.load_dataset("penguins")

In [None]:
# Inspect the full-batch HMC samples (columns are a (letter, index) MultiIndex
# once the column-splitting cell has run).
hmc_full_data

In [None]:
# Melt on the second column level (the "1"/"2" suffix), then plot the pooled
# values coloured by that suffix. NOTE(review): requires the MultiIndex columns
# created by the column-splitting cell.
sns.pairplot(hmc_full_data.melt(col_level=1), hue="variable")

In [None]:
# Scatter the (PCA-projected?) iris data per split and overlay the linear decision
# boundaries between class pairs (1, 2) and (0, 1).
# FIXME: `iris_data`, `get_decision_boundary` and `sgd_inference` are not defined
# in this notebook, so this cell raises NameError on a fresh kernel.
fg = sns.relplot(data=iris_data, x="Component 1", y="Component 2", hue="Species", col="Split")
for ax in fg.axes.flatten():
    plt.sca(ax)
    for between in [(1, 2), (0, 1)]:
        slope, intercept = get_decision_boundary(sgd_inference.model, between)
        plt.axline((0, intercept), slope=slope)

Use the toy model, but with a categorical likelihood.

In [None]:
# Fit the polynomial classifier with variational inference (10 particles).
# FIXME: `VariationalInference`, `PolynomialClassifier`, `train_data` and
# `test_data` are neither imported nor defined in this notebook.
vi_inference = VariationalInference(
    model=PolynomialClassifier(4, 3, bias=False),
    lr=1e-3,
    n_particles=10,
)
trainer = Trainer()
trainer.fit(
    model=vi_inference,
    train_dataloaders=DataLoader(train_data, batch_size=8, shuffle=True),
    val_dataloaders=DataLoader(test_data, batch_size=8),
)

In [None]:
# Fit the same classifier with MCMC, discarding the first 50 samples as burn-in.
# FIXME: `PolynomialClassifier`, `train_data` and `test_data` are not defined here.
mcmc_inference = MCMCInference(
    model=PolynomialClassifier(4, 3, bias=False),
    burn_in=50
)
trainer = Trainer()
trainer.fit(
    model=mcmc_inference,
    train_dataloaders=DataLoader(train_data, batch_size=8, shuffle=True),
    val_dataloaders=DataLoader(test_data, batch_size=8),
)


In [None]:
# MCMC samples from the fit above, one row per entry of sample_container.samples.
# Reuses to_matrix so the stacking logic lives in one place.
samples = pd.DataFrame(to_matrix(mcmc_inference.sample_container.samples))

In [None]:
# Trace plot: one line per parameter column across samples.
samples.plot.line()

In [None]:
# FIXME: `IrisModel` is not imported or defined in this notebook (legacy API?).
model = IrisModel(4, 3)

In [None]:
# Legacy HMC fit: 2000 burn-in steps, 500 retained samples.
# FIXME: `Hamiltonian`, `MonteCarloInference` and `train_data` are not defined in
# this notebook (it imports `HMC` and `MCMCInference` instead), so this cell
# raises NameError on a fresh kernel.
sampler = Hamiltonian(step_size=0.04)
inference = MonteCarloInference(sampler=sampler)
inference.fit(model, train_data, burn_in=2000, n_samples=500)

In [None]:
# Trace plot (bottom, sample index on y) with per-coefficient KDE marginals (top),
# sharing the value axis.
sample_df = pd.DataFrame({f"$c_{i}$" : c for i, c in enumerate(inference.samples_.T)})
plot_data = sample_df.reset_index().melt(id_vars="index")
fig = plt.figure()
grid_spec = fig.add_gridspec(2, 1, height_ratios=(2, 7))

ax_line = fig.add_subplot(grid_spec[1, 0])
ax_marg = fig.add_subplot(grid_spec[0, 0], sharex=ax_line)

sns.lineplot(x="value", y="index", hue="variable", ax=ax_line, data=plot_data, sort=False, legend=False)
# Target the marginal axes explicitly rather than relying on it being the most
# recently created (current) axes.
sns.kdeplot(x="value", hue="variable", data=plot_data, legend=False, ax=ax_marg)

plt.show()

In [None]:
# Posterior predictive for the test inputs, reduced to a majority vote per test
# point. NOTE(review): assumes y_pred_samples is (n_samples, n_test, n_classes);
# confirm against the inference object's `predictive`.
y_pred_samples = inference.predictive(x_test)
# Count votes per class with one-hot encoding (sized by the class dimension rather
# than a hard-coded 3); argmax returns the first maximum, so ties go to the lowest
# class index.
y_pred = (
    torch.nn.functional.one_hot(y_pred_samples.argmax(-1), y_pred_samples.shape[-1])
    .sum(0)
    .argmax(-1)
)

In [None]:
# Two-component PCA fitted on the training inputs, for visualising the test set.
# FIXME: `PCA` is not imported and `x_train` is not defined in this notebook.
pca = PCA(2).fit(x_train)

In [None]:
# Fraction of posterior samples whose predicted class agrees with the majority
# vote. Averaging over the sample axis replaces the hard-coded `/ 500`, which was
# only correct for exactly 500 samples.
certainty = (y_pred == y_pred_samples.argmax(-1)).float().mean(0)

In [None]:
# Project the test inputs onto the two PCA components fitted on x_train.
u, v = pca.transform(x_test).T
# First figure: colour points by predictive certainty.
sns.scatterplot(x=u, y=v, hue=certainty)
plt.figure()
# Second figure: colour points by the majority-vote predicted class.
sns.scatterplot(x=u, y=v, hue=y_pred)