In [None]:
import torch
from torch import nn
import gymnasium
import plotly.express as px
from collections import deque
from tqdm import tqdm
import plotly.graph_objects as go
from IPython.display import display



# Global tensor settings shared by every cell below.
dtype = torch.float32

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Build the environment and read the network sizes off its spaces:
# one input per observation component, one output per discrete action.
env = gymnasium.make("CartPole-v1")
observation_shape = env.observation_space.shape
input_dim = observation_shape[0]
output_dim = env.action_space.n
print(f"Input dimension: {input_dim}, Output dimension: {output_dim}")

In [None]:
class Model(nn.Module):
    """Fully connected Q-network: three 32-unit ReLU hidden layers and a linear head.

    Args:
        input_dim: Number of features in one observation.
        output_dim: Number of outputs (one Q-value per action).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Layer widths from the input through the three hidden layers.
        widths = [input_dim, 32, 32, 32]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output head: no activation, raw Q-values.
        layers.append(nn.Linear(widths[-1], output_dim))
        self.mod = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for a batch `x` of shape (batch, input_dim)."""
        q_values = self.mod(x)
        return q_values

    @torch.no_grad()
    def predict(self, x):
        """Return the index of the largest output for every row of `x` (no gradients tracked)."""
        q_values = self.mod(x)
        return q_values.argmax(dim=1)



In [None]:

def process_batch(
        batch, 
        gamma, 
        target_model,
        policy_model, 
        criterion, 
        optimizer
    ):
    """Run one DQN optimisation step on a sampled mini-batch and return the loss.

    Args:
        batch: Dict with keys 'terms' and 'non_terms'. Each value is either
            None or a dict of tensors with keys 'state', 'action' and
            'reward' ('non_terms' additionally has 'next_state'). Actions and
            rewards are expected with shape (N, 1), which is how the training
            loop in this notebook stores them.
        gamma: Discount factor applied to the bootstrapped next-state value.
        target_model: Network used (without gradients) to evaluate next states.
        policy_model: Network being trained; its parameters are updated in place.
        criterion: Loss callable taking (predictions, targets).
        optimizer: Optimizer stepping the parameters of `policy_model`.

    Returns:
        The scalar loss of this step as a Python float.

    Raises:
        ValueError: If both 'terms' and 'non_terms' are None.
    """
    policy_model.train()

    # Use the device the policy network actually lives on instead of relying
    # on a notebook-level global.
    model_device = next(policy_model.parameters()).device

    states, actions, targets = [], [], []

    # Terminal transitions: the target is just the immediate reward.
    terms = batch['terms']
    if terms is not None:
        terms = {k: v.to(model_device) for k, v in terms.items()}
        states.append(terms['state'])
        actions.append(terms['action'])
        targets.append(terms['reward'])

    # Non-terminal transitions: reward plus the discounted best next-state
    # value according to the (frozen) target network.
    non_terms = batch['non_terms']
    if non_terms is not None:
        non_terms = {k: v.to(model_device) for k, v in non_terms.items()}
        with torch.no_grad():
            max_q = target_model(non_terms['next_state']).max(dim=1, keepdim=True).values
        states.append(non_terms['state'])
        actions.append(non_terms['action'])
        targets.append(non_terms['reward'] + gamma * max_q)

    if not states:
        raise ValueError("batch contains neither terminal nor non-terminal transitions")

    # Terminal rows first, then non-terminal rows; all three tensors share
    # that ordering so predictions line up with their targets.
    X = torch.cat(states)
    chosen_actions = torch.cat(actions)
    y_true = torch.cat(targets)

    # Call the module (not .forward) so registered hooks run, then pick the
    # Q-value of the action that was actually taken.
    preds = policy_model(X).gather(dim=1, index=chosen_actions)
    loss = criterion(preds, y_true)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [None]:
class ReplayMemory():
    """Bounded FIFO buffer of transitions used for experience replay.

    A transition is a dict with keys 'state', 'action', 'reward' and
    'next_state'; 'next_state' is None for transitions that ended an episode.

    Args:
        max_capacity: Maximum number of transitions kept; once full, the
            oldest transition is dropped for every new one added.
    """

    def __init__(self, max_capacity):
        # Honour the requested capacity (it used to be hard-coded to 10000).
        self.max_capacity = max_capacity
        self.memory = deque(maxlen=max_capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)
    
    def add(self, transition):
        """Store `transition`, evicting the oldest entry if the buffer is full."""
        # deque(maxlen=...) discards from the left automatically on overflow.
        self.memory.append(transition)
        
    def sample(self, batch_size):
        """Draw `batch_size` transitions uniformly at random, with replacement.

        Returns:
            Dict with keys 'terms' and 'non_terms'. Each value is None when no
            sampled transition falls in that group, otherwise a dict of
            tensors concatenated along dim 0 and moved to `device`, with keys
            'state', 'action', 'reward' (plus 'next_state' for 'non_terms').

        Raises:
            ValueError: If the buffer is empty.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty replay memory")

        # Plain Python ints index the deque directly (no 0-d tensor overhead).
        indices = torch.randint(0, len(self.memory), (batch_size,)).tolist()
        picked = [self.memory[i] for i in indices]

        terms = [tr for tr in picked if tr['next_state'] is None]
        non_terms = [tr for tr in picked if tr['next_state'] is not None]

        def collate(rows, keys):
            # An empty group is reported as None rather than as empty tensors.
            if not rows:
                return None
            return {
                key: torch.cat([row[key] for row in rows], dim=0).to(device)
                for key in keys
            }

        return {
            'terms': collate(terms, ('state', 'action', 'reward')),
            'non_terms': collate(non_terms, ('state', 'action', 'reward', 'next_state')),
        }

| Num | Observation           | Min                          | Max                          |
|-----|------------------------|-------------------------------|-------------------------------|
| 0   | Cart Position          | -4.8                          | 4.8                           |
| 1   | Cart Velocity          | -Inf                          | Inf                           |
| 2   | Pole Angle             | ~ -0.418 rad (-24°)           | ~ 0.418 rad (24°)             |
| 3   | Pole Angular Velocity  | -Inf                          | Inf                           |


The cart x-position (index 0) can take values between (-4.8, 4.8), but the episode terminates if the cart leaves the (-2.4, 2.4) range.

The pole angle can be observed between (-.418, .418) radians (or ±24°), but the episode terminates if the pole angle is not in the range (-.2095, .2095) (or ±12°).



### Rewards


Since the goal is to keep the pole upright for as long as possible, by default, a reward of +1 is given for every step taken, including the termination step. The default reward threshold is 500 for v1 and 200 for v0 due to the time limit on the environment.

If `sutton_barto_reward=True`, then a reward of 0 is awarded for every non-terminating step and -1 for the terminating step. As a result, the reward threshold is 0 for v0 and v1.



In [None]:
# Training hyperparameters.
batch_size = 32          # transitions per optimisation step
learning_rate = 0.0001   # Adam step size
n_episodes = 1000        # number of training episodes
eps_greedy = 0.01        # probability of taking a random action
gamma = 0.99             # discount factor for bootstrapped targets

target_model = Model(input_dim, output_dim).to(device)
policy_model = Model(input_dim, output_dim).to(device)
# Start the target network as an exact copy of the policy network; otherwise
# the first updates bootstrap from an unrelated randomly initialised network.
target_model.load_state_dict(policy_model.state_dict())
# The target network is only ever evaluated, never trained.
target_model.eval()
replay_memory = ReplayMemory(int(1e4))

optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()

In [None]:
def torch_state(state):
    """Convert one observation into a tensor of the notebook `dtype` with a leading batch axis."""
    observation = torch.tensor(state, dtype=dtype)
    # Add the batch dimension up front: (...) -> (1, ...).
    return torch.unsqueeze(observation, dim=0)

# Live plot of the smoothed episode length, refreshed after every episode.
fig = go.FigureWidget()
fig.add_scatter(x=[], y=[], mode='lines+markers', name='Episode Lengths')
fig.update_layout(title='Episode Lengths', xaxis_title='Episode', yaxis_title='Length')
display(fig)

update_target_steps = 100  # optimisation steps between target-network syncs
update_it = 0              # optimisation steps since the last sync
smooth_window = 10         # number of episodes in the moving average

episode_lens = []
smooth = []
sm = 0  # running SUM of the last `smooth_window` episode lengths (not the mean)

for episode in tqdm(range(n_episodes)):
    finished = False
    
    state, info = env.reset()
    state = torch_state(state)

    episode_len = 0

    while not finished:
        episode_len += 1

        # Epsilon-greedy action selection.
        if torch.empty(1).uniform_(0.0, 1.0) < eps_greedy:
            # Sample uniformly over the whole action space instead of a
            # hard-coded two actions.
            action = torch.randint(low=0, high=int(output_dim), size=(1,))
        else:
            action = policy_model.predict(state.to(device)).to('cpu')
        
        next_state, reward, terminated, truncated, info = env.step(action.item())
        
        next_state = torch_state(next_state)
        reward = torch.tensor([reward], dtype=dtype).unsqueeze(0)
        action = action.unsqueeze(0)

        # The episode ends on either signal, but only a true termination is a
        # terminal state for learning. A time-limit truncation keeps its
        # next_state so the target still bootstraps from it.
        finished = terminated or truncated

        transition = {
            'state': state,
            'action': action,
            'reward': reward,
            'next_state': None if terminated else next_state
        }
        replay_memory.add(transition)

        state = next_state

        # Wait until the buffer can supply a full mini-batch.
        if len(replay_memory) < batch_size:
            continue
        
        batch = replay_memory.sample(batch_size)

        loss = process_batch(
            batch,
            gamma,
            target_model,
            policy_model,
            criterion,
            optimizer
        )

        # Periodically copy the policy weights into the target network.
        update_it += 1
        if update_it == update_target_steps:
            target_model.load_state_dict(policy_model.state_dict())
            update_it = 0

    episode_lens.append(episode_len)
    
    # Moving average over the last `smooth_window` episodes. The running sum
    # is kept intact and only the appended value is divided; dividing the sum
    # in place would corrupt every later average.
    sm += episode_len
    if episode >= smooth_window:
        sm -= episode_lens[episode - smooth_window]
    smooth.append(sm / min(episode + 1, smooth_window))
    
    fig.data[0].x = list(range(1, episode + 2))
    fig.data[0].y = smooth

In [None]:
# Draw one more mini-batch from the replay buffer for manual inspection.
batch = replay_memory.sample(batch_size)

In [None]:
# Manual check: run a single optimisation step on the batch sampled above and
# display its loss. Note that this updates the weights of `policy_model`.
process_batch(
    batch=batch,
    gamma=gamma,
    target_model=target_model,
    policy_model=policy_model,
    criterion=criterion,
    optimizer=optimizer,
)

In [None]:
# Show the non-terminal part of the sampled batch
# (None when the sample contained no non-terminal transitions).
batch['non_terms']