In [292]:
import gym
import numpy as np
import torch.nn as nn

In [293]:
import torch
print(torch.__version__)

0.4.1


In [294]:
class Net(nn.Module):
    def __init__(self,obs_size, hidden_size,n_actions):
        super(Net, self).__init__()
        self.net = nn.Sequential(nn.Linear(obs_size,hidden_size),
                                nn.Tanh(),
                                nn.Linear(hidden_size,n_actions))
    
    def forward(self,x):
        return self.net(x)

In [295]:
from torch.distributions.categorical import Categorical
import torch

In [71]:
c = Categorical(logits = torch.FloatTensor([1,0,1]))
c

Categorical()

In [72]:
c.sample()

tensor(0)

In [75]:
count = [0,0,0]
for i in range(0,100):
    x = c.sample().item()
    count[x]+=1
count

[41, 14, 45]

In [87]:
c.log_prob(torch.as_tensor(0, dtype=torch.int32))

tensor(-0.8620)

In [105]:
x =([2]*8)
x

[2, 2, 2, 2, 2, 2, 2, 2]

In [106]:
y = ([1]*3 + [2]*5)
y

[1, 1, 1, 2, 2, 2, 2, 2]

In [110]:
(torch.FloatTensor(x)*torch.FloatTensor(y)).mean()

tensor(3.2500)

In [285]:
# make function to compute action distribution
def get_policy(obs, net):
    logits = net(obs)
    return Categorical(logits=logits)

# make action selection function (outputs int actions, sampled from policy)
def get_action(obs, net):
    return get_policy(obs, net).sample().item()

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights, net):
    logp = get_policy(obs,net).log_prob(act)
    return -(logp * weights).mean()


In [337]:
# for training policy

def train2(env, net, optimizer,epochs = 50, batch_size=5000, lr=1e-2, render=False):
    
#     optimizer = Adam(net.parameters(), lr=lr)
    
#     # make function to compute action distribution
#     def get_policy(obs):
#         logits = net(obs)
#         return Categorical(logits=logits)

#     # make action selection function (outputs int actions, sampled from policy)
#     def get_action(obs):
#         return get_policy(obs).sample().item()

#     # make loss function whose gradient, for the right data, is policy gradient
#     def compute_loss(obs, act, weights):
#         logp = get_policy(obs).log_prob(act)
#         return -(logp * weights).mean()

    def get_action(obs):
#         obs_v = torch.as_tensor(obs,dtype=torch.float32)
        logits = net(obs)
        acts_sm = Categorical(logits=logits)
        act = acts_sm.sample().item()
        return act
    
    def compute_loss(obs, act, weights):
    #     obs_v = torch.FloatTensor(obs)
        policy = Categorical(logits = net(obs))
        log_p = policy.log_prob(act)
        return -(log_p*weights).mean()
    
    

    def train_one_epoch():
        batch_obs = []
        batch_wts = []
        batch_acts = []
        batch_rets = []
        batch_len = []
        eps_rew = []
        obs = env.reset()
        done=False
        epoch_finished_rendering = False

        while True:
            if not epoch_finished_rendering and render:
                env.render()

            act = get_action(torch.as_tensor(obs,dtype=torch.float32))
            
            batch_obs.append(obs.copy())
            
            batch_acts.append(act)

            obs,rew,done,_ = env.step(act)

            eps_rew.append(rew)

    #         obs= next_obs

            if done:
                eps_ret = sum(eps_rew)
                eps_len = len(eps_rew)
                batch_rets.append(eps_ret)
                batch_len.append(eps_len)

                batch_wts = batch_wts + [eps_ret]*eps_len

                eps_rew = []
                done = False

                obs = env.reset()
                epoch_finished_rendering = True

                if len(batch_obs)>batch_size:
                    break

        optimizer.zero_grad()
        batch_loss = compute_loss(obs = torch.as_tensor(batch_obs, dtype=torch.float32),
                                  act = torch.as_tensor(batch_acts, dtype = torch.int32),
                                  weights = torch.as_tensor(batch_wts, dtype = torch.float32))
        batch_loss.backward()
        optimizer.step()
        return batch_loss,batch_rets, batch_len
    
    # training loop
    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))


In [338]:
env = gym.make('CartPole-v0')
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

obs_size, n_actions

(4, 2)

In [339]:
HIDDEN_SIZE = 32
BATCH_SIZE = 5000

In [340]:
net = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions= n_actions)

In [341]:
from torch.optim import Adam
lr = 1e-2
optimizer = Adam(net.parameters(), lr=lr)

In [281]:
from tensorboardX import SummaryWriter
writer = SummaryWriter(comment="-vanilla_policy_grad")

In [290]:
# train

rew_req = 200
i=0
mean_rew = 0
while i<100:
    i+=1
#     render = True if i%100==0 else False
    render = False
    batch_loss,batch_ret, batch_len = train_one_epoch(env=env,batch_size=BATCH_SIZE, net=net, optimizer=optimizer,
                                                      render=render)
    mean_rew = np.mean(batch_ret)
    if render:
        env.close()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (i, batch_loss, np.mean(batch_ret), np.mean(batch_len)))
    writer.add_scalar("loss", batch_loss, i)
    writer.add_scalar("reward_mean", mean_rew, i)

epoch:   1 	 loss: 21.896 	 return: 24.019 	 ep_len: 24.019
epoch:   2 	 loss: 22.237 	 return: 24.169 	 ep_len: 24.169
epoch:   3 	 loss: 21.804 	 return: 24.488 	 ep_len: 24.488
epoch:   4 	 loss: 23.991 	 return: 25.738 	 ep_len: 25.738
epoch:   5 	 loss: 20.532 	 return: 23.838 	 ep_len: 23.838
epoch:   6 	 loss: 21.582 	 return: 24.184 	 ep_len: 24.184
epoch:   7 	 loss: 22.240 	 return: 25.025 	 ep_len: 25.025
epoch:   8 	 loss: 22.508 	 return: 24.549 	 ep_len: 24.549
epoch:   9 	 loss: 21.976 	 return: 23.943 	 ep_len: 23.943
epoch:  10 	 loss: 21.975 	 return: 24.072 	 ep_len: 24.072
epoch:  11 	 loss: 22.045 	 return: 23.590 	 ep_len: 23.590
epoch:  12 	 loss: 20.400 	 return: 23.245 	 ep_len: 23.245
epoch:  13 	 loss: 23.419 	 return: 25.263 	 ep_len: 25.263
epoch:  14 	 loss: 24.044 	 return: 26.681 	 ep_len: 26.681
epoch:  15 	 loss: 23.757 	 return: 26.067 	 ep_len: 26.067
epoch:  16 	 loss: 22.940 	 return: 25.582 	 ep_len: 25.582
epoch:  17 	 loss: 19.582 	 return: 23.1

In [254]:
wts

[13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 13.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 23.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 15.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 21.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 18.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 25.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,
 20.0,

In [301]:
train(env_name = 'CartPole-v0')

epoch:   0 	 loss: 19.815 	 return: 21.595 	 ep_len: 21.595
epoch:   1 	 loss: 22.077 	 return: 24.198 	 ep_len: 24.198
epoch:   2 	 loss: 27.823 	 return: 28.455 	 ep_len: 28.455
epoch:   3 	 loss: 27.448 	 return: 29.427 	 ep_len: 29.427
epoch:   4 	 loss: 30.861 	 return: 35.549 	 ep_len: 35.549
epoch:   5 	 loss: 33.756 	 return: 37.358 	 ep_len: 37.358
epoch:   6 	 loss: 33.728 	 return: 39.968 	 ep_len: 39.968
epoch:   7 	 loss: 38.875 	 return: 44.105 	 ep_len: 44.105
epoch:   8 	 loss: 41.882 	 return: 47.358 	 ep_len: 47.358
epoch:   9 	 loss: 41.350 	 return: 50.160 	 ep_len: 50.160
epoch:  10 	 loss: 40.570 	 return: 52.853 	 ep_len: 52.853
epoch:  11 	 loss: 44.237 	 return: 54.978 	 ep_len: 54.978
epoch:  12 	 loss: 42.222 	 return: 57.736 	 ep_len: 57.736
epoch:  13 	 loss: 45.412 	 return: 61.753 	 ep_len: 61.753
epoch:  14 	 loss: 45.116 	 return: 62.185 	 ep_len: 62.185
epoch:  15 	 loss: 52.168 	 return: 71.789 	 ep_len: 71.789
epoch:  16 	 loss: 49.939 	 return: 67.4

In [342]:
train2(env=env, net=net, optimizer=optimizer)

epoch:   0 	 loss: 15.750 	 return: 18.913 	 ep_len: 18.913
epoch:   1 	 loss: 18.182 	 return: 20.727 	 ep_len: 20.727
epoch:   2 	 loss: 23.703 	 return: 25.528 	 ep_len: 25.528
epoch:   3 	 loss: 25.573 	 return: 28.401 	 ep_len: 28.401
epoch:   4 	 loss: 25.571 	 return: 28.426 	 ep_len: 28.426
epoch:   5 	 loss: 28.797 	 return: 31.478 	 ep_len: 31.478
epoch:   6 	 loss: 28.968 	 return: 34.517 	 ep_len: 34.517
epoch:   7 	 loss: 35.402 	 return: 38.569 	 ep_len: 38.569
epoch:   8 	 loss: 38.308 	 return: 46.315 	 ep_len: 46.315
epoch:   9 	 loss: 41.127 	 return: 47.990 	 ep_len: 47.990
epoch:  10 	 loss: 46.987 	 return: 55.656 	 ep_len: 55.656
epoch:  11 	 loss: 39.756 	 return: 53.042 	 ep_len: 53.042
epoch:  12 	 loss: 52.824 	 return: 63.062 	 ep_len: 63.062
epoch:  13 	 loss: 50.006 	 return: 66.145 	 ep_len: 66.145


KeyboardInterrupt: 

In [237]:
obs = env.observation_space.sample()
obs

array([-5.3292841e-01,  1.3936392e+38,  2.2844117e-02,  8.5078421e+37],
      dtype=float32)

In [320]:
from gym.spaces import Discrete, Box

def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2, 
          epochs=50, batch_size=5000, render=False):

    # make environment, check spaces, get obs / act dims
    env = gym.make(env_name)
    assert isinstance(env.observation_space, Box), \
        "This example only works for envs with continuous state spaces."
    assert isinstance(env.action_space, Discrete), \
        "This example only works for envs with discrete action spaces."

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    # make core of policy network
    logits_net = Net(obs_dim, hidden_sizes[0], n_acts)

    # make function to compute action distribution
    def get_policy(obs):
        logits = logits_net(obs)
        return Categorical(logits=logits)

    # make action selection function (outputs int actions, sampled from policy)
    def get_action(obs):
        return get_policy(obs).sample().item()

    # make loss function whose gradient, for the right data, is policy gradient
    def compute_loss(obs, act, weights):
        logp = get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

    # make optimizer
    optimizer = Adam(logits_net.parameters(), lr=lr)

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for R(tau) weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs = env.reset()       # first obs comes from starting distribution
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau)
                batch_weights += [ep_ret] * ep_len

                # reset episode-specific variables
                obs, done, ep_rews = env.reset(), False, []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),
                                  act=torch.as_tensor(batch_acts, dtype=torch.int32),
                                  weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                                  )
        batch_loss.backward()
        optimizer.step()
        return batch_loss, batch_rets, batch_lens

    # training loop
    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))

In [148]:
net(torch.FloatTensor(obs))

tensor([79045296309277814133001641918405279744.,
         8074543887068950992878338433657864192.], grad_fn=<ThAddBackward>)

In [169]:
c2 = Categorical(net(torch.FloatTensor(obs)))
c2.sample()

tensor(0)

In [173]:
c2.log_prob(torch.as_tensor(1, dtype = torch.int32))

tensor(-2.3786, grad_fn=<SqueezeBackward1>)