In [1]:
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
from collections import deque
from IPython import display

from reinforce_agent import REINFORCE_AGENT

%matplotlib inline

In [2]:
# Create the CartPole-v1 environment and report the spaces the agent will
# observe (state) and act in (action).
env = gym.make("CartPole-v1")
print(
    f"State Space  : {env.observation_space}",
    f"Action Space : {env.action_space}",
    sep="\n",
)

State Space  : Box(4,)
Action Space : Discrete(2)


In [3]:
# Read the problem dimensions off the environment: the number of discrete
# actions and the length of the first axis of the observation space.
action_space = env.action_space.n
state_space = env.observation_space.shape[0]

# Build the agent from those dimensions.
# NOTE(review): the third positional argument (42) is presumably a random
# seed -- confirm against reinforce_agent.REINFORCE_AGENT.
reinforce_agent = REINFORCE_AGENT(state_space, action_space, 42)

# Adam optimises the parameters of the agent's policy network.
optimizer = torch.optim.Adam(params=reinforce_agent.policy.parameters(), lr=0.005)

In [4]:
# REINFORCE training loop.
#
# Every episode:
#   1. roll out the current policy until the environment reports `done`,
#      recording the per-step value returned by `reinforce_agent.act` and the
#      per-step reward;
#   2. compute the discounted episode return  G = sum_t GAMMA**t * r_t;
#   3. take one optimizer step on  loss = sum_t ( -act_output_t * G ).
#
# `scores` keeps every episode's undiscounted total reward; `scores_window`
# keeps only the most recent 100 for the running average.
GAMMA = 1            # discount factor; 1 leaves rewards undiscounted
num_episodes = 2000  # upper bound on the number of training episodes
# Average score over a full window at which training stops.
# NOTE(review): 195 is the threshold gym registers for CartPole-v0; CartPole-v1
# is registered with a reward threshold of 475. Kept at 195 so the experiment
# is unchanged -- confirm which target is intended.
SOLVED_SCORE = 195.0
scores_window = deque(maxlen=100)
scores = list()


for e in range(1, num_episodes+1):
    saved_probs = list()
    rewards = list()

    state = env.reset()

    # Roll out one full episode with the current policy.
    done = False
    while not done:
        # presumably `action_log_probs` is log pi(action | state) as a tensor
        # that carries gradients -- confirm in REINFORCE_AGENT.act
        action, action_log_probs = reinforce_agent.act(state)
        saved_probs.append(action_log_probs)
        state, reward, done, _ = env.step(action)
        rewards.append(reward)

    scores.append(sum(rewards))
    scores_window.append(scores[-1])

    # One discount factor per reward (the original built one extra entry that
    # zip() silently dropped).
    discounts = [GAMMA**i for i in range(len(rewards))]
    final_reward = sum(discount*reward for discount, reward in zip(discounts, rewards))

    # Every step is weighted by the same whole-episode return.
    policy_loss = torch.cat([-log_prob*final_reward for log_prob in saved_probs]).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    average_score = np.mean(scores_window)
    if e % 100 == 0:
        print('Episode {}\tAverage Score: {:.2f}'.format(e, average_score))
    # Only declare success once the window holds a full 100 episodes; otherwise
    # a few lucky early episodes could stop training prematurely and the
    # reported episode count below would be negative.
    window_is_full = len(scores_window) == scores_window.maxlen
    if window_is_full and average_score >= SOLVED_SCORE:
        print('Environment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(e-scores_window.maxlen, average_score))
        break

Episode 100	Average Score: 27.63
Episode 200	Average Score: 44.14
Episode 300	Average Score: 66.32
Episode 400	Average Score: 89.54
Episode 500	Average Score: 100.55
Episode 600	Average Score: 87.38
Episode 700	Average Score: 128.33
Episode 800	Average Score: 134.14
Environment solved in 761 episodes!	Average Score: 195.94
