In [None]:
import torch.nn as nn 
import torch.optim as opt
import random
from torch.distributions.categorical import Categorical
import torch 
import gymnasium as gym
import numpy as np

# Pick the compute device once for the whole notebook. Building the
# torch.device from the chosen name (rather than falling back to the bare
# string 'cpu') keeps `device` the same type on every machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
class Actor(nn.Module):
    """Policy network: maps states to a categorical distribution over actions.

    Args:
        in_features: size of the last dimension of the state input.
        out_features: number of logits produced (one per discrete action).
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Kept as attributes so a checkpoint can record the sizes needed to
        # rebuild the network.
        self.in_features = in_features
        self.out_features = out_features
        # NOTE: the attribute name `layers` and the layer order define the
        # state_dict keys ('layers.0.weight', ...); changing them breaks
        # previously saved checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_features),
        )

    def forward(self, states: torch.Tensor) -> Categorical:
        """Return the action distribution for `states`.

        Args:
            states: tensor whose last dimension is `in_features`.

        Returns:
            A `Categorical` built from the network's raw logits.
        """
        logits = self.layers(states)
        return Categorical(logits=logits)


class Critic(nn.Module):
    """Value network: maps states to `out_features` value estimates.

    Args:
        in_features: size of the last dimension of the state input.
        out_features: number of outputs per state.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Kept as attributes so a checkpoint can record the sizes needed to
        # rebuild the network.
        self.in_features = in_features
        self.out_features = out_features
        # NOTE: the attribute name `layers` and the layer order define the
        # state_dict keys; changing them breaks previously saved checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_features)
        )

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Return the value estimates for `states`.

        Args:
            states: tensor whose last dimension is `in_features`.

        Returns:
            Tensor with the same leading dimensions as `states` and a last
            dimension of `out_features`.
        """
        return self.layers(states)

In [None]:
class Agent:
    """Actor-critic agent trained with a clipped-ratio (PPO-style) objective.

    Args:
        env: vectorised environment; `fit` reads `env.num_envs` and calls
            `env.reset()` / `env.step(actions)`.
        actor: policy network whose call returns a distribution exposing
            `sample()` and `log_prob()`.
        critic: value network. `fit` assumes it emits a single value per
            state (last dimension of size 1).
        epsilon: clipping range for the probability ratio.
        gamma: reward discount factor.
        lam: lambda of generalised advantage estimation (GAE).
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
    """

    def __init__(self, env, actor: 'Actor', critic: 'Critic', epsilon: float, gamma: float, lam: float, actor_lr: float, critic_lr: float):
        self.env = env
        self.actor = actor
        self.critic = critic
        self.epsilon = epsilon
        self.gamma = gamma
        self.lam = lam
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_opt = opt.Adam(self.actor.parameters(), actor_lr)
        self.critic_opt = opt.Adam(self.critic.parameters(), critic_lr)
        # Per-iteration training statistics; (re)filled by `fit`.
        self.all_rewards = []
        self.all_steps = []

    def save_agent(self, path: str = 'cartpole_agent.pt') -> None:
        """Write network weights and hyperparameters to `path` via `torch.save`.

        Optimizer state is not saved.
        """
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'hyperparameters': {
                'actor_lr': self.actor_lr,
                'critic_lr': self.critic_lr,
                'epsilon': self.epsilon,
                'gamma': self.gamma,
                'lam': self.lam,
                'actor_in_feats': self.actor.in_features,
                'actor_out_feats': self.actor.out_features,
                'critic_in_feats': self.critic.in_features,
                'critic_out_feats': self.critic.out_features
            }
        }, path)

    @staticmethod
    def load_agent(path: str, env, device: str) -> 'Agent':
        """Rebuild an agent from a checkpoint written by `save_agent`.

        Args:
            path: checkpoint file.
            env: environment to attach to the restored agent.
            device: device the networks are moved to.

        Returns:
            A new `Agent` with restored weights and hyperparameters and
            freshly created optimizers.
        """
        # map_location lets a checkpoint saved from GPU tensors be restored on
        # a CPU-only host. torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        chckpt = torch.load(path, map_location=device)
        hparams = chckpt['hyperparameters']

        actor = Actor(hparams['actor_in_feats'], hparams['actor_out_feats'])
        critic = Critic(hparams['critic_in_feats'], hparams['critic_out_feats'])
        actor.load_state_dict(chckpt['actor'])
        critic.load_state_dict(chckpt['critic'])

        return Agent(
            env,
            actor.to(device),
            critic.to(device),
            hparams['epsilon'],
            hparams['gamma'],
            hparams['lam'],
            hparams['actor_lr'],
            hparams['critic_lr']
        )

    def select_action(self, states: torch.Tensor) -> tuple:
        """Sample actions from the policy.

        Returns:
            `(actions, log_probs)`, where `log_probs` are the log-probabilities
            of the sampled actions under the current policy.
        """
        dist = self.actor(states)
        actions = dist.sample()

        return actions, dist.log_prob(actions)

    def get_state_values(self, states: torch.Tensor) -> torch.Tensor:
        """Return the critic's output for `states`."""
        return self.critic(states)

    def update_nets(self, actor_loss: torch.Tensor, critic_loss: torch.Tensor) -> None:
        """Take one optimizer step on each network from its own loss."""
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

    @staticmethod
    def _compute_gae(rewards: torch.Tensor, values: torch.Tensor, next_values: torch.Tensor,
                     terminated: torch.Tensor, dones: torch.Tensor, gamma: float, lam: float) -> torch.Tensor:
        """Generalised advantage estimation over time-major tensors.

        All tensor arguments share one shape whose first dimension is time.
        `terminated` and `dones` are 0/1 float masks: `terminated` suppresses
        bootstrapping from the next value, `dones` stops the advantage from
        leaking across an episode boundary.

        Returns:
            Tensor of advantages with the same shape as `rewards`.
        """
        advantages = torch.zeros_like(rewards)
        next_advantage = torch.zeros(rewards.shape[1:], dtype=rewards.dtype, device=rewards.device)
        # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}.
        for t in reversed(range(rewards.shape[0])):
            delta = rewards[t] + gamma * next_values[t] * (1.0 - terminated[t]) - values[t]
            next_advantage = delta + gamma * lam * (1.0 - dones[t]) * next_advantage
            advantages[t] = next_advantage
        return advantages

    def _collect_rollout(self, timesteps: int, device: torch.device) -> tuple:
        """Step the environment `timesteps` times with the current policy.

        Returns:
            `(batch, total_reward)`; `batch` maps names to tensors stacked
            along a leading time dimension, `total_reward` is the sum of all
            rewards received across all environments.
        """
        obs, _ = self.env.reset()
        states = torch.as_tensor(obs, dtype=torch.float32, device=device)

        steps = {key: [] for key in (
            'states', 'actions', 'log_probs', 'values',
            'next_values', 'rewards', 'terminated', 'dones',
        )}
        total_reward = 0.0

        for _ in range(timesteps):
            # Rollout data is treated as constants during the update, so no
            # autograd graph is needed (or retained) here.
            with torch.no_grad():
                actions, log_probs = self.select_action(states)
                values = self.get_state_values(states).squeeze(-1)

            # Gymnasium's step returns (obs, reward, terminated, truncated, info).
            next_obs, rewards, terminated, truncated, _ = self.env.step(actions.cpu().numpy())
            next_states = torch.as_tensor(next_obs, dtype=torch.float32, device=device)
            with torch.no_grad():
                # NOTE(review): on a truncated step, whether `next_obs` is the
                # final or the reset observation depends on the vector env's
                # autoreset mode — confirm for the Gymnasium version in use.
                next_values = self.get_state_values(next_states).squeeze(-1)

            terminated_mask = torch.as_tensor(terminated, dtype=torch.bool)
            done_mask = terminated_mask | torch.as_tensor(truncated, dtype=torch.bool)
            rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=device)

            total_reward += rewards_t.sum().item()

            steps['states'].append(states)
            steps['actions'].append(actions)
            steps['log_probs'].append(log_probs)
            steps['values'].append(values)
            steps['next_values'].append(next_values)
            steps['rewards'].append(rewards_t)
            steps['terminated'].append(terminated_mask.to(device=device, dtype=torch.float32))
            steps['dones'].append(done_mask.to(device=device, dtype=torch.float32))

            states = next_states

        batch = {key: torch.stack(items) for key, items in steps.items()}
        return batch, total_reward

    def fit(self, train_iters: int, timesteps: int, K: int, bs: int) -> None:
        """Train the actor and critic.

        Each iteration resets the environment, collects `timesteps` steps from
        every sub-environment, computes GAE advantages per environment, then
        performs `K` updates on random minibatches of `bs` transitions.

        Args:
            train_iters: number of collect-then-update iterations.
            timesteps: environment steps collected per iteration.
            K: minibatch updates per iteration.
            bs: minibatch size.

        Raises:
            ValueError: if `bs` exceeds the number of collected transitions.

        Side effects:
            Resets and fills `self.all_rewards` / `self.all_steps` and prints
            per-iteration statistics.
        """
        if bs > self.env.num_envs * timesteps:
            raise ValueError('batch size cannot be greater than number of environments * timesteps')

        # Follow the device the networks live on instead of a notebook global.
        device = next(self.actor.parameters()).device

        self.all_rewards = []
        self.all_steps = []

        for train_iter in range(train_iters):
            batch, ep_reward = self._collect_rollout(timesteps, device)
            ep_steps = timesteps

            print('finished episode:', train_iter)
            print('total reward:', ep_reward)
            print('number of steps:', ep_steps)
            print('-' * 15)

            self.all_rewards.append(ep_reward)
            self.all_steps.append(ep_steps)

            # Advantages are computed on (time, env) tensors so that each
            # environment's trajectory is handled independently.
            advantages = self._compute_gae(
                batch['rewards'], batch['values'], batch['next_values'],
                batch['terminated'], batch['dones'], self.gamma, self.lam,
            )
            # Value targets use the raw (un-normalised) advantages.
            returns = advantages + batch['values']

            # Merge the time and env dimensions into one transition dimension.
            flat_states = batch['states'].flatten(0, 1)
            flat_actions = batch['actions'].flatten(0, 1)
            flat_old_log_probs = batch['log_probs'].flatten(0, 1)
            flat_advantages = advantages.flatten(0, 1)
            flat_returns = returns.flatten(0, 1)
            n_transitions = flat_returns.shape[0]

            for _ in range(K):
                idx = torch.tensor(random.sample(range(n_transitions), bs), dtype=torch.long, device=device)
                mb_states = flat_states[idx]
                mb_advantages = flat_advantages[idx]

                # std() of a single element is NaN, so only normalise real batches.
                if mb_advantages.numel() > 1:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # One batched forward pass instead of one pass per sample.
                new_log_probs = self.actor(mb_states).log_prob(flat_actions[idx])
                ratio = (new_log_probs - flat_old_log_probs[idx]).exp()

                clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
                actor_loss = -torch.min(ratio * mb_advantages, clipped_ratio * mb_advantages).mean()

                # squeeze(-1) keeps prediction and target both (bs,), avoiding
                # an accidental (bs, bs) broadcast.
                value_pred = self.critic(mb_states).squeeze(-1)
                critic_loss = ((flat_returns[idx] - value_pred) ** 2).mean()

                self.update_nets(actor_loss, critic_loss)


In [4]:
# Build a rendering environment and restore the trained agent. The networks
# are placed on the notebook-wide `device` (instead of a hard-coded 'cuda') so
# they live on the same device the evaluation cell moves its inputs to.
env = gym.make('CartPole-v1', render_mode = 'human')
agent = Agent.load_agent('cartpole_agent.pt', env, device)

In [None]:
# Evaluate the loaded policy: in every episode the most probable action is
# taken at each step (no sampling) until the environment reports that the
# episode terminated or was truncated.
eval_episodes = 10
eval_rewards = 0
for episode in range(eval_episodes):
    state, _ = agent.env.reset()
    total_reward = 0

    while True:
        # Add a leading batch dimension of size 1 before moving to `device`.
        obs_batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        obs_batch = obs_batch.to(device)
        with torch.no_grad():
            action = agent.actor(obs_batch).probs.argmax(dim = 1)
        state, reward, done, truncated, _ = agent.env.step(action.item())
        total_reward += reward
        if done or truncated:
            break

    print("evaluation reward:", total_reward)
    eval_rewards += total_reward

print('average reward:', eval_rewards / eval_episodes)

evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
evaluation reward: 500.0
average reward: 500.0


In [6]:
# Close the environment attached to the agent now that evaluation is finished.
agent.env.close()