In [32]:
import gym
import numpy as np
import torch.nn as nn
import pdb

In [33]:
import torch
print(torch.__version__)

0.4.1


In [34]:
class Net(nn.Module):
    def __init__(self,obs_size, hidden_size,n_actions):
        super(Net, self).__init__()
        self.net = nn.Sequential(nn.Linear(obs_size,hidden_size),
                                nn.Tanh(),
                                nn.Linear(hidden_size,n_actions))
    
    def forward(self,x):
        return self.net(x)

In [251]:
from torch.distributions.categorical import Categorical
from torch.distributions.normal import Normal
import torch

In [36]:
c = Categorical(logits = torch.FloatTensor([1,0,1]))
c

Categorical()

In [37]:
c.sample()

tensor(1)

In [38]:
count = [0,0,0]
for i in range(0,100):
    x = c.sample().item()
    count[x]+=1
count

[47, 18, 35]

In [39]:
c.log_prob(torch.as_tensor(0, dtype=torch.int32))

tensor(-0.8620)

In [40]:
x =([2]*8)
x = torch.FloatTensor(x)
len(x)

8

In [None]:
y = ([1]*3 + [2]*5)
y

In [42]:
(torch.FloatTensor(x)*torch.FloatTensor(y)).mean()

tensor(3.2500)

In [43]:
reward_to_go([1,0,2,1,0])

array([4, 3, 3, 1, 0])

In [130]:
np.mean([1,0,2,1,0]),np.std([1,0,2,1,0])

(0.8, 0.7483314773547883)

In [282]:
def get_policy(net, obs):
    logits = net(obs)
    return Categorical(logits=logits)

def get_policy_cont(net, obs):
    mu = net(obs)
    mu = torch.tanh(mu)*1.8
    return Normal(loc=mu, scale=STD)

def get_action(net, obs, cont=False):
    if cont:
        policy = get_policy_cont(net, obs)
    else:
        policy = get_policy(net, obs)
    act = policy.sample().item()
    if cont:
        act = np.array([act])
    return act

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

def reward_to_go_avg(rews, avg):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0) - avg/n
    return rtgs

In [294]:
def compute_loss(obs, acts, wts, net, cont=False):
#     obs_v = torch.FloatTensor(obs)
    if cont:
        policy = get_policy_cont(net, obs)
    else:
        policy = get_policy(net,obs)
#     print(policy)
#     pdb.set_trace()
    log_p = policy.log_prob(acts)
    return -(log_p*wts).mean(), policy.entropy().mean()



In [426]:
def get_critic_targets(eps_rews):
    targets = []
    val = 0
    for i in reversed(range(len(eps_rews))):
        val = val*GAMMA + eps_rews[i]
        targets.append(val)
    targets = targets[::-1]
    return targets

def get_critic_targets_biased(eps_rews, eps_obs, critic):
    targets = []

    eps_rews = torch.as_tensor(eps_rews,dtype=torch.float32)
        
#     pdb.set_trace()
    
#     eps_rews = (eps_rews - torch.mean(eps_rews))/(torch.std(eps_rews))
    eps_rews = eps_rews / torch.sum(eps_rews)
    
    
    for i in range(len(eps_obs)):
        if (i==len(eps_obs)-1):
            val = eps_rews[i]
        else:
            next_obs = eps_obs[i+1]
            val = eps_rews[i] + GAMMA*critic(torch.as_tensor(next_obs, dtype=torch.float32))
        targets.append(val)
        
    
    return targets

def get_advantage(batch_obs, batch_rews, critic):
#     batch_obs = torch.as_tensor(batch_obs, dtype = torch.float32)
    targets = []
    
    batch_rews = torch.as_tensor(batch_rews,dtype=torch.float32)
    
#     batch_rews = (batch_rews - torch.mean(batch_rews))/(torch.std(batch_rews))
    batch_rews = batch_rews/torch.sum(batch_rews)
    
    for i in range(len(batch_obs)):
        obs = batch_obs[i]
        obs = torch.as_tensor(obs, dtype = torch.float32)
        
        if (i==len(batch_obs)-1):
            target = batch_rews[i] - critic(obs).item() 
        else:
            next_obs = batch_obs[i+1]
            next_obs = torch.as_tensor(next_obs, dtype = torch.float32)
            target = batch_rews[i] + GAMMA*(critic(next_obs).item()) - critic(obs).item()
        
        targets.append(target.item())
    
    # last advantage should just be the rew
#     targets.append(batch_rews[-1])
    
    return targets


def get_batch_advantage(batch_obs, batch_rew, critic):
    advantage = []
    for i in range(len(batch_obs)):
        eps_obs = batch_obs[i]
        eps_rew = batch_rew[i]
        adv = []
#         eps_obs = torch.as_tensor(batch_obs, dtype = torch.float32)
        for j in range(len(eps_obs)):
            obs = eps_obs[j]
            obs = torch.as_tensor(obs, dtype = torch.float32)
            
            if (j==len(eps_obs)-1):
                val = eps_rew[j] - critic(obs).item()
            else:
                next_obs = eps_obs[j+1]
                next_obs = torch.as_tensor(next_obs, dtype = torch.float32)
                val = eps_rew[j] + GAMMA*critic(next_obs).item() - critic(obs).item()
            adv.append(val)
            
        advantage.extend(adv)
    
    return advantage

def get_normalised_adv_targets(ep_obs, ep_rews, critic):
    rewards = []
    disc_rew = 0
    for rew in ep_rews[::-1]:
        disc_rew = GAMMA*disc_rew + rew
        rewards.append(disc_rew)
        
    returns = rewards[::-1]
    eps = 1e-6

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    
#     targets=rewards
    targets = returns
    advantage = []
    
    ep_obs = torch.as_tensor(ep_obs, dtype = torch.float32)
    
#     for obs, rew in zip(ep_obs, returns):
#         adv = rew - critic(obs)
#         advantage.append(adv.item())
    
    for i in range(len(ep_obs)):
        a = returns[i] - critic(ep_obs[i])
#         if (i==len(ep_obs)-1):
#             a = rewards[i] - critic(ep_obs[i])
#         else:
#             a = rewards[i] + critic(ep_obs[i+1]) - critic(ep_obs[i])
        
        advantage.append(a.item())
    
    return advantage, targets

In [341]:
x = [1,1,0,0,0,1]
get_critic_targets(x)
y = [5,4]
x.extend(y)
x

[1, 1, 0, 0, 0, 1, 5, 4]

In [411]:
def train_one_epoch(env, net,critic, cont, batch_size=5000, render=False):
    batch_obs = []
    batch_wts = []
    batch_acts = []
    batch_rets = []
    batch_len = []
    eps_rew = []
    batch_rews = []
    targets = []
    advantage = []
    eps_obs = []
    obs = env.reset()
    done=False
    epoch_finished_rendering = False
    
    batch_pack_obs = []
    batch_pack_rew = []
    
    while True:
        if not epoch_finished_rendering and render:
            env.render()
        
        act = get_action(net, obs = torch.as_tensor(obs,dtype=torch.float32), cont=cont)
        batch_obs.append(obs.copy())
        batch_acts.append(act)
        
        eps_obs.append(obs)
        
        obs,rew,done,_ = env.step(act)
        
        eps_rew.append(rew)
        
        batch_rews.append(rew)
#         obs= next_obs
        
        if done:
            eps_ret = sum(eps_rew)
            eps_len = len(eps_rew)
            batch_rets.append(eps_ret)
            batch_len.append(eps_len)
            
#             eps_target = get_critic_targets(eps_rew)
#             eps_target = get_critic_targets_biased(eps_rew,eps_obs, critic)
#             targets.extend(eps_target)
            
#             # get advantage estimate
#             adv = get_advantage(eps_obs, eps_rew, critic)
#             advantage.extend(adv)


            batch_pack_obs.append(eps_obs)
            batch_pack_rew.append(eps_rew)
        
            ep_adv, ep_targ = get_normalised_adv_targets(eps_obs, eps_rew, critic)
            advantage.extend(ep_adv)
            targets.extend(ep_targ)
            
            #plain
#             batch_wts = batch_wts + [eps_ret]*eps_len

            #subtract avg reward
#             batch_wts = batch_wts + [eps_ret- avg_rew]*eps_len
            
            # reward to-go
#             batch_wts = batch_wts + list(reward_to_go(eps_rew))

            # reward to-go with avg rew
#             batch_wts = batch_wts + list(reward_to_go_avg(eps_rew, avg_rew))
            
            eps_rew = []
            eps_obs = []
            done = False
            
            obs = env.reset()
            epoch_finished_rendering = True
            
            if len(batch_obs)>batch_size:
                break
    
    # critic update
    pred_values = critic(torch.as_tensor(batch_obs, dtype = torch.float32))
#     actual_values = get_critic_targets(batch_rews)
    
    optimizer_critic.zero_grad()
    batch_loss_critic = loss_mae(pred_values.reshape(-1),
                                 torch.as_tensor(targets, dtype = torch.float32))
    batch_loss_critic.backward()
    optimizer_critic.step()
    
#     # get advantage estimate
#     advantage = get_advantage(batch_obs, batch_rews, critic)
#     advantage = get_batch_advantage(batch_pack_obs, batch_pack_rew, critic)
    
#     pdb.set_trace()
    # policy network update
    optimizer.zero_grad()
    
    if cont:
        batch_act_v = torch.as_tensor(batch_acts, dtype=torch.float32)
    else:
        batch_act_v = torch.as_tensor(batch_acts, dtype = torch.int32)
    
    batch_loss, entropy_v = compute_loss(obs = torch.as_tensor(batch_obs, dtype=torch.float32),
                              acts = batch_act_v,
                              wts = torch.as_tensor(advantage, dtype = torch.float32),
                             net = net,cont=cont)
    
    entropy = entropy_v.item()
    
    batch_loss.backward()
    optimizer.step()
    return batch_loss,batch_rets, batch_len, batch_loss_critic, advantage, entropy

In [461]:
# env = gym.make('CartPole-v1')
# env = gym.make('Pendulum-v0')
# env = gym.make('MountainCar-v0')
env = gym.make('LunarLander-v2')
cont = False

obs_size = env.observation_space.shape[0]

if cont:
    n_actions = env.action_space.shape[0]
else:
    n_actions = env.action_space.n
obs_size, n_actions

(8, 4)

In [462]:
env.observation_space, env.action_space

(Box(8,), Discrete(4))

In [463]:
HIDDEN_SIZE = 32
BATCH_SIZE = 500
GAMMA = 0.99
STD = 0.1

In [464]:
net = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions= n_actions)

In [465]:
obs = env.observation_space.sample()
act = get_action(net, obs= torch.as_tensor(obs, dtype=torch.float32), cont=cont)
act

0

In [466]:
obs = env.observation_space.sample()
pol = get_policy(net, obs= torch.as_tensor(obs, dtype=torch.float32))
# pol = get_policy_cont(net,obs= torch.as_tensor(obs, dtype=torch.float32))
pol

Categorical()

In [467]:
pol.loc

AttributeError: 'Categorical' object has no attribute 'loc'

In [468]:
pol.log_prob(torch.as_tensor(act, dtype=torch.float32))

tensor(-1.6612, grad_fn=<SqueezeBackward1>)

In [469]:
a = env.action_space.sample()
a

0

In [470]:
from torch.optim import Adam
lr = 1e-2
optimizer = Adam(net.parameters(), lr=lr)

In [471]:
# critic
critic = Net(obs_size = obs_size, hidden_size = HIDDEN_SIZE, n_actions=1)
lr_c = 1e-2
optimizer_critic = Adam(critic.parameters(), lr=lr_c)

from torch.nn import MSELoss
loss_mae = MSELoss(reduction='sum')
# loss_mae = MSELoss()

In [472]:
obs = env.observation_space.sample()
value = critic(torch.as_tensor(obs, dtype=torch.float32))
value

tensor([0.0424], grad_fn=<ThAddBackward>)

In [473]:
from tensorboardX import SummaryWriter
writer = SummaryWriter(comment="-actor_critic_mntcar")

In [474]:
# train

rew_req = 200
i=0
mean_rew = 0
mean_rew_sum = 0
avg_rew = 0

while i<500:
    i+=1
    render = True if i%20==0 else False
#     render = False
    batch_loss,batch_ret, batch_len, critic_loss,advantage, entropy = train_one_epoch(env, net, critic, cont,batch_size=BATCH_SIZE, render=render)
    mean_rew = np.mean(batch_ret) # mean return per episode
    mean_rew_sum += mean_rew  # sum of returns of episodes from start
    avg_rew = mean_rew_sum/i  # avg reward per episode from start
    if render:
        env.close()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f \t critic_loss: %.3f \t adv: %.3f'%
                (i, batch_loss, np.mean(batch_ret), np.mean(batch_len), critic_loss, np.mean(advantage)))
    writer.add_scalar("loss", batch_loss, i)
    writer.add_scalar("reward_mean", mean_rew, i)
    writer.add_scalar('entropy', entropy,i)
    writer.add_scalar('advantage',np.mean(advantage),i)
    writer.add_scalar('critc_loss',critic_loss,i)

epoch:   1 	 loss: 0.011 	 return: -166.005 	 ep_len: 86.667 	 critic_loss: 557.426 	 adv: 0.005
epoch:   2 	 loss: -0.011 	 return: -304.225 	 ep_len: 113.200 	 critic_loss: 558.280 	 adv: -0.019
epoch:   3 	 loss: 0.012 	 return: -133.544 	 ep_len: 93.500 	 critic_loss: 586.948 	 adv: 0.007
epoch:   4 	 loss: -0.083 	 return: -349.428 	 ep_len: 99.833 	 critic_loss: 536.006 	 adv: -0.068
epoch:   5 	 loss: 0.049 	 return: -167.461 	 ep_len: 87.500 	 critic_loss: 517.839 	 adv: 0.033
epoch:   6 	 loss: 0.010 	 return: -269.968 	 ep_len: 105.200 	 critic_loss: 502.290 	 adv: 0.011
epoch:   7 	 loss: -0.012 	 return: -173.752 	 ep_len: 90.667 	 critic_loss: 521.423 	 adv: 0.010
epoch:   8 	 loss: 0.320 	 return: -192.112 	 ep_len: 104.400 	 critic_loss: 602.143 	 adv: 0.234
epoch:   9 	 loss: 0.150 	 return: -162.837 	 ep_len: 80.286 	 critic_loss: 566.544 	 adv: 0.111
epoch:  10 	 loss: 0.239 	 return: -190.630 	 ep_len: 107.600 	 critic_loss: 617.755 	 adv: 0.152
epoch:  11 	 loss: 0.

epoch:  85 	 loss: 0.004 	 return: -28.831 	 ep_len: 1000.000 	 critic_loss: 727.586 	 adv: 0.033
epoch:  86 	 loss: 0.141 	 return: -17.105 	 ep_len: 477.667 	 critic_loss: 958.885 	 adv: 0.135
epoch:  87 	 loss: 0.023 	 return: -6.371 	 ep_len: 630.500 	 critic_loss: 761.966 	 adv: 0.035
epoch:  88 	 loss: -0.334 	 return: 6.738 	 ep_len: 189.333 	 critic_loss: 360.573 	 adv: -0.260
epoch:  89 	 loss: -0.140 	 return: -0.523 	 ep_len: 170.333 	 critic_loss: 272.319 	 adv: -0.130
epoch:  90 	 loss: -0.103 	 return: 0.648 	 ep_len: 254.667 	 critic_loss: 472.829 	 adv: -0.028
epoch:  91 	 loss: 0.545 	 return: 17.662 	 ep_len: 1000.000 	 critic_loss: 590.868 	 adv: 0.489
epoch:  92 	 loss: 0.472 	 return: 5.478 	 ep_len: 1000.000 	 critic_loss: 593.758 	 adv: 0.413
epoch:  93 	 loss: 0.253 	 return: 40.837 	 ep_len: 490.667 	 critic_loss: 955.575 	 adv: 0.220
epoch:  94 	 loss: 0.192 	 return: -16.071 	 ep_len: 394.333 	 critic_loss: 806.967 	 adv: 0.172
epoch:  95 	 loss: -0.419 	 ret

epoch: 170 	 loss: 0.143 	 return: 85.775 	 ep_len: 1000.000 	 critic_loss: 339.012 	 adv: 0.139
epoch: 171 	 loss: 0.069 	 return: 121.004 	 ep_len: 1000.000 	 critic_loss: 311.087 	 adv: 0.100
epoch: 172 	 loss: -0.111 	 return: 79.270 	 ep_len: 1000.000 	 critic_loss: 524.976 	 adv: -0.066
epoch: 173 	 loss: 0.162 	 return: 80.466 	 ep_len: 1000.000 	 critic_loss: 387.373 	 adv: 0.164
epoch: 174 	 loss: 0.086 	 return: 83.525 	 ep_len: 1000.000 	 critic_loss: 314.518 	 adv: 0.107
epoch: 175 	 loss: -0.040 	 return: 91.543 	 ep_len: 1000.000 	 critic_loss: 365.868 	 adv: -0.026
epoch: 176 	 loss: 0.116 	 return: 131.488 	 ep_len: 1000.000 	 critic_loss: 197.807 	 adv: 0.116
epoch: 177 	 loss: -0.091 	 return: 127.400 	 ep_len: 1000.000 	 critic_loss: 274.162 	 adv: -0.062
epoch: 178 	 loss: 0.164 	 return: 136.121 	 ep_len: 1000.000 	 critic_loss: 172.262 	 adv: 0.153
epoch: 179 	 loss: -1.079 	 return: 5.083 	 ep_len: 304.500 	 critic_loss: 740.333 	 adv: -1.008
epoch: 180 	 loss: -

epoch: 254 	 loss: 0.132 	 return: 129.559 	 ep_len: 1000.000 	 critic_loss: 101.435 	 adv: 0.136
epoch: 255 	 loss: -0.156 	 return: 58.749 	 ep_len: 597.500 	 critic_loss: 606.106 	 adv: -0.124
epoch: 256 	 loss: -0.096 	 return: 61.199 	 ep_len: 598.000 	 critic_loss: 426.891 	 adv: -0.077
epoch: 257 	 loss: -0.018 	 return: 84.989 	 ep_len: 634.500 	 critic_loss: 475.096 	 adv: 0.001
epoch: 258 	 loss: 0.098 	 return: 113.681 	 ep_len: 1000.000 	 critic_loss: 120.185 	 adv: 0.103
epoch: 259 	 loss: -0.187 	 return: 58.780 	 ep_len: 655.500 	 critic_loss: 624.643 	 adv: -0.148
epoch: 260 	 loss: 0.044 	 return: 145.845 	 ep_len: 1000.000 	 critic_loss: 283.381 	 adv: 0.063
epoch: 261 	 loss: 0.047 	 return: 81.327 	 ep_len: 1000.000 	 critic_loss: 744.172 	 adv: 0.208
epoch: 262 	 loss: -0.071 	 return: 135.917 	 ep_len: 1000.000 	 critic_loss: 417.585 	 adv: -0.045
epoch: 263 	 loss: -0.039 	 return: 130.253 	 ep_len: 1000.000 	 critic_loss: 329.543 	 adv: -0.009
epoch: 264 	 loss:

epoch: 338 	 loss: 0.195 	 return: 24.372 	 ep_len: 153.500 	 critic_loss: 534.163 	 adv: 0.130
epoch: 339 	 loss: 0.118 	 return: 243.646 	 ep_len: 251.000 	 critic_loss: 501.503 	 adv: -0.060
epoch: 340 	 loss: 0.166 	 return: 25.685 	 ep_len: 142.000 	 critic_loss: 468.167 	 adv: 0.080
epoch: 341 	 loss: 0.230 	 return: 271.258 	 ep_len: 231.333 	 critic_loss: 733.293 	 adv: 0.049
epoch: 342 	 loss: 0.130 	 return: 85.960 	 ep_len: 166.000 	 critic_loss: 503.446 	 adv: 0.021
epoch: 343 	 loss: 0.084 	 return: 26.961 	 ep_len: 148.750 	 critic_loss: 399.438 	 adv: 0.016
epoch: 344 	 loss: 0.068 	 return: 272.068 	 ep_len: 256.500 	 critic_loss: 509.327 	 adv: 0.052
epoch: 345 	 loss: 0.101 	 return: 155.654 	 ep_len: 219.750 	 critic_loss: 660.905 	 adv: 0.079
epoch: 346 	 loss: 0.248 	 return: 250.988 	 ep_len: 709.000 	 critic_loss: 617.509 	 adv: 0.285
epoch: 347 	 loss: 0.041 	 return: 108.745 	 ep_len: 202.667 	 critic_loss: 360.432 	 adv: 0.031
epoch: 348 	 loss: 0.040 	 return

KeyboardInterrupt: 

In [None]:
env.close()

In [202]:
x = [1,2,3]
y = []
y.append(x)
y

[[1, 2, 3]]

In [203]:
x=[4,2,3]
y

[[1, 2, 3]]

In [204]:
y.append(x.copy())

In [205]:
y

[[1, 2, 3], [4, 2, 3]]

In [206]:
y.append(x)

In [207]:
y

[[1, 2, 3], [4, 2, 3], [4, 2, 3]]

In [237]:
obs = env.observation_space.sample()
obs

array([-5.3292841e-01,  1.3936392e+38,  2.2844117e-02,  8.5078421e+37],
      dtype=float32)

In [148]:
net(torch.FloatTensor(obs))

tensor([79045296309277814133001641918405279744.,
         8074543887068950992878338433657864192.], grad_fn=<ThAddBackward>)

In [169]:
c2 = Categorical(net(torch.FloatTensor(obs)))
c2.sample()

tensor(0)

In [173]:
c2.log_prob(torch.as_tensor(1, dtype = torch.int32))

tensor(-2.3786, grad_fn=<SqueezeBackward1>)

In [377]:
def test_rew1(ep_rews):
    rewards = []
    disc_rew = 0
    for rew in ep_rews[::-1]:
        disc_rew = GAMMA*disc_rew + rew
        rewards.append(disc_rew)
    rewards = rewards[::-1]
    return rewards
        
def test_rew2(ep_rews):
    R = 0
    returns = []
    eps = 1e-6
    for r in ep_rews[::-1]:
        # calculate the discounted value
        R = r + GAMMA * R
        returns.insert(0, R)
    return returns

In [378]:
x = [1,1,1,1,1,1]
test_rew1(x)

[5.8519850599, 4.90099501, 3.9403989999999998, 2.9701, 1.99, 1.0]

In [379]:
test_rew2(x)

[5.8519850599, 4.90099501, 3.9403989999999998, 2.9701, 1.99, 1.0]