In [1]:
import gymnasium as gym
import torch
import torch.nn as nn

In [2]:
# # now make the env
# env = gym.make('CartPole-v1', render_mode='human')

# print (f"All possible actions: {env.action_space}")
# print (f"All possible states: {env.observation_space}")

# observation, info = env.reset(seed=42)

# print (f'STATE[INITIAL]: {observation}')
# for _ in range(1000):
#     # insert your learned policy
#     action = env.action_space.sample()
#     print (f"ACTION: {action}")

#     observation, reward, terminated, truncated, info = env.step(action)
#     print (f'STATE[CURR]: {observation} - TERMINATED: {terminated} - TRUNCATED: {truncated}')

#     # if episode has ended or truncated
#     # TODO: Diff b/w terminated and truncated
#     if terminated or truncated:
#         observation, info = env.reset()

# env.close()

In [3]:
config  = {
    "gamma": 0.99, # discount factor used for discounting future rewards
    "clip_epsilon": 0.2,   # used in the clipping function
    "rollout_length": 2048,  # number of timesteps we want the Actor to play with the environment for
    "epochs": 100,    # number of epochs used for PPO update step
    "batch_size": 64, # batch size to use in PPO update
    "lr": 1e-4, # learning rate for PPO update 
}

In [4]:
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

print (f"State dimension: {state_dim}")
print (f"Action dimension: {action_dim}")

State dimension: 4
Action dimension: 2


## Policy network

- Also known as actor network.
- The goal is to learn a policy that the "Actor" will use to interact with the environment


In [5]:
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
    def forward(self, state):
        # Takes as input a state and returns the action logits
        return self.net(state)


## Value network

- Also known as critic network

In [6]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, state):
        # this function takes as input a state and returns the value of that state
        return self.net(state).squeeze(-1) # remove the last dimension of size 1

# Initialise the networks

In [7]:
policy_net = PolicyNetwork(state_dim, action_dim)
value_net = ValueNetwork(state_dim)
policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=config['lr'])
value_optimizer = torch.optim.Adam(value_net.parameters(), lr=config['lr'])


# Utilities

### Advantage function

In [8]:
def calc_advantages(rewards, values, gamma=0.09):
    # first calculate the rewards-to-go at each step
    # we already are given the expected rewards AKA values at each step
    # advantage of each step is the differennce b/w the rewards-to-go and the value of that step
    cuml_reward= 0
    rewards_to_go = []
    for reward in reversed(rewards):
        cuml_reward = reward + gamma * cuml_reward
        rewards_to_go.insert(0, cuml_reward)
    
    # convert rewards_to_go and values to tensors
    rewards_to_go = torch.tensor(rewards_to_go, dtype=torch.float32)
    values = torch.tensor(values, dtype=torch.float32)

    # calculate advantages for each time step
    advantages = rewards_to_go - values
    return advantages, rewards_to_go


def test_calc_advantages():
    rewards = [1.0, 1.5, -2.0, -1.5, 3.0, 2.5]
    values = [6.0, 5.5, -3.5, 9.0, 8.5, 10.0]
    advantages, rewards_to_go = calc_advantages(rewards, values)

    print("Rewards to go:", rewards_to_go)
    print("Advantages:", advantages)

test_calc_advantages()

Rewards to go: tensor([ 1.1179,  1.3102, -2.1089, -1.2098,  3.2250,  2.5000])
Advantages: tensor([ -4.8821,  -4.1898,   1.3911, -10.2098,  -5.2750,  -7.5000])


### PPO update function

In [9]:
NUM_ITERATIONS = 20

for iter_idx in range(NUM_ITERATIONS):

    states, actions, rewards, log_probs, values = [], [], [], [], []
    state, _ = env.reset()   # initial state

    for _ in range(config['rollout_length']):
        # get the action from the current policy
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action_logits = policy_net(state_tensor)
        action_dist = torch.distributions.Categorical(logits=action_logits) # get the action distribution
        action = action_dist.sample()   # use the distribution to sample an action


        # pass the action to the environment
        next_state, reward, terminated, truncated, _ = env.step(action.item())

        states.append(state_tensor)
        actions.append(action)
        rewards.append(reward)
        log_probs.append(action_dist.log_prob(action))
        values.append(value_net(state_tensor).item())

        state = next_state if not terminated and not truncated else env.reset()[0]
    
    # convert the collected data to tensors
    states = torch.stack(states)
    actions = torch.stack(actions)
    old_log_probs = torch.stack(log_probs).detach()
    # we will use the collected rewards and values to calculate the advantages
    advantages, rewards_to_go = calc_advantages(rewards, values, config['gamma'])

    # TODO: Why is this necessary?
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # normalize the advantages

    # now we have to update the policy and value networks
    # using rollout data we collected
    # we will show the rollout data to the policy and value networks multiple times (epochs) times
    # so the networks can learn different aspects from the collected rollout data
    for _ in range(config['epochs']):
        # We will do batch updates to train the forward and backward passes faster
        for i in range(0, config['rollout_length'], config['batch_size']):
            idx = slice(i, i + config['batch_size'])
            batch_states = states[idx]
            batch_actions = actions[idx]
            batch_advantages = advantages[idx]
            batch_old_log_probs = old_log_probs[idx]
            batch_rewards_to_go = rewards_to_go[idx]

            # get the action logits for the batch from the updated policy
            new_action_logits = policy_net(batch_states)
            new_action_dist = torch.distributions.Categorical(logits=new_action_logits)
            new_log_probs = new_action_dist.log_prob(batch_actions)   # get the log probabilities of the actions using the new policy

            ratio = (new_log_probs - batch_old_log_probs).exp()   # this can also be written as ratio = new_probs/old_probs
            unclipped = ratio * batch_advantages
            clipped = torch.clamp(ratio, 1 - config['clip_epsilon'], 1 + config['clip_epsilon']) * batch_advantages
            policy_loss = -torch.min(unclipped, clipped).mean()


            value_preds = value_net(batch_states)
            value_loss = nn.MSELoss()(value_preds, batch_rewards_to_go)

            policy_optimizer.zero_grad() # just zero out the gradients before the backward pass
            policy_loss.backward()
            policy_optimizer.step()

            value_optimizer.zero_grad() # just zero out the gradients before the backward pass
            value_loss.backward()
            value_optimizer.step()

        
    print (f"Iteration {iter_idx+1}/{NUM_ITERATIONS} - Policy Loss: {policy_loss.item()} - Value Loss: {value_loss.item()}")



Iteration 1/20 - Policy Loss: 4.604244232177734 - Value Loss: 448.4206237792969
Iteration 2/20 - Policy Loss: 4.408319473266602 - Value Loss: 367.6126708984375
Iteration 3/20 - Policy Loss: 3.796520471572876 - Value Loss: 2441.7431640625
Iteration 4/20 - Policy Loss: 3.0511834621429443 - Value Loss: 4432.21533203125
Iteration 5/20 - Policy Loss: 2.7016100883483887 - Value Loss: 4440.46142578125
Iteration 6/20 - Policy Loss: 3.447909355163574 - Value Loss: 4865.9033203125
Iteration 7/20 - Policy Loss: 4.083921432495117 - Value Loss: 4731.8359375
Iteration 8/20 - Policy Loss: 4.270987510681152 - Value Loss: 4815.7451171875
Iteration 9/20 - Policy Loss: 4.466119766235352 - Value Loss: 4901.462890625
Iteration 10/20 - Policy Loss: 4.504435062408447 - Value Loss: 4815.451171875
Iteration 11/20 - Policy Loss: 4.4711689949035645 - Value Loss: 4833.5751953125
Iteration 12/20 - Policy Loss: 4.628590106964111 - Value Loss: 4879.32958984375
Iteration 13/20 - Policy Loss: 4.575385093688965 - Value

In [None]:
# Initial observations are - 
# 1. The value loss is very high > 4000
# 2. We also need to continuously visualise the losses
