In [None]:
# Third-party imports for the environment, the networks, and notebook display.
import gymnasium as gym

import torch
import torch.nn as nn
from torch.distributions import Categorical, MultivariateNormal

from IPython import display
import matplotlib.pyplot as plt

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Create the "Humanoid-v4" environment; the next cell reads its observation
# and action spaces to size the networks.
env = gym.make("Humanoid-v4")
# reset() is unpacked into the initial state and a second value that is discarded here.
state, _ = env.reset()

In [None]:
# Network input/output sizes: the first axis of the observation space and of
# the action space, respectively.
state_dim, action_dim = (
    space.shape[0] for space in (env.observation_space, env.action_space)
)
# Trailing expression so the notebook displays both sizes.
state_dim, action_dim

In [None]:
class ActorCritic(nn.Module):
    """Actor and critic heads on top of one shared two-layer Tanh trunk.

    The trunk (``head``) maps a state to 64 features. ``actor_mouth`` turns
    those features into either a Tanh-squashed action mean (continuous mode)
    or Softmax action probabilities (discrete mode); ``critic_mouth`` turns
    them into a single state-value estimate.

    In continuous mode the policy is a Gaussian with a fixed diagonal
    covariance taken from ``self.action_var``, a plain tensor attribute placed
    on the notebook-level ``device`` and replaced by ``set_action_std``.

    Args:
        state_dim: Size of the state vector fed to the trunk.
        action_dim: Number of action components (continuous) or number of
            discrete actions.
        has_continuous_action_space: Selects the Gaussian or Categorical policy.
        action_std_init: Initial per-component standard deviation; only used
            in continuous mode.
    """

    def __init__(self, state_dim=state_dim, action_dim=action_dim, has_continuous_action_space=True, action_std_init=0.6):
        super().__init__()

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            initial_variance = action_std_init * action_std_init
            self.action_var = torch.full((action_dim,), initial_variance).to(device)

        # Shared feature extractor used by both the actor and the critic.
        self.head = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )

        # Only the final activation differs between the two policy types.
        output_activation = nn.Tanh() if has_continuous_action_space else nn.Softmax(dim=-1)
        self.actor_mouth = nn.Sequential(
            nn.Linear(64, action_dim),
            output_activation,
        )

        self.critic_mouth = nn.Sequential(
            nn.Linear(64, 1),
        )

    def set_action_std(self, new_action_std):
        """Replace the fixed action variance with ``new_action_std`` squared.

        Prints a warning and changes nothing when the policy is discrete.
        """
        if not self.has_continuous_action_space:
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            return

        variance = new_action_std * new_action_std
        self.action_var = torch.full((self.action_dim,), variance).to(device)

    def forward(self, state):
        """Return ``(actor output, state value)`` for ``state``.

        The actor output is an action mean in continuous mode and action
        probabilities in discrete mode.
        """
        features = self.head(state)
        return self.actor_mouth(features), self.critic_mouth(features)

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            The sampled action, its log-probability, and the critic's value
            estimate, all detached from the autograd graph.
        """
        actor_out, state_val = self(state)

        if self.has_continuous_action_space:
            # One covariance matrix with a leading axis of size 1.
            covariance = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(actor_out, covariance)
        else:
            dist = Categorical(actor_out)

        action = dist.sample()
        return action.detach(), dist.log_prob(action).detach(), state_val.detach()

    def evaluate(self, state, action):
        """Score ``action`` under the current policy for ``state``.

        Returns:
            The log-probabilities of ``action``, the critic's value estimates
            and the entropy of the action distribution (gradients retained).
        """
        actor_out, state_val = self(state)

        if self.has_continuous_action_space:
            # Broadcast the shared variance over the batch, one diagonal
            # covariance matrix per row of means.
            variances = self.action_var.expand_as(actor_out)
            dist = MultivariateNormal(actor_out, torch.diag_embed(variances))

            # Single-component actions may arrive flattened; restore the
            # trailing action axis before scoring them.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            dist = Categorical(actor_out)

        return dist.log_prob(action), state_val, dist.entropy()



# Smoke test: build a network with the default dimensions, sample an action for
# one random state, then score that same state/action pair with evaluate().
ac = ActorCritic().to(device)
inp = torch.rand(1, state_dim).to(device)  # a batch containing a single random state
action, action_logprob, state_val = ac.act(inp)
# Last expression of the cell, so the (log-prob, state value, entropy) tuple is displayed.
ac.evaluate(inp, action)

In [None]:
class RolloutBuffer:
    """Plain container for the transitions gathered between two policy updates.

    Holds six Python lists (``actions``, ``states``, ``logprobs``, ``rewards``,
    ``state_values``, ``is_terminals``) that callers append to directly.
    """

    def __init__(self):
        self.actions, self.states, self.logprobs = [], [], []
        self.rewards, self.state_values, self.is_terminals = [], [], []

    def clear(self):
        """Empty every list in place; the list objects themselves are kept."""
        for stored in (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.state_values,
            self.is_terminals,
        ):
            stored.clear()

In [None]:
class PPO:
    """Clipped-surrogate PPO agent built around the ``ActorCritic`` network.

    ``policy_old`` collects experience through ``select_action`` into
    ``self.buffer``; ``update`` then optimises ``policy`` for ``K_epochs``
    passes over that buffer and copies the result back into ``policy_old``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of action components / discrete actions.
        lr_actor: Learning rate for the shared trunk and the actor head.
        lr_critic: Learning rate for the critic head.
        gamma: Discount factor used for the returns.
        K_epochs: Number of optimisation passes per call to ``update``.
        eps_clip: Clipping range of the probability ratio.
        has_continuous_action_space: Forwarded to both networks.
        action_std_init: Initial action standard deviation (continuous mode).
        update_timestep: Number of collected transitions after which the
            training loop is expected to call ``update``. Only stored here;
            ``update`` itself works on whatever the buffer holds.
    """

    def __init__(self, state_dim=state_dim, action_dim=action_dim, lr_actor=0.0003, lr_critic=0.001, gamma=0.99, K_epochs=80, eps_clip=0.2, has_continuous_action_space=True, action_std_init=0.6, update_timestep=10000):
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.update_timestep = update_timestep

        self.buffer = RolloutBuffer()

        # Forward the action-space settings so the networks match this agent.
        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        # Separate parameter groups so lr_critic applies to the critic head;
        # the shared trunk follows the actor's learning rate.
        self.optimizer = torch.optim.Adam([
            {"params": self.policy.head.parameters(), "lr": lr_actor},
            {"params": self.policy.actor_mouth.parameters(), "lr": lr_actor},
            {"params": self.policy.critic_mouth.parameters(), "lr": lr_critic},
        ])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the action standard deviation on the agent and both networks.

        Prints a warning and changes nothing when the policy is discrete.
        """
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Lower the action standard deviation by ``action_std_decay_rate``.

        The result is rounded to four decimals and never goes below
        ``min_action_std``. Does nothing for a discrete policy.
        """
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the transition start.

        Args:
            state: NumPy array holding the current state.

        Returns:
            A flat NumPy array (continuous mode) or a Python scalar (discrete
            mode). The state, action, log-probability and value estimate are
            appended to ``self.buffer``; the caller is responsible for
            appending the matching reward and terminal flag.
        """
        with torch.no_grad():
            state = torch.from_numpy(state).float().to(device)
            action, action_logprob, state_val = self.policy_old.act(state)

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            if self.has_continuous_action_space:
                return action.cpu().numpy().flatten()
            else:
                return action.item()

    def _discounted_returns(self):
        """Return the discounted return of every buffered step as a 1-D tensor.

        The running sum is restarted at every step flagged in
        ``buffer.is_terminals`` so rewards of a later episode never contribute
        to the returns of an earlier one.
        """
        rewards = self.buffer.rewards
        terminals = list(self.buffer.is_terminals)
        # Tolerate callers that did not record terminal flags for every step.
        terminals += [False] * (len(rewards) - len(terminals))

        returns = []
        running = 0.0
        for reward, is_terminal in zip(reversed(rewards), reversed(terminals)):
            if is_terminal:
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32, device=device)

    def update(self):
        """Run ``K_epochs`` PPO optimisation passes over the buffered rollout.

        Afterwards ``policy_old`` receives the new weights and the buffer is
        cleared. Returns immediately when the buffer holds no rewards.

        Raises:
            ValueError: If the number of buffered states differs from the
                number of buffered rewards.
        """
        n = len(self.buffer.rewards)
        if n == 0:
            return
        if len(self.buffer.states) != n:
            raise ValueError(
                f"buffer holds {len(self.buffer.states)} states but {n} rewards"
            )

        returns = self._discounted_returns()
        if n > 1:
            # std() of a single element is NaN, so only normalise real batches.
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        # Reshape with an explicit batch size instead of squeeze(), which would
        # also drop the batch axis when the buffer holds a single transition.
        old_states = torch.stack(self.buffer.states, dim=0).detach().to(device).reshape(n, -1)
        old_actions = torch.stack(self.buffer.actions, dim=0).detach().to(device)
        if self.has_continuous_action_space:
            old_actions = old_actions.reshape(n, -1)
        else:
            old_actions = old_actions.reshape(n)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).detach().to(device).reshape(n)
        old_state_values = torch.stack(self.buffer.state_values, dim=0).detach().to(device).reshape(n)

        advantages = returns - old_state_values

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.reshape(n)
            ratios = torch.exp(logprobs - old_logprobs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Clipped policy loss + value loss - entropy bonus.
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, returns) - 0.01 * dist_entropy
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

    def save(self, checkpoint_path):
        """Write the weights of ``policy_old`` to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load weights from ``checkpoint_path`` into both networks.

        NOTE(security): ``torch.load`` unpickles the file, so only load
        checkpoints from a trusted source.
        """
        # Deserialise once (onto the CPU) and reuse the result for both copies.
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)
        
# Create the agent and resume from an earlier checkpoint when one exists.
# On a first run nothing has been saved yet, so fall back to fresh weights
# instead of aborting the notebook.
ppo_agent = PPO()
try:
    ppo_agent.load("ppo_actor.pt")
except FileNotFoundError:
    print("No checkpoint found at ppo_actor.pt; starting from randomly initialised weights")

# Training progress shared with the training cell below.
list_episodes_reward = []
time_step = 0
i_episode = 0

In [None]:
max_ep_len = 1000
max_training_timesteps = int(3e6)
checkpoint_path = "ppo_actor.pt"

# Start from a freshly constructed agent (new buffer, new optimiser) rather
# than re-running __init__ on the existing object. PPO.__init__ assigns
# update_timestep itself, so setting it before re-initialising had no effect;
# the loop below uses the agent's own value.
ppo_agent = PPO()
try:
    ppo_agent.load(checkpoint_path)
except FileNotFoundError:
    print(f"No checkpoint found at {checkpoint_path}; starting from randomly initialised weights")
env = gym.make("Humanoid-v4")


action_std_decay_freq = int(2.5e5)  # NOTE: defined but not used by the loop below
print_freq = max_ep_len * 10
save_model_freq = int(1e4)

print_running_reward = 0
print_running_episodes = 0


# time_step, i_episode and list_episodes_reward come from the agent-setup cell
# and are deliberately not reset here, so re-running this cell continues the
# previous counters.
while time_step <= max_training_timesteps:
    state, _ = env.reset()
    current_ep_reward = 0
    for t in range(1, max_ep_len+1):
        action = ppo_agent.select_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        ppo_agent.buffer.rewards.append(reward)
        ppo_agent.buffer.is_terminals.append(done)

        time_step += 1
        current_ep_reward += reward

        # Trigger on the amount of buffered data rather than on the global
        # step counter: the buffer starts empty in this cell while time_step
        # may carry over from an earlier run.
        if len(ppo_agent.buffer.rewards) >= ppo_agent.update_timestep:
            ppo_agent.update()

        if time_step % print_freq == 0:
            display.clear_output()
            # No episode may have finished since the counters were last reset
            # (e.g. right after re-running this cell); avoid dividing by zero.
            if print_running_episodes > 0:
                print_avg_reward = round(print_running_reward / print_running_episodes, 2)
            else:
                print_avg_reward = float("nan")
            print(f"Episode: {i_episode} Timestep: {time_step} Reward: {print_avg_reward}")
            print_running_reward = 0
            print_running_episodes = 0
            plt.plot(list_episodes_reward)
            plt.xlabel("Episode")
            plt.ylabel("Episode reward")
            plt.show()

        if time_step % save_model_freq == 0:
            print('------------------------')
            ppo_agent.save(checkpoint_path)
            print('model saved')
            print('------------------------')

        if done:
            break

    list_episodes_reward.append(current_ep_reward)
    print_running_reward += current_ep_reward
    print_running_episodes += 1

    i_episode += 1
