In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from pathlib import Path
from scipy.stats import wasserstein_distance as WD
# Global matplotlib styling for every exported figure.
# NOTE: text.usetex=True renders all text through an external LaTeX installation.
matplotlib.rcParams["text.usetex"] = True
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.size"] = "35"

# Output directory for all PDFs written by this notebook (relative to the working dir).
PATH = Path.cwd() / "fig" / "circle"
# parents=True: on a fresh checkout "fig/" may not exist, and mkdir without it
# raises FileNotFoundError instead of creating the intermediate directory.
PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Pre-computed model samples, loaded from one configurable directory.
# Naming convention (as used by the plots below):
#   x_<model>       -> samples from the model's joint distribution (plotted as
#                      2-D scatters and as the "integrated marginal" per column)
#   x_<model>_marg  -> samples from the model's directly "predicted marginal"
DATA_DIR = Path("data")

x_mnf = np.load(DATA_DIR / "mnf_toy_circle_x2.npy")
nll = np.load(DATA_DIR / "mnf_toy_circle_l2.npy")  # NOTE(review): not referenced by later cells
x_mnf_marg = np.load(DATA_DIR / "mnf_toy_circle_marg.npy")
x_profiti = np.load(DATA_DIR / "profiti_toy_circle_x.npy")
x_profiti_marg = np.load(DATA_DIR / "profiti_toy_circle_x_nc.npy")
x_gmix_1 = np.load(DATA_DIR / "gmix_toy_circle_1_x.npy")
x_gmix_1_marg = np.load(DATA_DIR / "gmix_toy_circle_1_marg.npy")
x_gmix_5 = np.load(DATA_DIR / "gmix_toy_circle_5_x.npy")
x_gmix_10 = np.load(DATA_DIR / "gmix_toy_circle_10_x.npy")
x_gmix_15 = np.load(DATA_DIR / "gmix_toy_circle_15_x.npy")

In [None]:
def mean_marginal_wd(joint, marg, dims=(0, 1)):
    """Average 1-D Wasserstein distance between matching columns of two sample sets.

    Parameters
    ----------
    joint, marg : 2-D arrays indexed as ``[sample, dim]``
        Sample sets whose per-column empirical distributions are compared; the
        number of rows may differ between the two.
    dims : iterable of int, default (0, 1)
        Columns to compare.

    Returns
    -------
    float
        ``mean(WD(joint[:, d], marg[:, d]) for d in dims)``.

    Raises
    ------
    ValueError
        If ``dims`` is empty (the mean would be undefined).
    """
    dims = tuple(dims)
    if not dims:
        raise ValueError("dims must contain at least one column index")
    return float(np.mean([WD(joint[:, d], marg[:, d]) for d in dims]))


# Gap between each model's "integrated" marginals (columns of its joint samples)
# and its directly predicted marginals; 0 would mean the two agree exactly.
MI_gau_mnf = mean_marginal_wd(x_mnf, x_mnf_marg)
MI_gau_profiti = mean_marginal_wd(x_profiti, x_profiti_marg)
MI_gau_gmix = mean_marginal_wd(x_gmix_1, x_gmix_1_marg)
print(f"mnf={MI_gau_mnf}, profiti={MI_gau_profiti}, gmix_1={MI_gau_gmix}")

In [None]:
# Axis limits / ticks: X and Y for the joint scatter plots, P for marginal densities.
XLIM = np.array([-3, +3])
YLIM = np.array([-3, +3])
PLIM = np.array([0, +1.25])

XTICKS = np.array([-2, -1, 0, 1, 2])
YTICKS = np.array([-2, -1, 0, 1, 2])
PTICKS = np.array([0, 0.25, 0.5, 0.75, 1])

# Empty label lists: tick marks are drawn, tick labels are suppressed.
XTICKLABELS = []
YTICKLABELS = []
PTICKLABELS = []

FIGSIZE = (3, 3)
GRID = 1000  # gridsize for kde plot
NUM = None  # How many samples to plot (None=all)

# Ground-truth "circle" data. A dedicated, seeded generator makes the gt_* and
# true_dist_* figures reproducible across kernel restarts.
SEED = 0
rng = np.random.default_rng(SEED)

nsams = 2000
x_ = rng.uniform(-1, 1, nsams)
y = np.sqrt(1 - x_**2)  # upper half of the unit circle ...
y[nsams // 2 :] *= -1  # ... mirrored to the lower half for the second half of draws
# Each point is used twice, once as (x, y) and once transposed as (y, x), so the
# sample set is symmetric under swapping the two coordinates; small Gaussian noise
# (std 0.05) is added independently to both coordinates.
x_orig = np.concatenate([x_, y], 0) + rng.standard_normal(nsams * 2) * 0.05
y_orig = np.concatenate([y, x_], 0) + rng.standard_normal(nsams * 2) * 0.05

In [None]:
def set_lim_and_ticks(ax):
    """Apply the shared joint-scatter axis style to ``ax`` in place.

    Uses the module-level constants XLIM/YLIM for the view limits, XTICKS/YTICKS
    for tick positions and XTICKLABELS/YTICKLABELS for the tick labels.
    """
    ax.set_xlim(XLIM)
    ax.set_ylim(YLIM)
    ax.set_xticks(XTICKS)
    ax.set_yticks(YTICKS)
    ax.set_xticklabels(XTICKLABELS)
    ax.set_yticklabels(YTICKLABELS)


# One orange scatter of the joint samples per source, saved as its own PDF.
joint_panels = [
    ("gt_circle_joint.pdf", x_orig, y_orig),
    ("mnf_circle_joint.pdf", x_mnf[:, 0], x_mnf[:, 1]),
    ("profiti_circle_joint.pdf", x_profiti[:, 0], x_profiti[:, 1]),
    ("gmix_circle_joint.pdf", x_gmix_1[:, 0], x_gmix_1[:, 1]),
]
for out_name, xs, ys in joint_panels:
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    ax.scatter(xs, ys, s=1, c="orange")
    set_lim_and_ticks(ax)
    fig.savefig(PATH / out_name)

In [None]:
def format_ax(ax):
    """Style a 1-D marginal-density panel in place.

    x-axis: XLIM / XTICKS / XTICKLABELS. y-axis (density): PLIM / PTICKS /
    PTICKLABELS. The y-axis label is blanked, and any legend is replaced by an
    empty, frameless one so no legend box appears.
    """
    # Horizontal axis: same range and ticks as the joint scatter plots.
    ax.set_xlim(XLIM)
    ax.set_xticks(XTICKS)
    ax.set_xticklabels(XTICKLABELS)
    # Vertical axis: density values.
    ax.set_ylim(PLIM)
    ax.set_yticks(PTICKS)
    ax.set_yticklabels(PTICKLABELS)
    ax.set_ylabel("")
    ax.legend([], [], frameon=False)


# Marginal density panels for the ground truth and for MNF.
#   "integrated marginal" = one column of the joint samples x_mnf (blue)
#   "predicted marginal"  = the same column of x_mnf_marg (red)
# Every kdeplot receives ax=ax explicitly instead of drawing on the implicit
# "current" axes.

# Ground truth, one panel per coordinate (default KDE grid).
for tag, samples in (("y1", x_orig), ("y2", y_orig)):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(samples, ax=ax, color="green", fill=True, label=tag)
    format_ax(ax)
    fig.savefig(PATH / f"gt_circle_{tag}.pdf")

# MNF overlay: predicted vs integrated marginal per coordinate (default KDE grid).
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_mnf_marg[:NUM, dim], ax=ax, color="red", fill=True, label="predicted marginal"
    )
    sns.kdeplot(
        x_mnf[:NUM, dim], ax=ax, color="blue", fill=True, label="integrated marginal"
    )
    format_ax(ax)
    fig.savefig(PATH / f"mnf_circle_y{dim + 1}_both.pdf")

# MNF single-density panels on the fine KDE grid: integrated first, then predicted.
for stem, samples, color in (("int", x_mnf, "blue"), ("pred", x_mnf_marg, "red")):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:NUM, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"mnf_circle_{stem}_y{dim + 1}.pdf")

In [None]:
# Marginal density panels for ProFITi, all on the fine KDE grid (gridsize=GRID).
#   "integrated marginal" = one column of the joint samples x_profiti (blue)
#   "predicted marginal"  = the same column of x_profiti_marg (red)
# Every kdeplot receives ax=ax explicitly instead of drawing on the implicit
# "current" axes.

# Overlay: predicted vs integrated marginal per coordinate.
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_profiti_marg[:NUM, dim],
        ax=ax,
        color="red",
        fill=True,
        label="predicted marginal",
        gridsize=GRID,
    )
    sns.kdeplot(
        x_profiti[:NUM, dim],
        ax=ax,
        color="blue",
        fill=True,
        label="integrated marginal",
        gridsize=GRID,
    )
    format_ax(ax)
    fig.savefig(PATH / f"profiti_circle_y{dim + 1}_both.pdf")

# Single-density panels: integrated first, then predicted.
for stem, samples, color in (
    ("int", x_profiti, "blue"),
    ("pred", x_profiti_marg, "red"),
):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:NUM, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"profiti_circle_{stem}_y{dim + 1}.pdf")

In [None]:
# Marginal density panels for the 1-component Gaussian-mixture baseline, all on
# the fine KDE grid (gridsize=GRID).
#   "integrated marginal" = one column of the joint samples x_gmix_1 (blue)
#   "predicted marginal"  = the same column of x_gmix_1_marg (red)
# The integrated curves must come from x_gmix_1 (the joint samples), matching the
# MNF and ProFITi panels; plotting x_gmix_1_marg twice would overlay a
# distribution on itself.
# Every kdeplot receives ax=ax explicitly instead of drawing on the implicit
# "current" axes.

# Overlay: predicted vs integrated marginal per coordinate.
for dim in (0, 1):
    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
    sns.kdeplot(
        x_gmix_1_marg[:NUM, dim],
        ax=ax,
        color="red",
        fill=True,
        label="predicted marginal",
        gridsize=GRID,
    )
    sns.kdeplot(
        x_gmix_1[:NUM, dim],
        ax=ax,
        color="blue",
        fill=True,
        label="integrated marginal",
        gridsize=GRID,
    )
    format_ax(ax)
    fig.savefig(PATH / f"gmix_circle_y{dim + 1}_both.pdf")

# Single-density panels: integrated (joint samples) first, then predicted.
for stem, samples, color in (
    ("int", x_gmix_1, "blue"),
    ("pred", x_gmix_1_marg, "red"),
):
    for dim in (0, 1):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        sns.kdeplot(samples[:NUM, dim], ax=ax, color=color, fill=True, gridsize=GRID)
        format_ax(ax)
        fig.savefig(PATH / f"gmix_circle_{stem}_y{dim + 1}.pdf")

In [None]:
# gmix
fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_orig, y_orig, s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "true_dist_circle.pdf")

fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_mnf[:4000, 0], x_mnf[:4000, 1], s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "mymodel_circle.pdf")

fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_gmix_1[:4000, 0], x_gmix_1[:4000, 1], s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "g_mix_1_circle.pdf")

fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_gmix_5[:4000, 0], x_gmix_5[:4000, 1], s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "g_mix_5_circle.pdf")

fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_gmix_10[:4000, 0], x_gmix_10[:4000, 1], s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "g_mix_10_circle.pdf")

fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
ax.scatter(x_gmix_15[:4000, 0], x_gmix_15[:4000, 1], s=1, c="orange")
ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([], [])
ax.set_yticks([], [])
fig.savefig(PATH / "g_mix_15_circle.pdf")