In [6]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from jaxmarl import make
from functools import partial

# Hyperparameters
# Number of environment copies that are created and stepped on every iteration.
NUM_ENVS = 4
# Number of iterations of the outer training loop.
NUM_STEPS = 1000
# Agent count; passed as `num_agents` when each environment is constructed.
NUM_AGENTS = 3
# Learning rate handed to `optax.adam` for every agent's optimizer.
LR = 1e-2

In [7]:
# Create the environments: one PRNG key and one environment instance per copy.
keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
envs = [make("switch_riddle", num_agents=NUM_AGENTS) for _ in range(NUM_ENVS)]

# Reset every environment exactly once and unpack the (observation, state)
# pair it returns, instead of calling `reset` twice with the same key.
_initial = [env.reset(k) for env, k in zip(envs, keys)]
obs = [first_obs for first_obs, _ in _initial]
states = [first_state for _, first_state in _initial]

In [13]:
# === NETWORK DEFINITION ===
def net_fn(obs, num_actions=5):
    """Actor-critic network: a shared MLP trunk feeding a policy and a value head.

    Args:
        obs: Batched observations; everything after the leading axis is
            flattened before entering the MLP.
        num_actions: Number of discrete actions, i.e. the width of the policy
            head. Defaults to 5, the value the notebook originally hard-coded.
            NOTE(review): confirm this matches the environment's action space.

    Returns:
        A ``(logits, value)`` tuple: ``logits`` has ``num_actions`` entries per
        batch row and ``value`` has a single entry per batch row.
    """
    mlp = hk.Sequential([
        hk.Flatten(),
        # `activate_final=True` applies the non-linearity after the last hidden
        # layer too; without it the heads below would be a linear layer stacked
        # directly on a linear layer, wasting the second hidden layer.
        hk.nets.MLP([32, 32], activate_final=True),
    ])
    hidden = mlp(obs)
    logits = hk.Linear(num_actions)(hidden)
    value = hk.Linear(1)(hidden)
    return logits, value

def make_policy():
    """Build a fresh Haiku-transformed version of ``net_fn``.

    Returns:
        The object produced by ``hk.transform(net_fn)``; callers in this
        notebook use its ``init`` and ``apply`` functions.
    """
    transformed = hk.transform(net_fn)
    return transformed

In [14]:
# === INIT NETWORKS, OPTIMIZERS ===
# Per-agent containers, all keyed by the environment's own agent names.
policy_fns = {}
params = {}
opt_state = {}
optimizers = {}

# Take the agent names from the environment rather than rebuilding them with a
# format string: the training loop indexes these dictionaries with
# `env.agents`, so both places now share a single source of truth.
for i, agent in enumerate(envs[0].agents):
    policy = make_policy()
    policy_fns[agent] = policy

    # A zero observation with a leading batch axis of 1 is enough to let Haiku
    # infer the parameter shapes.
    dummy_obs = jnp.zeros((1, *envs[0].observation_space(agent).shape))
    # A distinct seed per agent so the networks do not start out identical.
    params[agent] = policy.init(jax.random.PRNGKey(42 + i), dummy_obs)

    optimizers[agent] = optax.adam(LR)
    opt_state[agent] = optimizers[agent].init(params[agent])

In [15]:
# === ACTION FUNCTION ===
def select_action(params, obs, key, policy_fn):
    """Sample an action from the policy and report its value and log-probability.

    Args:
        params: Parameters forwarded to ``policy_fn.apply``.
        obs: Observation with a leading batch axis; only batch row 0 is read
            from the network outputs (callers in this notebook add that axis
            with ``[None, ...]``).
        key: JAX PRNG key. It is split internally so the forward pass and the
            categorical sampling never consume the same key.
        policy_fn: Object exposing ``apply(params, key, obs)`` that returns
            ``(logits, value)``.

    Returns:
        ``(action, value, log_prob)`` where ``action`` is whatever
        ``jax.random.categorical`` yields for the batched logits (the batch
        axis is kept), ``value`` is the critic output for row 0, and
        ``log_prob`` is the log-probability of the sampled action.
    """
    # JAX keys must not be reused: derive one for the network, one for sampling.
    apply_key, sample_key = jax.random.split(key)
    logits, value = policy_fn.apply(params, apply_key, obs)
    action = jax.random.categorical(sample_key, logits)
    log_prob = jax.nn.log_softmax(logits)[0, action]
    return action, value[0, 0], log_prob


In [21]:
# === LOSS FUNCTION ===
def loss_fn(params, key, obs, action, advantage, old_log_prob, returns, policy_fn,
            clip_eps=0.2, value_coef=0.5):
    """PPO-style clipped surrogate loss plus a squared-error critic loss.

    Args:
        params: Network parameters; the argument gradients are taken against.
        key: PRNG key forwarded to ``policy_fn.apply``.
        obs: Observation with a leading batch axis; only row 0 is used.
        action: Index of the action that was taken.
        advantage: Advantage estimate for that action.
        old_log_prob: Log-probability of ``action`` under the policy that
            collected the data.
        returns: Regression target for the value head.
        policy_fn: Object exposing ``apply(params, key, obs)`` that returns
            ``(logits, value)``.
        clip_eps: Half-width of the ratio clipping interval; the ratio is
            clipped to ``[1 - clip_eps, 1 + clip_eps]``. The default 0.2
            reproduces the original 0.8 / 1.2 bounds.
        value_coef: Weight of the critic loss in the total. The default 0.5
            reproduces the original weighting.

    Returns:
        The total loss with size-1 axes squeezed away, so it is a scalar when
        ``action`` selects a single entry.
    """
    logits, value = policy_fn.apply(params, key, obs)
    log_probs = jax.nn.log_softmax(logits)
    new_log_prob = log_probs[0, action]

    # Probability ratio between the current policy and the data-collecting one.
    ratio = jnp.exp(new_log_prob - old_log_prob)
    clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)

    # Pessimistic (minimum) surrogate objective, negated to obtain a loss.
    actor_loss = -jnp.minimum(ratio * advantage, clipped_ratio * advantage)
    critic_loss = (returns - value[0, 0]) ** 2

    loss = actor_loss + value_coef * critic_loss
    return loss.squeeze()

In [None]:
# === TRAINING LOOP ===
# Build one jitted gradient function per agent up front. The original rebuilt
# `jax.grad(loss_fn)` and ran it un-jitted for every agent, environment and
# step, which makes the loop extremely slow. Binding `policy_fn` with `partial`
# leaves only array arguments, so the function can be jitted directly.
grad_fns = {
    agent: jax.jit(jax.grad(partial(loss_fn, policy_fn=policy_fns[agent])))
    for agent in policy_fns
}

# All training randomness is derived from this key with `fold_in`, so distinct
# (step, env) pairs can never collide the way `step * 100 + env_idx` seeds
# would once there are 100 or more environments.
train_key = jax.random.PRNGKey(1)
LOG_EVERY = 100

for step in range(NUM_STEPS):
    step_key = jax.random.fold_in(train_key, step)

    for env_idx, env in enumerate(envs):
        state = states[env_idx]
        obs_env = obs[env_idx]
        key = jax.random.fold_in(step_key, env_idx)

        actions = {}
        values = {}
        log_probs = {}
        agent_obs = {}

        # Act: every agent samples an action from its own policy.
        for agent in env.agents:
            # Add the leading batch axis the networks expect; cache the result
            # so the update below reuses it instead of rebuilding it.
            agent_obs[agent] = jnp.asarray(obs_env[agent])[None, ...]
            key, act_key = jax.random.split(key)
            action, value, log_prob = select_action(
                params[agent], agent_obs[agent], act_key, policy_fns[agent]
            )
            actions[agent] = action
            values[agent] = value
            log_probs[agent] = log_prob

        # Step the environment with a fresh key rather than the carry key.
        # NOTE(review): `done` is never consulted here; confirm that
        # `env.step` resets finished episodes on its own.
        key, env_key = jax.random.split(key)
        obs_next, state_next, reward, done, info = env.step(env_key, state, actions)

        # PPO step (1-step, advantage = reward - value)
        for agent in env.agents:
            rew = reward[agent]
            advantage = rew - values[agent]
            returns = rew

            # Fresh key per loss evaluation instead of a stale `subkey`.
            key, loss_key = jax.random.split(key)
            grads = grad_fns[agent](
                params[agent], loss_key, agent_obs[agent], actions[agent],
                advantage, log_probs[agent], returns,
            )

            updates, opt_state[agent] = optimizers[agent].update(grads, opt_state[agent])
            params[agent] = optax.apply_updates(params[agent], updates)

        # Update env state
        states[env_idx] = state_next
        obs[env_idx] = obs_next

    if step % LOG_EVERY == 0:
        # `reward` is the reward dict of the last environment stepped above.
        # Print the step header once, not once per agent.
        print(f'\nStep {step}: ')
        for agent in reward:
            print(f"reward[{agent}] = {float(reward[agent]):.2f}")

{'agent_0': Array(0., dtype=float32, weak_type=True), 'agent_1': Array(0., dtype=float32, weak_type=True), 'agent_2': Array(0., dtype=float32, weak_type=True)}
Step 0: reward[agent_0] = 0.00
Step 0: reward[agent_1] = 0.00
Step 0: reward[agent_2] = 0.00


KeyboardInterrupt: 