In [1]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from IPython import display

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
from ott.tools.sinkhorn_divergence import SinkhornDivergenceOutput
from ott.solvers.linear import implicit_differentiation as imp_diff
import equinox as eqx

In [58]:


def sink_div(combined, states, y, b, key) -> tuple[float, float]:
    # y - intentions of expert
    agent_value, agent_policy = combined
    z_dist = eqx.filter_vmap(agent_policy)(states)
    z, log_prob = z_dist.sample_and_log_prob(seed=key) # intentions of agent

    geom = pointcloud.PointCloud(z, y, epsilon=0.001)
    
    a = eqx.filter_vmap(agent_value)(states, z).squeeze() # weights for intents of agent
    an = jax.nn.softplus(a - jnp.quantile(a, 0.01)) 
    bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
        

    an = an / an.sum()
    bn = bn / bn.sum()
    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        a=an,
        b=bn,
        y=geom.y,
        static_b=True,
        sinkhorn_kwargs={
            "implicit_diff": imp_diff.ImplicitDiff(),
            "use_danskin": True,
            "threshold": 1e-4,
            "max_iterations": 2000
        },
    )
    return ot.divergence, (-log_prob.squeeze()).min()

In [39]:
import equinox as eqx
import equinox.nn as eqxnn

class MonolithicVF_EQX(eqx.Module):
    net: eqx.Module
    
    def __init__(self, key, state_dim, intents_dim, hidden_dims):
        key, mlp_key = jax.random.split(key, 2)
        self.net = eqxnn.MLP(
            in_size=state_dim + intents_dim, out_size=1, width_size=hidden_dims[-1], depth=len(hidden_dims), key=mlp_key
        )
        
    def __call__(self, observations, intents):
        # TODO: Maybe try FiLM conditioning like in SAC-RND?
        conditioning = jnp.concatenate([observations, intents], axis=-1)
        return self.net(conditioning)

In [41]:
from typing import Any
import distrax
from jax import Array
from jaxtyping import PyTree


class FixedDistrax(eqx.Module):
    cls: type
    args: PyTree[Any]
    kwargs: PyTree[Any]

    def __init__(self, cls, *args, **kwargs):
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def sample_and_log_prob(self, *, seed):
        return self.cls(*self.args, **self.kwargs).sample_and_log_prob(seed=seed)

    def sample(self, *, seed):
        return self.cls(*self.args, **self.kwargs).sample(seed=seed)

    def log_prob(self, x):
        return self.cls(*self.args, **self.kwargs).log_prob(x)

    def mean(self):
        return self.cls(*self.args, **self.kwargs).mean()



class ND(eqx.Module):

    loc: jax.Array
    scale_diag: jax.Array
    
    def __init__(self, loc, scale_diag):
        self.loc = loc
        self.scale_diag = scale_diag
        
    def sample_and_log_prob(self, seed):
        s = self.sample(seed)
        return s, self.log_prob(s)

    def sample(self, seed):
        loc = self.loc
        e = jax.random.normal(seed, loc.shape)
        return loc + self.scale_diag * e 

    def log_prob(self, x):
        return jax.scipy.stats.norm.logpdf(x, self.loc, self.scale_diag).sum(-1)



class GaussianIntentPolicy(eqx.Module):
    net: eqx.Module
    
    log_std_min: int = -5.0
    log_std_max: int = 2.0
    temperature: float = 10.0
    
    def __init__(self, key, state_dim, intent_dim, hidden_dims):
        key, key_means, key_log_std = jax.random.split(key, 3)
        
        self.net = eqx.nn.MLP(in_size=state_dim,
                              out_size=2 * intent_dim,
                              width_size=hidden_dims[0],
                              depth=len(hidden_dims),
                              key=key_means)
        
    def __call__(self, state):
        means, log_std = jnp.split(self.net(state), 2)
        log_stds = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        # dist = FixedDistrax(distrax.MultivariateNormalDiag, loc=means,
        #                     scale_diag=jnp.exp(log_stds)) #ND(loc=means, scale_diag=jnp.exp(log_stds))
        dist = ND(loc=means, scale_diag=jnp.exp(log_stds))
        return dist

In [5]:
key = jax.random.PRNGKey(1)
model = GaussianIntentPolicy(key=key,
                             hidden_dims=[64, 64, 64],
                             state_dim=2,
                             intent_dim=2)

x = 0.25 * jax.random.normal(key, (250, 2)) 
z_dist = eqx.filter_vmap(model)(x)
z, log_prob = z_dist.sample_and_log_prob(seed=key)

print(log_prob.shape)

2023-11-10 23:24:40.515521: W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-bio-dna-20106e7d-5766-609d21c28e80d, line 10; fatal   : Unsupported .version 7.8; current version is '7.7'
ptxas fatal   : Ptx assembly aborted due to errors

Relying on driver to perform ptx compilation. 
Setting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda  or modifying $PATH can be used to set the location of ptxas
This message will only be logged once.


(250,)


In [6]:
import optax
import dataclasses

class TrainStateEQX(eqx.Module):
    model: eqx.Module
    optim: optax.GradientTransformation
    optim_state: optax.OptState

    @classmethod
    def create(cls, *, model, optim, **kwargs):
        optim_state = optim.init(eqx.filter(model, eqx.is_array))
        return cls(model=model, optim=optim, optim_state=optim_state,
                   **kwargs)
    
    @eqx.filter_jit
    def apply_updates(self, grads):
        updates, new_optim_state = self.optim.update(grads, self.optim_state, self.model)
        new_model = eqx.apply_updates(self.model, updates)
        return dataclasses.replace(
            self,
            model=new_model,
            optim_state=new_optim_state
        )

In [71]:

def gradient_flow(
    x: jnp.ndarray,
    y: jnp.ndarray,
    b,
    cost_fn: callable,
    num_iter: int = 6000,
    dump_every: int = 50
):
    """Compute a gradient flow."""

    def v_loss(agent_policy, agent_value, states, key) -> float:
        z_dist = eqx.filter_vmap(agent_policy)(states)
        z, _ = z_dist.sample_and_log_prob(seed=key)
        v = eqx.filter_vmap(agent_value)(states, z).squeeze()
        return -v.mean() * 0.1

    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn, has_aux=True))
    v_loss_vg = eqx.filter_jit(eqx.filter_value_and_grad(v_loss, has_aux=False))

    key = jax.random.PRNGKey(42)
    V = TrainStateEQX.create(
        model=MonolithicVF_EQX(key, 2, 2, [128, 128, 128]),
        optim=optax.adam(learning_rate=3e-4)
    )

    key, pkey = jax.random.split(key, 2)

    policy = TrainStateEQX.create(
        model=GaussianIntentPolicy(key=pkey,
                             hidden_dims=[128, 128, 128],
                             state_dim=2,
                             intent_dim=2),
        optim=optax.adam(learning_rate=3e-4)
    )
    for i in range(0, num_iter + 1):
        key, key_6 = jax.random.split(key, 2)

        (cost, pmin), (value_grads, policy_grads) = cost_fn_vg((V.model, policy.model), x, y, b, key_6)
        v_loss, policy_grads_2 = v_loss_vg(policy.model, V.model, x, key_6)
       
        V = V.apply_updates(value_grads)
        policy_grads = jax.tree_map(lambda g1, g2: g1 + g2, policy_grads, policy_grads_2)
        policy = policy.apply_updates(policy_grads)

        if i % dump_every == 0:
            z = eqx.filter_vmap(policy.model)(x).sample(seed=key_6)
            a = eqx.filter_vmap(V.model)(x, z).squeeze()
            
            an = jax.nn.softplus(a - jnp.quantile(a, 0.01))
            bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
            an = an / an.sum()
            bn = bn / bn.sum()

            geom = pointcloud.PointCloud(z, y, epsilon=0.001)
            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom, a = an, b = bn)).reg_ot_cost
            # print(a.min(), a.mean(), a.max())
            # print(an.min(), an.mean(), an.max())
            print(cost, diff, pmin)
            print()

    return policy.model, V.model
    

In [72]:
key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)


x = 0.25 * jax.random.normal(key1, (100, 2))  # Source
y = 0.5 * jax.random.normal(key2, (400, 2)) + jnp.array((1, 0))  # Target

marginal_b = jax.random.normal(key2, shape=(400, ))
policy_model, V_model = gradient_flow(x, y, marginal_b, cost_fn=sink_div)

2.074768 1.8227637 1.7290224

0.107341364 0.038548037 0.5317926

0.10932518 0.050748993 0.26898557

0.10725307 0.048360553 0.21133506

0.11491291 0.050901167 0.252268

0.10030399 0.038433574 0.11526367

0.10839401 0.04057438 0.2886464

0.10331808 0.033942413 0.34556502

0.10312855 0.042218458 0.28834325

0.10705197 0.040992375 0.29649103

0.09709285 0.033289075 0.36462563

0.109361514 0.046177126 0.27830768

0.105315395 0.039995104 0.3126964

0.110388905 0.045361415 0.29012984

0.11780122 0.05492336 0.34669518

0.101461306 0.037622146 0.3320344

0.11104167 0.046675704 0.35370427

0.10872314 0.044631526 0.18860152

0.10147652 0.03933075 0.20048974

0.096749775 0.03376815 0.32717204

0.097046405 0.036267973 0.33117652

0.112862185 0.04903081 0.28869057

0.09960293 0.039875746 0.32367522

0.09779629 0.038981725 0.29801893

0.10668324 0.044474345 0.30882272

0.10430345 0.04366511 0.3388177

0.11779995 0.052845545 0.24904889

0.110859096 0.04805383 0.28831878

0.10110547 0.04256174 0.151417

In [70]:
key = jax.random.PRNGKey(0)
z_dist = eqx.filter_vmap(policy_model)(x)
z, log_prob = z_dist.sample_and_log_prob(seed=key)

print(z_dist.scale_diag)
print(z_dist.loc)
print((-log_prob).min())


[[0.00673795 0.01371963]
 [0.00673795 0.0124318 ]
 [0.00673795 0.01049418]
 [0.00673795 0.01051167]
 [0.00673795 0.01393469]
 [0.00673795 0.01329899]
 [0.00673795 0.01291853]
 [0.00673795 0.00869805]
 [0.00673795 0.00673795]
 [0.00673795 0.01056215]
 [0.00673795 0.0083251 ]
 [0.00673795 0.01127638]
 [0.00673795 0.01374628]
 [0.00673795 0.01356966]
 [0.00673795 0.00760112]
 [0.00673795 0.01076203]
 [0.00673795 0.01078376]
 [0.00673795 0.01412819]
 [0.00673795 0.00723075]
 [0.00673795 0.0133966 ]
 [0.00673795 0.01373136]
 [0.00673795 0.01288247]
 [0.00673795 0.00968719]
 [0.00673795 0.00950502]
 [0.00673795 0.00902373]
 [0.00673795 0.00707102]
 [0.00673795 0.01315871]
 [0.00673795 0.01330595]
 [0.00673795 0.01290078]
 [0.00673795 0.01143418]
 [0.00673795 0.00915954]
 [0.00673795 0.00976726]
 [0.00673795 0.01201004]
 [0.00673795 0.00875782]
 [0.00673795 0.00891351]
 [0.00673795 0.01329588]
 [0.00673795 0.01108508]
 [0.00673795 0.00944947]
 [0.00673795 0.01112261]
 [0.00673795 0.01057859]
