In [129]:
import torch
import numpy as np
# seans agent
class ppo_agent():
    def __init__(self, gamma=0.99, lam=0.95, clip_epsilon=0.2, device='cpu', norm_adv=True):
        self.gamma = gamma
        self.lam = lam  # Lambda for GAE
        self.clip_epsilon = clip_epsilon  # Clipping range for policy loss
        self.device = device
        self.norm_adv = norm_adv
        
    def compute_gae(self, rewards, values, dones):
        values = list(values) + [0]  # Convert values to a list before concatenation
        advantages = []
        gae = 0
        for step in reversed(range(len(rewards))):
            # print( rewards[step], values[step + 1] , 1 - dones[step])
            delta = rewards[step] + self.gamma * values[step + 1] * (1 - dones[step]) - values[step]
            gae = delta + self.gamma * self.lam * (1 - dones[step]) * gae
            advantages.append(gae)
        advantages.reverse()
        return advantages
    
    def process_gae_returns(self,rewards, values, dones):
        with torch.no_grad():
            returns = []
            # G = 0
            # # Compute returns for each timestep
            # for reward, done in zip(reversed(rewards), reversed(dones)):
            #     G = reward + self.gamma * G * (1 - done)
            #     returns.append(G)
            # returns.reverse()
            
            # Compute advantages using GAE
            advantages = self.compute_gae(rewards, values, dones)
            advantages = torch.FloatTensor(advantages).to(self.device)
            # Normalize advantages for training stability
            # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            
            # Prepare tensors
            # returns = torch.FloatTensor(returns).to(self.device)
            returns = advantages + torch.FloatTensor(values).to(self.device)

        return returns,advantages
    def process_gae_returns0(self,rewards, values, dones):
        with torch.no_grad():
            returns = []
            G = 0
            # Compute returns for each timestep
            for reward, done in zip(reversed(rewards), reversed(dones)):
                G = reward + self.gamma * G * (1 - done)
                returns.append(G)
            returns.reverse()
            
            # Compute advantages using GAE
            advantages = self.compute_gae(rewards, values, dones)
            advantages = torch.FloatTensor(advantages).to(self.device)
            # Normalize advantages for training stability
            # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            
            # Prepare tensors
            returns = torch.FloatTensor(returns).to(self.device)
            # returns = advantages + torch.FloatTensor(values).to(self.device)

        return returns,advantages

In [113]:
rewards = [0,0,0,0,1,0,0,-1,0,0,0,1,0,0,1,0]
dones =   [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
next_done = 1

# values =  np.random.rand(len(rewards))
values = [0.47495744, 0.04955575, 0.88012086, 0.81747744, 0.47915682,
       0.75750459, 0.11583911, 0.42138228, 0.77161771, 0.4410075 ,
       0.72813253, 0.42859623, 0.50210429, 0.47359902, 0.91350348,
       0.97346051]

next_value = 0 # can be any, as next_done is 1
len(dones)

16

In [115]:
returns,advantages = ppo_agent().process_gae_returns0(rewards, values, dones[1:]+[1])
(returns, advantages)

(tensor([1.7926, 1.8107, 1.8290, 1.8475, 1.8661, 0.8749, 0.8837, 0.8927, 1.9118,
         1.9311, 1.9506, 1.9703, 0.9801, 0.9900, 1.0000, 0.0000]),
 tensor([ 0.8716,  1.3796,  0.5931,  0.7059,  1.1154, -0.1652,  0.5078,  0.2195,
          0.9325,  1.3477,  1.1354,  1.5303,  0.4910,  0.5574,  0.1347, -0.9735]))

In [130]:
# Our ppo-test collected is actually next done
returns,advantages = ppo_agent().process_gae_returns(rewards, values, dones[1:]+[1])
(returns, advantages)

(tensor([1.3465, 1.4291, 1.4732, 1.5234, 1.5945, 0.5923, 0.6236, 0.6409, 1.7041,
         1.7887, 1.8636, 1.9589, 0.9931, 1.0310, 1.0482, 0.0000]),
 tensor([ 0.8716,  1.3796,  0.5931,  0.7059,  1.1154, -0.1652,  0.5078,  0.2195,
          0.9325,  1.3477,  1.1354,  1.5303,  0.4910,  0.5574,  0.1347, -0.9735]))

In [126]:
# CleanRL
class args:
    num_steps = len(rewards)
    gamma = 0.99
    gae_lambda = 0.95
    
def cleanrl_gae_returns(rewards, values, dones, args, device='cpu'):
    with torch.no_grad():
        print(values)
        next_value = values[-1]
        advantages = torch.zeros_like(rewards).to(device)
        lastgaelam = 0
        for t in reversed(range(args.num_steps)):
            if t == args.num_steps - 1:
                nextnonterminal = 1.0 - dones[-1]
                nextvalues = next_value
            else:
                nextnonterminal = 1.0 - dones[t + 1]
                nextvalues = values[t + 1]
            print(t, rewards[t], nextvalues, nextnonterminal)
            delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
            advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
        returns = advantages + torch.FloatTensor(values[:-1]).to(device)
    return returns, advantages

cleanrl_gae_returns(torch.FloatTensor(rewards), values+[0], dones+[1], args)

[0.47495744, 0.04955575, 0.88012086, 0.81747744, 0.47915682, 0.75750459, 0.11583911, 0.42138228, 0.77161771, 0.4410075, 0.72813253, 0.42859623, 0.50210429, 0.47359902, 0.91350348, 0.97346051, 0]
15 tensor(0.) 0 0.0
14 tensor(1.) 0.97346051 1.0
13 tensor(0.) 0.91350348 1.0
12 tensor(0.) 0.47359902 1.0
11 tensor(1.) 0.50210429 1.0
10 tensor(0.) 0.42859623 1.0
9 tensor(0.) 0.72813253 1.0
8 tensor(0.) 0.4410075 1.0
7 tensor(-1.) 0.77161771 1.0
6 tensor(0.) 0.42138228 1.0
5 tensor(0.) 0.11583911 1.0
4 tensor(1.) 0.75750459 1.0
3 tensor(0.) 0.47915682 1.0
2 tensor(0.) 0.81747744 1.0
1 tensor(0.) 0.88012086 1.0
0 tensor(0.) 0.04955575 1.0


(tensor([1.3465, 1.4291, 1.4732, 1.5234, 1.5945, 0.5923, 0.6236, 0.6409, 1.7041,
         1.7887, 1.8636, 1.9589, 0.9931, 1.0310, 1.0482, 0.0000]),
 tensor([ 0.8716,  1.3796,  0.5931,  0.7059,  1.1154, -0.1652,  0.5078,  0.2195,
          0.9325,  1.3477,  1.1354,  1.5303,  0.4910,  0.5574,  0.1347, -0.9735]))