In [None]:
# !pip install gym imageio torchvision==0.15.1 torch==2.0.0 torchaudio==2.0.0 ale-py gymnasium==0.28.1

In [None]:
# !pip install gym[atari] autorom[accept-rom-license]

In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from utils import transform
from fileIO import FileIO
from utils import file_exists, write_history, get_last_history, compute_reward, decimal2

# Define the PPO actor and critic neural network
class ActorCritic(nn.Module):
    """CNN with 3 convolution layers"""
    def __init__(self, input_shape, num_actions):
        super(ActorCritic, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.device = torch.device("cpu")
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU()
        )
        self.actor = nn.Linear(512, num_actions)
        self.critic = nn.Linear(512, 1)

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(torch.prod(torch.tensor(o.size())))

    def forward(self, x):
        # m = x[0].permute(1, 2, 0).numpy()
        # plt.imshow(m)
        # plt.axis('off')  # Optional: Turn off axis ticks and labels
        # plt.show()
        # return 'uuu'
        # x = x.to(self.device)
        conv_out = self.conv(x).view(x.size()[0], -1)
        fc_out = self.fc(conv_out)
        return self.actor(fc_out), self.critic(fc_out)


# Proximal Policy Optimization Agent
class PPOAgent:
    def __init__(self, env, env_name: str, lr=0.001, gamma=0.99, clip_epsilon=0.2,
                 value_coef=0.5, entropy_coef=0.01):
        self.env = env
        self.device = torch.device("cpu")
        self.name = env_name.lower()
        self.input_shape = (3, 160, 160)
        self.num_actions = env.action_space.n
        self.actor_critic = ActorCritic(self.input_shape, self.num_actions).to(self.device)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr)
        
        # Define hyperparameters
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        
        self.resume = False
        self.policy_path = f'saves/{self.name}/ppo_policy__.h5'
        self.history_path = f'saves/{self.name}/ppo_history.csv'
        self.log = ''
        
        # save time by retraining model from a known trained weights
        if (file_exists(self.policy_path)):
            self.resume = True
            self.actor_critic.load_state_dict(torch.load(self.policy_path))
            print('----- loaded saved weights ------')

        # implement evaluation history so traning can be ran in batches to not crash system
        self.current_episode =  get_last_history(self.history_path)

    def act(self, state):
        # state = torch.from_numpy(state).float().unsqueeze(0)
        with torch.no_grad():
            dist, value = self.actor_critic(state)
            logits = F.softmax(dist, dim=-1)
            action_probs = Categorical(logits)
            action = action_probs.sample()
            # self.log = f'{logits}, {action}'
            
        return action.item(), logits.tolist()

    def learn(self, rollout):
        states, actions, old_probs, rewards, next_states, dones = rollout

        actions = torch.LongTensor([actions]).to(self.device)
        rewards = torch.FloatTensor([rewards]).to(self.device)
        old_probs = torch.FloatTensor([old_probs]).to(self.device)

        dist, values = self.actor_critic(states)
        _, next_values = self.actor_critic(next_states)

        advantages = rewards + self.gamma * next_values * (1 - dones) - values.detach()
        returns = rewards + self.gamma * next_values * (1 - dones)

        prob_ratio = torch.exp(dist.squeeze(dim=0) - old_probs)

        surrogate1 = prob_ratio * advantages
        surrogate2 = torch.clamp(prob_ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages

        actor_loss = -torch.min(surrogate1, surrogate2).mean()
        critic_loss = F.smooth_l1_loss(values, returns)

        # Calculate entropy using Categorical distribution
        entropy = Categorical(probs=F.softmax(dist, dim=-1)).entropy().mean()

        loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    
    def train(self, num_episodes):
        history_data = []
        save_after = 10
        for episode in range(num_episodes):
            state = transform(self.env.reset()[0])
            # state = self.env.reset()[0]

            done = False
            total_trad_reward = 0
            total_nontrad_reward = 0
            timestep = 0

            while not done:
                action, logits = self.act(state)
                print('act', action, logits)
                next_state, reward, done, _, __ = self.env.step(action)
                next_state = transform(next_state)
                
                # get traditional and non-traditional reward
                trad, nontrad = compute_reward(reward, done)
                
                combined_reward = trad + nontrad
                total_trad_reward += trad
                total_nontrad_reward += nontrad

                self.learn((state, action, logits, combined_reward, next_state, done))

                state = next_state
                
                done = done or _
                print('done', done)
                timestep += 1

            total_nontrad_reward = decimal2(total_nontrad_reward)
            total_reward = total_trad_reward + total_nontrad_reward
            
            print(f"Episode {episode + 1} | Total Reward: {total_reward}, Traditional: {total_trad_reward}")
            
            # persist history and policy after every 20 episodes
            history_data.append(f"{int(self.current_episode) + episode + 1}, {total_reward}, {total_trad_reward}, {total_nontrad_reward}, {timestep}, {self.log}")
            if ((episode + 1) % save_after == 0):
                # Save the network weights
                torch.save(self.actor_critic.state_dict(), self.policy_path)
                
                # write history
                write_history(self.history_path, history_data)
                history_data = []

    def evaluate(self, num_episodes):
        wins = 0
        for episode in range(num_episodes):
            state = transform(self.env.reset()[0])
            done = False
            while not done:
                action = self.act(state)[0]
                next_state, reward, done, _, __ = self.env.step(action)
                state = transform(next_state)
                if reward > 0:
                    wins += 1
        
        print(f'Wins - {wins}; Episodes - {num_episodes}; Average - {wins / num_episodes}') 
    
    def close_session(self):
        self.env.close()


---- starting environment 1 ----
act 0 [[0.24607427418231964, 0.2610017955303192, 0.24890394508838654, 0.24401994049549103]]
done False
act 1 [[0.25154709815979004, 0.22577346861362457, 0.24970334768295288, 0.2729760408401489]]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


done False
act 2 [[0.24789516627788544, 0.139862522482872, 0.2412540167570114, 0.3709883391857147]]
done False
act 3 [[0.2005818486213684, 0.04341711848974228, 0.1945118010044098, 0.5614892840385437]]
done False
act 0 [[0.1345239281654358, 0.004168695770204067, 0.1417335569858551, 0.7195737361907959]]
done False
act 3 [[0.04260402172803879, 0.0008143780287355185, 0.055215585976839066, 0.9013660550117493]]
done False
act 3 [[0.004798536188900471, 0.00024314771872013807, 0.01533783320337534, 0.9796205163002014]]
done False
act 3 [[0.005713850725442171, 7.650707266293466e-05, 0.011482235044240952, 0.9827273488044739]]
done False
act 0 [[0.38437148928642273, 0.00424599414691329, 0.08630974590778351, 0.5250726938247681]]
done False
act 0 [[0.9335028529167175, 0.04904400184750557, 0.017165854573249817, 0.0002873367920983583]]
done False
act 0 [[0.4462622106075287, 0.5531933307647705, 0.0005444622947834432, 6.250958084486058e-10]]
done False
act 1 [[0.004953577648848295, 0.9950461387634277, 3

KeyboardInterrupt: 

# ATARI Breakout

### Train Proxy Proximal Optimization Agent

In [None]:
env_name = 'Breakout-v4'
# env_name = 'Berzerk-v4'
training_batch = 1
training_episodes = 10
eval_episodes = 10

In [None]:
for i in range(training_batch):
  print(f'---- starting environment {i + 1} ----')
  env = gym.make(env_name)
  env_name = env.spec.id.split('-')[0]
  agent = PPOAgent(env, env_name)
  agent.train(training_episodes)
  agent.close_session()
  
  # free up memory for next iteration
  del env
  del agent


### Evaluate Trained PPO Agent

In [None]:
env = gym.make(env_name)
env_name = env.spec.id.split('-')[0]
agent = PPOAgent(env, env_name)
agent.evaluate(eval_episodes)
agent.close_session()

In [None]:
!pip list