In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

In [2]:
env_name = 'CartPole-v0'
env = gym.make(env_name)


# Hard coded policy for the cartpole problem
# Will eventually want to build up infrastructure to develop a policy depending on:
# env.action_space
# env.observation_space

policy = nn.Sequential(
    nn.Linear(4, 12),
    nn.ReLU(),
    nn.Linear(12,12),
    nn.ReLU(),
    nn.Linear(12,2),
    nn.Softmax(dim=-1)
    )

optimizer = optim.Adam(policy.parameters(), lr = .1)

# I guess we'll start with a categorical policy
# TODO investigate the cost of action.detach.numpy() and torch.Tensor(state)
def select_action(policy, state):
    m = Categorical(policy(torch.Tensor(state)))
    action = m.sample()
    logprob = m.log_prob(action)
    
    return action.detach().numpy(), logprob
    

In [3]:
policy(torch.randn(1,4))

tensor([[0.4328, 0.5672]], grad_fn=<SoftmaxBackward>)

In [3]:
#def vanilla_policy_grad(env, policy, optimizer):
    
action_list = []
state_list = []
logprob_list = []
reward_list = []

avg_reward_hist = []

num_epochs = 100
#batch_size = 20 # how many steps we want to use before we update our gradients
num_steps = 100 # number of steps in an episode (unless we terminate early)

loss = torch.zeros(1,requires_grad=True)

for epoch in range(num_epochs):

    # Probably just want to preallocate these with zeros, as either a tensor or an array
    loss_hist = []
    episode_length_hist = []
    action_list = []
    total_steps = 0

    while True:

        state = env.reset()
        logprob_list = []
        reward_list  = []
        action_list  = []
        
        for t in range(num_steps):

            action, logprob = select_action(policy, state)
            state, reward, done, _ = env.step(action.item())

            logprob_list.append(logprob)
            reward_list.append(reward)
            action_list.append(action)
            total_steps += 1

            if done:
                break

        # Now Calculate cumulative rewards for each action
        episode_length_hist.append(t)
        
        reward_ar = torch.tensor(reward_list)
        logprob_ar = torch.stack(logprob_list)
        
        episode_loss = torch.sum(
                        torch.stack(
                            [torch.sum(reward_ar[i:]*logprob_ar[i:]) for i in range(len(reward_list))]
            )
        )

        
        #if total_steps > batch_size:
        # update our gradients
        #print("here")
        avg_reward_hist.append(sum(episode_length_hist)/len(episode_length_hist))
        #other_list.append(1)
        #loss = torch.sum(torch.stack(loss_hist))
        #for action in episode_loss:
        episode_loss.backward()
            
        optimizer.step()


KeyboardInterrupt: 

In [11]:
while True:
    state = env.reset()
    cum_rewards = 0


    for t in range(num_steps):
        action, _ = select_action(policy,state)
        state, reward, done, _ = env.step(action.item())
        env.render()
        
        cum_rewards += reward
        if done:
            
            print('summed reward for espide: ', cum_rewards)
            print('time terminated:' , t)
            break

summed reward for espide:  26.0
time terminated: 25
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  21.0
time terminated: 20
summed reward for espide:  20.0
time terminated: 19
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  23.0
time terminated: 22
summed reward for espide:  21.0
time terminated: 20
summed reward for espide:  9.0
time terminated: 8
summed reward for espide:  9.0
time terminated: 8
summed reward for espide:  25.0
time terminated: 24
summed reward for espide:  9.0
time terminated: 8
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  9.0
time terminated: 8
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  9.0
time terminated: 8
summed reward for espide:  18.0
time terminated: 17
summed reward for espide:  10.0
time terminated: 9
summed reward for espide:  9.

KeyboardInterrupt: 