In [8]:
from torch import nn
import torch.nn.functional as F
import torch
from torch.distributions import Categorical, Normal
import gymnasium as gym
from tqdm.notebook import tnrange
import numpy as np
import scipy
import wandb
from gymnasium.spaces import Box, Discrete
import os
import random
from gymnasium.wrappers.record_video import RecordVideo

In [9]:
def discount_cumsum(x, discount):
    """
    magic from rllab for computing discounted cumulative sums of vectors.

    input: 
        vector x, 
        [x0, 
         x1, 
         x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,  
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)

class PPOBuffer():
    def __init__(self, observation_dim, action_dim, capacity, gamma, lam):
        self.obs_buf = np.zeros(combined_shape(capacity, observation_dim), dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(capacity, action_dim), dtype=np.float32)
        self.adv_buf = np.zeros(capacity, dtype=np.float32)
        self.rew_buf = np.zeros(capacity, dtype=np.float32)
        self.rtg_buf = np.zeros(capacity, dtype=np.float32)
        self.val_buf = np.zeros(capacity, dtype=np.float32)
        self.logp_buf = np.zeros(capacity, dtype=np.float32)
        # self.logp_buf = np.zeros(combined_shape(capacity, action_dim), dtype=np.float32)
        self.capacity = capacity
        self.idx = 0
        self.path_idx = 0
        self.gamma = gamma
        self.lam = lam

    def push(self, obs, act, rew, val, logp):
        assert self.idx < self.capacity
        self.obs_buf[self.idx] = obs
        self.act_buf[self.idx] = act
        self.rew_buf[self.idx] = rew
        self.val_buf[self.idx] = val
        self.logp_buf[self.idx] = logp

        self.idx += 1

    def GAE_cal(self, last_val):
        path_slice = slice(self.path_idx, self.idx)
        # to make the deltas the same dim
        rewards = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)

        deltas = rewards[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)

        ### OpenAI spinning up implemetation comment: No ideal, big value loss when episode rewards are large
        # self.rtg_buf[path_slice] = discount_cumsum(rewards, self.gamma)[:-1]

        ### OpenAI stable_baseline3 implementation
        ### in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        ### TD(lambda) estimator, see "Telescoping in TD(lambda)"
        self.rtg_buf[path_slice] = self.adv_buf[path_slice] + self.val_buf[path_slice]
        
        self.path_idx = self.idx

                
    def sample(self, minibatch_size, device):
        """This method sample a list of minibatches from the memory

        Args:
            minibatch_size (int): size of minibatch, usually 2^n
            device (object): CPU or GPU

        Returns:
            list: a list of minibatches
        """
        assert self.idx == self.capacity
        # reset the index
        self.idx, self.path_idx = 0, 0
        # normalise advantage
        self.adv_buf = (self.adv_buf - np.mean(self.adv_buf)) / (np.std(self.adv_buf) + 1e-8)
        
        inds = np.arange(self.capacity)
        
        np.random.shuffle(inds)
        
        data = []
        for start in range(0, self.capacity, minibatch_size):
            end = start + minibatch_size
            minibatch_inds = inds[start:end]
            minibatch = dict(obs=self.obs_buf[minibatch_inds], act=self.act_buf[minibatch_inds], \
                             rtg=self.rtg_buf[minibatch_inds], adv=self.adv_buf[minibatch_inds], \
                             logp=self.logp_buf[minibatch_inds])
            data.append({k: torch.as_tensor(v, dtype=torch.float32, device=device) for k,v in minibatch.items()})
        
        return data

In [10]:
def layer_init(layer, std=np.sqrt(2)):
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, 0.0)
    return layer

class Actor_Net(nn.Module):
    def __init__(self, n_observations, n_actions, num_cells, continous_action, log_std_init=0.0):
        super(Actor_Net,self).__init__()
        
        self.layer1 = layer_init(nn.Linear(n_observations, num_cells))
        self.layer2 = layer_init(nn.Linear(num_cells, num_cells))
        self.layer3 = layer_init(nn.Linear(num_cells, n_actions), std=0.01)

        self.continous_action = continous_action
        self.action_dim = n_actions
        
        if self.continous_action:
            log_std = log_std_init * np.ones(self.action_dim, dtype=np.float32)
            # Add it to the list of parameters
            self.log_std = torch.nn.Parameter(torch.as_tensor(log_std), requires_grad=True)
            #
            ### https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/  implementation
            # self.log_std = nn.Parameter(torch.zeros(1, self.action_dim))  

            ### Stable-baseline3 implementation
            # self.log_std = nn.Parameter(torch.ones(self.action_dim) * log_std_init, requires_grad=False)      

        

    def forward(self, x):
        activation1 = F.tanh(self.layer1(x))
        activation2 = F.tanh(self.layer2(activation1))
        activation3 = self.layer3(activation2)

        return activation3
    
    def act(self, x):
        if self.continous_action:
            mu = self.forward(x)
            std = torch.exp(self.log_std)
            dist = Normal(mu, std)
        else:
            log_probs = F.log_softmax(self.forward(x), dim=1)
            dist = Categorical(log_probs)
    
        action = dist.sample()
        if self.continous_action:
            action_logprob = dist.log_prob(action).sum(axis=-1)
        else:
            action_logprob = dist.log_prob(action)

        return action.detach().cpu().numpy(), action_logprob.detach().cpu().numpy()
    
    def logprob_ent_from_state_acton(self, x, act):
        if self.continous_action:
            mu = self.forward(x)
            std = torch.exp(self.log_std)
            dist = Normal(mu, std)
            # sum term is crucial to reduce dimension, otherwise the ratio = torch.exp(logp - logp_old) will have wrong result with boardcasting
            act_logp = dist.log_prob(act).sum(axis=-1) 
        else:
            dist = Categorical(F.softmax(self.forward(x)))
            act_logp = dist.log_prob(act)
        entropy = dist.entropy()
        
        return entropy, act_logp
    
   
class Critic_Net(nn.Module):
    def __init__(self, n_observations, num_cells):
        super(Critic_Net,self).__init__()
        self.layer1 = layer_init(nn.Linear(n_observations, num_cells))
        self.layer2 = layer_init(nn.Linear(num_cells, num_cells))
        self.layer3 = layer_init(nn.Linear(num_cells, 1), std=1.0)

    def forward(self, x):
        activation1 = F.tanh(self.layer1(x))
        activation2 = F.tanh(self.layer2(activation1))
        activation3 = self.layer3(activation2)

        return activation3

class Actor_Critic_net(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim, continous_action, parameters_hardshare, log_std_init=0.0):

        super(Actor_Critic_net, self).__init__()

        self.parameters_hardshare = parameters_hardshare
        self.continous_action = continous_action
        self.action_dim = act_dim

        if self.parameters_hardshare:
            self.layer1 = layer_init(nn.Linear(obs_dim, hidden_dim))
            self.layer2 = layer_init(nn.Linear(hidden_dim, hidden_dim))

            self.actor_head = layer_init(nn.Linear(hidden_dim, act_dim), std=0.01)
            self.critic_head = layer_init(nn.Linear(hidden_dim, 1), std=1.0)
            if self.continous_action:
                log_std = log_std_init * np.ones(self.action_dim, dtype=np.float32)
                # Add it to the list of parameters
                self.log_std = torch.nn.Parameter(torch.as_tensor(log_std), requires_grad=True)
                #
                ### https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/  implementation
                # self.log_std = nn.Parameter(torch.zeros(1, self.action_dim))  

                ### Stable-baseline3 implementation
                # self.log_std = nn.Parameter(torch.ones(self.act_dim) * log_std_init, requires_grad=False) 

        else:
            self.actor = Actor_Net(obs_dim, act_dim, hidden_dim, continous_action)
            self.critic = Critic_Net(obs_dim, hidden_dim)


    def forward(self, x):
        if self.parameters_hardshare:
            activation1 = F.tanh(self.layer1(x))
            activation2 = F.tanh(self.layer2(activation1))
            actor_logits = self.actor_head(activation2)
            value = self.critic_head(activation2)
        else:
            actor_logits = self.actor.forward(x)
            value = self.critic.forward(x)

        return actor_logits, value

    def get_value(self, x):
        return self.critic(x).item()

    
    def act(self, x):
        if self.continous_action:
            mu, value = self.forward(x)
            log_std = self.log_std if self.parameters_hardshare else self.actor.log_std
            std = torch.exp(log_std)
            dist = Normal(mu, std)
        else:
            actor_logits, value = self.forward(x)
            log_probs = F.log_softmax(actor_logits, dim=1)
            dist = Categorical(log_probs)

        action = dist.sample()
        if self.continous_action:
            action_logprob = dist.log_prob(action).sum(axis=-1)
        else:
            action_logprob = dist.log_prob(action)
        

        return action.cpu().numpy(), action_logprob.cpu().numpy(), value.item()     

    def logprob_ent_from_state_acton(self, x, act):

        if self.continous_action:
            mu, value = self.forward(x)
            log_std = self.log_std if self.parameters_hardshare else self.actor.log_std
            std = torch.exp(log_std)
            dist = Normal(mu, std)
            # sum term is crucial to reduce dimension,-0.5 otherwise the ratio = torch.exp(logp - logp_old) will have wrong result with boardcasting
            act_logp = dist.log_prob(act).sum(axis=-1) 
        else:
            actor_logits, value = self.forward(x)
            dist = Categorical(F.softmax(actor_logits))
            act_logp = dist.log_prob(act)
        entropy = dist.entropy().sum(axis=-1)
        
        return entropy, act_logp, value

In [11]:
class PPO():
    def __init__(self, gamma, lamb, eps_clip, K_epochs, \
                 observation_space, action_space, num_cells, \
                 actor_lr, critic_lr, memory_size , minibatch_size,\
                 max_training_iter, cal_total_loss, c1, c2, \
                 early_stop, kl_threshold, parameters_hardshare, \
                 max_grad_norm , device
                 ):
        self.gamma = gamma
        self.lamb = lamb
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.max_training_iter = max_training_iter

        self.observation_space = observation_space
        self.action_space = action_space
        self.memory_size = memory_size
        self.minibatch_size = minibatch_size
        
        self.cal_total_loss = cal_total_loss
        self.c1 = c1
        self.c2 = c2
        self.early_stop = early_stop
        self.kl_threshold = kl_threshold

        self.parameters_hardshare = parameters_hardshare
        self.episode_count = 0
        self.max_grad_norm = max_grad_norm
        self.global_step = 0


        if isinstance(action_space, Box):
            self.continous_action = True
        elif isinstance(action_space, Discrete):
            self.continous_action = False
        else:
            raise AssertionError(f"action space is not valid {action_space}")


        self.observtion_dim = observation_space.shape[0]

        # self.actor = Actor_Net(self.observtion_dim, \
        #                        action_space.shape[0] if self.continous_action else action_space.n, \
        #                           num_cells, self.continous_action).to(device)
          
        # self.critic = Critic_Net(self.observtion_dim, num_cells).to(device)
        self.actor_critic = Actor_Critic_net(self.observtion_dim, \
                               action_space.shape[0] if self.continous_action else action_space.n, \
                                  num_cells, self.continous_action, parameters_hardshare).to(device)

        # self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        # self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        if parameters_hardshare:
            ### eps=1e-5 follows stable-baseline3
            self.actor_critic_opt = torch.optim.Adam(self.actor_critic.parameters(), lr=actor_lr, eps=1e-5)
            
        else:
            self.actor_critic_opt = torch.optim.Adam([ 
                {'params': self.actor_critic.actor.parameters(), 'lr': actor_lr, 'eps' : 1e-5},
                {'params': self.actor_critic.critic.parameters(), 'lr': critic_lr, 'eps' : 1e-5} 
            ])


        self.memory = PPOBuffer(observation_space.shape, action_space.shape, memory_size, gamma, lamb)

        self.device = device
        
        wandb.watch(self.actor_critic.actor, log='all', log_freq=1000, idx=1)
        wandb.watch(self.actor_critic.critic, log='all', log_freq=1000, idx=2)
        # wandb.watch(self.actor_critic, log='all', log_freq=1000)

    def roll_out(self, env):
        # TODO: implement multiple thread
        # make a new environment instance
        
        

        # Maybe a deep copy is necessary for multi-thread processing
        obs, _ = env.reset()

        ep_reward = 0
        
        action_shape = env.action_space.shape
        # Run the policy for T timestep
        for i in tnrange(self.memory_size, desc="roll_out", leave=False):
            with torch.no_grad():
                obs_tensor = torch.tensor(obs, \
                                        dtype=torch.float32, device=self.device).unsqueeze(0)
                
                # action, action_logprob = self.actor.act(obs_tensor)

                # action = action.reshape(action_shape)
                
                # value = self.critic.forward(obs_tensor).item()
            
            
                action, action_logprob, value = self.actor_critic.act(obs_tensor)
            
            action = action.reshape(action_shape)

            clipped_action = action

            if self.continous_action:
                clipped_action = np.clip(action, self.action_space.low, self.action_space.high)

            next_obs, reward, terminated, truncated, _ = env.step(clipped_action)

            self.global_step += 1

            self.memory.push(obs, action, reward, value, action_logprob)

            obs = next_obs

            ep_reward += reward

            if terminated or truncated:
                if truncated:
                    # last_value = self.critic.forward(torch.tensor(next_obs, dtype=torch.float32, device=self.device)).item()
                    last_value = self.actor_critic.get_value(torch.tensor(next_obs, dtype=torch.float32, device=self.device))
                else:
                    last_value = 0
                    

                
                self.memory.GAE_cal(last_value)

                obs,_ = env.reset()
                self.episode_count += 1
                wandb.log({'episode_reward' : ep_reward}, step=self.global_step)
                ep_reward = 0


        
        with torch.no_grad():
            last_value = self.actor_critic.get_value(torch.tensor(next_obs, dtype=torch.float32, device=self.device))
        self.memory.GAE_cal(last_value)


    def evaluate_recording(self, env):
        
        env_name = env.spec.id

        video_folder = os.path.join(wandb.run.dir, 'videos')

        env = RecordVideo(env, video_folder, name_prefix=env_name)

        obs, _ = env.reset()

        done = False

        action_shape = env.action_space.shape

        while not done:
            obs_tensor = torch.tensor(obs, \
                                    dtype=torch.float32, device=self.device).unsqueeze(0)
            action, _, _ = self.actor_critic.act(obs_tensor)
            action = action.reshape(action_shape)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            obs = next_obs

        
        mp4_files = [file for file in os.listdir(video_folder) if file.endswith(".mp4")]

        for mp4_file in mp4_files:
            wandb.log({'Episode_recording': wandb.Video(os.path.join(video_folder, mp4_file))})

        env.close()
        




    # def compute_loss(self, data):
    #     observations, actions, logp_old = data['obs'], data['act'], data['logp']
    #     advs, rtgs = data['adv'], data['rtg']

    #     # Calculate the pi_theta (a_t|s_t)
    #     entropy, logp = self.actor.logprob_ent_from_state_acton(observations, actions)
    #     ratio = torch.exp(logp - logp_old)
    #     # Kl approx according to http://joschu.net/blog/kl-approx.html
    #     kl_apx = ((ratio - 1) - (logp - logp_old)).mean()
    #     wandb.log({'KL approx': kl_apx})
    #     clip_advs = torch.clamp(ratio, 1-self.eps_clip, 1+self.eps_clip) * advs
    #     # Torch Adam implement tation mius the gradient, to plus the gradient, we need make the loss negative
    #     actor_loss = -(torch.min(ratio*advs, clip_advs)).mean()

    #     values = self.critic(observations).squeeze()
    #     critic_loss = nn.MSELoss()(values, rtgs)

    #     entropy_loss = entropy.mean()

    #     return actor_loss, critic_loss, entropy_loss, kl_apx


    def compute_loss(self, data):
        observations, actions, logp_old = data['obs'], data['act'], data['logp']
        advs, rtgs = data['adv'], data['rtg']

        # Calculate the pi_theta (a_t|s_t)
        entropy, logp, values = self.actor_critic.logprob_ent_from_state_acton(observations, actions)
        ratio = torch.exp(logp - logp_old)
        # Kl approx according to http://joschu.net/blog/kl-approx.html
        kl_apx = ((ratio - 1) - (logp - logp_old)).mean()
    
        clip_advs = torch.clamp(ratio, 1-self.eps_clip, 1+self.eps_clip) * advs
        # Torch Adam implement tation mius the gradient, to plus the gradient, we need make the loss negative
        actor_loss = -(torch.min(ratio*advs, clip_advs)).mean()

        values = values.flatten() # I used squeeze before, maybe a mistake

        critic_loss = F.mse_loss(values, rtgs)
        # critic_loss = ((values - rtgs) ** 2).mean()

        entropy_loss = entropy.mean()

        return actor_loss, critic_loss, entropy_loss, kl_apx        

    def optimise(self):

        data = self.memory.sample(self.minibatch_size, self.device)

        early_stop_count = 0

        entropy_loss_list = []
        actor_loss_list = []
        critic_loss_list = []
        kl_approx_list = []
        
        # for _ in tnrange(self.K_epochs, desc=f"epochs", position=1, leave=False):
        for _ in range(self.K_epochs):
            
            for minibatch in data:
            
                actor_loss, critic_loss, entropy_loss, kl_apx = self.compute_loss(minibatch)

                entropy_loss_list.append(entropy_loss.item())
                actor_loss_list.append(actor_loss.item())
                critic_loss_list.append(critic_loss.item())
                kl_approx_list.append(kl_apx.item())

                if self.cal_total_loss:
                    total_loss = actor_loss + self.c1 * critic_loss - self.c2 * entropy_loss

                # If this update is too big, early stop and try next minibatch
                if self.early_stop and kl_apx > self.kl_threshold:
                    early_stop_count += 1
                    continue

                # self.actor_opt.zero_grad()
                # self.critic_opt.zero_grad()
                # if self.cal_total_loss:
                #     wandb.log({'total_loss': total_loss})
                #     total_loss.backward()
                #     self.actor_opt.step()
                #     self.critic_opt.step()
                # else:
                #     wandb.log({'actor_loss': actor_loss, 'critic_loss': critic_loss})
                #     actor_loss.backward()
                #     self.actor_opt.step()
                #     critic_loss.backward()
                #     self.critic_opt.step()
                
                self.actor_critic_opt.zero_grad()
                if self.cal_total_loss:
                    total_loss.backward()
                    # Used by stable-baseline3, maybe more important for RNN
                    torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
                    self.actor_critic_opt.step()

                else:
                    actor_loss.backward()
                    critic_loss.backward()
                    # Used by stable-baseline3, maybe more important for RNN
                    torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
                    self.actor_critic_opt.step()
            
        # wandb.run.summary['early_stop_count'] = early_stop_count
        # Logging, use the same metric as stable-baselines3 to compare performance
        with torch.no_grad():
            mean_std = self.actor_critic.actor.log_std.mean().item()

        wandb.log(
            {
                'actor_loss': np.mean(actor_loss_list),
                'critic_loss' : np.mean(critic_loss_list),
                'entropy_loss' : np.mean(entropy_loss_list),
                'KL_approx' : np.mean(kl_approx_list),
                'mean_std' : mean_std
            }, step=self.global_step
        )    

                
    def train(self, env):

        for i in tnrange(self.max_training_iter // self.memory_size):

            self.roll_out(env)

            self.optimise()

        # save the model to the wandb run folder
        # PATH = os.path.join(wandb.run.dir, "actor_critic.pt")
        # torch.save(self.actor_critic.state_dict(), PATH)


        wandb.run.summary['total_episode'] = self.episode_count

        
    
            

        


### Sweep for HalfCheetah
#### Continous action space

In [12]:
sweep_configuration = {
    'method': 'random',
    'metric':{'goal':'maximize', 'name':'episode_reward'},
    'parameters':
    {
        'actor_lr' : {'value' : 3e-4},
        'memory_size' : {'value' : 2048},
        'k_epochs' : {'value' : 10},
        'gamma' : {'value' : 0.99},
        'lam' : {'value' : 0.95}, 
        'early_stop': {'value': False},
        'cal_total_loss' : {'value' : False},
        'parameters_hardshare' : {'value' : False},
        'c1' : {'value': 0.5},
        'c2' : {'value' : 0},
        # 'kl_threshold' : {'min': 0.01, 'max': 0.04},
        'minibatch_size' : {'value' : 64}
    }
}

In [13]:
def main():



    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run = wandb.init(
            project='PPO-HalfCheetah_cmp',
            # mode='disabled',
            # config = sweep_configuration
        )


    gamma = 0.99
    # gamma = wandb.config.gamma
    lamb = 0.95
    # lamb = wandb.config.lam
    eps_clip = 0.2
    max_training_iter = 1_000_000
    k_epochs = 10
    # k_epochs = wandb.config.k_epochs
    num_cells = 64
    actor_lr = 3e-4 
    # actor_lr = wandb.config.actor_lr
    critic_lr = actor_lr
    # critic_lr = wandb.config.critic_lr
    memory_size = 2048
    # memory_size = wandb.config.memory_size
    minibatch_size = 64
    # minibatch_size = wandb.config.minibatch_size
    
    c1 = 0.5
    c2 = 0
    kl_threshold = 0.013
    # c1 = wandb.config.c1
    # c2 = wandb.config.c2
    # kl_threshold = wandb.config.kl_threshold
    
    env_name = "HalfCheetah-v4" # CartPole-v1
    parameters_hardshare = False
    early_stop = False
    cal_total_loss = False
    # parameters_hardshare = wandb.config.parameters_hardshare
    # early_stop = wandb.config.early_stop
    # cal_total_loss = wandb.config.cal_total_loss
    max_grad_norm = 0.5

    seed = 123

    wandb.config.update(
        {
            'actor_lr' : actor_lr,
            'critic_lr' : critic_lr,
            'gamma' : gamma,
            'lambda' : lamb,
            'eps_clip' : eps_clip,
            'max_training_iter' : max_training_iter,
            'k_epochs' : k_epochs,
            'hidden_cell_dim' : num_cells,
            'memory_size' : memory_size,
            'minibatch_size' : minibatch_size,
            'cal_total_loss' : cal_total_loss,
            'c1' : c1,
            'c2' : c2,
            'early_stop' : early_stop,
            'env_name': env_name,
            'kl_threshold' : kl_threshold

        }
    )

    # wandb.define_metric("episode_reward", summary="mean")
    # wandb.define_metric("KL_approx", summary="mean")
        
    # Using render_mode slow the training process down
    env = gym.make(env_name)
    recording_env = gym.make(env_name, render_mode = 'rgb_array_list')

    # video_path = os.path.join(wandb.run.dir, 'recorded_episodes')
    # print(f"Videos are stored at {video_path}")
    # video_recorder = RecordVideo(env, video_path, episode_trigger=200)

    # Seeding evaluation purpose
    env.np_random = np.random.default_rng(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    recording_env.np_random = np.random.default_rng(seed)
    recording_env.action_space.seed(seed)
    recording_env.observation_space.seed(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Deterministic operations for CuDNN, it may impact performances
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    my_ppo = PPO(gamma, lamb, eps_clip, k_epochs, env.observation_space, env.action_space, num_cells,\
                 actor_lr, critic_lr, memory_size, minibatch_size, max_training_iter, \
                 cal_total_loss, c1, c2, early_stop, kl_threshold, parameters_hardshare, max_grad_norm, device)
    
    my_ppo.train(env)
    my_ppo.evaluate_recording(recording_env)

    env.close()
    recording_env.close()
    # video_recorder.close()
    run.finish()

In [14]:
# %%wandb
main()

VBox(children=(Label(value='0.003 MB of 0.013 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.240034…

0,1
episode_reward,█▁

0,1
episode_reward,-440.41693


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670751099991322, max=1.0…

  0%|          | 0/488 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

roll_out:   0%|          | 0/2048 [00:00<?, ?it/s]

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:

# # %env "WANDB_NOTEBOOK_NAME" "PPO_GYM"
# sweep_id = wandb.sweep(sweep=sweep_configuration, project='PPO-HalfCheetah_cmp')
# wandb.agent(sweep_id, function=main, count=1)

### Sweep configuration for Pendulum
#### Continous action space

In [None]:
# sweep_configuration = {
#     'method': 'grid',
#     'metric':{'goal':'maximize', 'name':'episode_reward'},
#     'parameters':
#     {
#         'early_stop': {'value': False},
#         'cal_total_loss' : {'value' : True},
#         'parameters_hardshare' : {'value' : False},
#         'c1' : {'value': 0.5020639303776493},
#         'c2' : {'value' : 0.910077248529638},
#         # 'kl_threshold' : {'min': 0.01, 'max': 0.04},
#         'minibatch_size' : {'values' : [128, 256, 512, 1024]}
#     }
# }
# %env "WANDB_NOTEBOOK_NAME" "PPO_GYM"
# sweep_id = wandb.sweep(sweep=sweep_configuration, project='PPO-Pendulum-2')
# wandb.agent(sweep_id, function=main, count=4)

In [None]:
wandb.init(mode='disabled')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gamma = 0.99
# gamma = wandb.config.gamma
lamb = 0.95
# lamb = wandb.config.lam
eps_clip = 0.2
max_training_iter = 1_000_000
k_epochs = 10
# k_epochs = wandb.config.k_epochs
num_cells = 64
actor_lr = 3e-4 
# actor_lr = wandb.config.actor_lr
critic_lr = actor_lr
# critic_lr = wandb.config.critic_lr
memory_size = 2048
# memory_size = wandb.config.memory_size
minibatch_size = 64
# minibatch_size = wandb.config.minibatch_size

c1 = 0.5
c2 = 0
kl_threshold = 0.013
# c1 = wandb.config.c1
# c2 = wandb.config.c2
# kl_threshold = wandb.config.kl_threshold

env_name = "HalfCheetah-v4" # CartPole-v1
parameters_hardshare = False
early_stop = False
cal_total_loss = True
# parameters_hardshare = wandb.config.parameters_hardshare
# early_stop = wandb.config.early_stop
# cal_total_loss = wandb.config.cal_total_loss
max_grad_norm = 0.5

seed = 0

env = gym.make(env_name)
recording_env = gym.make(env_name, render_mode = 'rgb_array_list')

env.np_random = np.random.default_rng(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)

recording_env.np_random = np.random.default_rng(seed)
recording_env.action_space.seed(seed)
recording_env.observation_space.seed(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Deterministic operations for CuDNN, it may impact performances
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

my_ppo = PPO(gamma, lamb, eps_clip, k_epochs, env.observation_space, env.action_space, num_cells,\
                actor_lr, critic_lr, memory_size, minibatch_size, max_training_iter, \
                cal_total_loss, c1, c2, early_stop, kl_threshold, parameters_hardshare, max_grad_norm, device)

In [None]:
class Actor_Critic_net(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim, continous_action, parameters_hardshare, log_std_init=0.0):
        super(Actor_Critic_net, self).__init__()
        self.parameters_hardshare = parameters_hardshare
        self.continous_action = continous_action
        self.action_dim = act_dim
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, act_dim), std=0.01)

        )
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.0)
        )
        if self.continous_action:
            log_std = log_std_init * np.ones(self.action_dim, dtype=np.float32)
            # Add it to the list of parameters
            self.log_std = torch.nn.Parameter(torch.as_tensor(log_std), requires_grad=True)
        #TODO: implement hard-sharing parameters

    def get_value(self, x):
        return self.critic(x).detach().item()

    def act(self, x):
        if self.continous_action:
            mu = self.actor(x)
            value = self.critic(x)
            std = torch.exp(self.log_std)
            dist = Normal(mu, std)
        else:
            actor_logits = self.actor(x)
            value = self.critic(x)
            log_probs = F.log_softmax(actor_logits, dim=1)
            dist = Categorical(log_probs)

        action = dist.sample()
        if self.continous_action:
            action_logprob = dist.log_prob(action).sum(axis=-1)
        else:
            action_logprob = dist.log_prob(action)
        
        return action.detach().cpu().numpy(), action_logprob.detach().cpu().numpy(), value.detach().item()   

    def logprob_ent_from_state_acton(self, x, act):

        if self.continous_action:
            mu = self.actor(x)
            std = torch.exp(self.log_std)
            dist = Normal(mu, std)
            # sum term is crucial to reduce dimension,-0.5 otherwise the ratio = torch.exp(logp - logp_old) will have wrong result with boardcasting
            act_logp = dist.log_prob(act).sum(axis=-1) 
        else:
            actor_logits = self.actor(x)
            dist = Categorical(F.log_softmax(actor_logits, dim=1))
            act_logp = dist.log_prob(act)

        entropy = dist.entropy().sum(axis=-1)
        value = self.critic(x)

        return entropy, act_logp, value  

In [None]:
obs, _ = env.reset()
print(obs)

In [None]:
my_ppo.actor_critic_opt

In [None]:
my_ppo.actor_critic.critic.state_dict

In [None]:
my_ppo.actor_critic.actor.state_dict()['4.weight']

In [None]:
obs_tensor = torch.tensor(obs, dtype=torch.float32, device=device)

In [None]:

action, action_logprob, value = my_ppo.actor_critic.act(obs_tensor)

In [None]:
action

In [None]:
value

In [None]:
action_logprob

In [None]:
next_obs, reward, terminated, truncated, _ = env.step(action)

In [None]:
next_obs

In [None]:
vector_parameters = torch.nn.utils.parameters_to_vector(my_ppo.actor_critic.parameters()).detach().cpu().numpy()

In [None]:
vector_parameters.shape

In [None]:
agent = Agent(env)

In [None]:
agent.get_action_and_value(obs_tensor)