# ICNN Dual Solver 

In this tutorial, we explore how to learn the solution of the Kantorovich dual based on parameterizing the two dual potentials $f$ and $g$ with two input convex neural networks ({class}`~ott.solvers.nn.icnn.ICNN`) {cite}`amos:17`, a method developed by {cite}`makkuva:20`. For more insights on the approach itself, we refer the user to the original publication.

Given dataloaders containing samples of the *source* and the *target* distribution, `OTT`'s {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` finds the pair of optimal potentials $f$ and $g$ to solve the corresponding dual of the optimal transport problem. Once a solution has been found, these neural {class}`~ott.problems.linear.potentials.DualPotentials` can be used to transport unseen source data samples to the target distribution (or vice-versa) or to compute the corresponding distance between new source and target distributions.

In [None]:
import sys

# When running inside Google Colab, install OTT from the main branch on GitHub.
if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main

In [None]:
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from IPython.display import clear_output, display
from torch.utils.data import DataLoader, IterableDataset

from ott.geometry import pointcloud
from ott.solvers.nn import icnn, mlp, neuraldual
from ott.tools import sinkhorn_divergence

## Helper Functions

Let us define some helper functions which we use for the subsequent analysis.

In [None]:
def plot_ot_map(neural_dual, source, target, inverse=False):
    """Scatter both point clouds together with the transported ``source``.

    ``source`` is pushed through ``neural_dual.transport`` (forward map
    unless ``inverse`` is set). Three layers are drawn in this order: the
    ``target`` cloud, the ``source`` cloud, and the transported points. An
    arrow then links every point of ``source`` to its image.

    With ``inverse=True`` the legend names of the two input clouds are
    swapped and the transported layer is labelled as the gradient of ``g``;
    the caller is expected to pass target samples as ``source`` in that case.

    Returns:
        The ``(fig, ax)`` pair holding the plot.
    """
    mapped = neural_dual.transport(source, forward=not inverse)

    fig = plt.figure(facecolor="white")
    ax = fig.add_subplot(111)

    # Only the legend text depends on the direction; colours and draw order
    # are shared by both directions.
    if inverse:
        legend_names = (r"$source$", r"$target$", r"$\nabla g(target)$")
    else:
        legend_names = (r"$target$", r"$source$", r"$\nabla f(source)$")

    layers = (
        (target, "#A7BED3"),
        (source, "#1A254B"),
        (mapped, "#F2545B"),
    )
    for (points, colour), name in zip(layers, legend_names):
        ax.scatter(
            points[:, 0], points[:, 1], color=colour, alpha=0.5, label=name
        )

    fig.legend(ncol=3, loc="upper center")

    # One faint arrow per sample, from its original position to its image.
    for start, end in zip(source, mapped):
        plt.arrow(
            start[0],
            start[1],
            end[0] - start[0],
            end[1] - start[1],
            color=[0.5, 0.5, 1],
            alpha=0.3,
        )
    return fig, ax


def plot_potential(neural_dual, source, target):
    """Draw filled contours of ``neural_dual._f`` over the square [-6, 6]^2.

    The potential is evaluated on a regular grid built from
    ``np.linspace(-6, 6)`` along each axis. ``source`` and ``target`` are
    accepted but not used by this function.

    Returns:
        The ``(fig, ax)`` pair holding the contour plot.
    """
    fig, ax = plt.subplots(figsize=(6, 6), facecolor="white")

    ticks_x = np.linspace(-6, 6)
    ticks_y = np.linspace(-6, 6)
    grid_x, grid_y = np.meshgrid(ticks_x, ticks_y)

    # One row per grid node: (x, y).
    grid_points = np.stack((np.ravel(grid_x), np.ravel(grid_y))).T
    values = neural_dual._f(grid_points)
    surface = np.array(values.reshape(grid_x.shape))

    contours = ax.contourf(grid_x, grid_y, surface, cmap="Blues")
    fig.colorbar(contours, ax=ax)
    fig.tight_layout()
    ax.set_title(r"$f$")
    return fig, ax

In [None]:
def get_optimizer(optimizer, lr, b1, b2, eps):
    """Build the optax optimizer selected by name.

    Args:
        optimizer: ``"Adam"`` or ``"SGD"``.
        lr: Learning rate passed to the optimizer.
        b1: Passed as ``b1`` to ``optax.adam``; ignored for SGD.
        b2: Passed as ``b2`` to ``optax.adam``; ignored for SGD.
        eps: Passed as ``eps`` to ``optax.adam``; ignored for SGD.

    Returns:
        The object returned by ``optax.adam`` or ``optax.sgd``.

    Raises:
        NotImplementedError: If ``optimizer`` is neither ``"Adam"`` nor
            ``"SGD"``.
    """
    if optimizer == "Adam":
        return optax.adam(learning_rate=lr, b1=b1, b2=b2, eps=eps)
    if optimizer == "SGD":
        # Plain SGD: no momentum, no Nesterov acceleration.
        return optax.sgd(learning_rate=lr, momentum=None, nesterov=False)
    raise NotImplementedError(f"Optimizer {optimizer} not supported yet!")


def get_optimizer_decay(init_lr, alpha, num_train_iter, b1, b2):
    """Return an AdamW optimizer driven by a cosine-decay learning rate.

    Args:
        init_lr: Initial value of the schedule.
        alpha: Forwarded as ``alpha`` to ``optax.cosine_decay_schedule``.
        num_train_iter: Number of steps over which the schedule decays.
        b1: Passed as ``b1`` to ``optax.adamw``.
        b2: Passed as ``b2`` to ``optax.adamw``.
    """
    return optax.adamw(
        learning_rate=optax.cosine_decay_schedule(
            init_value=init_lr, decay_steps=num_train_iter, alpha=alpha
        ),
        b1=b1,
        b2=b2,
    )

In [None]:
@jax.jit
def sinkhorn_loss(x, y, epsilon=0.1):
    """Return the Sinkhorn divergence between two uniformly weighted clouds.

    Every point of ``x`` gets weight ``1 / len(x)`` and every point of ``y``
    gets weight ``1 / len(y)``. The divergence is computed on a
    ``PointCloud`` geometry with regularization ``epsilon``.
    """
    n_x, n_y = len(x), len(y)
    weights_x = jnp.ones(n_x) / n_x
    weights_y = jnp.ones(n_y) / n_y

    result = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud,
        x,
        y,
        a=weights_x,
        b=weights_y,
        epsilon=epsilon,
    )
    return result.divergence

## Setup Training and Validation Datasets

We apply the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` to compute the transport between toy datasets. In this tutorial, the user can choose between the datasets `simple` (data clustered in one center), `circle` (two-dimensional Gaussians arranged on a circle), `square_five` (two-dimensional Gaussians on a square with one Gaussian in the center), and `square_four` (two-dimensional Gaussians in the corners of a rectangle).

In [None]:
@dataclass
class ToyDataset:
    """Endless stream of 2-D toy batches drawn around a fixed set of centers.

    Iterating over an instance yields arrays of shape ``(batch_size, 2)``.
    Each row is one randomly chosen center plus Gaussian noise.
    """

    # Layout of the centers: "simple", "circle", "square_five" or
    # "square_four".
    name: str
    # Number of points in every yielded batch.
    batch_size: int
    # Seed of the JAX PRNG key that drives all sampling.
    init_seed: int

    def __iter__(self):
        """Return a fresh batch generator that starts from ``init_seed``."""
        return self.create_sample_generators()

    def _unit_centers(self):
        """Return the unscaled centers for ``self.name`` as an ``(n, 2)`` array.

        Raises:
            NotImplementedError: If ``self.name`` is not a known dataset.
        """
        if self.name == "simple":
            # Must be 2-D (one center, two coordinates). With a 1-D array,
            # ``jax.random.choice`` would sample scalars and the result could
            # not be added to the ``(batch_size, 2)`` noise.
            return np.array([[0, 0]])

        if self.name == "circle":
            diag = 1.0 / np.sqrt(2)
            return np.array(
                [
                    (1, 0),
                    (-1, 0),
                    (0, 1),
                    (0, -1),
                    (diag, diag),
                    (diag, -diag),
                    (-diag, diag),
                    (-diag, -diag),
                ]
            )

        if self.name == "square_five":
            return np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])

        if self.name == "square_four":
            return np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])

        raise NotImplementedError(f"Unknown toy dataset: {self.name!r}")

    def create_sample_generators(self, scale=5.0, variance=0.5):
        """Yield batches of noisy samples around the scaled centers, forever.

        Args:
            scale: Factor applied to the unit centers.
            variance: Noise parameter. The standard-normal noise is multiplied
                by ``variance ** 2`` (0.25 by default); this scaling is kept
                unchanged so that generated data stays the same.

        Yields:
            Arrays of shape ``(batch_size, 2)``.

        Raises:
            NotImplementedError: On the first ``next()`` call if ``self.name``
                is not a known dataset.
        """
        centers = scale * self._unit_centers()
        key = jax.random.PRNGKey(self.init_seed)
        while True:
            # Fresh sub-keys for center choice and noise; carry `key` forward.
            k1, k2, key = jax.random.split(key, 3)
            sample_centers = jax.random.choice(k1, centers, [self.batch_size])
            samples = sample_centers + variance**2 * jax.random.normal(
                k2, [self.batch_size, 2]
            )
            yield samples


def load_toy_data(
    name_source: str,
    name_target: str,
    batch_size: int = 10000,
    valid_batch_size: int = 1000,
):
    """Create train and validation iterators for a source/target pair.

    Returns:
        A pair ``(dataloaders, input_dim)``. ``dataloaders`` is a 4-tuple of
        iterators ordered as (source train, target train, source validation,
        target validation), seeded with 0, 1, 2 and 3 respectively.
        ``input_dim`` is always 2.
    """
    # (dataset name, batch size) per loader; the position doubles as the seed.
    specs = (
        (name_source, batch_size),
        (name_target, batch_size),
        (name_source, valid_batch_size),
        (name_target, valid_batch_size),
    )
    dataloaders = tuple(
        iter(ToyDataset(name, batch_size=size, init_seed=seed))
        for seed, (name, size) in enumerate(specs)
    )
    return dataloaders, 2

## Solve Neural Dual

In order to solve the neural dual, we need to define our dataloaders. The only requirement is that the corresponding source and target train and validation datasets are *iterators*.

In [None]:
# Build the iterators: "square_five" is the source and "square_four" the
# target. The first two are the training loaders, the last two the validation
# loaders.
(
    dataloader_source,
    dataloader_target,
    dataloader_source_eval,
    dataloader_target_eval,
), input_dim = load_toy_data("square_five", "square_four")

Next, we define the architectures parameterizing the dual potentials $f$ and $g$. These need to be parameterized by {class}`~ott.solvers.nn.icnn.ICNN`s. You can adapt the size of the ICNNs by passing a sequence containing hidden layer sizes. While ICNNs by default contain partially positive weights, we can run the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` using approximations to this positivity constraint (via weight clipping and a weight penalization). For this, set `positive_weights` to `True` in both the ICNN architecture and {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` configuration. For more details on how to customize the {class}`~ott.solvers.nn.icnn.ICNN` architectures, we refer you to the documentation.

In [None]:
# initialize models
# Leaky ReLU with negative slope 0.2, used as the activation of the ICNN.
lrelu = partial(jax.nn.leaky_relu, negative_slope=0.2)
neural_f = icnn.ICNN(
    dim_data=2, dim_hidden=[128, 128], act_fn=lrelu, init_std=1.0
)
neural_g = mlp.PotentialGradientMLP(dim_hidden=[512, 512])

# initialize optimizers
# NOTE(review): the cosine schedule decays over 50000 steps, while the
# training cell sets `num_train_iters = 15000` -- confirm this is intended.
optimizer_f = get_optimizer_decay(
    init_lr=5e-4, alpha=1e-4, num_train_iter=50000, b1=0.5, b2=0.5
)
# Both potentials are trained with the same optimizer configuration.
optimizer_g = optimizer_f

We then initialize the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` by passing the two {class}`~ott.solvers.nn.icnn.ICNN` models parameterizing $f$ and $g$, as well as by specifying the input dimensions of the data and the number of training iterations to execute. Once the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` is initialized, we can obtain the neural {class}`~ott.problems.linear.potentials.DualPotentials` by passing the corresponding dataloaders to it. As here our training and validation datasets do not differ, we pass (`dataloader_source`, `dataloader_target`) for both training and validation steps. For more details on how to configure the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver`, we refer you to the documentation.

<font color='#F2545B'>Execution of the following cell might take up to 15 minutes per 5000 iterations (depending on your system and the number of training iterations).</font>

In [None]:
# Number of solver iterations; also printed in the callback's progress line.
num_train_iters = 15000

# Sample an unseen batch.
# These come from the validation iterators, which are not passed to the solver.
eval_data_source = next(dataloader_source_eval)
eval_data_target = next(dataloader_target_eval)


def training_callback(step, neural_dual):
    """Refresh the diagnostic plots every 1000 training steps.

    On steps that are not a multiple of 1000 nothing happens. Otherwise the
    cell output is cleared, a progress line is printed, and three figures are
    shown one after another: the forward map on the held-out source batch,
    the inverse map on the held-out target batch, and the contour plot of the
    potential. Each figure is closed right after being displayed.
    """
    if step % 1000 != 0:
        return

    clear_output()
    print(f"Training iteration: {step}/{num_train_iters}")

    # Figures are built lazily so that each one is created, displayed and
    # closed before the next one is drawn.
    figure_builders = (
        lambda: plot_ot_map(
            neural_dual, eval_data_source, eval_data_target, inverse=False
        ),
        lambda: plot_ot_map(
            neural_dual, eval_data_target, eval_data_source, inverse=True
        ),
        lambda: plot_potential(
            neural_dual, eval_data_target, eval_data_source
        ),
    )
    for build in figure_builders:
        fig, _ = build()
        display(fig)
        plt.close(fig)


# Configure the solver with the two networks and their optimizers.
neural_dual_solver = neuraldual.NeuralDualSolver(
    input_dim,
    neural_f,
    neural_g,
    optimizer_f,
    optimizer_g,
    num_train_iters=num_train_iters,
    num_inner_iters=1,
    amortization_loss="regression",
    parallel_updates=True,
)
# Run training. The training loaders are passed a second time in the
# validation slots; `training_callback` receives the step and current
# potentials (see its `step` / `neural_dual` parameters).
neural_dual = neural_dual_solver(
    dataloader_source,
    dataloader_target,
    dataloader_source,
    dataloader_target,
    callback=training_callback,
)

## Evaluate Neural Dual

After training has completed successfully, we can evaluate the neural {class}`~ott.problems.linear.potentials.DualPotentials` on unseen incoming data. For this, we use the batch sampled from the source and target distributions before training, which the solver has not seen.

Now, we can plot the corresponding transport from source to target using the gradient of the learned potential $f$, i.e., $\nabla f(\text{source})$, or from target to source via the gradient of the learned potential $g$, i.e., $\nabla g(\text{target})$.

In [None]:
# Forward map: held-out source samples and their images.
plot_ot_map(neural_dual, eval_data_source, eval_data_target, inverse=False)

In [None]:
# Inverse map: the held-out target samples are passed in the `source` slot.
plot_ot_map(neural_dual, eval_data_target, eval_data_source, inverse=True)

We further test how close the predicted samples are to the sampled data.

First for potential $f$, transporting source to target samples. Ideally the resulting {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance is close to $0$.

In [None]:
# Push the held-out source batch forward and compare it with target samples.
pred_target = neural_dual.transport(eval_data_source)
print(
    f"Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_target, eval_data_target)}"
)

Then for potential $g$, transporting target to source samples. Again, the resulting {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance should be close to $0$.

In [None]:
# Pull the held-out target batch back and compare it with source samples.
pred_source = neural_dual.transport(eval_data_target, forward=False)
print(
    f"Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_source, eval_data_source)}"
)

Besides computing the transport and mapping source to target samples or vice versa, we can also compute the overall distance between new source and target samples.

In [None]:
# Distance between the held-out batches as given by the learned potentials.
neural_dual_dist = neural_dual.distance(eval_data_source, eval_data_target)
print(
    f"Neural dual distance between source and target data: {neural_dual_dist}"
)

This can be compared with the primal {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance, computed below.

In [None]:
# Reference value: Sinkhorn divergence computed directly on the two batches.
sinkhorn_dist = sinkhorn_loss(eval_data_source, eval_data_target)
print(f"Sinkhorn distance between source and target data: {sinkhorn_dist}")