<a href="https://colab.research.google.com/github/shirsh008/reinforcement-learning-model-training/blob/main/cart_pole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gymnasium as gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
# Global seed so the policy initialisation and the action sampling done by
# torch are reproducible on Restart & Run All. NOTE(review): environment
# resets are not seeded here; pass `seed=` to the first `env.reset()` if fully
# deterministic rollouts are required (CUDA kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
env_id = "CartPole-v1"
# Two independent instances: `env` is stepped during training, `eval_env`
# during evaluation, so neither disturbs the other's episode state.
env, eval_env = (gym.make(env_id) for _ in range(2))
s_size = env.observation_space.shape[0]  # length of the observation vector
a_size = env.action_space.n              # number of discrete actions

In [None]:
class Policy(nn.Module):
  """Two-layer MLP mapping observations to a categorical action distribution.

  Args:
    s_size: length of the observation vector (input features).
    a_size: number of discrete actions (output probabilities).
    h_size: width of the single hidden layer.
  """

  def __init__(self, s_size, a_size, h_size):
    super().__init__()
    # Attribute names double as state_dict keys, so they are kept stable.
    self.fc1 = nn.Linear(s_size, h_size)
    self.fc2 = nn.Linear(h_size, a_size)

  def forward(self, x):
    """Return action probabilities for a batch `x` of shape (batch, s_size)."""
    hidden = F.relu(self.fc1(x))
    logits = self.fc2(hidden)
    # Normalise across the action axis (dim=1), so `x` must be batched.
    return F.softmax(logits, dim=1)

  def act(self, state):
    """Sample an action for one numpy observation.

    Returns:
      (action, log_prob): the sampled action as a Python int and its
      log-probability as a 1-element tensor still attached to the graph.
    """
    batch = torch.from_numpy(state).float().unsqueeze(0).to(device)
    dist = Categorical(self.forward(batch).cpu())
    sampled = dist.sample()
    return sampled.item(), dist.log_prob(sampled)

In [None]:
def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every,
              train_env=None):
  """Train `policy` with REINFORCE (Monte-Carlo policy gradient).

  Each episode is rolled out until it terminates, is truncated, or reaches
  `max_t` steps. Discounted returns-to-go are then computed and standardised,
  and one optimizer step is taken on the loss  -sum_t log pi(a_t|s_t) * G_t.

  Args:
    policy: object exposing `act(state) -> (action: int, log_prob: Tensor)`.
    optimizer: torch optimizer over the policy's parameters.
    n_training_episodes: number of episodes (one optimizer step each).
    max_t: maximum number of environment steps per episode.
    gamma: discount factor for future rewards.
    print_every: print the mean score of the last 100 episodes every
      `print_every` episodes; a falsy value disables logging.
    train_env: environment to train on; defaults to the notebook-level `env`.

  Returns:
    List with the undiscounted total reward of every training episode.
  """
  if train_env is None:
    train_env = env  # fall back to the notebook-level training environment
  eps = np.finfo(np.float32).eps.item()  # guards the std division below
  scores_deque = deque(maxlen=100)
  scores = []

  for i_episode in range(1, n_training_episodes + 1):
    saved_log_probs = []
    rewards = []
    state = train_env.reset()[0]  # reset() returns (observation, info)
    for _ in range(max_t):
      action, log_prob = policy.act(state)
      saved_log_probs.append(log_prob)
      state, reward, terminated, truncated, _ = train_env.step(action)
      rewards.append(reward)
      if terminated or truncated:
        break

    episode_score = sum(rewards)
    scores_deque.append(episode_score)
    scores.append(episode_score)

    # An empty rollout (max_t <= 0) has nothing to learn from; torch.cat([])
    # would raise, so skip the update.
    if rewards:
      # Return-to-go G_t = r_t + gamma * G_{t+1}, accumulated backwards.
      returns = []
      running_return = 0.0
      for r in reversed(rewards):
        running_return = r + gamma * running_return
        returns.append(running_return)
      returns.reverse()
      returns = torch.tensor(returns, dtype=torch.float32)

      if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)
      else:
        # std() of a single sample is NaN (Bessel's correction); a NaN loss
        # would poison every parameter, so just centre the lone return (-> 0).
        returns = returns - returns.mean()

      policy_loss = torch.cat(
          [-lp * g for lp, g in zip(saved_log_probs, returns)]
      ).sum()

      optimizer.zero_grad()
      policy_loss.backward()
      optimizer.step()

    if print_every and i_episode % print_every == 0:
      print(f"Episode {i_episode}\tAverage Score: {np.mean(scores_deque):.2f}")

  return scores

In [None]:
# Training / evaluation settings for REINFORCE on CartPole.
cartpole_hyperparameters = dict(
    h_size=64,                  # hidden-layer width of the policy MLP
    n_training_episodes=1000,
    n_evaluation_episodes=10,
    max_t=1000,                 # per-episode step cap used by training and evaluation
    gamma=1.0,                  # no discounting of future rewards
    lr=1e-3,                    # Adam learning rate
    env_id=env_id,
    state_space=s_size,
    action_space=a_size,
)

In [None]:
# Policy network on the selected device, optimised with Adam.
cartpole_policy = Policy(
    s_size=cartpole_hyperparameters["state_space"],
    a_size=cartpole_hyperparameters["action_space"],
    h_size=cartpole_hyperparameters["h_size"],
).to(device)
cartpole_optimizer = optim.Adam(
    cartpole_policy.parameters(), lr=cartpole_hyperparameters["lr"]
)

In [None]:
# Train the policy; `scores` holds the total reward of every training episode.
scores = reinforce(
    policy=cartpole_policy,
    optimizer=cartpole_optimizer,
    n_training_episodes=cartpole_hyperparameters["n_training_episodes"],
    max_t=cartpole_hyperparameters["max_t"],
    gamma=cartpole_hyperparameters["gamma"],
    print_every=100,
)

In [None]:
def evaluate_agent(env, max_steps, n_eval_episodes, policy):
  """Roll out `policy` for `n_eval_episodes` episodes and summarise the returns.

  Actions are sampled through `policy.act`, the same way as during training.

  Args:
    env: gymnasium-style environment (`reset() -> (obs, info)`,
      `step(a) -> (obs, reward, terminated, truncated, info)`).
    max_steps: maximum number of steps per evaluation episode.
    n_eval_episodes: number of episodes to run.
    policy: object exposing `act(state) -> (action, log_prob)`.

  Returns:
    `(mean_reward, std_reward)` of the undiscounted per-episode returns.
  """
  episode_rewards = []
  # Evaluation needs no gradients; disabling autograd avoids building a
  # computation graph for every log-probability `act` returns.
  with torch.no_grad():
    for _episode in range(n_eval_episodes):
      state = env.reset()[0]  # reset() returns (observation, info)
      total_rewards_ep = 0
      for _step in range(max_steps):
        action, _log_prob = policy.act(state)
        state, reward, terminated, truncated, _info = env.step(action)
        total_rewards_ep += reward
        if terminated or truncated:
          break
      episode_rewards.append(total_rewards_ep)

  mean_reward = np.mean(episode_rewards)
  std_reward = np.std(episode_rewards)
  return mean_reward, std_reward

In [None]:
# Evaluate the trained policy on the separate evaluation environment.
mean_reward, std_reward = evaluate_agent(
    env=eval_env,
    max_steps=cartpole_hyperparameters["max_t"],
    n_eval_episodes=cartpole_hyperparameters["n_evaluation_episodes"],
    policy=cartpole_policy,
)

print(f"Mean reward: {mean_reward:.2f}")
print(f"Standard deviation of reward: {std_reward:.2f}")

Mean reward: 500.00
Standard deviation of reward: 0.00
