In [11]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from IPython import display

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
from ott.tools.sinkhorn_divergence import SinkhornDivergenceOutput
from ott.solvers.linear import implicit_differentiation as imp_diff
import equinox as eqx

In [116]:


def sink_div(combined, states, y, b, key) -> tuple[float, float]:

    agent_value, agent_policy = combined
    z_dist = eqx.filter_vmap(agent_policy)(states)
    z, log_prob = z_dist.sample_and_log_prob(seed=key)

    geom = pointcloud.PointCloud(z, y, epsilon=0.001)
    
    a = eqx.filter_vmap(agent_value)(states, z).squeeze()
    an = jax.nn.softplus(a - jnp.quantile(a, 0.01)) 
    bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
  
    adv = jax.lax.stop_gradient(a - jnp.quantile(a, 0.1))
    policy_loss = -(log_prob.squeeze() * adv).mean() 
            
    an = an / an.sum()
    bn = bn / bn.sum()
    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        a=an,
        b=bn,
        y=geom.y,
        static_b=True,
        sinkhorn_kwargs={
            "implicit_diff": imp_diff.ImplicitDiff(),
            "use_danskin": True,
            "threshold": 1e-4,
            "max_iterations": 2000
        },
    )
    return ot.divergence * 10 + policy_loss, (-log_prob.squeeze()).min()

In [13]:
import equinox as eqx
import equinox.nn as eqxnn

class MonolithicVF_EQX(eqx.Module):
    net: eqx.Module
    
    def __init__(self, key, state_dim, intents_dim, hidden_dims):
        key, mlp_key = jax.random.split(key, 2)
        self.net = eqxnn.MLP(
            in_size=state_dim + intents_dim, out_size=1, width_size=hidden_dims[-1], depth=len(hidden_dims), key=mlp_key
        )
        
    def __call__(self, observations, intents):
        # TODO: Maybe try FiLM conditioning like in SAC-RND?
        conditioning = jnp.concatenate((observations, intents), axis=-1)
        return self.net(conditioning)

In [73]:
from typing import Any
import distrax
from jax import Array
from jaxtyping import PyTree


class FixedDistrax(eqx.Module):
    cls: type
    args: PyTree[Any]
    kwargs: PyTree[Any]

    def __init__(self, cls, *args, **kwargs):
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def sample_and_log_prob(self, *, seed):
        return self.cls(*self.args, **self.kwargs).sample_and_log_prob(seed=seed)

    def sample(self, *, seed):
        return self.cls(*self.args, **self.kwargs).sample(seed=seed)

    def log_prob(self, x):
        return self.cls(*self.args, **self.kwargs).log_prob(x)

    def mean(self):
        return self.cls(*self.args, **self.kwargs).mean()



class ND(eqx.Module):

    loc: jax.Array
    scale_diag: jax.Array
    
    def __init__(self, loc, scale_diag):
        self.loc = loc
        self.scale_diag = scale_diag
        
    def sample_and_log_prob(self, seed):
        s = self.sample(seed)
        return s, self.log_prob(s)

    def sample(self, seed):
        loc = self.loc
        e = jax.random.normal(seed, loc.shape)
        return loc + self.scale_diag * e 

    def log_prob(self, x):
        return jax.scipy.stats.norm.logpdf(x, self.loc, self.scale_diag).sum(-1)



class GaussianIntentPolicy(eqx.Module):
    net: eqx.Module
    
    log_std_min: int = -10.0
    log_std_max: int = 2.0
    temperature: float = 10.0
    
    def __init__(self, key, state_dim, intent_dim, hidden_dims):
        key, key_means, key_log_std = jax.random.split(key, 3)
        
        self.net = eqx.nn.MLP(in_size=state_dim,
                              out_size=2 * intent_dim,
                              width_size=hidden_dims[0],
                              depth=len(hidden_dims),
                              key=key_means)
        
    def __call__(self, state):
        means, log_std = jnp.split(self.net(state), 2)
        log_stds = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        dist = ND(loc=means, scale_diag=jnp.exp(log_stds))
        return dist

In [74]:
key = jax.random.PRNGKey(1)
model = GaussianIntentPolicy(key=key,
                             hidden_dims=[64, 64, 64],
                             state_dim=2,
                             intent_dim=2)

x = 0.25 * jax.random.normal(key, (250, 2)) 
z_dist = eqx.filter_vmap(model)(x)
z, log_prob = z_dist.sample_and_log_prob(seed=key)

print(log_prob.shape)

(250,)


In [16]:
import optax
import dataclasses

class TrainStateEQX(eqx.Module):
    model: eqx.Module
    optim: optax.GradientTransformation
    optim_state: optax.OptState

    @classmethod
    def create(cls, *, model, optim, **kwargs):
        optim_state = optim.init(eqx.filter(model, eqx.is_array))
        return cls(model=model, optim=optim, optim_state=optim_state,
                   **kwargs)
    
    @eqx.filter_jit
    def apply_updates(self, grads):
        updates, new_optim_state = self.optim.update(grads, self.optim_state, self.model)
        new_model = eqx.apply_updates(self.model, updates)
        return dataclasses.replace(
            self,
            model=new_model,
            optim_state=new_optim_state
        )

In [113]:

def gradient_flow(
    x: jnp.ndarray,
    y: jnp.ndarray,
    b,
    cost_fn: callable,
    num_iter: int = 5000,
    dump_every: int = 50
):
    """Compute a gradient flow."""

    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn, has_aux=True))

    key = jax.random.PRNGKey(1)
    V = TrainStateEQX.create(
        model=MonolithicVF_EQX(key, 2, 2, [64, 64, 64]),
        optim=optax.adam(learning_rate=1e-4)
    )

    key, pkey = jax.random.split(key, 2)

    policy = TrainStateEQX.create(
        model=GaussianIntentPolicy(key=pkey,
                             hidden_dims=[64, 64, 64],
                             state_dim=2,
                             intent_dim=2),
        optim=optax.adam(learning_rate=1e-4)
    )

    # opt = optax.adam(1e-3)
    # opt_state = opt.init(x)

    for i in range(0, num_iter + 1):
        # geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
        key, key_6 = jax.random.split(key, 2)
        y = 0.5 * jax.random.normal(key_6, (400, 2)) + jnp.array((1, 0)) 

        (cost, pmin), (value_grads, policy_grads) = cost_fn_vg((V.model, policy.model), x, y, b, key_6)
        # assert ot.converged[0]
        # assert ot.converged[1]
        # assert ot.converged[2]

        V = V.apply_updates(value_grads)
        policy = policy.apply_updates(policy_grads)

        if i % dump_every == 0:
            z = eqx.filter_vmap(policy.model)(x).sample(seed=key_6)
            
            a = eqx.filter_vmap(V.model)(x, z).squeeze()
            an = jax.nn.softplus(a - jnp.quantile(a, 0.01))
            bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
            an = an / an.sum()
            bn = bn / bn.sum()

            geom = pointcloud.PointCloud(z, y, epsilon=0.001)
            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom, a = an, b = bn)).reg_ot_cost
            # print(a.min(), a.mean(), a.max())
            # print(an.min(), an.mean(), an.max())
            print(cost, diff, pmin)
            print()

    return policy.model, V.model
    

In [117]:
key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)


x = 0.25 * jax.random.normal(key1, (100, 2))  # Source
y = 0.5 * jax.random.normal(key2, (400, 2)) + jnp.array((1, 0))  # Target
# y = x

marginal_b = jax.random.normal(key2, shape=(400, ))


policy_model, V_model = gradient_flow(x, y, marginal_b, cost_fn=sink_div)

3.8447678 1.7022855 1.8073897

2.3665216 0.9446982 1.6229136

1.2102695 0.36669117 1.4218076

1.066475 0.24136572 1.0982442

0.64712954 0.06962061 0.6963582

0.50382584 0.04013389 0.3792808

0.44154263 0.07119808 0.11715929

0.40440542 0.058285195 -0.11411727

0.40213943 0.07540182 -0.093509756

0.36885267 0.069088705 -0.26413527

0.34937504 0.07743274 -0.23358262

0.31983906 0.0620846 -0.17591989

0.3751124 0.08339907 -0.24904776

0.34582102 0.048021846 -0.18516405

0.33626342 0.060669906 -0.18174604

0.3252278 0.04566594 -0.2349845

0.33356822 0.05300928 -0.18948251

0.43610442 0.05194739 -0.083176

0.3405601 0.07143687 -0.21454926

0.30168957 0.051284727 -0.08499307

0.37063336 0.048987504 -0.07419868

0.4459473 0.06584742 -0.32289675

0.37267342 0.063605405 -0.18544775

0.32874697 0.060184576 -0.21829103

0.36131996 0.06786085 -0.16509394

0.32153884 0.053866267 -0.19421567

0.33107182 0.054082144 -0.13586485

0.30109733 0.047443252 -0.18800749

0.2831055 0.05526406 -0.17613503

0.

KeyboardInterrupt: 

In [115]:
key = jax.random.PRNGKey(0)
z_dist = eqx.filter_vmap(policy_model)(x)
z, log_prob = z_dist.sample_and_log_prob(seed=key)

print(z_dist.scale_diag)
print(z_dist.loc)
print((-log_prob).min())


[[0.4711008  0.4992154 ]
 [0.4683991  0.48698327]
 [0.4251204  0.43983382]
 [0.45871246 0.47339636]
 [0.4930999  0.52430993]
 [0.47805205 0.4995513 ]
 [0.4481231  0.47100344]
 [0.4740806  0.48846596]
 [0.38585225 0.39462364]
 [0.4217649  0.43729648]
 [0.47485846 0.48892057]
 [0.47473437 0.49206588]
 [0.48301592 0.5143294 ]
 [0.4842028  0.5067877 ]
 [0.48769727 0.5021746 ]
 [0.44983792 0.46433115]
 [0.40839612 0.43259528]
 [0.50427204 0.5315879 ]
 [0.42929387 0.43758118]
 [0.45823327 0.48484293]
 [0.49506098 0.530093  ]
 [0.4719749  0.49180138]
 [0.5082524  0.52769434]
 [0.5517867  0.5805801 ]
 [0.41023695 0.42235565]
 [0.4008704  0.40957537]
 [0.45612365 0.4792257 ]
 [0.45801425 0.4817824 ]
 [0.5184257  0.5452121 ]
 [0.47725558 0.4950465 ]
 [0.3879591  0.41104648]
 [0.47053665 0.48572096]
 [0.50541854 0.5275797 ]
 [0.5170749  0.5362931 ]
 [0.42050987 0.43135056]
 [0.49313495 0.5154984 ]
 [0.45125967 0.4665445 ]
 [0.44599363 0.4579235 ]
 [0.4721431  0.48905343]
 [0.5322167  0.55695313]
