In [1]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import ot
import ot.plot
from ipywidgets import interactive, HBox, VBox, widgets, interact

In [2]:
def plot_ot_dist(fig, samples1, samples2, plt_name='ot.png'):
    """Redraw ``fig`` with the optimal-transport coupling between two point sets.

    The figure is cleared, then populated with one line per coupled
    (source, target) pair and one scatter trace per point set.

    Parameters
    ----------
    fig : plotly figure (e.g. ``go.FigureWidget``)
        Figure that is reset and redrawn in place.
    samples1 : array of shape (n_src, >=2)
        Source points; columns 0 and 1 are used as x and y.
    samples2 : array of shape (n_tgt, >=2)
        Target points; columns 0 and 1 are used as x and y.
    plt_name : str, optional
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    float
        Transport cost of the optimal plan, i.e. ``sum(G0 * M)`` where ``M``
        is the cost matrix produced by ``ot.dist`` with its default metric.

    Raises
    ------
    ValueError
        If either point set is empty.
    """
    fig.data = []
    fig.layout = {}

    n_src = samples1.shape[0]
    n_tgt = samples2.shape[0]
    if n_src == 0 or n_tgt == 0:
        raise ValueError("samples1 and samples2 must each contain at least one point")

    # Uniform weights, sized independently for each side so that point sets of
    # different cardinality are supported. float64 avoids float32 rounding in
    # the marginals.
    a = np.full(n_src, 1.0 / n_src)
    b = np.full(n_tgt, 1.0 / n_tgt)
    M = ot.dist(samples1, samples2)

    # Solve the transport problem once; the cost follows from the plan itself,
    # so a second solver run is not needed.
    G0 = np.asarray(ot.emd(a, b, M))
    ot_dist = float(np.sum(G0 * np.asarray(M)))

    # One line per pair carrying mass. np.nonzero walks the plan row by row,
    # and rows without any mass simply contribute no line.
    src_idx, tgt_idx = np.nonzero(G0)
    shape_list = [
        go.layout.Shape(
            type="line",
            x0=samples1[i, 0],
            y0=samples1[i, 1],
            x1=samples2[j, 0],
            y1=samples2[j, 1],
            line=dict(
                color="rgba(0, 225, 0, 0.4)",
                width=1
            )
        )
        for i, j in zip(src_idx, tgt_idx)
    ]

    fig.update_layout(shapes=shape_list)

    fig.add_trace(go.Scatter(
        x=samples1[:, 0],
        y=samples1[:, 1],
        mode="markers",
        marker=go.scatter.Marker(
            color='rgb(255, 0, 0)',
            opacity=0.7,
            size=7
        ),
        name='Source'
    ))

    fig.add_trace(go.Scatter(
        x=samples2[:, 0],
        y=samples2[:, 1],
        mode="markers",
        marker=go.scatter.Marker(
            color='rgb(0, 0, 255)',
            opacity=0.7,
            size=7
        ),
        name='Target'
    ))

    return ot_dist

In [3]:
def generate_samples(alpha_src, alpha_tgt, n_modes=3, rad=10, angle_shift=0.1,
                     n_samples=300, rng=None):
    """Sample source and target point clouds from Gaussian mixtures on a circle.

    Mode ``k`` of the source mixture is centred at angle ``2*pi*k/n_modes`` on
    a circle of radius ``rad``; the matching target mode is rotated by an extra
    ``pi * angle_shift``. Every mode uses an identity covariance.

    Mode ``k`` of the source gets weight ``alpha_src ** k`` (normalised to sum
    to one). The target uses ``alpha_tgt ** k`` with the weights rotated by one
    position, so the heaviest target mode is the second one.

    Parameters
    ----------
    alpha_src, alpha_tgt : float
        Imbalance factors; 1 gives equally weighted modes.
    n_modes : int, optional
        Number of mixture components (default 3).
    rad : float, optional
        Radius of the circle carrying the mode centres (default 10).
    angle_shift : float, optional
        Rotation of the target centres, in units of pi radians (default 0.1).
    n_samples : int, optional
        Nominal number of points per cloud (default 300). Per-mode counts are
        truncated to integers, so a cloud may hold slightly fewer points.
    rng : object with a ``multivariate_normal`` method, optional
        For example ``np.random.default_rng(seed)``. When omitted, the global
        ``np.random`` state is used.

    Returns
    -------
    (samples_src, samples_tgt) : tuple of arrays with 2 columns each.

    Raises
    ------
    ValueError
        If ``n_modes`` is smaller than 1.
    """
    if n_modes < 1:
        raise ValueError("n_modes must be at least 1")

    sampler = np.random if rng is None else rng

    # Mode centres: evenly spaced on the circle, target centres rotated.
    base_angles = np.pi * 2 * (np.arange(n_modes) / float(n_modes))
    shifted_angles = base_angles + np.pi * angle_shift
    mean_src = rad * np.column_stack((np.cos(base_angles), np.sin(base_angles)))
    mean_tgt = rad * np.column_stack((np.cos(shifted_angles), np.sin(shifted_angles)))
    cov = np.eye(2)

    # Geometric mode weights; the target weights are rotated by one position.
    weights_src = np.array([1] + [alpha_src ** i for i in range(1, n_modes)], dtype=float)
    weights_tgt = np.array([1] + [alpha_tgt ** i for i in range(1, n_modes)], dtype=float)
    weights_tgt = np.roll(weights_tgt, 1)
    weights_src = weights_src / np.sum(weights_src)
    weights_tgt = weights_tgt / np.sum(weights_tgt)

    samples_src = []
    samples_tgt = []
    for mode in range(n_modes):
        # int() truncates, so the totals can fall slightly short of n_samples.
        n_src = int(weights_src[mode] * n_samples)
        n_tgt = int(weights_tgt[mode] * n_samples)

        # Source and target draws stay interleaved per mode so that a seeded
        # global state yields the same stream as before.
        samples_src.append(sampler.multivariate_normal(mean_src[mode], cov, n_src))
        samples_tgt.append(sampler.multivariate_normal(mean_tgt[mode], cov, n_tgt))

    return np.vstack(samples_src), np.vstack(samples_tgt)

In [4]:
# Base plot: initial imbalance factors. These module-level values are not read
# by update_plot, which takes its value from the slider argument instead.
alpha_src = 1
alpha_tgt = 1

# Empty FigureWidget; plot_ot_dist clears and repopulates it in place.
fig = go.FigureWidget(data=[], layout={})

# Slider that drives the imbalance factor passed to update_plot.
slider = widgets.FloatSlider(
    value=1,
    min=0,
    max=1.0,
    step=0.1,
    description='Imbalance factor',
    orientation='horizontal',
    readout=True,
    # Only fire when the handle is released, so the plot is not recomputed
    # for every intermediate position while dragging.
    continuous_update=False,
    readout_format='.2f',
    # Let the label take its natural width; the default description width
    # would clip a label this long.
    style={'description_width': 'initial'},
)

slider.layout.width = '800px'


# Function that will modify the OT plot
def update_plot(y):
    """Regenerate samples for imbalance factor ``y`` and redraw the OT plot.

    The same factor is used for the source and the target mixture. If drawing
    fails, the attempt is repeated with fresh samples and the factor nudged up
    by 0.0001 per retry. Retries are bounded; once they are exhausted the last
    error is re-raised wrapped in a ``RuntimeError`` instead of being hidden.

    Parameters
    ----------
    y : float
        Imbalance factor supplied by the slider.

    Raises
    ------
    RuntimeError
        If every attempt fails; the original exception is chained as cause.
    """
    max_attempts = 10
    nudge = 0.0001
    last_error = None

    for attempt in range(max_attempts):
        alpha = y + attempt * nudge
        samples_src, samples_tgt = generate_samples(alpha, alpha)
        try:
            plot_ot_dist(fig, samples_src, samples_tgt)
            return
        except Exception as err:  # not a bare except: Ctrl-C must still interrupt
            last_error = err

    raise RuntimeError(
        "could not draw the OT plot after %d attempts" % max_attempts
    ) from last_error


# Stack the figure above the interactive slider control, centred horizontally.
vb = VBox(
    children=(fig, interactive(update_plot, y=slider)),
    layout=widgets.Layout(align_items='center'),
)
# A bare expression on the last line makes the notebook render the container.
vb

VBox(children=(FigureWidget({
    'data': [], 'layout': {'template': '...'}
}), interactive(children=(FloatSli…