In [None]:
!pip install brax flax optax

In [None]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import tensorflow_probability.substrates.jax as tfp
from brax.envs import inverted_pendulum
from brax.io import html
from IPython.display import HTML

tfd = tfp.distributions

from typing import Sequence

In [None]:
# Report which backend JAX selected and every device it can see.
backend_name = jax.default_backend()
print(f"JAXのデフォルトバックエンド: {backend_name}")
print("JAXが認識しているデバイス:", jax.devices(), sep="\n")

In [None]:
# Root JAX PRNG key for the notebook; later cells split sub-keys off it.
key = jax.random.PRNGKey(0)

# Instantiate the inverted-pendulum environment on the "positional" backend.
# (Alternative kept for reference: inverted_pendulum.InvertedPendulum() with no
# explicit backend argument.)
env = inverted_pendulum.InvertedPendulum(backend="positional")

# JIT-compile reset and step once so the rest of the notebook reuses them.
jit_env_reset, jit_env_step = jax.jit(env.reset), jax.jit(env.step)

## Actor-Critic (PPO)


In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared MLP trunk.

    The trunk is a stack of Dense + ReLU layers. On top of it sit a Gaussian
    policy head (tanh-squashed mean and a state-independent learned log-std)
    and a scalar value head.

    Attributes:
        action_size: number of action dimensions produced by the policy head.
        hidden_sizes: width of each hidden layer in the shared trunk.
    """

    action_size: int
    hidden_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """Return `(loc, scale, value)` for the observation(s) `x`."""
        # Shared trunk: one Dense -> ReLU pair per configured width.
        features = x
        for width in self.hidden_sizes:
            features = nn.relu(nn.Dense(features=width)(features))

        # Actor head: the mean is squashed into (-1, 1) by tanh.
        policy_mean = nn.tanh(nn.Dense(features=self.action_size)(features))

        # One learnable log-std per action dimension, initialised to zero
        # (i.e. a standard deviation of 1); it does not depend on the input.
        policy_std = jnp.exp(
            self.param("log_std", nn.initializers.zeros, (self.action_size,))
        )

        # Critic head: single output unit with the trailing axis removed.
        state_value = nn.Dense(features=1)(features).squeeze(axis=-1)

        return policy_mean, policy_std, state_value

In [None]:
# Optimizer: clip the global gradient norm to 1.0, then apply Adam.
learning_rate = 3e-4
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate),
)

# Actor-critic network sized from the environment's action dimensionality.
action_size = env.action_size
ac_net = ActorCritic(action_size=action_size, hidden_sizes=[64, 64])

# A single all-zero observation is enough to initialise the parameters.
dummy_obs = jnp.zeros((1, env.observation_size))
key, ac_key = jax.random.split(key)
params = ac_net.init(ac_key, dummy_obs)["params"]

# Optimizer state is built from the freshly initialised parameters.
opt_state = optimizer.init(params)

In [None]:
@jax.jit
def train_step(params, opt_state, key):
    """Run one PPO update: collect a rollout, estimate advantages, step the optimizer.

    Args:
        params: ActorCritic parameter pytree passed to `ac_net.apply`.
        opt_state: optax optimizer state that belongs to `params`.
        key: JAX PRNG key; split for the environment reset and action sampling.

    Returns:
        A tuple `(new_params, new_opt_state, loss)` where `loss` is the scalar
        PPO objective evaluated on the collected rollout before the update.
    """
    rollout_length = 200  # environment steps collected per update
    gamma, lambda_ = 0.99, 0.95  # discount factor and GAE smoothing factor
    clip_eps = 0.2  # PPO probability-ratio clipping range
    value_coef = 0.5  # weight of the value loss in the total loss

    def rollout_step(carry, _):
        """Sample one action from the current policy and advance the environment."""
        state, key = carry
        key, policy_key = jax.random.split(key)
        loc, scale, value = ac_net.apply({"params": params}, state.obs)
        dist = tfd.Normal(loc=loc, scale=scale)
        action = dist.sample(seed=policy_key)
        # Joint log-probability of the action vector: reduce over the action
        # axis only, matching the reduction used in `loss_fn`.
        log_prob = dist.log_prob(action).sum(axis=-1)
        # NOTE(review): nothing here resets the environment after `done`, so
        # later steps keep simulating from the terminated state. Confirm
        # whether an auto-reset is intended.
        next_state = jit_env_step(state, action)
        transition = (
            state.obs,
            action,
            log_prob,
            value,
            next_state.reward,
            1.0 - next_state.done,  # continuation mask: 0 where the step ended the episode
        )

        return (next_state, key), transition

    def calculate_gae(transitions, last_val):
        """Compute GAE advantages and value targets by scanning the rollout backwards."""

        def scan_fn(gae_and_next_val, transition):
            gae, next_val = gae_and_next_val
            _, _, _, value, reward, not_done = transition
            # The mask removes both the bootstrap value and the accumulated
            # advantage across an episode boundary.
            delta = reward + gamma * next_val * not_done - value
            gae = delta + gamma * lambda_ * gae * not_done
            return (gae, value), gae

        _, advantages = jax.lax.scan(
            scan_fn, (0.0, last_val), transitions, reverse=True
        )
        # Value targets: advantage plus the value predicted at collection time.
        returns = advantages + transitions[3]

        return advantages, returns

    def loss_fn(params, obs, action, log_prob_old, advantage, return_val):
        """Clipped-surrogate policy loss plus weighted value regression loss."""
        loc, scale, value_pred = ac_net.apply({"params": params}, obs)
        dist = tfd.Normal(loc=loc, scale=scale)
        # Reduce over the action axis only. A plain `.sum()` would also sum
        # over the time axis and collapse the batch to one scalar, which then
        # broadcasts against `log_prob_old` and corrupts every ratio.
        log_prob_new = dist.log_prob(action).sum(axis=-1)
        ratio = jnp.exp(log_prob_new - log_prob_old)
        policy_loss = -jnp.minimum(
            ratio * advantage,
            jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage,
        )
        value_loss = optax.l2_loss(value_pred, return_val)
        return policy_loss.mean() + value_coef * value_loss.mean()

    # Rollout
    key, reset_key = jax.random.split(key)
    initial_state = jit_env_reset(reset_key)
    (final_state, _), transitions = jax.lax.scan(
        rollout_step, (initial_state, key), None, length=rollout_length
    )

    # GAE calculation: bootstrap from the value of the state after the last step.
    _, _, last_val = ac_net.apply({"params": params}, final_state.obs)

    advantages, returns = calculate_gae(transitions, last_val)
    # Normalise advantages over the rollout; the epsilon guards a zero std.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Loss calculation and a single gradient update on the whole rollout.
    obs_batch, action_batch, log_prob_batch, _, _, _ = transitions
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(
        params, obs_batch, action_batch, log_prob_batch, advantages, returns
    )
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss

In [None]:
# Optimise the policy for a fixed number of epochs, one `train_step` each.
print("🚀 Starting training run...")
total_epochs = 500
policy_params = params

for epoch in range(1, total_epochs + 1):
    # Split off a fresh sub-key so each epoch's rollout uses new randomness.
    key, train_key = jax.random.split(key)
    policy_params, opt_state, loss = train_step(policy_params, opt_state, train_key)

    # Log the latest loss on every 50th epoch.
    if not epoch % 50:
        print(f"Epoch: {epoch}/{total_epochs}, Loss: {loss:.4f}")

print("✅ Training complete!")

In [None]:
# --- 6. Evaluation ---
print("✅ Evaluating trained model...")


@jax.jit
def jit_policy(params, state):
    """Deterministic policy: return the actor's mean action for `state`."""
    mean_action, _, _ = ac_net.apply({"params": params}, state.obs)
    return mean_action


key, eval_key = jax.random.split(key)
eval_state = jit_env_reset(eval_key)

# Record each state before stepping from it; run for at most 1000 steps and
# stop early once the environment reports the episode as done.
rollout = []
for _ in range(1000):
    rollout.append(eval_state)
    action = jit_policy(policy_params, eval_state)
    eval_state = jit_env_step(eval_state, action)
    if eval_state.done:
        break

# The renderer is given the physics states stored under `pipeline_state`.
physics_states = [recorded.pipeline_state for recorded in rollout]
rendered_html = html.render(env.sys, physics_states)
display(HTML(rendered_html))