In [1]:
import gymnasium as gym
import torch
from torch import nn as nn
from tqdm import tqdm

In [13]:
class ActorModel(nn.Module):
    """Single-hidden-layer network that parameterizes a diagonal Gaussian policy.

    One shared hidden layer feeds two linear heads: ``mean`` produces the
    location of the distribution and ``std`` produces its scale.
    """

    def __init__(self, observation_size, hidden_size, out_size) -> None:
        """Create the hidden layer and the two output heads.

        Args:
            observation_size: Number of input features.
            hidden_size: Width of the shared hidden layer.
            out_size: Number of outputs produced by each head.
        """
        super().__init__()
        self.activation = nn.LeakyReLU(0.2)
        self.lin1 = nn.Linear(observation_size, hidden_size)
        self.mean = nn.Linear(hidden_size, out_size)
        self.std = nn.Linear(hidden_size, out_size)

    def forward(self, state):
        """Return the pair ``(mean, std)`` for the given state(s).

        The scale head's raw output is passed through ``abs`` and shifted by
        ``0.001`` so the returned standard deviation is never zero.
        """
        features = self.activation(self.lin1(state))
        mu = self.mean(features)
        sigma = self.std(features).abs() + 0.001
        return mu, sigma


In [14]:
class Actor:
    """Gaussian policy plus its Adam optimizer and an importance-weighted update."""

    def __init__(self, gamma, lr, device, observation_size, hidden_size, out_size, b) -> None:
        """Build the policy network on ``device`` and its optimizer.

        Args:
            gamma: Discount factor used in the one-step advantage.
            lr: Learning rate for Adam.
            device: Device the policy network is moved to.
            observation_size: Number of observation features.
            hidden_size: Width of the policy's hidden layer.
            out_size: Number of action dimensions.
            b: Upper bound applied to the importance ratio in ``update``.
        """
        self.policy = ActorModel(observation_size, hidden_size, out_size)
        self.policy = self.policy.to(device)
        self.gamma = gamma
        self.b = b
        self.optim = torch.optim.Adam(self.policy.parameters(), lr)

    def get_policy(self, observation):
        """Return the ``Normal`` distribution the policy assigns to ``observation``."""
        means, std = self.policy(observation)
        distribution = torch.distributions.Normal(means, std)
        return distribution

    def get_action(self, observation):
        """Sample an action and return it with its log-probability.

        The per-dimension log-probabilities are summed over the last axis, so
        the result is the joint log-density of the sampled action vector.
        """
        distribution = self.get_policy(observation)
        action = distribution.sample()
        log_prob = distribution.log_prob(action).sum(dim=-1)
        return action, log_prob

    def update(self, rewards, next_state_values, state_values, log_probs, terminals, current_pi, past_pi):
        """Take one policy-gradient step on a batch of transitions.

        Args:
            rewards: Per-sample rewards.
            next_state_values: Critic estimates for the successor states.
            state_values: Critic estimates for the current states.
            log_probs: Log-probabilities of the stored actions under the
                current policy (the only argument gradients should flow through).
            terminals: 1 where the transition ended the episode, else 0.
            current_pi: Probability of the stored action under the current policy.
            past_pi: Probability of the stored action when it was collected.
        """
        self.optim.zero_grad()
        # The critic emits one value per row with a trailing size-1 axis. Without
        # this reshape, combining it with 1-D rewards/terminals broadcasts to a
        # square matrix that mixes every sample with every other sample.
        state_values = state_values.reshape(rewards.shape)
        next_state_values = next_state_values.reshape(rewards.shape)

        advantage = rewards + self.gamma * next_state_values * (1 - terminals) - state_values
        # Importance ratio truncated from above at ``b``.
        weight = torch.clamp(current_pi / past_pi, max=self.b)

        loss = torch.mean(-advantage * log_probs * weight)
        loss.backward()
        self.optim.step()

In [15]:
class CriticModel(nn.Module):
    """Single-hidden-layer network mapping a state to one scalar value."""

    def __init__(self, observation_size, hidden_size) -> None:
        """Create the hidden layer and the scalar output layer.

        Args:
            observation_size: Number of input features.
            hidden_size: Width of the hidden layer.
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.lin1 = nn.Linear(observation_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        """Return the value estimate; the last axis of the result has size 1."""
        hidden = self.activation(self.lin1(state))
        value = self.lin3(hidden)
        return value


In [16]:
class Critic:
    """State-value estimator plus its Adam optimizer and a weighted TD update."""

    def __init__(self, gamma, lr, b, device, observation_size, hidden_size) -> None:
        """Build the value network on ``device`` and its optimizer.

        Args:
            gamma: Discount factor used in the one-step TD target.
            lr: Learning rate for Adam.
            b: Upper bound applied to the importance ratio in ``update``.
            device: Device the value network is moved to.
            observation_size: Number of observation features.
            hidden_size: Width of the value network's hidden layer.
        """
        self.value_function = CriticModel(observation_size, hidden_size)
        self.value_function = self.value_function.to(device)
        self.gamma = gamma
        self.b = b
        self.optim = torch.optim.Adam(self.value_function.parameters(), lr)

    def get_state_value(self, state):
        """Return the value network's output for ``state``."""
        return self.value_function(state)

    def update(self, rewards, state_values, next_state_values, terminals, current_pi, past_pi):
        """Take one gradient step on the importance-weighted squared TD error.

        Args:
            rewards: Per-sample rewards.
            state_values: Value estimates for the current states (with grad).
            next_state_values: Value estimates for the successor states.
            terminals: 1 where the transition ended the episode, else 0.
            current_pi: Probability of the stored action under the current policy.
            past_pi: Probability of the stored action when it was collected.
        """
        self.optim.zero_grad()
        # The value network emits a trailing size-1 axis. Without this reshape,
        # combining it with 1-D rewards/terminals broadcasts to a square matrix,
        # pairing each sample's reward with every other sample's value.
        state_values = state_values.reshape(rewards.shape)
        next_state_values = next_state_values.reshape(rewards.shape)

        # NOTE(review): the bootstrapped target is not detached here, so gradients
        # also flow through next_state_values when the caller passes it with grad.
        # Confirm whether that is intended.
        td_error = rewards + (self.gamma * next_state_values) * (1 - terminals) - state_values
        # Importance ratio truncated from above at ``b``.
        weight = torch.clamp(current_pi / past_pi, max=self.b)

        loss = torch.mean((td_error * weight).pow(2))
        loss.backward()
        self.optim.step()

In [17]:
class ReplayBuffer:
    """Fixed-capacity circular store of transitions with uniform random sampling."""

    def __init__(self, max_length, observation_shape, action_shape):
        """Pre-allocate storage for ``max_length`` transitions.

        Args:
            max_length: Capacity; once full, the oldest entries are overwritten.
            observation_shape: Number of features per observation.
            action_shape: Number of components per action.
        """
        self.max_length = max_length
        self.rewards = torch.zeros(max_length)
        self.observations = torch.zeros([max_length, observation_shape])
        self.actions = torch.zeros([max_length, action_shape])
        self.next_observations = torch.zeros([max_length, observation_shape])
        self.pi = torch.zeros(max_length)
        self.terminals = torch.zeros(max_length)
        self.current_idx = 0  # slot the next transition is written to
        self.added = 0  # total number of transitions ever added

    @staticmethod
    def _detached(value):
        """Return ``value`` cut from any autograd graph; non-tensors pass through."""
        return value.detach() if isinstance(value, torch.Tensor) else value

    def add_experience(self, reward, observation, action, next_observation, pi, terminal):
        """Write one transition into the current slot and advance the cursor.

        Tensor arguments are detached first. Writing a grad-tracking tensor into
        the storage in place would make the storage itself require grad and keep
        an ever-growing graph alive across inserts.
        """
        slot = self.current_idx
        self.rewards[slot] = self._detached(reward)
        self.observations[slot] = self._detached(observation)
        self.actions[slot] = self._detached(action)
        self.next_observations[slot] = self._detached(next_observation)
        self.pi[slot] = self._detached(pi)
        self.terminals[slot] = self._detached(terminal)

        self.added += 1
        self.current_idx = (slot + 1) % self.max_length

    def get_batch(self, size):
        """Sample up to ``size`` stored transitions uniformly with replacement.

        The batch is capped at the number of transitions currently held. An
        empty buffer yields empty tensors rather than an error from the sampler.

        Returns:
            Tuple ``(rewards, observations, actions, next_obs, pi, terminals)``.
        """
        available = min(self.added, self.max_length)
        size = min(size, available)
        if available == 0:
            # torch.randint rejects an empty range, so build the empty index directly.
            indices = torch.empty(0, dtype=torch.long)
        else:
            indices = torch.randint(available, (size,))

        rewards = self.rewards[indices]
        observations = self.observations[indices]
        actions = self.actions[indices]
        next_obs = self.next_observations[indices]
        pi = self.pi[indices]
        terminals = self.terminals[indices]

        return rewards, observations, actions, next_obs, pi, terminals


In [24]:
class Agent:
    """Off-policy actor-critic agent trained from a replay buffer.

    Owns a training environment wrapped for automatic resets, a second
    environment with human rendering for periodic evaluation, the actor, the
    critic and the replay buffer.
    """

    def __init__(self, env_name, gamma, actor_lr, critic_lr, device, hidden_size_actor, hidden_size_critic, batch_size, buffer_length, b) -> None:
        """Create the environments, networks and replay buffer.

        Args:
            env_name: Gymnasium environment id; it must accept ``continuous=True``.
            gamma: Discount factor shared by actor and critic.
            actor_lr: Learning rate of the actor.
            critic_lr: Learning rate of the critic.
            device: Torch device for the networks and training batches.
            hidden_size_actor: Hidden width of the policy network.
            hidden_size_critic: Hidden width of the value network.
            batch_size: Number of transitions sampled per update.
            buffer_length: Replay buffer capacity.
            b: Upper bound on the importance ratio, passed to actor and critic.
        """
        self.device = device
        env = gym.make(env_name, max_episode_steps=500, continuous=True)
        self.test_env = gym.make(env_name, render_mode="human", max_episode_steps=500, continuous=True)
        self.env_wrapper = gym.wrappers.AutoResetWrapper(env)

        observation_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]
        self.actor = Actor(
            gamma=gamma,
            lr=actor_lr,
            device=device,
            observation_size=observation_size,
            hidden_size=hidden_size_actor,
            out_size=action_size,
            b=b
        )
        self.critic = Critic(
            gamma=gamma,
            lr=critic_lr,
            device=device,
            observation_size=observation_size,
            hidden_size=hidden_size_critic,
            b=b
        )
        self.replay_buffer = ReplayBuffer(buffer_length, observation_size, action_size)
        self.batch_size = batch_size

    def _to_tensor(self, observation):
        """Convert a NumPy observation into a batch-of-one tensor on the agent's device."""
        return torch.from_numpy(observation).unsqueeze(0).to(self.device)

    def _optimize(self):
        """Sample one replay batch and apply one critic step, then one actor step."""
        batch = self.replay_buffer.get_batch(self.batch_size)
        batch_rewards, batch_observations, batch_actions, batch_next_obs, batch_pi, terminals = (
            item.to(self.device) for item in batch
        )

        # Drop the critic's trailing size-1 axis so values line up element-wise
        # with the 1-D rewards/terminals instead of broadcasting to a matrix.
        batch_obs_value = self.critic.get_state_value(batch_observations).squeeze(-1)
        batch_next_obs_value = self.critic.get_state_value(batch_next_obs).squeeze(-1)

        policy = self.actor.get_policy(batch_observations)
        batch_log_probs = policy.log_prob(batch_actions).sum(dim=-1)
        current_pi = torch.exp(batch_log_probs)

        self.critic.update(
            rewards=batch_rewards,
            state_values=batch_obs_value,
            next_state_values=batch_next_obs_value,
            terminals=terminals,
            current_pi=current_pi.detach(),
            past_pi=batch_pi.detach()
        )
        self.actor.update(
            rewards=batch_rewards.detach(),
            state_values=batch_obs_value.detach(),
            next_state_values=batch_next_obs_value.detach(),
            log_probs=batch_log_probs,
            terminals=terminals.detach(),
            current_pi=current_pi.detach(),
            past_pi=batch_pi.detach()
        )

    def _evaluate(self, episodes=5):
        """Roll out ``episodes`` rendered episodes with the current policy."""
        for _ in range(episodes):
            done = False
            test_observation, info = self.test_env.reset()
            test_observation = self._to_tensor(test_observation)
            while not done:
                with torch.no_grad():
                    action, _ = self.actor.get_action(test_observation)
                    action = action.squeeze(0)
                test_observation, _, terminated, truncated, _ = self.test_env.step(action.cpu().numpy())
                test_observation = self._to_tensor(test_observation)
                done = terminated or truncated

    def learn(self, epochs):
        """Train for ``epochs`` episodes, updating after every environment step.

        Every 100 episodes the policy is rendered for a few evaluation episodes
        and the mean training return since the last report is printed. The
        rendering environment is closed when training ends, including on error
        or interruption.
        """
        observation, info = self.env_wrapper.reset()
        observation = self._to_tensor(observation)

        rewards = []
        try:
            for epoch in tqdm(range(epochs)):
                ep_rewards = []

                done = False
                while not done:
                    # Acting needs no gradients; the update recomputes log-probs itself.
                    with torch.no_grad():
                        action, log_prob = self.actor.get_action(observation)
                    # Remove only the batch axis so a 1-D action space keeps its axis.
                    action, log_prob = action.squeeze(0), log_prob.squeeze(0)

                    new_observation, reward, terminated, truncated, info = self.env_wrapper.step(action.cpu().numpy())
                    done = terminated or truncated

                    # On episode end the auto-reset wrapper returns the first
                    # observation of the NEXT episode; the true successor of this
                    # transition is reported under info["final_observation"].
                    final_observation = info.get("final_observation") if done else None
                    successor = new_observation if final_observation is None else final_observation
                    new_observation = self._to_tensor(new_observation)

                    self.replay_buffer.add_experience(
                        reward=reward,
                        observation=observation,
                        action=action,
                        next_observation=self._to_tensor(successor),
                        pi=torch.exp(log_prob),
                        # Only real termination disables bootstrapping; a time-limit
                        # truncation still has a meaningful successor value.
                        terminal=terminated
                    )

                    self._optimize()
                    ep_rewards.append(reward)
                    observation = new_observation

                rewards.append(sum(ep_rewards))

                if epoch % 100 == 0:
                    self._evaluate()
                    print(f"Episode {epoch}\t Mean reward = {sum(rewards)/len(rewards)}")
                    rewards = []
        finally:
            self.test_env.close()


In [25]:
# Build the agent for continuous LunarLander, on GPU when one is available.
agent = Agent(
    env_name="LunarLander-v2",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Optimisation settings.
    gamma=0.99,
    actor_lr=0.003,
    critic_lr=0.003,
    # Network widths.
    hidden_size_actor=128,
    hidden_size_critic=128,
    # Replay settings.
    batch_size=256,
    buffer_length=100000,
    # Cap on the importance ratio, forwarded to both actor and critic.
    b=10000000,
)

In [26]:
# Train for 5000 episodes.
agent.learn(epochs=5000)

  0%|          | 1/5000 [00:15<21:45:07, 15.66s/it]

Episode 0	 Mean reward = -398.9090468282223


  0%|          | 17/5000 [00:26<1:14:50,  1.11it/s]