In [1]:
import gymnasium as gym
import torch
from torch import nn as nn
from tqdm import tqdm

In [2]:
class ActorModel(nn.Module):
    """Policy network: one hidden layer followed by a softmax over 2 outputs.

    Maps an input whose last dimension is 4 to a probability vector of
    length 2 (the softmax is taken over the last dimension).
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are registered in this order on purpose: it fixes both the
        # state_dict layout and the order in which random init is drawn.
        self.activation = nn.LeakyReLU(negative_slope=0.2)
        self.lin1 = nn.Linear(in_features=4, out_features=128)
        self.out = nn.Linear(in_features=128, out_features=2)

    def forward(self, state):
        """Return the softmax-normalised output for ``state``."""
        hidden = self.activation(self.lin1(state))
        logits = self.out(hidden)
        return logits.softmax(dim=-1)


In [3]:
class Actor:
    """Wraps the policy network together with its Adam optimiser.

    Args:
        gamma: discount factor applied to ``next_state_value`` in ``update``.
        lr: learning rate passed to ``torch.optim.Adam``.
        device: device the policy network is moved to.
    """

    def __init__(self, gamma, lr, device) -> None:
        self.gamma = gamma
        self.policy = ActorModel().to(device)
        self.optim = torch.optim.Adam(self.policy.parameters(), lr)

    def get_action(self, action):
        """Sample from the categorical distribution produced by the policy.

        NOTE: despite its name, the ``action`` argument is the *input* fed to
        the policy network (the caller in this file passes the observation);
        the name is kept so existing callers are unaffected.

        Returns:
            ``(sampled_action, log_prob)`` where ``log_prob`` is the log
            probability of the sampled action under the current policy.
        """
        action_probs = self.policy(action).squeeze()
        dist = torch.distributions.Categorical(action_probs)
        sampled = dist.sample()
        return sampled, dist.log_prob(sampled)

    def update(self, reward, next_state_value, state_value, log_prob):
        """Take one optimiser step on ``-(TD error) * log_prob``.

        The TD error is ``reward + gamma * next_state_value - state_value``.
        Gradients flow through whatever tensors still carry a graph, so the
        caller decides which of the value estimates are detached.
        """
        td_error = reward + self.gamma * next_state_value - state_value
        objective = torch.sum(-td_error * log_prob)

        self.optim.zero_grad()
        objective.backward()
        self.optim.step()

In [4]:
class CriticModel(nn.Module):
    """Value network: one ReLU hidden layer followed by a single linear output.

    Maps an input whose last dimension is 4 to an output whose last
    dimension is 1.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order is kept stable so the state_dict layout and the
        # random initialisation sequence do not change.
        self.activation = nn.ReLU()
        self.lin1 = nn.Linear(in_features=4, out_features=128)
        self.lin3 = nn.Linear(in_features=128, out_features=1)

    def forward(self, state):
        """Return the network output for ``state``."""
        hidden = self.activation(self.lin1(state))
        return self.lin3(hidden)


In [5]:
class Critic:
    """Wraps the value network together with its Adam optimiser.

    Args:
        gamma: discount factor applied to ``next_state_value`` in ``update``.
        lr: learning rate passed to ``torch.optim.Adam``.
        device: device the value network is moved to.
    """

    def __init__(self, gamma, lr, device) -> None:
        self.value_function = CriticModel()
        self.value_function = self.value_function.to(device)
        self.gamma = gamma
        self.optim = torch.optim.Adam(self.value_function.parameters(), lr)

    def get_state_value(self, state):
        """Return the value network's output for ``state`` (graph attached)."""
        return self.value_function(state)

    def update(self, reward, state_value, next_state_value):
        """Take one semi-gradient TD(0) step on the squared TD error.

        Args:
            reward: reward observed for the transition (number or tensor).
            state_value: value estimate of the current state; must carry the
                autograd graph of the value network.
            next_state_value: value estimate used to build the bootstrapped
                target ``reward + gamma * next_state_value``.
        """
        # The bootstrapped target is treated as a constant. Without the
        # detach, backward() would also push gradient through the target
        # (which is produced by this same network), pulling the next-state
        # estimate toward the current one instead of the other way round.
        if torch.is_tensor(next_state_value):
            next_state_value = next_state_value.detach()
        td_target = reward + self.gamma * next_state_value

        # sum() is the identity for the scalar case and keeps backward()
        # valid if a batch of values is ever passed (mirrors Actor.update).
        loss = (td_target - state_value).pow(2).sum()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

In [6]:
class Agent:
    """One-step actor-critic trainer on ``CartPole-v1``.

    Owns the environment (wrapped so it resets itself at episode end), an
    ``Actor`` and a ``Critic``, and runs the online update loop in ``learn``.

    Args:
        gamma: discount factor handed to both the actor and the critic.
        actor_lr: learning rate for the actor's optimiser.
        critic_lr: learning rate for the critic's optimiser.
        device: torch device used for the networks and observation tensors.
    """

    def __init__(self, gamma, actor_lr, critic_lr, device) -> None:
        self.device = device
        env = gym.make("CartPole-v1", max_episode_steps=500)
        # self.test_env = gym.make("CartPole-v1", render_mode="human", max_episode_steps=500)
        # Keep a handle on the raw environment so callers can reach it
        # (e.g. ``agent.env.close()``).
        self.env = env
        self.env_wrapper = gym.wrappers.AutoResetWrapper(env)
        self.actor = Actor(gamma, actor_lr, device)
        self.critic = Critic(gamma, critic_lr, device)

    def _to_tensor(self, observation):
        """Convert a numpy observation to a tensor with a leading batch dim on ``self.device``."""
        return torch.from_numpy(observation).unsqueeze(0).to(self.device)

    def learn(self, epochs):
        """Run ``epochs`` episodes, updating critic and actor after every step.

        Every 500 episodes the mean episode return since the last report is
        printed and the running list is cleared.
        """
        observation, info = self.env_wrapper.reset()
        observation = self._to_tensor(observation)

        rewards = []
        for epoch in tqdm(range(epochs)):
            ep_rewards = []

            done = False
            while not done:
                action, log_prob = self.actor.get_action(observation)
                action, log_prob = action.squeeze(), log_prob.squeeze()
                first_state_value = self.critic.get_state_value(observation).squeeze()

                observation, reward, terminated, truncated, info = self.env_wrapper.step(action.cpu().numpy())
                # On episode end the auto-reset wrapper has already reset the
                # env, so this is the first observation of the next episode.
                observation = self._to_tensor(observation)
                done = terminated or truncated

                # Pick the observation to bootstrap from:
                #  * terminated -> true terminal state, target value is 0;
                #  * truncated  -> the episode was cut by the time limit, so
                #    the last state still has value; the auto-reset wrapper
                #    exposes it as info["final_observation"];
                #  * otherwise  -> the observation we just stepped into.
                if terminated:
                    bootstrap_obs = None
                elif truncated:
                    final_obs = info.get("final_observation")
                    bootstrap_obs = None if final_obs is None else self._to_tensor(final_obs)
                else:
                    bootstrap_obs = observation

                if bootstrap_obs is None:
                    # Created on self.device so every operand of the TD error
                    # lives on the same device.
                    next_state_value = torch.zeros((), device=self.device)
                else:
                    # The bootstrapped target is a constant for both updates.
                    with torch.no_grad():
                        next_state_value = self.critic.get_state_value(bootstrap_obs).squeeze()

                self.critic.update(reward=reward, state_value=first_state_value, next_state_value=next_state_value)
                self.actor.update(
                    reward=reward,
                    state_value=first_state_value.detach(),
                    next_state_value=next_state_value.detach(),
                    log_prob=log_prob,
                )
                ep_rewards.append(reward)

            rewards.append(sum(ep_rewards))

            if epoch % 500 == 0:
                # for _ in range(10):
                #     done = False
                #     test_observation, info = self.test_env.reset()
                #     test_observation = torch.from_numpy(test_observation).unsqueeze(0).to(self.device)
                #     while not done:
                #         with torch.no_grad():
                #             action, _ = self.actor.get_action(test_observation)
                #             action = action.squeeze()
                #         test_observation, _, terminated, truncated, _ = self.test_env.step(action.cpu().numpy())
                #         test_observation = torch.from_numpy(test_observation).unsqueeze(0).to(self.device)
                #         done = terminated or truncated
                print(f"Episode {epoch}\t Mean reward = {sum(rewards)/len(rewards)}")
                rewards = []

        # The rendering test env is optional (its creation is commented out in
        # __init__); only close it when it actually exists, so that finishing
        # training does not raise AttributeError.
        test_env = getattr(self, "test_env", None)
        if test_env is not None:
            test_env.close()


In [7]:
# Hyperparameters are spelled out by keyword; training runs on the GPU when one is available.
agent = Agent(
    gamma=0.99,
    actor_lr=0.0003,
    critic_lr=0.0003,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [8]:
# Train for 50,000 episodes (progress is reported every 500 episodes).
agent.learn(epochs=50_000)

  0%|          | 4/50000 [00:01<2:52:45,  4.82it/s] 

Episode 0	 Mean reward = 31.0


  0%|          | 234/50000 [00:24<1:27:16,  9.50it/s]


KeyboardInterrupt: 

In [52]:
# The Agent keeps its environment as `env_wrapper`; closing the wrapper
# also closes the environment it wraps.
agent.env_wrapper.close()