# Learning the Kantorovich Dual using Input Convex Neural Networks

In [1]:
# Move the working directory two levels up from where the notebook is run.
cd ../..

/Users/bunnech/Documents/PhD/Projects/ott


In [2]:
import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from ott.geometry import pointcloud
from ott.core.neuraldual import NeuralDualSolver, NeuralDual



## Helper Functions

In [3]:
def plot_ot_map(state, source, target):
    """Scatter the samples together with the map induced by ``state``.

    Each row of ``source`` is pushed through the gradient of
    ``state.apply_fn`` (differentiated with respect to its second argument,
    with ``state.params`` supplied as the ``'params'`` collection). The
    target samples, the source samples and the mapped samples are drawn as
    three scatter layers, and a faint arrow joins every source point to its
    image.

    Args:
        state: object exposing ``apply_fn`` and ``params``.
        source: array of points, indexed as ``source[:, 0]`` / ``source[:, 1]``.
        target: array of points, indexed the same way.
    """
    # Start from a clean slate so repeated calls do not pile up figures.
    plt.close('all')

    def mapped_point(point):
        return jax.grad(state.apply_fn, argnums=1)(
            {'params': state.params}, point)

    mapped = jax.vmap(mapped_point)(source)

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # (points, colour, legend label) for each scatter layer, in draw order.
    layers = (
        (target, '#A7BED3', r'$x$'),
        (source, '#1A254B', r'$y$'),
        (mapped, '#F2545B', r'$\nabla g(y)$'),
    )
    for points, colour, label in layers:
        ax.scatter(points[:, 0], points[:, 1], color=colour,
                   alpha=0.5, label=label)

    plt.legend()

    # One arrow per sample, from the source point to where it is mapped.
    for start, end in zip(source, mapped):
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  color=[0.5, 0.5, 1], alpha=0.3)


In [4]:
@jax.jit
def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):
    """Return the Sinkhorn divergence between two uniformly weighted samples.

    Args:
        x: first batch of points; ``len(x)`` fixes the number of weights.
        y: second batch of points; ``len(y)`` fixes the number of weights.
        epsilon: forwarded to ``sinkhorn_divergence`` as ``epsilon``.
        power: forwarded to ``sinkhorn_divergence`` as ``power``.

    Returns:
        The ``divergence`` field of the ``sinkhorn_divergence`` output.
    """
    n_x, n_y = len(x), len(y)

    # Every sample carries the same mass, 1 / (number of samples).
    weights_x = jnp.ones(n_x) / n_x
    weights_y = jnp.ones(n_y) / n_y

    # Options handed through to the underlying Sinkhorn solver.
    solver_options = {'threshold': 1e-2, 'tau_a': 0.95, 'tau_b': 0.95}

    result = sinkhorn_divergence(
        pointcloud.PointCloud, x, y,
        power=power,
        epsilon=epsilon,
        a=weights_x,
        b=weights_y,
        static_b=False,
        sinkhorn_kwargs=solver_options,
    )
    return result.divergence

## Setup Training and Validation Datasets

In [5]:
class ToyDataset(IterableDataset):
    """Endless stream of noisy 2-D points around a named set of centers.

    Supported names are ``"simple"`` (a single center at the origin),
    ``"circle"`` (eight centers on the unit circle), ``"square_five"``
    (four corners plus the origin) and ``"square_four"`` (four axis points).
    """

    def __init__(self, name):
        # Name of the center layout; it is only validated when sampling starts.
        self.name = name

    def __iter__(self):
        return self.create_sample_generators()

    def create_sample_generators(self, scale=5.0, variance=0.5):
        """Yield points forever: a randomly chosen center plus Gaussian noise.

        Args:
            scale: factor applied to every center coordinate.
            variance: noise control. The standard-normal draw is multiplied
                by ``variance**2``, so that value (not ``variance`` itself) is
                the per-coordinate standard deviation of the noise.

        Yields:
            NumPy arrays of shape ``(2,)``.

        Raises:
            NotImplementedError: on first iteration, if ``self.name`` is not
                one of the supported dataset names.
        """
        # given name of dataset, select centers (always an (n_centers, 2) array)
        if self.name == "simple":
            # Keep this 2-D like every other layout: a flat [0, 0] would make
            # len(centers) == 2 and index out scalars instead of a center row.
            centers = np.array([[0, 0]])

        elif self.name == "circle":
            centers = np.array(
                [
                    (1, 0),
                    (-1, 0),
                    (0, 1),
                    (0, -1),
                    (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                    (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
                    (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                    (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
                ]
            )

        elif self.name == "square_five":
            centers = np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])

        elif self.name == "square_four":
            centers = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])

        else:
            raise NotImplementedError(
                f"Unknown toy dataset name: {self.name!r}")

        # create generator which randomly picks center and adds noise
        centers = scale * centers
        while True:
            center = centers[np.random.choice(len(centers))]
            point = center + variance**2 * np.random.randn(2)

            yield point


def load_toy_data(name_source: str,
                  name_target: str,
                  batch_size: int = 256,
                  valid_batch_size: int = 1024):
    """Create batch iterators over the source and target toy datasets.

    Args:
        name_source: dataset name handed to ``ToyDataset`` for the source.
        name_target: dataset name handed to ``ToyDataset`` for the target.
        batch_size: batch size of the first two (training) iterators.
        valid_batch_size: batch size of the last two (validation) iterators.

    Returns:
        A pair ``(dataloaders, input_dim)`` where ``dataloaders`` is the tuple
        ``(train source, train target, valid source, valid target)`` of
        ``DataLoader`` iterators and ``input_dim`` is ``2``.
    """

    def batch_stream(name, size):
        # A fresh dataset instance per iterator, wrapped for batching.
        return iter(DataLoader(ToyDataset(name), batch_size=size))

    # Outer loop over batch sizes, inner over names: train pair first, then
    # the validation pair, each ordered source before target.
    dataloaders = tuple(
        batch_stream(name, size)
        for size in (batch_size, valid_batch_size)
        for name in (name_source, name_target)
    )
    input_dim = 2
    return dataloaders, input_dim

## Solve Neural Dual

In [6]:
# Build the batch iterators for the 'simple' -> 'circle' problem. Only the two
# training iterators are kept; the two validation iterators are discarded.
(dataloader_source, dataloader_target, _, _), input_dim = load_toy_data('simple', 'circle')

In [None]:
# Configure the solver for 2-D inputs and 100000 training iterations.
# Long-running: the saved progress bar shows ~46 minutes elapsed at 47%.
neural_dual_solver = NeuralDualSolver(
    input_dim=input_dim, num_train_iters=100000)
# The same two training iterators are passed twice. The second pair is
# presumably the validation source/target slot -- TODO confirm against
# NeuralDualSolver.__call__.
neural_dual = neural_dual_solver(
    dataloader_source, dataloader_target, dataloader_source, dataloader_target)

 47%|█████████████████▊                    | 46957/100000 [45:52<52:27, 16.85it/s]

## Evaluate Neural Dual

In [None]:
# Pull one fresh batch from each training iterator and convert it to NumPy
# for the evaluation cells.
data_source = next(dataloader_source).numpy()
data_target = next(dataloader_target).numpy()

In [None]:
# Plot both batches and the gradient of the learned `g` at each source point.
plot_ot_map(neural_dual.g, data_source, data_target)

In [None]:
# Transport the source batch with 'g' and score the result against the target
# batch using the sinkhorn_loss helper.
pred_target = neural_dual.transport(data_source, 'g')
print(f'Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_target, data_target)}')

In [None]:
# Reverse direction: transport the target batch with 'f' and score the result
# against the source batch using the sinkhorn_loss helper.
pred_source = neural_dual.transport(data_target, 'f')
print(f'Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_source, data_source)}')

In [None]:
# Distance reported by the trained neural dual for the two evaluation batches.
dual_dist = neural_dual.distance(data_source, data_target)
print(f'Neural dual distance between source and target data: {dual_dist}')