In [None]:
import torch
from torch import nn
import gymnasium
import plotly.express as px
from collections import deque
from tqdm import tqdm
import plotly.graph_objects as go
from IPython.display import display
from utils import Model, ReplayMemory

In [2]:
dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch (presumably network weight init, plus the epsilon-greedy draws and
# random actions, which use torch RNG) and the environment so Restart & Run All
# is repeatable. Passing a seed to reset() once fixes the env's RNG; later
# env.reset() calls without a seed continue from it.
# NOTE(review): ReplayMemory.sample may use its own RNG (e.g. `random`/numpy) —
# seed that too if so. GPU kernels may still add small non-determinism.
SEED = 42
torch.manual_seed(SEED)

env = gymnasium.make("CartPole-v1")
env.reset(seed=SEED)

input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
# CartPole pays +1 per step, so the best achievable return equals the episode
# time limit (500 for v1); read it from the spec instead of hard-coding it.
max_reward = env.spec.max_episode_steps


print(f"Using device: {device}")
print(f"Input dimension: {input_dim}, Output dimension: {output_dim}")

Using device: cuda
Input dimension: 4, Output dimension: 2


In [3]:
def dict_to_device(dict, target_device=None):
    """Move every tensor in a transition dict to a device.

    Args:
        dict: mapping of name -> tensor, or None (returned unchanged).
        target_device: destination device (``torch.device`` or string);
            defaults to the notebook-level ``device``.

    Returns:
        None if ``dict`` is None, otherwise a new dict whose values have been
        moved with ``Tensor.to(target_device)``.
    """
    if dict is None:
        return None
    if target_device is None:
        target_device = device
    # No CPU shortcut: the previous `device == 'cpu'` test compared a
    # torch.device with a str, which is never equal, so it never fired.
    # Tensor.to() already returns the tensor itself when it is on the
    # destination device, so moving unconditionally is cheap and also correct
    # if a tensor ever lives on a different device.
    return {k: v.to(target_device) for k, v in dict.items()}

def terms_batch(terms):
    """Unpack terminal transitions into ``(X, y, actions)``.

    Terminal transitions have no bootstrap term, so the regression target
    ``y`` is simply the stored reward.
    """
    state, reward, action = (terms[key] for key in ('state', 'reward', 'action'))
    return state, reward, action

def non_terms_batch(non_terms, target_model, gamma):
    """Build ``(X, y, actions)`` for non-terminal transitions.

    The TD target is ``y = reward + gamma * target_model.get_max_q(next_state)``.
    Presumably ``get_max_q`` is the max over actions of the target network's
    Q-values — TODO confirm against utils.Model.

    The target is computed under ``torch.no_grad()``. The target network is not
    being optimised, and this keeps autograd from recording its forward pass,
    backpropagating the loss into it, and accumulating unused ``.grad`` buffers
    on its parameters (in case ``get_max_q`` does not already detach).

    Args:
        non_terms: dict with 'state', 'action', 'reward', 'next_state' tensors.
        target_model: network providing ``get_max_q``.
        gamma: discount factor.

    Returns:
        Tuple ``(states, targets, actions)``.
    """
    with torch.no_grad():
        max_q = target_model.get_max_q(non_terms['next_state'])
        y = non_terms['reward'] + gamma * max_q
    return non_terms['state'], y, non_terms['action']

def process_batch(
    batch, 
    gamma, 
    target_model,
    policy_model, 
    criterion, 
    optimizer
):
    """Run one DQN gradient step on a replay batch.

    Args:
        batch: dict with keys 'terms' and 'non_terms'. Each value is either
            None or a dict of batched tensors ('state', 'action', 'reward',
            plus 'next_state' for non-terminal transitions). Presumably
            produced by ReplayMemory.sample — TODO confirm.
        gamma: discount factor for the bootstrapped target.
        target_model: network used for the max-Q bootstrap term.
        policy_model: network being trained; its Q(s, a) are the predictions.
        criterion: loss between predicted Q(s, a) and the TD target.
        optimizer: optimizer over policy_model's parameters.

    Returns:
        The loss value as a Python float.

    Raises:
        ValueError: if the batch holds neither terminal nor non-terminal
            transitions (previously this surfaced as an UnboundLocalError).
    """
    policy_model.train()

    non_terms = dict_to_device(batch['non_terms'])
    terms = dict_to_device(batch['terms'])

    # Collect (X, y, actions) for each non-empty half; terminal transitions
    # come first, matching the original concatenation order.
    parts = []
    if terms is not None:
        parts.append(terms_batch(terms))
    if non_terms is not None:
        parts.append(non_terms_batch(non_terms, target_model, gamma))
    if not parts:
        raise ValueError(
            "batch contains no transitions: both 'terms' and 'non_terms' are None"
        )

    if len(parts) == 1:
        x, y, actions = parts[0]
    else:
        x, y, actions = (torch.cat(group) for group in zip(*parts))

    preds = policy_model.get_action_qs(x, actions)
    loss = criterion(preds, y)

    optimizer.zero_grad()
    loss.backward()
    # Clip the gradient norm to prevent radical changes in the model.
    nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()


| Num | Observation           | Min                          | Max                          |
|-----|------------------------|-------------------------------|-------------------------------|
| 0   | Cart Position          | -4.8                          | 4.8                           |
| 1   | Cart Velocity          | -Inf                          | Inf                           |
| 2   | Pole Angle             | ~ -0.418 rad (-24°)           | ~ 0.418 rad (24°)             |
| 3   | Pole Angular Velocity  | -Inf                          | Inf                           |


The cart x-position (index 0) can take values in (-4.8, 4.8), but the episode terminates if the cart leaves the (-2.4, 2.4) range.

The pole angle (index 2) can be observed in (-0.418, 0.418) radians (±24°), but the episode terminates if the pole angle leaves the range (-0.2095, 0.2095) radians (±12°).



### Rewards


Since the goal is to keep the pole upright for as long as possible, by default, a reward of +1 is given for every step taken, including the termination step. The default reward threshold is 500 for v1 and 200 for v0 due to the time limit on the environment.

If `sutton_barto_reward=True` is passed to `gymnasium.make`, a reward of 0 is given for every non-terminating step and -1 for the terminating step, so the reward threshold is 0 for both v0 and v1. This notebook uses the default +1-per-step reward.



In [5]:
batch_size = 128
learning_rate = 0.001
n_episodes = 300
eps_greedy = 0.1           # initial exploration probability
eps_decay = 0.95           # per-episode multiplicative decay of eps_greedy
gamma = 0.99               # discount factor for TD targets
update_target_steps = 100  # gradient steps between target-network syncs
quality_check_freq = 25    # episodes per smoothing / plotting window

hidden_dim = 32

target_model = Model(input_dim, hidden_dim, output_dim).to(device)
policy_model = Model(input_dim, hidden_dim, output_dim).to(device)
# Start both networks from identical weights. Otherwise the first
# `update_target_steps` updates regress toward targets produced by an
# unrelated, independently initialised network.
target_model.load_state_dict(policy_model.state_dict())
# The target network is never optimised. Freezing its parameters stops
# autograd from tracking them when TD targets are computed; later
# load_state_dict syncs still copy the weights in.
target_model.requires_grad_(False)
replay_memory = ReplayMemory(int(1e4))

optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
# Linearly anneal the LR from 1.0x to 0.1x over n_episodes scheduler steps.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, 
    start_factor=1.0, 
    end_factor=0.1, 
    total_iters=n_episodes
)
criterion = nn.MSELoss()
criterion = nn.MSELoss()

In [None]:
def torch_state(state):
    """Turn a raw env observation into a `dtype` tensor with a leading axis of size 1."""
    as_tensor = torch.tensor(state, dtype=dtype)
    return as_tensor[None, ...]

def should_random(eps_greedy):
    """Epsilon-greedy coin flip.

    Returns a one-element bool tensor that is True with probability
    ``eps_greedy`` (meaning: take a random action this step).
    """
    draw = torch.rand(1)  # uniform sample in [0, 1)
    return draw < eps_greedy
def random_action(n_actions=2):
    """Sample a uniformly random discrete action.

    Args:
        n_actions: size of the discrete action space. Defaults to 2, matching
            CartPole-v1's output_dim.

    Returns:
        int64 tensor of shape (1,) holding a value in [0, n_actions).
    """
    return torch.randint(low=0, high=n_actions, size=(1,))

fig = go.FigureWidget()
fig.add_scatter(x=[], y=[], mode='lines+markers', name='Episode Lengths')
fig.update_layout(title='Episode Lengths', xaxis_title='Episode', yaxis_title='Length')
display(fig)

# Each plotted point is the mean episode length over one window of
# `quality_check_freq` episodes. Training stops early once a whole window
# averages `max_reward` (CartPole pays +1 per step, so length == return).
episode_lens = []  # raw length of every episode
smooth = []        # mean length of each completed window (plotted)
sm = 0             # running sum of episode lengths in the current window
update_it = 0      # gradient steps since the last target-network sync


for episode in tqdm(range(1, n_episodes + 1)):
    finished = False

    state, info = env.reset()
    state = torch_state(state)

    episode_len = 0

    while not finished:
        episode_len += 1

        # Epsilon-greedy action selection.
        if should_random(eps_greedy):
            action = random_action()
        else:
            policy_model.eval()
            state_c = state.to(device)
            action = policy_model.predict(state_c).to('cpu')

        next_state, reward, terminated, truncated, info = env.step(action.item())

        next_state = torch_state(next_state)
        reward = torch.tensor([reward], dtype=dtype).unsqueeze(0)
        action = action.unsqueeze(0)

        finished = terminated or truncated

        # Only a genuine termination has no future value. A time-limit
        # truncation (the 500-step cap of CartPole-v1) must still bootstrap
        # from next_state. Storing None there would teach the agent that
        # surviving to the limit leads to a dead end.
        replay_memory.add({
            'state': state,
            'action': action,
            'reward': reward,
            'next_state': None if terminated else next_state
        })

        state = next_state

        # Warm-up: do not learn until the buffer can supply a full batch.
        if len(replay_memory) < batch_size:
            continue

        batch = replay_memory.sample(batch_size)
        loss = process_batch(
            batch,
            gamma,
            target_model,
            policy_model,
            criterion,
            optimizer
        )

        # Hard-sync the target network every `update_target_steps` updates.
        update_it += 1
        if update_it == update_target_steps:
            target_model.load_state_dict(policy_model.state_dict())
            update_it = 0

    episode_lens.append(episode_len)
    sm += episode_len

    # Decay exploration and step the LR schedule once per episode, but only
    # after learning has started.
    if len(replay_memory) >= batch_size:
        eps_greedy *= eps_decay
        scheduler.step()

    if episode % quality_check_freq == 0:
        sm /= quality_check_freq
        smooth.append(sm)
        fig.data[0].x = list(range(
            quality_check_freq,
            quality_check_freq * (len(smooth) + 1),
            quality_check_freq
        ))
        fig.data[0].y = smooth

        if sm == max_reward:
            print(f"Max reward reached: {max_reward} at episode {episode}")
            break

        sm = 0


FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Episode Lengths',
              'type': 'scatter',
              'uid': '20eee492-a8b7-434e-8cfb-b949f5d07384',
              'x': [],
              'y': []}],
    'layout': {'template': '...',
               'title': {'text': 'Episode Lengths'},
               'xaxis': {'title': {'text': 'Episode'}},
               'yaxis': {'title': {'text': 'Length'}}}
})

 75%|███████▍  | 224/300 [02:14<00:45,  1.66it/s]

Max reward reached: 500 at episode 226





In [8]:
# Save the online (policy) network: its epsilon-greedy actions produced the
# episode lengths plotted above. The target network is only a lagged copy used
# for TD targets, and it can trail the policy by up to `update_target_steps`
# gradient updates. This is the final network, not a best-so-far checkpoint.
# torch.save on a whole module pickles it, so loading needs utils.Model importable.
torch.save(policy_model, 'best_agent.pt')