In [1]:
import gym
env = gym.make("CartPole-v1")

In [2]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [3]:
for action in range(env.action_space.n):
  print(action)

0
1


In [4]:
# the poles's starting state:
env.reset()

(array([ 0.02464805,  0.01900045, -0.03181395,  0.02836171], dtype=float32),
 {})

In [5]:
env.step(action=0)
# return: obs/next state, reward, terminated, truncated, info

(array([ 0.02502805, -0.17565116, -0.03124672,  0.31083965], dtype=float32),
 1.0,
 False,
 False,
 {})

In [6]:
# train stage PPO
# https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
import gym
env = gym.make("CartPole-v1")
env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose=1) # hide log
model.learn(total_timesteps=100000)
model.save('cartpole_ppo_model')
env.close()



Using cpu device
-----------------------------
| time/              |      |
|    fps             | 1535 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 933        |
|    iterations           | 2          |
|    time_elapsed         | 4          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.00870229 |
|    clip_fraction        | 0.0987     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.686     |
|    explained_variance   | -0.00387   |
|    learning_rate        | 0.0003     |
|    loss                 | 7.99       |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0169    |
|    value_loss           | 59.8       |
----------------------------------------
-----------------------------------

In [9]:
import gym
from PIL import Image
from stable_baselines3 import PPO
env = gym.make("CartPole-v1", render_mode='rgb_array')
obs, info = env.reset()
screen = env.render() # render image
images= [Image.fromarray(screen)]
i = 1
truncated = False
while (not truncated) and i <= 1000:
  action = env.action_space.sample()
  model = PPO.load("cartpole_ppo_model")
  action, states = model.predict(obs, deterministic =True)
  obs, reward, terminated, truncated, info = env.step(action)
  if (i%100 == 0): print("step", i, "action", action, obs, reward, terminated, info)
  i += 1
  screen=env.render()
  images.append(Image.fromarray(screen))
  env.close()

image_file = '63130500113-cartpole-v1-PPO.gif'
# loop=0: loop forever, duration=1: play each frame for ims
images[0].save(image_file, save_all =True, append_images=images[1:], loop=1, duration=1)

step 100 action 1 [ 0.30153978  0.01557292 -0.01578305 -0.0860551 ] 1.0 False {}
step 200 action 0 [ 0.17144533  0.00961364 -0.00075421  0.04541368] 1.0 False {}
step 300 action 0 [ 0.21665552  0.01204553 -0.00170038 -0.00823242] 1.0 False {}
step 400 action 1 [ 0.21790631  0.0126099   0.00428678 -0.02068246] 1.0 False {}
step 500 action 1 [ 0.21953093  0.01256703  0.00202602 -0.01973692] 1.0 False {}


In [None]:
# train stage DQN
import gymnasium as gym
from stable_baselines3 import DQN
env = gym.make("CartPole-v1", render_mode="human")
model = DQN("MlpPolicy", env, verbose=1) # hide log
model.learn(total_timesteps = 10000, log_interval=4)
model.save("dqn_cartpole")
env.close()

In [12]:
import gymnasium as gym
from PIL import Image
from stable_baselines3 import DQN
env = gym.make("CartPole-v1", render_mode='rgb_array')
env.reset()
screen = env.render() # render image
images= [Image.fromarray(screen)]
i = 1
truncated = False
while (not truncated) and i <= 100:
  action = env.action_space.sample()
  model = DQN.load("dqn_cartpole")
  # obs, info = env.reset()
  action, states = model.predict(obs, deterministic =True)
  obs, reward, terminated, truncated, info = env.step(action)
  if (i%5 == 0): print ("step", i, "action", action, obs, reward, terminated, info)
  screen=env.render()
  images.append(Image.fromarray(screen))
  # if terminated or truncated:
  #   obs, info = env.reset()
  i += 1
  env.close()

image_file = '63130500113-cartpole-v1-DQN.gif'
# loop=0: loop forever, duration=1: play each frame for ims
images[0].save(image_file, save_all =True, append_images=images[1:], loop=1, duration=1)

step 5 action 1 [ 0.03579208  0.9882051  -0.0489381  -1.4514667 ] 1.0 False {}
step 10 action 1 [ 0.17378734  1.9678771  -0.25737777 -3.08326   ] 1.0 True {}
step 15 action 1 [ 0.4093853  2.927927  -0.6394004 -4.9691753] 0.0 True {}
step 20 action 1 [ 0.71511364  3.4125738  -1.1899567  -6.492723  ] 0.0 True {}
step 25 action 1 [ 1.0852885  4.101621  -1.9065473 -8.048179 ] 0.0 True {}
step 30 action 1 [ 1.5205128  4.7577624 -2.7444549 -8.570825 ] 0.0 True {}
step 35 action 1 [ 2.0313601  5.7007384 -3.563955  -7.2815375] 0.0 True {}
step 40 action 1 [ 2.645497   6.7979584 -4.205675  -5.113517 ] 0.0 True {}
step 45 action 1 [ 3.3667715  7.8147163 -4.6385956 -3.2502618] 0.0 True {}
step 50 action 1 [ 4.1862383  8.754143  -4.9042115 -1.8463231] 0.0 True {}
step 55 action 1 [ 5.0980325  9.65999   -5.043811  -0.7697154] 0.0 True {}
step 60 action 1 [ 6.099889   10.555672   -5.0836344   0.14209874] 0.0 True {}
step 65 action 1 [ 7.1912613 11.451677  -5.033786   1.0470749] 0.0 True {}
step 70 a

In [13]:
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
import gym
env = gym.make("CartPole-v1", render_mode="human")
models = {"dqn": DQN.load('dqn_cartpole'), "ppo": PPO.load('cartpole_ppo_model') }
for k, v in models.items():
    mean_rw, sd_rw = evaluate_policy(models[k], env, n_eval_episodes = 10, render=True)
    print(k,' reward: Mean', mean_rw, 'SD', sd_rw )



dqn  reward: Mean 9.3 SD 0.6403124237432849
ppo  reward: Mean 500.0 SD 0.0
