In [14]:
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

from ignite.engine import Engine, Events

import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display


In [15]:
seed_val = 543
gamma = 0.99
log_interval = 100
max_episodes = 1000000
render = True
verbose_steps = True  # print policy prediction and reward every step
writer = SummaryWriter("runs/cartpole")


In [16]:
env = gym.make("CartPole-v0", render_mode="rgb_array")

display = Display(visible=0, size=(1400, 900))
display.start()

def wrap_env(env):
  env = RecordVideo(env, './video', disable_logger=True)
  return env

env = wrap_env(env)



  logger.deprecation(
  logger.warn(


In [17]:
# The Policy class defines a neural network used to decide which action to take in the CartPole environment given an observation (state) of the environment.
# Specifically:
# - The network receives the environment state (which has 4 features for CartPole) as input.
# - It passes the state through a fully connected layer -> dropout -> ReLU activation -> another fully connected layer.
# - The output is converted to action probabilities using a softmax, representing the probability of choosing each possible action (left or right).
# - The network also keeps track of the log-probabilities of actions taken (`saved_log_probs`) and received rewards (`rewards`) for use in the REINFORCE algorithm's policy gradient update procedure.
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        # Stores log-probabilities of actions taken for each episode/timestep
        self.saved_log_probs = []
        # Stores rewards received for each timestep
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # Output is the probability distribution over actions (used for sampling actions)
        return F.softmax(action_scores, dim=1)


In [18]:
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()
timesteps = range(10000)

In [19]:
def select_action(policy, observation):
    state = torch.from_numpy(observation).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item(), probs.detach()

def finish_episode(policy, optimizer, gamma):
    if not policy.saved_log_probs or not policy.rewards:
        return None
    R = 0
    policy_loss = []
    returns = deque()
    for r in policy.rewards[::-1]:
        R = r + gamma * R
        returns.appendleft(R)
    
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)

    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)

    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    loss_value = policy_loss.item()
    policy_loss.backward()
    optimizer.step()

    del policy.rewards[:]
    del policy.saved_log_probs[:]
    return loss_value

def run_single_timestep(engine, timestep):
    observation = engine.state.observation
    action, probs = select_action(policy, observation)
    engine.state.observation, reward, done, _, _ = env.step(action)
    if render:
        env.render()

    policy.rewards.append(reward)
    engine.state.ep_reward += reward

    step_in_episode = len(policy.rewards) - 1
    if not hasattr(engine.state, "global_step"):
        engine.state.global_step = 0
    engine.state.global_step += 1
    gs = engine.state.global_step

    if verbose_steps:
        p = probs[0].cpu().numpy()
        print(f"Episode {engine.state.epoch} Step {step_in_episode} probs=[{p[0]:.3f}, {p[1]:.3f}] action={action} reward={reward}")
    writer.add_scalar("policy/prob_left", probs[0, 0].item(), gs)
    writer.add_scalar("policy/prob_right", probs[0, 1].item(), gs)
    writer.add_scalar("policy/reward_step", reward, gs)
    writer.add_scalar("policy/cumulative_reward", engine.state.ep_reward, gs)

    if done:
        engine.terminate_epoch()
        engine.state.timestep = timestep

trainer = Engine(run_single_timestep)


In [20]:
EPISODE_STARTED = Events.EPOCH_STARTED
EPISODE_COMPLETED = Events.EPOCH_COMPLETED

trainer.state.running_reward = 10

@trainer.on(EPISODE_STARTED)
def reset_environment_state():
    if not hasattr(trainer.state, "global_step"):
        trainer.state.global_step = 0
    torch.manual_seed(seed_val + trainer.state.epoch)
    trainer.state.observation, _ = env.reset(seed=seed_val + trainer.state.epoch)
    trainer.state.ep_reward = 0

@trainer.on(EPISODE_COMPLETED)
def update_model():
    trainer.state.running_reward = 0.05 * trainer.state.ep_reward + (1 - 0.05) * trainer.state.running_reward
    loss_value = finish_episode(policy, optimizer, gamma)
    if loss_value is not None:
        print(f"Episode {trainer.state.epoch} policy loss: {loss_value:.4f}")
        writer.add_scalar("policy_loss", loss_value, trainer.state.epoch)
        writer.add_scalar("episode_reward", trainer.state.ep_reward, trainer.state.epoch)

@trainer.on(EPISODE_COMPLETED(every=log_interval))
def log_episode():
    i_episode = trainer.state.epoch
    print(
        f"Episode {i_episode}\tLast reward: {trainer.state.ep_reward:.2f}"
        f"\tAverage length: {trainer.state.running_reward:.2f}"
    )

@trainer.on(EPISODE_COMPLETED)
def should_finish_training():
    running_reward = trainer.state.running_reward
    if running_reward > env.spec.reward_threshold:
        print(
            f"Solved! Running reward is now {running_reward} and "
            f"the last episode runs to {trainer.state.timestep} time steps!"
        )
        trainer.should_terminate = True



In [21]:
trainer.run(timesteps, max_epochs=max_episodes)


Episode 1 Step 0 probs=[0.632, 0.368] action=0 reward=1.0
Episode 1 Step 1 probs=[0.625, 0.375] action=0 reward=1.0
Episode 1 Step 2 probs=[0.657, 0.343] action=1 reward=1.0
Episode 1 Step 3 probs=[0.564, 0.436] action=1 reward=1.0
Episode 1 Step 4 probs=[0.483, 0.517] action=0 reward=1.0
Episode 1 Step 5 probs=[0.607, 0.393] action=0 reward=1.0
Episode 1 Step 6 probs=[0.742, 0.258] action=0 reward=1.0
Episode 1 Step 7 probs=[0.683, 0.317] action=0 reward=1.0
Episode 1 Step 8 probs=[0.663, 0.337] action=1 reward=1.0
Episode 1 Step 9 probs=[0.633, 0.367] action=0 reward=1.0
Episode 1 Step 10 probs=[0.705, 0.295] action=1 reward=1.0
Episode 1 Step 11 probs=[0.690, 0.310] action=1 reward=1.0
Episode 1 Step 12 probs=[0.566, 0.434] action=1 reward=1.0
Episode 1 Step 13 probs=[0.623, 0.377] action=1 reward=1.0
Episode 1 Step 14 probs=[0.487, 0.513] action=1 reward=1.0
Episode 1 Step 15 probs=[0.544, 0.456] action=1 reward=1.0
Episode 1 Step 16 probs=[0.615, 0.385] action=0 reward=1.0
Episode

State:
	iteration: 15077
	epoch: 131
	epoch_length: 10000
	max_epochs: 1000000
	output: <class 'NoneType'>
	batch: 380
	metrics: <class 'dict'>
	dataloader: <class 'range'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	running_reward: 199.67944762477586
	global_step: 15077
	observation: <class 'numpy.ndarray'>
	ep_reward: 381.0
	timestep: 380

In [25]:
import os

mp4list = glob.glob('video/*.mp4')

if len(mp4list) > 0:
    # Get the latest mp4 file by creation/modification time
    mp4 = max(mp4list, key=os.path.getmtime)
    print(mp4)
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
              </video>'''.format(encoded.decode('ascii'))))
else: 
    print("Could not find video")

video/rl-video-episode-125.mp4
