In [1]:
import gym
import numpy as np
import torch.nn as nn
import pdb

In [2]:
import torch
print(torch.__version__)

0.4.1


In [3]:
class Net(nn.Module):
    def __init__(self,obs_size, hidden_size,n_actions):
        super(Net, self).__init__()
        self.net = nn.Sequential(nn.Linear(obs_size,hidden_size),
                                nn.Tanh(),
                                nn.Linear(hidden_size,n_actions))
    
    def forward(self,x):
        return self.net(x)

In [4]:
from torch.distributions.categorical import Categorical
from torch.distributions.normal import Normal
import torch

In [13]:
def get_policy(net, obs):
    logits = net(obs)
    return Categorical(logits=logits)

def get_policy_cont(net, obs):
    mu = net(obs)
    mu = torch.tanh(mu)*1.8
    return Normal(loc=mu, scale=STD)

def get_action(net, obs, cont=False):
    if cont:
        policy = get_policy_cont(net, obs)
    else:
        policy = get_policy(net, obs)
    act = policy.sample().item()
    if cont:
        act = np.array([act])
    return act

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

def reward_to_go_avg(rews, avg):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0) - avg/n
    return rtgs

In [14]:
def compute_loss(obs, acts, wts, net, cont=False):
#     obs_v = torch.FloatTensor(obs)
    if cont:
        policy = get_policy_cont(net, obs)
    else:
        policy = get_policy(net,obs)
#     print(policy)
#     pdb.set_trace()
    log_p = policy.log_prob(acts)
    return -(log_p*wts).mean(), policy.entropy().mean()



In [15]:
def get_critic_targets(eps_rews):
    targets = []
    val = 0
    for i in reversed(range(len(eps_rews))):
        val = val*GAMMA + eps_rews[i]
        targets.append(val)
    targets = targets[::-1]
    return targets

def get_critic_targets_biased(eps_rews, eps_obs, critic):
    targets = []

    eps_rews = torch.as_tensor(eps_rews,dtype=torch.float32)
        
#     pdb.set_trace()
    
#     eps_rews = (eps_rews - torch.mean(eps_rews))/(torch.std(eps_rews))
    eps_rews = eps_rews / torch.sum(eps_rews)
    
    
    for i in range(len(eps_obs)):
        if (i==len(eps_obs)-1):
            val = eps_rews[i]
        else:
            next_obs = eps_obs[i+1]
            val = eps_rews[i] + GAMMA*critic(torch.as_tensor(next_obs, dtype=torch.float32))
        targets.append(val)
        
    
    return targets

def get_advantage(batch_obs, batch_rews, critic):
#     batch_obs = torch.as_tensor(batch_obs, dtype = torch.float32)
    targets = []
    
    batch_rews = torch.as_tensor(batch_rews,dtype=torch.float32)
    
#     batch_rews = (batch_rews - torch.mean(batch_rews))/(torch.std(batch_rews))
    batch_rews = batch_rews/torch.sum(batch_rews)
    
    for i in range(len(batch_obs)):
        obs = batch_obs[i]
        obs = torch.as_tensor(obs, dtype = torch.float32)
        
        if (i==len(batch_obs)-1):
            target = batch_rews[i] - critic(obs).item() 
        else:
            next_obs = batch_obs[i+1]
            next_obs = torch.as_tensor(next_obs, dtype = torch.float32)
            target = batch_rews[i] + GAMMA*(critic(next_obs).item()) - critic(obs).item()
        
        targets.append(target.item())
    
    # last advantage should just be the rew
#     targets.append(batch_rews[-1])
    
    return targets


def get_batch_advantage(batch_obs, batch_rew, critic):
    advantage = []
    for i in range(len(batch_obs)):
        eps_obs = batch_obs[i]
        eps_rew = batch_rew[i]
        adv = []
#         eps_obs = torch.as_tensor(batch_obs, dtype = torch.float32)
        for j in range(len(eps_obs)):
            obs = eps_obs[j]
            obs = torch.as_tensor(obs, dtype = torch.float32)
            
            if (j==len(eps_obs)-1):
                val = eps_rew[j] - critic(obs).item()
            else:
                next_obs = eps_obs[j+1]
                next_obs = torch.as_tensor(next_obs, dtype = torch.float32)
                val = eps_rew[j] + GAMMA*critic(next_obs).item() - critic(obs).item()
            adv.append(val)
            
        advantage.extend(adv)
    
    return advantage

def get_normalised_adv_targets(ep_obs, ep_rews, critic):
    rewards = []
    disc_rew = 0
    for rew in ep_rews[::-1]:
        disc_rew = GAMMA*disc_rew + rew
        rewards.append(disc_rew)
        
    returns = rewards[::-1]
    eps = 1e-6

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    
#     targets=rewards
    targets = returns
    advantage = []
    
    ep_obs = torch.as_tensor(ep_obs, dtype = torch.float32)
    
#     for obs, rew in zip(ep_obs, returns):
#         adv = rew - critic(obs)
#         advantage.append(adv.item())
    
    for i in range(len(ep_obs)):
        a = returns[i] - critic(ep_obs[i])
#         if (i==len(ep_obs)-1):
#             a = rewards[i] - critic(ep_obs[i])
#         else:
#             a = rewards[i] + critic(ep_obs[i+1]) - critic(ep_obs[i])
        
        advantage.append(a.item())
    
    return advantage, targets

In [17]:
def train_one_epoch(env, net,critic, cont, batch_size=5000, render=False):
    batch_obs = []
    batch_wts = []
    batch_acts = []
    batch_rets = []
    batch_len = []
    eps_rew = []
    batch_rews = []
    targets = []
    advantage = []
    eps_obs = []
    obs = env.reset()
    done=False
    epoch_finished_rendering = False
    
    batch_pack_obs = []
    batch_pack_rew = []
    
    while True:
        if not epoch_finished_rendering and render:
            env.render()
        
        act = get_action(net, obs = torch.as_tensor(obs,dtype=torch.float32), cont=cont)
        batch_obs.append(obs.copy())
        batch_acts.append(act)
        
        eps_obs.append(obs)
        
        obs,rew,done,_ = env.step(act)
        
        eps_rew.append(rew)
        
        batch_rews.append(rew)
#         obs= next_obs
        
        if done:
            eps_ret = sum(eps_rew)
            eps_len = len(eps_rew)
            batch_rets.append(eps_ret)
            batch_len.append(eps_len)
            
#             eps_target = get_critic_targets(eps_rew)
#             eps_target = get_critic_targets_biased(eps_rew,eps_obs, critic)
#             targets.extend(eps_target)
            
#             # get advantage estimate
#             adv = get_advantage(eps_obs, eps_rew, critic)
#             advantage.extend(adv)


            batch_pack_obs.append(eps_obs)
            batch_pack_rew.append(eps_rew)
        
            ep_adv, ep_targ = get_normalised_adv_targets(eps_obs, eps_rew, critic)
            advantage.extend(ep_adv)
            targets.extend(ep_targ)
            
            #plain
#             batch_wts = batch_wts + [eps_ret]*eps_len

            #subtract avg reward
#             batch_wts = batch_wts + [eps_ret- avg_rew]*eps_len
            
            # reward to-go
#             batch_wts = batch_wts + list(reward_to_go(eps_rew))

            # reward to-go with avg rew
#             batch_wts = batch_wts + list(reward_to_go_avg(eps_rew, avg_rew))
            
            eps_rew = []
            eps_obs = []
            done = False
            
            obs = env.reset()
            epoch_finished_rendering = True
            
            if len(batch_obs)>batch_size:
                break
    
    # critic update
    pred_values = critic(torch.as_tensor(batch_obs, dtype = torch.float32))
#     actual_values = get_critic_targets(batch_rews)
    
    optimizer_critic.zero_grad()
    batch_loss_critic = loss_mae(pred_values.reshape(-1),
                                 torch.as_tensor(targets, dtype = torch.float32))
    batch_loss_critic.backward()
    optimizer_critic.step()
    
#     # get advantage estimate
#     advantage = get_advantage(batch_obs, batch_rews, critic)
#     advantage = get_batch_advantage(batch_pack_obs, batch_pack_rew, critic)
    
#     pdb.set_trace()
    # policy network update
    optimizer.zero_grad()
    
    if cont:
        batch_act_v = torch.as_tensor(batch_acts, dtype=torch.float32)
    else:
        batch_act_v = torch.as_tensor(batch_acts, dtype = torch.int32)
    
    batch_loss, entropy_v = compute_loss(obs = torch.as_tensor(batch_obs, dtype=torch.float32),
                              acts = batch_act_v,
                              wts = torch.as_tensor(advantage, dtype = torch.float32),
                             net = net,cont=cont)
    
    entropy = entropy_v.item()
    
    batch_loss.backward()
    optimizer.step()
    return batch_loss,batch_rets, batch_len, batch_loss_critic, advantage, entropy

In [18]:
# env = gym.make('CartPole-v1')
# env = gym.make('Pendulum-v0')
# env = gym.make('MountainCar-v0')
env = gym.make('LunarLander-v2')
cont = False

obs_size = env.observation_space.shape[0]

if cont:
    n_actions = env.action_space.shape[0]
else:
    n_actions = env.action_space.n
obs_size, n_actions



(8, 4)

In [19]:
env.observation_space, env.action_space

(Box(8,), Discrete(4))

In [20]:
HIDDEN_SIZE = 32
BATCH_SIZE = 500
GAMMA = 0.99
STD = 0.1

In [21]:
net = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions= n_actions)

In [22]:
obs = env.observation_space.sample()
act = get_action(net, obs= torch.as_tensor(obs, dtype=torch.float32), cont=cont)
act

1

In [23]:
obs = env.observation_space.sample()
pol = get_policy(net, obs= torch.as_tensor(obs, dtype=torch.float32))
# pol = get_policy_cont(net,obs= torch.as_tensor(obs, dtype=torch.float32))
pol

Categorical()

In [25]:
pol.log_prob(torch.as_tensor(act, dtype=torch.float32))

tensor(-1.2531, grad_fn=<SqueezeBackward1>)

In [26]:
a = env.action_space.sample()
a

0

In [27]:
from torch.optim import Adam
lr = 1e-2
optimizer = Adam(net.parameters(), lr=lr)

In [28]:
# critic
critic = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions=1)
lr_c = 1e-2
optimizer_critic = Adam(critic.parameters(), lr=lr_c)

from torch.nn import MSELoss
loss_mae = MSELoss(reduction='sum')
# loss_mae = MSELoss()

In [29]:
obs = env.observation_space.sample()
value = critic(torch.as_tensor(obs, dtype=torch.float32))
value

tensor([0.1364], grad_fn=<ThAddBackward>)

In [30]:
from tensorboardX import SummaryWriter
writer = SummaryWriter(comment="-actor_critic_mntcar")

In [31]:
# train

rew_req = 200
i=0
mean_rew = 0
mean_rew_sum = 0
avg_rew = 0

while i<500:
    i+=1
    render = True if i%20==0 else False
#     render = False
    batch_loss,batch_ret, batch_len, critic_loss,advantage, entropy = train_one_epoch(env, net, critic, cont,batch_size=BATCH_SIZE, render=render)
    mean_rew = np.mean(batch_ret) # mean return per episode
    mean_rew_sum += mean_rew  # sum of returns of episodes from start
    avg_rew = mean_rew_sum/i  # avg reward per episode from start
    if render:
        env.close()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f \t critic_loss: %.3f \t adv: %.3f'%
                (i, batch_loss, np.mean(batch_ret), np.mean(batch_len), critic_loss, np.mean(advantage)))
    writer.add_scalar("loss", batch_loss, i)
    writer.add_scalar("reward_mean", mean_rew, i)
    writer.add_scalar('entropy', entropy,i)
    writer.add_scalar('advantage',np.mean(advantage),i)
    writer.add_scalar('critc_loss',critic_loss,i)

epoch:   1 	 loss: 0.141 	 return: -301.001 	 ep_len: 86.500 	 critic_loss: 456.352 	 adv: 0.084
epoch:   2 	 loss: 0.039 	 return: -260.674 	 ep_len: 88.333 	 critic_loss: 518.926 	 adv: 0.030
epoch:   3 	 loss: -0.059 	 return: -247.123 	 ep_len: 91.833 	 critic_loss: 582.611 	 adv: -0.035
epoch:   4 	 loss: 0.049 	 return: -189.375 	 ep_len: 90.667 	 critic_loss: 536.127 	 adv: 0.031
epoch:   5 	 loss: 0.023 	 return: -186.485 	 ep_len: 95.500 	 critic_loss: 553.543 	 adv: 0.027
epoch:   6 	 loss: -0.008 	 return: -166.354 	 ep_len: 103.600 	 critic_loss: 610.368 	 adv: 0.007
epoch:   7 	 loss: -0.128 	 return: -152.780 	 ep_len: 81.857 	 critic_loss: 612.724 	 adv: -0.100
epoch:   8 	 loss: -0.095 	 return: -180.444 	 ep_len: 90.667 	 critic_loss: 589.227 	 adv: -0.074
epoch:   9 	 loss: -0.184 	 return: -169.805 	 ep_len: 97.833 	 critic_loss: 651.115 	 adv: -0.117
epoch:  10 	 loss: -0.087 	 return: -124.913 	 ep_len: 102.400 	 critic_loss: 490.891 	 adv: -0.073
epoch:  11 	 loss

epoch:  84 	 loss: -0.538 	 return: -112.007 	 ep_len: 184.333 	 critic_loss: 534.094 	 adv: -0.438
epoch:  85 	 loss: 0.180 	 return: -87.032 	 ep_len: 449.000 	 critic_loss: 1020.179 	 adv: 0.147
epoch:  86 	 loss: -0.013 	 return: -121.580 	 ep_len: 258.000 	 critic_loss: 456.497 	 adv: 0.007
epoch:  87 	 loss: 0.457 	 return: -29.532 	 ep_len: 1000.000 	 critic_loss: 692.234 	 adv: 0.391
epoch:  88 	 loss: 0.342 	 return: -37.794 	 ep_len: 557.500 	 critic_loss: 987.093 	 adv: 0.307
epoch:  89 	 loss: -0.651 	 return: -23.108 	 ep_len: 167.000 	 critic_loss: 491.758 	 adv: -0.581
epoch:  90 	 loss: 0.107 	 return: -43.139 	 ep_len: 366.250 	 critic_loss: 1130.261 	 adv: 0.099
epoch:  91 	 loss: -0.515 	 return: -82.644 	 ep_len: 147.000 	 critic_loss: 662.374 	 adv: -0.458
epoch:  92 	 loss: 0.234 	 return: -32.221 	 ep_len: 369.750 	 critic_loss: 1215.390 	 adv: 0.193
epoch:  93 	 loss: -0.517 	 return: -27.152 	 ep_len: 136.000 	 critic_loss: 369.531 	 adv: -0.447
epoch:  94 	 lo

epoch: 168 	 loss: -0.346 	 return: 32.468 	 ep_len: 662.500 	 critic_loss: 640.004 	 adv: -0.285
epoch: 169 	 loss: 0.214 	 return: 110.256 	 ep_len: 1000.000 	 critic_loss: 234.916 	 adv: 0.214
epoch: 170 	 loss: -0.252 	 return: 35.185 	 ep_len: 476.000 	 critic_loss: 627.342 	 adv: -0.235
epoch: 171 	 loss: -0.089 	 return: 41.257 	 ep_len: 634.500 	 critic_loss: 579.761 	 adv: -0.083
epoch: 172 	 loss: -0.073 	 return: 34.723 	 ep_len: 633.500 	 critic_loss: 504.072 	 adv: -0.061
epoch: 173 	 loss: -0.039 	 return: 17.368 	 ep_len: 624.000 	 critic_loss: 760.825 	 adv: -0.017
epoch: 174 	 loss: -0.151 	 return: 80.557 	 ep_len: 653.000 	 critic_loss: 621.848 	 adv: -0.127
epoch: 175 	 loss: -0.973 	 return: -7.131 	 ep_len: 316.500 	 critic_loss: 676.367 	 adv: -0.846
epoch: 176 	 loss: 0.223 	 return: 41.957 	 ep_len: 1000.000 	 critic_loss: 248.954 	 adv: 0.250
epoch: 177 	 loss: 0.002 	 return: 48.186 	 ep_len: 622.500 	 critic_loss: 608.890 	 adv: 0.000
epoch: 178 	 loss: 0.32

epoch: 252 	 loss: -0.012 	 return: 79.851 	 ep_len: 619.000 	 critic_loss: 539.122 	 adv: -0.019
epoch: 253 	 loss: 0.237 	 return: 110.742 	 ep_len: 1000.000 	 critic_loss: 308.863 	 adv: 0.242
epoch: 254 	 loss: 0.251 	 return: 144.827 	 ep_len: 1000.000 	 critic_loss: 319.791 	 adv: 0.257
epoch: 255 	 loss: 0.120 	 return: 113.624 	 ep_len: 1000.000 	 critic_loss: 251.840 	 adv: 0.149
epoch: 256 	 loss: 0.054 	 return: 124.517 	 ep_len: 1000.000 	 critic_loss: 241.440 	 adv: 0.083
epoch: 257 	 loss: 0.175 	 return: 84.518 	 ep_len: 1000.000 	 critic_loss: 334.299 	 adv: 0.202
epoch: 258 	 loss: 0.021 	 return: 117.188 	 ep_len: 1000.000 	 critic_loss: 276.569 	 adv: 0.035
epoch: 259 	 loss: -0.842 	 return: -38.327 	 ep_len: 285.000 	 critic_loss: 572.572 	 adv: -0.848
epoch: 260 	 loss: -0.272 	 return: 138.317 	 ep_len: 719.000 	 critic_loss: 1294.877 	 adv: 0.061
epoch: 261 	 loss: -0.090 	 return: 49.326 	 ep_len: 636.500 	 critic_loss: 418.022 	 adv: -0.096
epoch: 262 	 loss: 

epoch: 336 	 loss: 0.123 	 return: 116.060 	 ep_len: 1000.000 	 critic_loss: 655.644 	 adv: 0.088
epoch: 337 	 loss: 0.082 	 return: 63.566 	 ep_len: 680.500 	 critic_loss: 843.640 	 adv: 0.036
epoch: 338 	 loss: -0.132 	 return: 12.089 	 ep_len: 279.000 	 critic_loss: 313.411 	 adv: -0.175
epoch: 339 	 loss: 0.122 	 return: 109.766 	 ep_len: 1000.000 	 critic_loss: 604.513 	 adv: 0.149
epoch: 340 	 loss: -0.297 	 return: 18.549 	 ep_len: 356.500 	 critic_loss: 384.122 	 adv: -0.313
epoch: 341 	 loss: -0.295 	 return: -13.280 	 ep_len: 322.000 	 critic_loss: 312.264 	 adv: -0.289
epoch: 342 	 loss: -0.462 	 return: 76.312 	 ep_len: 324.500 	 critic_loss: 790.728 	 adv: -0.304
epoch: 343 	 loss: -0.040 	 return: 204.147 	 ep_len: 665.000 	 critic_loss: 1650.826 	 adv: 0.252
epoch: 344 	 loss: 0.289 	 return: 99.093 	 ep_len: 1000.000 	 critic_loss: 440.680 	 adv: 0.317
epoch: 345 	 loss: 0.273 	 return: 158.253 	 ep_len: 1000.000 	 critic_loss: 436.425 	 adv: 0.323
epoch: 346 	 loss: 0.

epoch: 420 	 loss: -0.547 	 return: 81.118 	 ep_len: 321.500 	 critic_loss: 797.904 	 adv: -0.367
epoch: 421 	 loss: 0.196 	 return: 137.674 	 ep_len: 1000.000 	 critic_loss: 419.311 	 adv: 0.254
epoch: 422 	 loss: 0.052 	 return: 44.070 	 ep_len: 632.500 	 critic_loss: 397.596 	 adv: 0.049
epoch: 423 	 loss: 0.032 	 return: 116.172 	 ep_len: 1000.000 	 critic_loss: 336.798 	 adv: 0.048
epoch: 424 	 loss: -0.370 	 return: 68.344 	 ep_len: 408.000 	 critic_loss: 916.248 	 adv: -0.247
epoch: 425 	 loss: -0.521 	 return: 96.524 	 ep_len: 287.500 	 critic_loss: 717.796 	 adv: -0.396
epoch: 426 	 loss: 0.139 	 return: 79.101 	 ep_len: 1000.000 	 critic_loss: 288.163 	 adv: 0.112
epoch: 427 	 loss: -0.556 	 return: 34.949 	 ep_len: 226.667 	 critic_loss: 581.520 	 adv: -0.545
epoch: 428 	 loss: 0.043 	 return: 190.265 	 ep_len: 1000.000 	 critic_loss: 388.900 	 adv: 0.078
epoch: 429 	 loss: -0.392 	 return: 67.949 	 ep_len: 268.667 	 critic_loss: 649.097 	 adv: -0.410
epoch: 430 	 loss: 0.11

In [None]:
env.close()

# Scratchpad

In [5]:
c = Categorical(logits = torch.FloatTensor([1,0,1]))
c

Categorical()

In [6]:
c.sample()

tensor(2)

In [7]:
count = [0,0,0]
for i in range(0,100):
    x = c.sample().item()
    count[x]+=1
count

[39, 13, 48]

In [8]:
c.log_prob(torch.as_tensor(0, dtype=torch.int32))

tensor(-0.8620)

In [9]:
x =([2]*8)
x = torch.FloatTensor(x)
len(x)

8

In [10]:
y = ([1]*3 + [2]*5)
y

[1, 1, 1, 2, 2, 2, 2, 2]

In [11]:
(torch.FloatTensor(x)*torch.FloatTensor(y)).mean()

tensor(3.2500)

In [12]:
reward_to_go([1,0,2,1,0])

NameError: name 'reward_to_go' is not defined

In [None]:
np.mean([1,0,2,1,0]),np.std([1,0,2,1,0])

In [16]:
x = [1,1,0,0,0,1]
get_critic_targets(x)
y = [5,4]
x.extend(y)
x

NameError: name 'GAMMA' is not defined

In [202]:
x = [1,2,3]
y = []
y.append(x)
y

[[1, 2, 3]]

In [203]:
x=[4,2,3]
y

[[1, 2, 3]]

In [204]:
y.append(x.copy())

In [205]:
y

[[1, 2, 3], [4, 2, 3]]

In [206]:
y.append(x)

In [207]:
y

[[1, 2, 3], [4, 2, 3], [4, 2, 3]]

In [237]:
obs = env.observation_space.sample()
obs

array([-5.3292841e-01,  1.3936392e+38,  2.2844117e-02,  8.5078421e+37],
      dtype=float32)

In [148]:
net(torch.FloatTensor(obs))

tensor([79045296309277814133001641918405279744.,
         8074543887068950992878338433657864192.], grad_fn=<ThAddBackward>)

In [169]:
c2 = Categorical(net(torch.FloatTensor(obs)))
c2.sample()

tensor(0)

In [173]:
c2.log_prob(torch.as_tensor(1, dtype = torch.int32))

tensor(-2.3786, grad_fn=<SqueezeBackward1>)

In [377]:
def test_rew1(ep_rews):
    rewards = []
    disc_rew = 0
    for rew in ep_rews[::-1]:
        disc_rew = GAMMA*disc_rew + rew
        rewards.append(disc_rew)
    rewards = rewards[::-1]
    return rewards
        
def test_rew2(ep_rews):
    R = 0
    returns = []
    eps = 1e-6
    for r in ep_rews[::-1]:
        # calculate the discounted value
        R = r + GAMMA * R
        returns.insert(0, R)
    return returns

In [378]:
x = [1,1,1,1,1,1]
test_rew1(x)

[5.8519850599, 4.90099501, 3.9403989999999998, 2.9701, 1.99, 1.0]

In [379]:
test_rew2(x)

[5.8519850599, 4.90099501, 3.9403989999999998, 2.9701, 1.99, 1.0]