In [20]:
import gymnasium as gym
import torch
from torch import nn as nn
from tqdm import tqdm

In [21]:
class ActorModel(nn.Module):
    def __init__(self, observation_size, hidden_size, out_size) -> None:
        super().__init__()
        self.activation = nn.LeakyReLU(0.2)
        self.lin1 = nn.Linear(observation_size, hidden_size)
        self.out = nn.Linear(hidden_size, out_size)
    
    def forward(self, state):
        out = self.lin1(state)
        out = self.activation(out)
        out = self.out(out)
        return nn.functional.softmax(out, dim=-1)


In [22]:
class Actor:
    def __init__(self, gamma, lr, device, observation_size, hidden_size, out_size) -> None:
        self.policy = ActorModel(observation_size, hidden_size, out_size)
        self.policy = self.policy.to(device)
        self.gamma = gamma
        self.optim = torch.optim.Adam(self.policy.parameters(), lr)

    def get_action(self, action):
        probs = self.policy(action)
        probs = probs.squeeze()
        distribution = torch.distributions.Categorical(probs)
        action = distribution.sample()
        log_prob = distribution.log_prob(action)
        return action, log_prob 
    
    def update(self, reward, next_state_value, state_value, log_prob):
        self.optim.zero_grad()
        loss = -(reward + self.gamma * next_state_value - state_value) * log_prob
        loss = torch.sum(loss)
        loss.backward()
        self.optim.step()

In [23]:
class CriticModel(nn.Module):
    def __init__(self, observation_size, hidden_size) -> None:
        super().__init__()
        self.activation = nn.ReLU()
        self.lin1 = nn.Linear(observation_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, 1)
    
    def forward(self, state):
        out = self.lin1(state)
        out = self.activation(out)
        return self.lin3(out)


In [24]:
class Critic:
    def __init__(self, gamma, lr, device, observation_size, hidden_size) -> None:
        self.value_function = CriticModel(observation_size, hidden_size)
        self.value_function = self.value_function.to(device)
        self.gamma = gamma
        self.optim = torch.optim.Adam(self.value_function.parameters(), lr)
    
    def get_state_value(self, state):
        return self.value_function(state)
    
    def update(self, reward, state_value, next_state_value):
        self.optim.zero_grad()
        loss = (reward + self.gamma * next_state_value - state_value).pow(2)
        loss.backward()
        self.optim.step()

In [25]:
class Agent:
    def __init__(self, env_name, gamma, actor_lr, critic_lr, device, hidden_size_actor, hidden_size_critic) -> None:
        self.device = device
        env = gym.make(env_name, max_episode_steps=500)
        self.test_env = gym.make(env_name, render_mode="human", max_episode_steps=500)
        self.env_wrapper = gym.wrappers.AutoResetWrapper(env)
        self.actor = Actor(gamma, actor_lr, device, observation_size=env.observation_space.shape[0], hidden_size=hidden_size_actor, out_size=env.action_space.n)
        self.critic = Critic(gamma, critic_lr, device, observation_size=env.observation_space.shape[0], hidden_size= hidden_size_critic)
    
    def learn(self, epochs):
        observation, info = self.env_wrapper.reset()
        observation = torch.from_numpy(observation).unsqueeze(0).to(self.device)

        rewards = []
        for epoch in tqdm(range(epochs)):
            ep_rewards = []

            done = False
            while not done:
                action, log_prob = self.actor.get_action(observation)
                action, log_prob = action.squeeze(), log_prob.squeeze()
                first_state_value = self.critic.get_state_value(observation)
                first_state_value = first_state_value.squeeze()

                observation, reward, terminated, truncated, info = self.env_wrapper.step(action.cpu().numpy())
                observation = torch.from_numpy(observation).unsqueeze(0).to(self.device)

                done = terminated or truncated
                if not done:
                    next_state_value = self.critic.get_state_value(observation)
                    next_state_value = next_state_value.squeeze()
                else:
                    next_state_value = torch.tensor(0)

                self.critic.update(reward=reward, state_value=first_state_value, next_state_value=next_state_value)
                self.actor.update(reward=reward, state_value=first_state_value.detach(), next_state_value=next_state_value.detach(), log_prob=log_prob)
                ep_rewards.append(reward)
                        
            rewards.append(sum(ep_rewards))
            
            if epoch % 100 == 0:
                for _ in range(10):
                    done = False
                    test_observation, info = self.test_env.reset()
                    test_observation = torch.from_numpy(test_observation).unsqueeze(0).to(self.device)
                    while not done:
                        with torch.no_grad():
                            action, _ = self.actor.get_action(test_observation)
                            action = action.squeeze()
                        test_observation, _, terminated, truncated, _ = self.test_env.step(action.cpu().numpy())
                        test_observation = torch.from_numpy(test_observation).unsqueeze(0).to(self.device)
                        done = terminated or truncated
                print(f"Episode {epoch}\t Mean reward = {sum(rewards)/len(rewards)}")
                rewards = []
        self.test_env.close()


In [26]:
agent = Agent(
    env_name="LunarLander-v2",
    gamma=0.99,
    actor_lr=0.0003,
    critic_lr=0.0003,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    hidden_size_actor=128,
    hidden_size_critic=128
)

In [27]:
agent.learn(5000)

  0%|          | 1/5000 [00:21<29:47:11, 21.45s/it]

Episode 0	 Mean reward = -44.37886048115726


  2%|▏         | 101/5000 [02:08<24:28:24, 17.98s/it]

Episode 100	 Mean reward = -188.90449342421743


  4%|▍         | 201/5000 [05:42<36:32:04, 27.41s/it]

Episode 200	 Mean reward = -110.15812566757918


  6%|▌         | 301/5000 [09:33<34:54:49, 26.75s/it]

Episode 300	 Mean reward = -99.86985433964301


  8%|▊         | 401/5000 [13:35<34:36:31, 27.09s/it]

Episode 400	 Mean reward = -66.08879226071484


 10%|█         | 501/5000 [18:04<35:33:07, 28.45s/it]

Episode 500	 Mean reward = -70.70048310471535


 12%|█▏        | 601/5000 [22:06<36:30:02, 29.87s/it]

Episode 600	 Mean reward = -59.64430906998118


 14%|█▍        | 701/5000 [26:04<34:00:04, 28.47s/it]

Episode 700	 Mean reward = -53.69280277567734


 16%|█▌        | 801/5000 [30:10<35:49:39, 30.72s/it]

Episode 800	 Mean reward = -72.92890242792123


 18%|█▊        | 901/5000 [34:14<33:40:47, 29.58s/it]

Episode 900	 Mean reward = -51.952124271450074


 20%|██        | 1001/5000 [38:28<34:56:40, 31.46s/it]

Episode 1000	 Mean reward = -68.71974778543132


 22%|██▏       | 1101/5000 [42:34<30:48:28, 28.45s/it]

Episode 1100	 Mean reward = -49.73753179388853


 24%|██▍       | 1201/5000 [46:50<33:44:55, 31.98s/it]

Episode 1200	 Mean reward = -66.14909050301968


 26%|██▌       | 1301/5000 [51:04<32:51:42, 31.98s/it]

Episode 1300	 Mean reward = -63.71828729222067


 28%|██▊       | 1401/5000 [55:28<31:37:12, 31.63s/it]

Episode 1400	 Mean reward = -74.87543334694847


 30%|███       | 1501/5000 [59:51<31:06:50, 32.01s/it]

Episode 1500	 Mean reward = -64.00129977374881


 32%|███▏      | 1601/5000 [1:04:15<30:12:44, 32.00s/it]

Episode 1600	 Mean reward = -72.69347831896879


 34%|███▍      | 1701/5000 [1:08:42<29:02:44, 31.70s/it]

Episode 1700	 Mean reward = -67.08139972834347


 36%|███▌      | 1801/5000 [1:13:02<28:21:29, 31.91s/it]

Episode 1800	 Mean reward = -43.636396605961124


 38%|███▊      | 1901/5000 [1:17:15<25:33:46, 29.70s/it]

Episode 1900	 Mean reward = -41.22268433580156


 40%|████      | 2001/5000 [1:21:38<26:37:26, 31.96s/it]

Episode 2000	 Mean reward = -34.492464438475174


 42%|████▏     | 2101/5000 [1:26:03<25:01:44, 31.08s/it]

Episode 2100	 Mean reward = -57.32520203814877


 44%|████▍     | 2201/5000 [1:30:23<24:52:00, 31.98s/it]

Episode 2200	 Mean reward = -58.84065636780387


 46%|████▌     | 2301/5000 [1:34:36<22:34:41, 30.12s/it]

Episode 2300	 Mean reward = -39.23274531166462


 48%|████▊     | 2401/5000 [1:38:59<23:03:04, 31.93s/it]

Episode 2400	 Mean reward = -42.892657898951505


 50%|█████     | 2501/5000 [1:43:32<21:55:34, 31.59s/it]

Episode 2500	 Mean reward = -46.86751174185342


 52%|█████▏    | 2600/5000 [1:46:15<1:11:56,  1.80s/it] 