In [1]:
import jax
import jax.numpy as jnp
import os
os.environ['CUDA_VISIBLE_DEVICES']='2'

import matplotlib.pyplot as plt
from IPython import display

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
from ott.tools.sinkhorn_divergence import SinkhornDivergenceOutput
from ott.solvers.linear import implicit_differentiation as imp_diff
import equinox as eqx

In [48]:
def sink_div(combined, states, y, b, key) -> tuple[float, float]:
    agent_value, agent_policy = combined
    z_dist = eqx.filter_vmap(agent_policy)(states)
    z, log_prob = z_dist.sample_and_log_prob(seed=key) # intentions of agent
    geom = pointcloud.PointCloud(z, y)
    
    a = eqx.filter_vmap(agent_value)(states, z).squeeze() # weights for intents of agent
    an = jax.nn.softplus(a - jnp.quantile(a, 0.01)) 
    bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
        

    an = an / an.sum()
    bn = bn / bn.sum()
    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        a=an,
        b=bn,
        y=geom.y,
        static_b=True,
        sinkhorn_kwargs={
            "implicit_diff": imp_diff.ImplicitDiff(),
            "use_danskin": True,
            "max_iterations": 2000
        },
    )
    return ot.divergence, (-log_prob.squeeze()).min()

In [49]:
import equinox as eqx
import equinox.nn as eqxnn

class MonolithicVF_EQX(eqx.Module):
    net: eqx.Module
    
    def __init__(self, key, state_dim, intents_dim, hidden_dims):
        key, mlp_key = jax.random.split(key, 2)
        self.net = eqxnn.MLP(
            in_size=state_dim + intents_dim, out_size=1, width_size=hidden_dims[-1], depth=len(hidden_dims), key=mlp_key
        )
        
    def __call__(self, observations, intents):
        # TODO: Maybe try FiLM conditioning like in SAC-RND?
        conditioning = jnp.concatenate([observations, intents], axis=-1)
        return self.net(conditioning)

In [50]:
from jaxrl_m.common import TrainStateEQX
from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy
import optax

key = jax.random.PRNGKey(42)
actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,
                             hidden_dims=[128, 128, 128],
                             state_dim=29,
                             intent_dim=2), optim=optax.adam(learning_rate=3e-4))

x = jax.random.normal(key, (250, 29)) 
z_dist = eqx.filter_vmap(actor_intents_learner.model)(x)
z, log_prob = z_dist.sample_and_log_prob(seed=key)

print(log_prob.shape)

(250,)


In [51]:
def gradient_flow(
    x: jnp.ndarray,
    y: jnp.ndarray,
    b,
    cost_fn: callable,
    num_iter: int = 6000,
    dump_every: int = 50
):
    def v_loss(agent_policy, agent_value, states, key) -> float:
        z_dist = eqx.filter_vmap(agent_policy)(states)
        z, _ = z_dist.sample_and_log_prob(seed=key)
        v = eqx.filter_vmap(agent_value)(states, z).squeeze()
        return -v.mean() * 0.1

    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn, has_aux=True))
    v_loss_vg = eqx.filter_jit(eqx.filter_value_and_grad(v_loss, has_aux=False))

    key = jax.random.PRNGKey(42)
    V = TrainStateEQX.create(
        model=MonolithicVF_EQX(key, 29, 256, [128, 128, 128]),
        optim=optax.adam(learning_rate=3e-4)
    )

    key, pkey = jax.random.split(key, 2)

    actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,
                             hidden_dims=[128, 128, 128],
                             state_dim=29,
                             intent_dim=256), optim=optax.adam(learning_rate=3e-4))
    
    for i in range(0, num_iter + 1):
        key, key_6 = jax.random.split(key, 2)

        (cost, pmin), (value_grads, policy_grads) = cost_fn_vg((V.model, actor_intents_learner.model), x, y, b, key_6)
        v_loss, policy_grads_2 = v_loss_vg(actor_intents_learner.model, V.model, x, key_6)
        V = V.apply_updates(value_grads)
        policy_grads = jax.tree_map(lambda g1, g2: g1 + g2, policy_grads, policy_grads_2)
        actor_intents_learner = actor_intents_learner.apply_updates(policy_grads)

        if i % dump_every == 0:
            z = eqx.filter_vmap(actor_intents_learner.model)(x).sample(seed=key_6)
            a = eqx.filter_vmap(V.model)(x, z).squeeze()
            
            an = jax.nn.softplus(a - jnp.quantile(a, 0.01))
            bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))
            an = an / an.sum()
            bn = bn / bn.sum()

            geom = pointcloud.PointCloud(z, y, epsilon=0.01)
            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom, a = an, b = bn)).reg_ot_cost
            print(cost, diff, pmin)
            print()

    return policy.model, V.model
    

In [73]:
def sink_div(agent_policy, states, y, b, key) -> tuple[float, float]:
    z_dist = eqx.filter_vmap(agent_policy)(states)
    z, log_prob = z_dist.sample_and_log_prob(seed=key) # intentions of agent
    geom = pointcloud.PointCloud(z, y)
    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        y=geom.y,
        epsilon=0.1,
        static_b=True,
        sinkhorn_kwargs={
            "implicit_diff": imp_diff.ImplicitDiff(),
            "use_danskin": True,
            "max_iterations": 2000
        },
    )
    return ot.divergence
    
def gradient_flow(
    x: jnp.ndarray,
    y: jnp.ndarray,
    b,
    cost_fn: callable,
    num_iter: int = 6000,
    dump_every: int = 50
):

    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn))
    key = jax.random.PRNGKey(42)
    key, pkey = jax.random.split(key, 2)

    actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,
                             hidden_dims=[128, 128, 128],
                             state_dim=29,
                             intent_dim=256), optim=optax.adam(learning_rate=3e-4))
    
    for i in range(0, num_iter + 1):
        key, key_6 = jax.random.split(key, 2)

        cost, policy_grads = cost_fn_vg(actor_intents_learner.model, x, y, b, key_6)
        actor_intents_learner = actor_intents_learner.apply_updates(policy_grads)

        if i % dump_every == 0:
            z = eqx.filter_vmap(actor_intents_learner.model)(x).sample(seed=key_6)
            geom = pointcloud.PointCloud(z, y, epsilon=0.01)
            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom)).reg_ot_cost
            print(cost, diff)
            print()

    return policy.model
    

In [74]:
key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)


x = jax.random.normal(key1, (100, 29))
y = jax.random.normal(key2, (400, 256))

policy_model, V_model = gradient_flow(x, y, marginal_b, cost_fn=sink_div)

435.81137 434.6088

425.7082 424.34558



KeyboardInterrupt: 