In [4]:
import gym
import numpy as np
import torch.nn as nn

In [5]:
import torch
print(torch.__version__)

0.4.1


In [6]:
class Net(nn.Module):
    def __init__(self,obs_size, hidden_size,n_actions):
        super(Net, self).__init__()
        self.net = nn.Sequential(nn.Linear(obs_size,hidden_size),
                                nn.Tanh(),
                                nn.Linear(hidden_size,n_actions))
    
    def forward(self,x):
        return self.net(x)

In [16]:
from torch.distributions.categorical import Categorical
import torch

In [26]:
def get_policy(net, obs):
    logits = net(obs)
    return Categorical(logits=logits)

def get_action(net, obs):
    policy = get_policy(net, obs)
    act = policy.sample().item()
    return act

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

def reward_to_go_avg(rews, avg):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0) - avg/n
    return rtgs

def disc_rtg_avg(rews, avg):
    rtgs = []
    val = 0
    for i in reversed(range(len(rews))):
        val = val*GAMMA + rews[i] - avg/len(rews)
        rtgs.append(val)
    rtgs = rtgs[::-1]
    return rtgs

def disc_rtg(rews):
    rtgs = []
    val = 0
    for i in reversed(range(len(rews))):
        val = val*GAMMA + rews[i]
        rtgs.append(val)
    rtgs = rtgs[::-1]
    return rtgs

In [118]:
def compute_loss(obs, acts, wts, net):
#     obs_v = torch.FloatTensor(obs)
    policy = get_policy(net,obs)
    log_p = policy.log_prob(acts)
    
    loss = -(log_p*wts).mean()
    entropy_v = policy.entropy().mean()
    return loss, entropy_v, policy

def get_kl_div(obs, old_policy, net ):
    new_policy = get_policy(net, obs)
    kl_div = -((new_policy.probs/old_policy.probs).log() * old_policy.probs).sum(-1).mean()
    return kl_div

In [119]:
def train_one_epoch(env, net, batch_size=5000, render=False):
    batch_obs = []
    batch_wts = []
    batch_acts = []
    batch_rets = []
    batch_len = []
    eps_rew = []
    batch_rew = []
    obs = env.reset()
    done=False
    epoch_finished_rendering = False
    
    while True:
        if not epoch_finished_rendering and render:
            env.render()
        
        act = get_action(net, obs = torch.as_tensor(obs,dtype=torch.float32))
        batch_obs.append(obs.copy())
        batch_acts.append(act)
        
        obs,rew,done,_ = env.step(act)
        
        eps_rew.append(rew)
        batch_rew.append(rew)
#         obs= next_obs
        
        if done:
            eps_ret = sum(eps_rew)
            eps_len = len(eps_rew)
            batch_rets.append(eps_ret)
            batch_len.append(eps_len)
            
            #plain
#             batch_wts = batch_wts + [eps_ret]*eps_len

            #subtract avg reward
#             batch_wts = batch_wts + [eps_ret- avg_rew]*eps_len
            
            # reward to-go
#             batch_wts = batch_wts + list(reward_to_go(eps_rew))

            # reward to-go with avg rew
#             batch_wts = batch_wts + list(reward_to_go_avg(eps_rew, avg_rew))
            
            # disc rtg
            batch_wts = batch_wts + list(disc_rtg(eps_rew))
            
            eps_rew = []
            done = False
            
            obs = env.reset()
            epoch_finished_rendering = True
            
            if len(batch_obs)>batch_size:
                break
    
    optimizer.zero_grad()
    batch_loss, entropy_v, policy = compute_loss(obs = torch.as_tensor(batch_obs, dtype=torch.float32),
                              acts = torch.as_tensor(batch_acts, dtype = torch.int32),
                              wts = torch.as_tensor(batch_wts, dtype = torch.float32),
                             net = net)
    
    
    
    entropy = entropy_v.item()
    batch_loss.backward()
    optimizer.step()
    
    
    kl_div = get_kl_div(obs = torch.as_tensor(batch_obs, dtype=torch.float32),
                       old_policy = policy,
                       net=net)
    return batch_loss,batch_rets, batch_len, entropy, kl_div

In [120]:
env = gym.make('CartPole-v1')
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

obs_size, n_actions

(4, 2)

In [121]:
env.observation_space, env.action_space

(Box(4,), Discrete(2))

In [122]:
HIDDEN_SIZE = 32
BATCH_SIZE = 500
GAMMA = 0.99

In [123]:
net = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions= n_actions)

In [124]:
obs = env.observation_space.sample()
act = get_action(net, obs= torch.as_tensor(obs, dtype=torch.float32))
act

1

In [125]:
a = env.action_space.sample()
a

1

In [126]:
from torch.optim import Adam
lr = 1e-2
optimizer = Adam(net.parameters(), lr=lr)

In [127]:
from tensorboardX import SummaryWriter
writer = SummaryWriter(comment="-vanilla_policy_grad_cartpole_disc")

In [128]:
# train

rew_req = 200
i=0
mean_rew = 0
mean_rew_sum = 0
avg_rew = 0

while i<500:
    i+=1
    render = True if i%20==0 else False
#     render = False
    batch_loss,batch_ret, batch_len, entropy, kl_div = train_one_epoch(env, net,batch_size=BATCH_SIZE, render=render)
    mean_rew = np.mean(batch_ret)
    mean_rew_sum += mean_rew
    avg_rew = mean_rew_sum/i
    if render:
        env.close()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f \t entropy: %.3f \t kl_div: %.3f'%
                (i, batch_loss, np.mean(batch_ret), np.mean(batch_len), entropy,kl_div))
    writer.add_scalar("loss", batch_loss, i)
    writer.add_scalar("reward_mean", mean_rew, i)
    writer.add_scalar("kl_div", kl_div, i)
    writer.add_scalar("entropy", entropy,i)
#     writer.add_scalar("rew_baseline", baseline,i)

epoch:   1 	 loss: 9.193 	 return: 25.400 	 ep_len: 25.400 	 entropy: 0.691 	 kl_div: 0.002
epoch:   2 	 loss: 12.293 	 return: 33.062 	 ep_len: 33.062 	 entropy: 0.684 	 kl_div: 0.002
epoch:   3 	 loss: 12.725 	 return: 33.800 	 ep_len: 33.800 	 entropy: 0.674 	 kl_div: 0.002
epoch:   4 	 loss: 13.458 	 return: 37.286 	 ep_len: 37.286 	 entropy: 0.668 	 kl_div: 0.001
epoch:   5 	 loss: 15.231 	 return: 45.500 	 ep_len: 45.500 	 entropy: 0.659 	 kl_div: 0.001
epoch:   6 	 loss: 13.723 	 return: 43.333 	 ep_len: 43.333 	 entropy: 0.654 	 kl_div: 0.001
epoch:   7 	 loss: 14.847 	 return: 42.500 	 ep_len: 42.500 	 entropy: 0.642 	 kl_div: 0.001
epoch:   8 	 loss: 16.975 	 return: 51.000 	 ep_len: 51.000 	 entropy: 0.635 	 kl_div: 0.001
epoch:   9 	 loss: 14.603 	 return: 48.182 	 ep_len: 48.182 	 entropy: 0.630 	 kl_div: 0.001
epoch:  10 	 loss: 13.114 	 return: 43.417 	 ep_len: 43.417 	 entropy: 0.630 	 kl_div: 0.001
epoch:  11 	 loss: 18.762 	 return: 66.875 	 ep_len: 66.875 	 entropy: 

epoch:  88 	 loss: 40.968 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.486 	 kl_div: 0.000
epoch:  89 	 loss: 40.631 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.494 	 kl_div: 0.001
epoch:  90 	 loss: 38.515 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.500 	 kl_div: 0.004
epoch:  91 	 loss: 38.996 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.492 	 kl_div: 0.001
epoch:  92 	 loss: 36.137 	 return: 362.500 	 ep_len: 362.500 	 entropy: 0.500 	 kl_div: 0.001
epoch:  93 	 loss: 33.755 	 return: 334.500 	 ep_len: 334.500 	 entropy: 0.508 	 kl_div: 0.001
epoch:  94 	 loss: 35.277 	 return: 295.500 	 ep_len: 295.500 	 entropy: 0.487 	 kl_div: 0.001
epoch:  95 	 loss: 33.373 	 return: 292.000 	 ep_len: 292.000 	 entropy: 0.491 	 kl_div: 0.002
epoch:  96 	 loss: 31.532 	 return: 241.000 	 ep_len: 241.000 	 entropy: 0.489 	 kl_div: 0.003
epoch:  97 	 loss: 31.877 	 return: 228.333 	 ep_len: 228.333 	 entropy: 0.476 	 kl_div: 0.002
epoch:  98 	 loss: 32.724 	 return: 255.500 	 ep_l

epoch: 175 	 loss: 40.535 	 return: 417.500 	 ep_len: 417.500 	 entropy: 0.508 	 kl_div: 0.002
epoch: 176 	 loss: 41.895 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.514 	 kl_div: 0.001
epoch: 177 	 loss: 41.025 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.512 	 kl_div: 0.001
epoch: 178 	 loss: 42.293 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.506 	 kl_div: 0.000
epoch: 179 	 loss: 41.347 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.515 	 kl_div: 0.001
epoch: 180 	 loss: 41.585 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.513 	 kl_div: 0.002
epoch: 181 	 loss: 40.458 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.521 	 kl_div: 0.002
epoch: 182 	 loss: 41.992 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.511 	 kl_div: 0.001
epoch: 183 	 loss: 40.886 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.519 	 kl_div: 0.002
epoch: 184 	 loss: 41.679 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.517 	 kl_div: 0.001
epoch: 185 	 loss: 40.716 	 return: 500.000 	 ep_l

epoch: 262 	 loss: 37.207 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.448 	 kl_div: 0.007
epoch: 263 	 loss: 35.779 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.453 	 kl_div: 0.005
epoch: 264 	 loss: 29.463 	 return: 306.000 	 ep_len: 306.000 	 entropy: 0.458 	 kl_div: 0.007
epoch: 265 	 loss: 34.307 	 return: 345.500 	 ep_len: 345.500 	 entropy: 0.433 	 kl_div: 0.002
epoch: 266 	 loss: 31.463 	 return: 300.000 	 ep_len: 300.000 	 entropy: 0.441 	 kl_div: 0.005
epoch: 267 	 loss: 35.440 	 return: 500.000 	 ep_len: 500.000 	 entropy: 0.442 	 kl_div: 0.004
epoch: 268 	 loss: 31.599 	 return: 300.500 	 ep_len: 300.500 	 entropy: 0.438 	 kl_div: 0.003
epoch: 269 	 loss: 33.095 	 return: 308.000 	 ep_len: 308.000 	 entropy: 0.433 	 kl_div: 0.002
epoch: 270 	 loss: 30.987 	 return: 302.500 	 ep_len: 302.500 	 entropy: 0.437 	 kl_div: 0.001
epoch: 271 	 loss: 33.377 	 return: 344.500 	 ep_len: 344.500 	 entropy: 0.429 	 kl_div: 0.001
epoch: 272 	 loss: 31.507 	 return: 304.500 	 ep_l

KeyboardInterrupt: 

In [20]:
env.close()

# Scratch Pad

In [108]:
c = Categorical(logits = torch.FloatTensor([[1,0,1],[1,1,1],[1,0,0]]))
c

Categorical()

In [117]:
(c.probs * torch.log(c.probs)).sum(dim=-1)

tensor([-1.0174, -1.0986, -0.9753])

In [112]:
c.entropy()

tensor([1.0174, 1.0986, 0.9753])

In [18]:
c.sample()

tensor(0)

In [19]:
count = [0,0,0]
for i in range(0,100):
    x = c.sample().item()
    count[x]+=1
count

[35, 16, 49]

In [20]:
c.log_prob(torch.as_tensor(0, dtype=torch.int32))

tensor(-0.8620)

In [21]:
x =([2]*8)
x

[2, 2, 2, 2, 2, 2, 2, 2]

In [22]:
y = ([1]*3 + [2]*5)
y

[1, 1, 1, 2, 2, 2, 2, 2]

In [23]:
(torch.FloatTensor(x)*torch.FloatTensor(y)).mean()

tensor(3.2500)

In [24]:
reward_to_go([1,0,2,1,0])

array([4, 3, 3, 1, 0])

In [25]:
disc_rtg([1,0,2,1,0])

[3.930499, 2.9601, 2.99, 1.0, 0.0]

In [39]:
[0.3,0.3,0.4] * np.log([0.3,0.3,0.4])

array([-0.36119184, -0.36119184, -0.36651629])

In [202]:
x = [1,2,3]
y = []
y.append(x)
y

[[1, 2, 3]]

In [203]:
x=[4,2,3]
y

[[1, 2, 3]]

In [204]:
y.append(x.copy())

In [205]:
y

[[1, 2, 3], [4, 2, 3]]

In [206]:
y.append(x)

In [207]:
y

[[1, 2, 3], [4, 2, 3], [4, 2, 3]]

In [237]:
obs = env.observation_space.sample()
obs

array([-5.3292841e-01,  1.3936392e+38,  2.2844117e-02,  8.5078421e+37],
      dtype=float32)

In [148]:
net(torch.FloatTensor(obs))

tensor([79045296309277814133001641918405279744.,
         8074543887068950992878338433657864192.], grad_fn=<ThAddBackward>)

In [169]:
c2 = Categorical(net(torch.FloatTensor(obs)))
c2.sample()

tensor(0)

In [173]:
c2.log_prob(torch.as_tensor(1, dtype = torch.int32))

tensor(-2.3786, grad_fn=<SqueezeBackward1>)