In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from pathlib import Path
from scipy.stats import wasserstein_distance as WD
# Publication-style text: LaTeX rendering (needs a LaTeX installation), serif, large font.
matplotlib.rcParams["text.usetex"] = True
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.size"] = "35"

# Output directory for every figure saved by this notebook.
PATH = Path.cwd() / "fig" / "toy"
# parents=True: also create fig/ on a fresh checkout; without it mkdir raises
# FileNotFoundError when the parent directory is missing.
PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Pre-computed sample arrays for the 2-D Gaussian toy problem, loaded in this
# exact order. `nll` is loaded but not used by any cell below.
(
    x_mnf,
    nll,
    x_mnf_marg,
    x_profiti,
    x_profiti_marg,
    x_gmix_1,
    x_gmix_1_marg,
) = map(
    np.load,
    (
        "data/mnf_toy_gaussian2_x2.npy",
        "data/mnf_toy_gaussian2_l2.npy",
        "data/mnf_toy_gaussian2_marg.npy",
        "data/profiti_toy_gaussian2_x.npy",
        "data/profiti_toy_gaussian2_x_nc.npy",
        "data/gmix_toy_gaussian2_1_x.npy",
        "data/gmix_toy_gaussian2_1_marg.npy",
    ),
)

In [None]:
def mean_marginal_wd(joint, marg, dims=(0, 1)):
    """Average 1-D Wasserstein distance between matching columns of two sample sets.

    Parameters
    ----------
    joint, marg : 2-D arrays
        Sample matrices indexed as ``[sample, coordinate]``; the two sample
        counts may differ.
    dims : iterable of int, default (0, 1)
        Columns to compare; the default covers both coordinates of this 2-D toy.

    Returns
    -------
    float
        Mean over ``dims`` of ``WD(joint[:, j], marg[:, j])``.
    """
    return np.mean([WD(joint[:, j], marg[:, j]) for j in dims])


# Per model: distance between the marginals of its joint samples and its directly
# predicted marginals. Evaluated on the data before the rescaling cell runs.
MI_gau_mnf = mean_marginal_wd(x_mnf, x_mnf_marg)
MI_gau_profiti = mean_marginal_wd(x_profiti, x_profiti_marg)
MI_gau_gmix = mean_marginal_wd(x_gmix_1, x_gmix_1_marg)
print(MI_gau_mnf, MI_gau_profiti, MI_gau_gmix)

In [None]:
# Plot geometry, expressed in the original (un-rescaled) data units.
XLIM = np.array([-10, +10])
YLIM = np.array([-10, +10])
PLIM = np.array([0, +1.25])  # density-axis range of the marginal KDE panels

XTICKS = np.array([-20 / 3, -10 / 3, 0, 10 / 3, 20 / 3])
YTICKS = np.array([-20 / 3, -10 / 3, 0, 10 / 3, 20 / 3])
PTICKS = np.array([0, 0.25, 0.5, 0.75, 1])

# Empty label lists: tick marks are drawn but left unlabelled.
XTICKLABELS = []
YTICKLABELS = []
PTICKLABELS = []

FIGSIZE = (3, 3)
GRID = 1000  # gridsize for kde plot
NUM = None  # How many samples to plot (None=all) -- NOTE(review): not referenced by any cell shown

# Ground-truth toy distribution: z ~ N(0, I_2), mixed by a lower-triangular
# matrix (y1 = z1, y2 = z1 + z2), then warped elementwise by v -> sign(v) * v**2.
SEED = 0
n = 6000
rng = np.random.default_rng(SEED)  # seeded so the ground-truth figures are reproducible
z = rng.normal(0, 1, (n, 2, 1))
cov = np.array([[1, 0], [1, 1]])  # mixing matrix (name kept for compatibility; not a covariance)

mixed = np.matmul(cov, z)  # (n, 2, 2) @ broadcast -> (n, 2, 1)
warped = np.sign(mixed) * (mixed**2)
x_orig = warped[:, 0].squeeze()
y_orig = warped[:, 1].squeeze()

In [None]:
# RESCALE (so that same axes as circle plot).
# NOTE: not idempotent -- every run multiplies by SCALE again. Run it exactly
# once, after the data-loading and ground-truth cells.
SCALE = 3 / 10
XLIM = XLIM * SCALE
YLIM = YLIM * SCALE
XTICKS = XTICKS * SCALE
YTICKS = YTICKS * SCALE

x_orig = x_orig * SCALE
y_orig = y_orig * SCALE
x_mnf = x_mnf * SCALE
x_mnf_marg = x_mnf_marg * SCALE
x_profiti = x_profiti * SCALE
x_profiti_marg = x_profiti_marg * SCALE
# The Gaussian-mixture samples are drawn against the same rescaled limits and
# ticks, so they must be rescaled too; otherwise they sit on a 10/3x larger scale.
x_gmix_1 = x_gmix_1 * SCALE
x_gmix_1_marg = x_gmix_1_marg * SCALE

In [None]:
def set_lim_and_ticks(ax):
    """Apply the shared joint-scatter limits and ticks to ``ax``.

    Reads the module-level XLIM/YLIM, XTICKS/YTICKS and XTICKLABELS/YTICKLABELS
    at call time, so the rescaled values are used once the rescale cell has run.
    """
    ax.set_xlim(XLIM)
    ax.set_ylim(YLIM)
    ax.set_xticks(XTICKS)
    ax.set_yticks(YTICKS)
    ax.set_xticklabels(XTICKLABELS)
    # The y axis takes its own labels (previously XTICKLABELS by mistake).
    ax.set_yticklabels(YTICKLABELS)


# One orange scatter of the joint samples per source, all on the shared axes,
# saved as <source>_toy_joint.pdf in PATH.
joint_panels = [
    ("gt_toy_joint.pdf", x_orig, y_orig),
    ("mnf_toy_joint.pdf", x_mnf[:, 0], x_mnf[:, 1]),
    ("profiti_toy_joint.pdf", x_profiti[:, 0], x_profiti[:, 1]),
    ("gmix_toy_joint.pdf", x_gmix_1[:, 0], x_gmix_1[:, 1]),
]
for fname, xs, ys in joint_panels:
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    ax.scatter(xs, ys, s=1, c="orange")
    set_lim_and_ticks(ax)
    fig.savefig(PATH / fname)

In [None]:
def format_ax(ax):
    """Style a 1-D marginal-density panel.

    Applies the shared x range/ticks (XLIM, XTICKS, XTICKLABELS) and the
    density-axis range/ticks (PLIM, PTICKS, PTICKLABELS), blanks the y-axis
    label, and replaces any legend with an empty, frameless one.
    """
    # Keep this order: tick labels are applied after the tick positions they label.
    for setter, value in (
        (ax.set_xlim, XLIM),
        (ax.set_ylim, PLIM),
        (ax.set_xticks, XTICKS),
        (ax.set_yticks, PTICKS),
        (ax.set_xticklabels, XTICKLABELS),
        (ax.set_yticklabels, PTICKLABELS),
    ):
        setter(value)
    ax.set_ylabel("")
    ax.legend([], [], frameon=False)


nbins = 100  # histogram bin count for an optional ax.hist view; unused by the KDE plots below

# Ground-truth marginals of the toy distribution.
for values, tag in ((x_orig, "y1"), (y_orig, "y2")):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(values, ax=ax, color="green", label=tag, fill=True, gridsize=GRID)
    format_ax(ax)  # also hides the legend
    fig.savefig(PATH / f"gt_toy_{tag}.pdf")

# MNF, per coordinate: predicted marginal (red) overlaid on the marginal of the
# joint samples (blue). format_ax already empties the legend, so no extra call.
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_mnf_marg[:, dim],
        ax=ax,
        color="red",
        label="predicted marginal",
        fill=True,
        gridsize=GRID,
    )
    sns.kdeplot(
        x_mnf[:, dim],
        ax=ax,
        color="blue",
        label="integrated marginal",
        fill=True,
        gridsize=GRID,
    )
    format_ax(ax)
    fig.savefig(PATH / f"mnf_toy_y{dim + 1}_both.pdf")

# Each marginal on its own panel: "int" from the joint samples, "pred" predicted directly.
for samples, color, kind in ((x_mnf, "blue", "int"), (x_mnf_marg, "red", "pred")):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"mnf_toy_{kind}_y{dim + 1}.pdf")

In [None]:
# PROFITI marginals, drawn from the first N_PROFITI_PLOT samples only.
# NOTE(review): the MNF panels use every sample -- confirm this subsample is intended.
N_PROFITI_PLOT = 2000

# Per coordinate: predicted marginal (red) overlaid on the marginal of the joint samples (blue).
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_profiti_marg[:N_PROFITI_PLOT, dim],
        ax=ax,
        color="red",
        label="predicted marginal",
        fill=True,
        gridsize=GRID,
    )
    sns.kdeplot(
        x_profiti[:N_PROFITI_PLOT, dim],
        ax=ax,
        color="blue",
        label="integrated marginal",
        fill=True,
        gridsize=GRID,
    )
    format_ax(ax)  # also hides the legend
    fig.savefig(PATH / f"profiti_toy_y{dim + 1}_both.pdf")

# Each marginal on its own panel: "int" from the joint samples, "pred" predicted directly.
for samples, color, kind in ((x_profiti, "blue", "int"), (x_profiti_marg, "red", "pred")):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:N_PROFITI_PLOT, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"profiti_toy_{kind}_y{dim + 1}.pdf")

In [None]:
# `format_ax` and `nbins` are defined once, in the MNF marginal-plot cell.
# The byte-identical re-definition that used to live here silently shadowed
# that one and was removed so there is a single source of truth.

# Gaussian-mixture marginals, drawn from the first N_GMIX_PLOT samples.
N_GMIX_PLOT = 2000

# Per coordinate: predicted marginal (red) overlaid on the marginal of the joint
# samples (blue). The blue layer must come from the joint samples x_gmix_1, as
# for MNF and PROFITI; drawing x_gmix_1_marg twice would make both layers identical.
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_gmix_1_marg[:N_GMIX_PLOT, dim],
        ax=ax,
        color="red",
        label="predicted marginal",
        fill=True,
        gridsize=GRID,
    )
    sns.kdeplot(
        x_gmix_1[:N_GMIX_PLOT, dim],
        ax=ax,
        color="blue",
        label="integrated marginal",
        fill=True,
        gridsize=GRID,
    )
    format_ax(ax)  # also hides the legend
    fig.savefig(PATH / f"gmix_toy_y{dim + 1}_both.pdf")

# Each marginal on its own panel: "int" from the joint samples, "pred" predicted directly.
for samples, color, kind in ((x_gmix_1, "blue", "int"), (x_gmix_1_marg, "red", "pred")):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:N_GMIX_PLOT, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"gmix_toy_{kind}_y{dim + 1}.pdf")

In [None]:
# Quick preview (not saved to disk) of PROFITI's predicted marginal of y2,
# using every sample rather than the first 2000 used by the saved panels.
preview_kde = dict(color="red", fill=True, gridsize=GRID)
fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
sns.kdeplot(x_profiti_marg[:, 1], ax=ax, **preview_kde)
format_ax(ax)