In [None]:
!pip install gymnax --quiet

In [None]:
import jax
import gymnax
from flax import nnx
import jax.numpy as jnp
import wandb
import optax
import collections
from tqdm import tqdm

# Seed a root PRNG key and immediately derive a fresh root plus three
# independent sub-keys (reset / action sampling / stepping) from it.
key, key_reset, key_act, key_step = jax.random.split(jax.random.key(0), 4)

# Instantiate the CartPole environment together with its parameter object.
env, env_params = gymnax.make("CartPole-v1")

In [None]:
# Size of the action space (evaluated, but a cell only displays its last expression).
env.action_space(env_params).n
# Length of the observation vector; this is the value the cell displays.
env.observation_space(env_params).shape[0]

4

In [None]:
# Take a single sample transition so the step API can be inspected.
# env.step requires (key, state, action, params); calling it with no
# arguments raises a TypeError and breaks "Restart & Run All".
sample_obs, sample_state = env.reset(key_reset, env_params)
sample_action = env.action_space(env_params).sample(key_act)
env.step(key_step, sample_state, sample_action, env_params)

In [None]:
class Policy(nnx.Module):
  """Three-layer MLP that maps an observation to one logit per action."""

  def __init__(self, observation_space, action_space, rngs:nnx.Rngs):
    """Build the network.

    Args:
      observation_space: space whose ``shape[0]`` gives the input width.
      action_space: space whose ``n`` gives the number of output logits.
      rngs: NNX random-number streams used to initialise the layers.
    """
    super().__init__()
    input_dim = observation_space.shape[0]
    num_actions = action_space.n
    hidden_dim = 128

    # Layers are created in this order so that `rngs` is consumed identically.
    self.layer1 = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
    self.layer2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.layer3 = nnx.Linear(hidden_dim, num_actions, rngs=rngs)

  def __call__(self, x):
    """Return the unnormalised action logits for observation(s) `x`."""
    hidden = jax.nn.relu(self.layer1(x))
    hidden = jax.nn.relu(self.layer2(hidden))
    return self.layer3(hidden)

  def select_action(self, x, key):
    """Sample an action from the categorical distribution defined by the logits."""
    return jax.random.categorical(key, self(x))

def loss(model, obs, actions, returns):
    """Policy-gradient surrogate loss.

    Args:
        model: callable mapping a batch of observations to per-action logits.
        obs: batch of observations passed to `model`.
        actions: integer action index taken at each step (one per row of `obs`).
        returns: weight applied to each step's log-probability.

    Returns:
        Scalar: the negated mean of ``log pi(action | obs) * return``.
    """
    log_probs = jax.nn.log_softmax(model(obs))

    # Select, for every row, the log-probability of the action actually taken.
    action_index = actions[:, None]
    chosen_log_probs = jnp.take_along_axis(log_probs, action_index, axis=1).squeeze()

    weighted = chosen_log_probs * returns
    return -jnp.mean(weighted)

def compute_returns(rewards, gamma):
    """Compute the discounted return-to-go for every step of one episode.

    Args:
        rewards: sequence of per-step rewards, in chronological order.
        gamma: discount factor applied to each later reward.

    Returns:
        A 1-D jnp array where entry ``t`` equals
        ``rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...``.
    """
    running = 0
    returns = []
    # Walk backwards so each return reuses the one that follows it.
    # Appending and reversing once is linear; inserting at index 0 on every
    # step would make the loop quadratic in the episode length.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return jnp.array(returns)

# Build the policy network from the environment's observation/action spaces.
model = Policy(env.observation_space(env_params), env.action_space(env_params), rngs=nnx.Rngs(0))

# Record the hyperparameters of the run. These mirror what the training call
# in this notebook actually uses (episodes=100 with train()'s default
# learning_rate=1e-3 and gamma=0.99); previously the logged lr/episodes
# (1e-2 / 500) did not match the run, which made the wandb record misleading.
wandb.init(project="JAX-GYMNAX", config={
    "env": "Cartpole",
    "lr": 1e-3,
    "gamma": 0.99,
    "episodes": 100,
})

In [None]:
def _rollout_episode(env, env_params, model, key):
    """Play one episode with the current policy.

    Returns:
        (observations, actions, rewards, key): per-step Python lists plus the
        advanced PRNG key, so the caller can keep threading randomness.
    """
    key, reset_key = jax.random.split(key)
    obs, state = env.reset(reset_key, env_params)

    observations, actions, rewards = [], [], []
    done = False
    while not done:
        key, action_key, step_key = jax.random.split(key, 3)
        action = model.select_action(obs, action_key)
        next_obs, state, reward, done, _ = env.step(step_key, state, action, env_params)

        # Store the observation the action was chosen from, not its successor.
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        obs = next_obs

    return observations, actions, rewards, key


def train(env, env_params, model, episodes: int = 50, learning_rate=1e-3, gamma=0.99,
          episodes_per_batch: int = 10, seed: int = 0):
    """Train `model` in place with a batched policy-gradient update.

    Each outer iteration rolls out `episodes_per_batch` episodes, computes
    discounted returns per episode, normalises them across the batch, and
    applies one optimizer step on the `loss` surrogate.

    Args:
        env: gymnax environment exposing `reset` and `step`.
        env_params: parameter object passed to every `reset`/`step` call.
        model: policy exposing `select_action(obs, key)`; updated in place.
        episodes: number of outer iterations (one gradient update each).
        learning_rate: Adam learning rate.
        gamma: discount factor used by `compute_returns`.
        episodes_per_batch: episodes collected per gradient update.
        seed: seed for the PRNG key that drives resets, actions and steps.

    Raises:
        ValueError: if `episodes_per_batch` is smaller than 1.
    """
    if episodes_per_batch < 1:
        raise ValueError("episodes_per_batch must be at least 1")

    # NOTE(review): newer Flax releases appear to change the nnx.Optimizer API
    # (a `wrt=` argument and `update(model, grads)`) - confirm against the
    # installed Flax version before upgrading.
    optimizer = nnx.Optimizer(
        model,
        optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(learning_rate=learning_rate),
        )
    )
    grad_func = nnx.value_and_grad(loss)
    key = jax.random.PRNGKey(seed)
    # Rolling window over the most recent 100 episode totals.
    total_rewards = collections.deque(maxlen=100)

    with tqdm(range(episodes)) as pbar:
        for i in pbar:
            batch_all_obs = []
            batch_all_actions = []
            batch_all_returns = []
            # Using Batching for Gradient Stability
            for _ in range(episodes_per_batch):
                episode_obs, episode_actions, episode_rewards, key = _rollout_episode(
                    env, env_params, model, key
                )

                # Keep plain Python floats so averaging/logging never handles device arrays.
                total_rewards.append(float(sum(episode_rewards)))

                batch_all_obs.extend(episode_obs)
                batch_all_actions.extend(episode_actions)
                # Keep each episode's returns as one array; they are concatenated
                # once below instead of being unpacked element by element.
                batch_all_returns.append(compute_returns(episode_rewards, gamma))

            final_obs = jnp.stack(batch_all_obs)
            final_actions = jnp.array(batch_all_actions)
            final_returns = jnp.concatenate(batch_all_returns)

            # Standardise returns across the batch; the epsilon guards against zero variance.
            final_returns = (final_returns - jnp.mean(final_returns)) / (jnp.std(final_returns) + 1e-8)

            value, grad = grad_func(model, final_obs, final_actions, final_returns)
            optimizer.update(grad)

            loss_value = value.item()
            avg_reward = sum(total_rewards) / len(total_rewards)
            wandb.log({
                "episodic_reward": avg_reward,
                "global_step": i,
                "loss": loss_value
            })
            pbar.set_description(f"Episode: {i}, Loss: {loss_value:.4f}, Reward: {avg_reward:.2f}")

In [None]:
# Run 100 policy updates with train()'s default learning rate and discount factor.
train(env, env_params, model, episodes=100)

Episode: 99, Loss: -0.0097, Reward: 471.87: 100%|██████████| 100/100 [39:06<00:00, 23.47s/it]


In [None]:
from flax import serialization
import orbax.checkpoint as orbax

# Extract the model's variables as a pytree and write them with Orbax.
state = nnx.state(model)
checkpointer = orbax.PyTreeCheckpointer()
# force=True overwrites an existing checkpoint directory; without it the save
# fails once the directory exists, so the cell could not be re-run.
checkpointer.save('/content/model_state', state, force=True)

In [None]:
import os
import shutil

# Where the checkpoint was written, and the folder it should be copied into.
source_dir = '/content/model_state'
destination_dir = '/content/drive/MyDrive/RL_MODELS/'

# Ensure the destination folder exists before copying into it.
os.makedirs(destination_dir, exist_ok=True)

# Copy the checkpoint directory under the destination, keeping its own name;
# dirs_exist_ok lets a previous copy be overwritten in place.
target_dir = os.path.join(destination_dir, os.path.basename(source_dir))
shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)

print(f"Model state copied from {source_dir} to {destination_dir}")