# Gaussian-mixture generator visualization (Plotly slider)

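This notebook samples 50 Gaussian-mixture problems in 1D and 50 in 2D with `GaussianMixtureGenerator`, and uses a Plotly slider to step through them, showing the source (mu) and target (nu) measures of each dataset side by side.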

In [1]:
import sys, os

# XLA memory settings have to be in the environment before jax is first
# imported. Per the JAX GPU memory-allocation docs, these disable up-front
# preallocation and switch to the on-demand "platform" allocator.
os.environ.update(
    {
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
        'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
    }
)

# Best effort: turn on 64-bit floats in JAX if it can be imported. Both
# generators below are built with use_jax=False, so any failure here is
# deliberately ignored rather than treated as fatal.
try:
    from jax import config

    config.update("jax_enable_x64", True)
except Exception:
    pass

# Put the parent directory first on the import path so the local `uot`
# package imported below can be found (assumes the notebook lives one level
# below the project root).
sys.path.insert(0, os.path.abspath(".."))

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from uot.problems.generators import GaussianMixtureGenerator
from uot.utils.costs import cost_euclid_squared


## 1D case

Slider over the 1D datasets: each step plots the source (mu) and target (nu) weights on the shared grid as curves.

In [2]:
# --- 1D generator settings ---------------------------------------------------
# These constants are referenced again further down (NUM_DATASETS_1D appears in
# the figure title), so change them here rather than inline.
NUM_DATASETS_1D = 50        # passed as num_datasets
N_POINTS_1D = 50            # passed as n_points
DIM_1D = 1
NUM_COMPONENTS_1D = 3       # passed as num_components
BORDERS_1D = (0.0, 1.0)     # passed as borders

gen_1d = GaussianMixtureGenerator(
    name="Gaussian (1D)",
    dim=DIM_1D,
    num_components=NUM_COMPONENTS_1D,
    n_points=N_POINTS_1D,
    num_datasets=NUM_DATASETS_1D,
    borders=BORDERS_1D,
    measure_mode="grid",
    use_jax=False,
    cost_fn=cost_euclid_squared,
)

# Materialise the generator so every problem can be indexed by the slider.
problems_1d = [*gen_1d.generate()]
if not problems_1d:
    raise RuntimeError("No 1D problems generated.")

# Shared x-coordinates, taken from the first problem's source measure.
# include_zeros=True (as in the 2D cell) keeps every grid node, so each weight
# vector collected below lines up element-wise with x_1d.
mu0_1d, nu0_1d = problems_1d[0].get_marginals()
x_1d, _ = mu0_1d.to_discrete(include_zeros=True)
x_1d = np.asarray(x_1d).reshape(-1)

mu_grids_1d = []
nu_grids_1d = []
for p in problems_1d:
    mu, nu = p.get_marginals()
    for measure, bucket in ((mu, mu_grids_1d), (nu, nu_grids_1d)):
        xs, w = measure.to_discrete(include_zeros=True)
        xs = np.asarray(xs).reshape(-1)
        w = np.asarray(w).reshape(-1)
        # Plotly silently accepts x/y arrays of different lengths, so verify
        # that this measure really lives on the shared grid before plotting.
        if xs.shape != x_1d.shape or w.shape != x_1d.shape or not np.allclose(xs, x_1d):
            raise ValueError(
                f"1D measure with {w.size} weights does not lie on the shared "
                f"{x_1d.size}-point grid taken from the first problem."
            )
        bucket.append(w)

# One y-limit for both panels and all slider steps so curves are comparable.
# Fall back to 1.0 if every weight is zero, which would give an empty range.
ymax_1d = max(float(g.max()) for g in (*mu_grids_1d, *nu_grids_1d)) or 1.0

# Two side-by-side panels: source measure on the left, target on the right.
fig_1d = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Source (mu)", "Target (nu)"),
)

# Initial traces show dataset 0; the animation frames replace their data.
# The traces get explicit names (shown in hover labels) and are hidden from
# the legend. The subplot titles already identify each panel, and the default
# legend would only read "trace 0" / "trace 1".
fig_1d.add_trace(
    go.Scatter(
        x=x_1d,
        y=mu_grids_1d[0],
        mode="lines",
        line=dict(color="#1f77b4"),
        name="mu",
        showlegend=False,
    ),
    row=1,
    col=1,
)
fig_1d.add_trace(
    go.Scatter(
        x=x_1d,
        y=nu_grids_1d[0],
        mode="lines",
        line=dict(color="#ff7f0e"),
        name="nu",
        showlegend=False,
    ),
    row=1,
    col=2,
)

# One animation frame per dataset, named by its index. With no explicit
# `traces`, Plotly applies frame data to the figure's traces in order:
# first the mu curve (left panel), then the nu curve (right panel).
frames_1d = [
    go.Frame(
        name=str(idx),
        data=[
            go.Scatter(x=x_1d, y=mu_grid, mode="lines", line=dict(color="#1f77b4")),
            go.Scatter(x=x_1d, y=nu_grid, mode="lines", line=dict(color="#ff7f0e")),
        ],
    )
    for idx, (mu_grid, nu_grid) in enumerate(zip(mu_grids_1d, nu_grids_1d))
]
fig_1d.frames = frames_1d

# Each slider step jumps straight to the frame of the same name, with no
# frame or transition duration.
steps_1d = [
    dict(
        method="animate",
        args=[
            [frame_name],
            {"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}},
        ],
        label=frame_name,
    )
    for frame_name in map(str, range(len(frames_1d)))
]

# Play steps through the frames every 300 ms, starting from the current one.
# Pause issues an immediate animate to [None], which stops playback.
_play_button_1d = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True}],
)
_pause_button_1d = dict(
    label="Pause",
    method="animate",
    args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
)

fig_1d.update_layout(
    title=f"Gaussian mixtures on a 1D grid ({NUM_DATASETS_1D} datasets)",
    sliders=[dict(steps=steps_1d, currentvalue=dict(prefix="dataset: "), pad=dict(t=40))],
    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            buttons=[_play_button_1d, _pause_button_1d],
            x=0.1,
            y=-0.1,
        )
    ],
)

# Both y-axes get the same fixed range (5% headroom above the global maximum),
# so the panels do not rescale as the slider moves.
fig_1d.update_xaxes(title_text="x")
fig_1d.update_yaxes(title_text="density", range=[0, ymax_1d * 1.05])

fig_1d.show()


## 2D case

Slider over the 2D datasets: each step shows the source (mu) and target (nu) weights on the shared grid as heatmaps.

In [5]:
# --- 2D generator settings ---------------------------------------------------
# These constants are referenced again further down (NUM_DATASETS_2D appears in
# the figure title), so change them here rather than inline.
NUM_DATASETS_2D = 50        # passed as num_datasets
N_POINTS_2D = 50            # passed as n_points
DIM_2D = 2
NUM_COMPONENTS_2D = 3       # passed as num_components
BORDERS_2D = (0.0, 1.0)     # passed as borders

gen_2d = GaussianMixtureGenerator(
    name="Gaussian (2D)",
    dim=DIM_2D,
    num_components=NUM_COMPONENTS_2D,
    n_points=N_POINTS_2D,
    num_datasets=NUM_DATASETS_2D,
    borders=BORDERS_2D,
    measure_mode="grid",
    use_jax=False,
    cost_fn=cost_euclid_squared,
)

# Materialise the generator so every problem can be indexed by the slider.
problems_2d = [*gen_2d.generate()]
if not problems_2d:
    raise RuntimeError("No 2D problems generated.")

# Grid axes, taken from the first problem's source measure. include_zeros=True
# keeps every grid node, so the full rectangle is available for the heatmaps.
mu0_2d, nu0_2d = problems_2d[0].get_marginals()
points_2d, _ = mu0_2d.to_discrete(include_zeros=True)
points_2d = np.asarray(points_2d)
x_vals = np.unique(points_2d[:, 0])
y_vals = np.unique(points_2d[:, 1])


def _weights_to_grid_2d(points, weights, xs, ys):
    """Place flat (point, weight) pairs onto an array of shape (len(xs), len(ys)).

    ``grid[i, j]`` holds the weight at ``(xs[i], ys[j])``. Each weight is placed
    by its coordinates instead of reshaping the flat vector. A plain reshape
    silently assumes x-major node order, and a square grid listed in the other
    order would come out transposed.

    Raises:
        ValueError: if the point and weight counts disagree with each other or
            with the grid size, or if the points do not cover the xs-by-ys grid
            exactly once.
    """
    points = np.asarray(points, dtype=float).reshape(-1, 2)
    weights = np.asarray(weights).reshape(-1)
    if points.shape[0] != weights.size or weights.size != xs.size * ys.size:
        raise ValueError(
            f"expected {xs.size * ys.size} grid weights, got {weights.size} "
            f"weights for {points.shape[0]} points"
        )
    # Nearest grid index along each axis (tolerant of tiny float differences).
    ix = np.abs(points[:, [0]] - xs[None, :]).argmin(axis=1)
    iy = np.abs(points[:, [1]] - ys[None, :]).argmin(axis=1)
    if not (np.allclose(xs[ix], points[:, 0]) and np.allclose(ys[iy], points[:, 1])):
        raise ValueError("measure support does not lie on the shared 2D grid")
    if np.unique(ix * ys.size + iy).size != weights.size:
        raise ValueError("measure support does not cover every node of the shared 2D grid")
    grid = np.zeros((xs.size, ys.size), dtype=weights.dtype)
    grid[ix, iy] = weights
    return grid


mu_grids_2d = []
nu_grids_2d = []
for p in problems_2d:
    mu, nu = p.get_marginals()
    mu_pts, mu_w = mu.to_discrete(include_zeros=True)
    nu_pts, nu_w = nu.to_discrete(include_zeros=True)
    mu_grids_2d.append(_weights_to_grid_2d(mu_pts, mu_w, x_vals, y_vals))
    nu_grids_2d.append(_weights_to_grid_2d(nu_pts, nu_w, x_vals, y_vals))

# One colour limit for both panels and all slider steps. Fall back to 1.0 if
# every weight is zero, which would otherwise give a degenerate colour scale.
zmax_2d = max(float(g.max()) for g in (*mu_grids_2d, *nu_grids_2d)) or 1.0

# Side-by-side heatmaps for dataset 0, source left and target right.
# Both use the same colour limits [0, zmax_2d] so colours are comparable across
# panels, and only the left panel draws a colour bar. The grids are stored with
# x on the first axis, while Heatmap's z is indexed [row = y][column = x], hence
# the transpose.
fig_2d = make_subplots(rows=1, cols=2, subplot_titles=("Source (mu)", "Target (nu)"))

_heatmap_common_2d = dict(x=x_vals, y=y_vals, zmin=0, zmax=zmax_2d, colorscale="Viridis")

fig_2d.add_trace(
    go.Heatmap(z=mu_grids_2d[0].T, colorbar=dict(title="density"), **_heatmap_common_2d),
    row=1,
    col=1,
)
fig_2d.add_trace(
    go.Heatmap(z=nu_grids_2d[0].T, showscale=False, **_heatmap_common_2d),
    row=1,
    col=2,
)

# One frame per dataset, named by its index. Frame traces update the figure's
# heatmaps in order (mu on the left, nu on the right), and each frame repeats
# the fixed colour limits.
frames_2d = [
    go.Frame(
        name=str(idx),
        data=[
            go.Heatmap(x=x_vals, y=y_vals, z=grid_mu.T, zmin=0, zmax=zmax_2d, colorscale="Viridis"),
            go.Heatmap(x=x_vals, y=y_vals, z=grid_nu.T, zmin=0, zmax=zmax_2d, colorscale="Viridis"),
        ],
    )
    for idx, (grid_mu, grid_nu) in enumerate(zip(mu_grids_2d, nu_grids_2d))
]
fig_2d.frames = frames_2d

# Each slider step jumps straight to the frame of the same name, without easing.
steps_2d = [
    dict(
        method="animate",
        args=[
            [frame_name],
            {"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}},
        ],
        label=frame_name,
    )
    for frame_name in map(str, range(len(frames_2d)))
]

# Play steps through the frames every 300 ms from the current one; Pause sends
# an immediate animate to [None] to stop playback.
fig_2d.update_layout(
    # Report the grid shape actually found in the data rather than assuming
    # N_POINTS_2D nodes per axis.
    title=f"Gaussian mixtures on a {len(x_vals)}x{len(y_vals)} grid ({NUM_DATASETS_2D} datasets)",
    sliders=[
        dict(
            steps=steps_2d,
            currentvalue=dict(prefix="dataset: "),
            pad=dict(t=40),
        )
    ],
    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            buttons=[
                dict(label="Play", method="animate", args=[None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True}]),
                dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]),
            ],
            x=0.1,
            y=-0.1,
        )
    ],
)

fig_2d.update_xaxes(title_text="x")
fig_2d.update_yaxes(title_text="y")
# Keep a 1:1 aspect ratio in each panel by anchoring every y-axis to the x-axis
# of its own subplot. A blanket update_yaxes(scaleanchor="x") would also tie
# the right panel's y-axis (yaxis2) to the left panel's x-axis, so zooming the
# left panel would rescale the right one.
fig_2d.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1)
fig_2d.update_yaxes(scaleanchor="x2", scaleratio=1, row=1, col=2)

fig_2d.show()
