In [1]:
!pip install brax jax jaxlib flax optax

Collecting brax
  Downloading brax-0.10.4-py3-none-any.whl (998 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m998.3/998.3 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
Collecting dm-env (from brax)
  Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Collecting flask-cors (from brax)
  Downloading Flask_Cors-4.0.1-py2.py3-none-any.whl (14 kB)
Collecting jaxopt (from brax)
  Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.3/172.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Collecting ml-collections (from brax)
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mujoco (from brax)
  Downloading mujoco-3.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [10]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from collections import deque
from IPython.display import HTML
import brax
from brax.io import html
from brax.envs import create

# GPU 사용 확인
print("JAX Devices:", jax.devices())

# 병렬 환경 설정
def create_env(env_name, batch_size):
    env = create(env_name)
    reset_fn = jax.vmap(env.reset)
    step_fn = jax.vmap(env.step)
    return env, reset_fn, step_fn

# Actor-Critic 네트워크 정의
class Actor(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        x = nn.tanh(x)
        return x

class Critic(nn.Module):
    @nn.compact
    def __call__(self, x, a):
        x = jnp.concatenate([x, a], axis=-1)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

# Replay Buffer 정의
class ReplayBuffer:
    def __init__(self, buffer_size, state_dim, action_dim):
        self.buffer = deque(maxlen=buffer_size)
        self.state_dim = state_dim
        self.action_dim = action_dim

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in indices])
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

# DDPG 알고리즘 정의
class DDPGAgent:
    def __init__(self, state_dim, action_dim, actor_lr, critic_lr, gamma, tau, buffer_size, batch_size):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size

        self.actor = Actor(action_dim)
        self.critic = Critic()
        self.target_actor = Actor(action_dim)
        self.target_critic = Critic()

        self.actor_params = self.actor.init(jax.random.PRNGKey(0), jnp.ones((state_dim,)))
        self.critic_params = self.critic.init(jax.random.PRNGKey(1), jnp.ones((state_dim,)), jnp.ones((action_dim,)))
        self.target_actor_params = self.actor_params
        self.target_critic_params = self.critic_params

        self.actor_optimizer = optax.adam(actor_lr)
        self.critic_optimizer = optax.adam(critic_lr)
        self.actor_opt_state = self.actor_optimizer.init(self.actor_params)
        self.critic_opt_state = self.critic_optimizer.init(self.critic_params)

        self.replay_buffer = ReplayBuffer(buffer_size, state_dim, action_dim)

    def select_action(self, state, noise_scale):
        action = self.actor.apply(self.actor_params, state)
        action = action + noise_scale * np.random.randn(self.action_dim)
        return np.clip(action, -1, 1)

    def update(self):
        if len(self.replay_buffer.buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Update Critic
        next_actions = self.target_actor.apply(self.target_actor_params, next_states)
        target_q_values = self.target_critic.apply(self.target_critic_params, next_states, next_actions)
        y = rewards + self.gamma * (1 - dones) * target_q_values.squeeze()
        y = y[:, None]

        def critic_loss_fn(critic_params):
            q_values = self.critic.apply(critic_params, states, actions)
            loss = jnp.mean((q_values - y) ** 2)
            return loss

        grad_fn = jax.value_and_grad(critic_loss_fn)
        loss, grads = grad_fn(self.critic_params)
        updates, self.critic_opt_state = self.critic_optimizer.update(grads, self.critic_opt_state)
        self.critic_params = optax.apply_updates(self.critic_params, updates)

        # Update Actor
        def actor_loss_fn(actor_params):
            actions = self.actor.apply(actor_params, states)
            q_values = self.critic.apply(self.critic_params, states, actions)
            loss = -jnp.mean(q_values)
            return loss

        grad_fn = jax.value_and_grad(actor_loss_fn)
        loss, grads = grad_fn(self.actor_params)
        updates, self.actor_opt_state = self.actor_optimizer.update(grads, self.actor_opt_state)
        self.actor_params = optax.apply_updates(self.actor_params, updates)

        # Update Target Networks
        self.target_actor_params = jax.tree_multimap(lambda x, y: x * (1 - self.tau) + y * self.tau,
                                                     self.target_actor_params, self.actor_params)
        self.target_critic_params = jax.tree_multimap(lambda x, y: x * (1 - self.tau) + y * self.tau,
                                                      self.target_critic_params, self.critic_params)

# 에이전트 학습 및 평가
batch_size = 32  # 병렬로 처리할 환경의 수
env, reset_fn, step_fn = create_env('ant', batch_size=batch_size)

state_dim = env.observation_size
action_dim = env.action_size

agent = DDPGAgent(state_dim, action_dim, actor_lr=1e-3, critic_lr=1e-3, gamma=0.99, tau=0.005, buffer_size=100000, batch_size=64)

num_episodes = 500
log_interval = 10  # 로그를 출력할 에피소드 간격

for episode in range(num_episodes):
    keys = jax.random.split(jax.random.PRNGKey(episode), batch_size)
    states = reset_fn(keys)
    total_rewards = np.zeros(batch_size)
    dones = np.zeros(batch_size, dtype=bool)

    while not np.all(dones):
        # State 객체에서 상태 배열을 추출
        states = states.obs
        actions = jax.vmap(agent.select_action, in_axes=(0, None))(states, 0.1)
        next_states, rewards, dones, _ = step_fn(states, actions)
        # 다음 상태에서도 상태 배열을 추출
        next_states = next_states.obs
        for i in range(batch_size):
            agent.replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], dones[i])
        agent.update()
        states = next_states
        total_rewards += rewards * (~dones)

    if episode % log_interval == 0:
        print(f"Episode: {episode}, Average Total Reward: {np.mean(total_rewards)}")

# 학습 후 에이전트의 동작 시각화
def visualize_agent(agent, env):
    state = env.reset(rng=jax.random.PRNGKey(seed=0))
    done = False
    frames = []

    while not done:
        state = state.obs  # 상태 배열을 추출
        action = agent.select_action(state, noise_scale=0)
        next_state, reward, done, _ = env.step(action)
        next_state = next_state.obs  # 상태 배열을 추출
        frames.append(env.render())
        state = next_state

    return html.render(frames)

# 환경의 병렬 처리 비활성화
env = create('ant')
HTML(visualize_agent(agent, env))

JAX Devices: [cuda(id=0)]


AttributeError: BatchTracer has no attribute info