In [1]:
import gymnasium as gym
from gymnasium.spaces.box import Box

import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from IPython import display
import cv2
import matplotlib.pyplot as plt

In [2]:
def _process_frame42(frame):
    """Convert one raw Atari frame into a (1, 42, 42) float32 grayscale array.

    Steps: crop rows 34-193 and columns 0-159, shrink to 80x80 and then to
    42x42, average over the channel axis, multiply by 1/255 and move the
    channel axis to the front.

    NOTE(review): assumes ``frame`` is (H, W, C) with 0-255 pixel values, so the
    result lands in [0, 1] — confirm against the env's raw observation space.
    """
    top, size = 34, 160
    cropped = frame[top:top + size, :size]
    # Two-stage shrink (80x80, then 42x42), essentially mipmapping: resizing
    # straight to 42x42 loses pixels that, when mapped to 42x42, aren't close
    # enough to a pixel boundary.
    shrunk = cv2.resize(cv2.resize(cropped, (80, 80)), (42, 42))
    gray = shrunk.mean(axis=2, keepdims=True).astype(np.float32)
    gray *= 1.0 / 255.0
    return np.moveaxis(gray, -1, 0)

class AtariRescale42x42(gym.ObservationWrapper):
    """Observation wrapper that runs ``_process_frame42`` on every frame.

    The declared observation space is a Box of shape (1, 42, 42) bounded to
    [0, 1], matching the preprocessed frames.
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = Box(low=0.0, high=1.0, shape=(1, 42, 42))

    def observation(self, observation):
        """Return the preprocessed (1, 42, 42) version of ``observation``."""
        return _process_frame42(observation)
    
class NormalizedEnv(gym.ObservationWrapper):
    """Standardize observations with running, exponentially weighted mean/std.

    Every frame's mean and std are folded into exponential moving averages with
    decay ``alpha``. Both averages start at 0, so they are bias-corrected by
    dividing by ``1 - alpha ** num_steps`` before use.

    The standardized frames are no longer confined to the inner env's bounds, so
    the observation space is re-declared as an unbounded float32 Box with the
    inner env's shape.
    """

    def __init__(self, env=None):
        super().__init__(env)
        # The inner space (e.g. Box(0, 1)) would not contain the standardized
        # outputs; keep the shape but drop the bounds.
        self.observation_space = Box(
            -np.inf, np.inf, shape=self.env.observation_space.shape, dtype=np.float32
        )
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.9999
        self.num_steps = 0

    def observation(self, observation):
        """Update the running statistics with ``observation`` and return it standardized."""
        self.num_steps += 1
        self.state_mean = self.state_mean * self.alpha + observation.mean() * (1 - self.alpha)
        self.state_std = self.state_std * self.alpha + observation.std() * (1 - self.alpha)

        # Bias correction for the zero-initialized moving averages.
        correction = 1 - pow(self.alpha, self.num_steps)
        unbiased_mean = self.state_mean / correction
        unbiased_std = self.state_std / correction

        # The epsilon avoids division by zero on perfectly flat frames.
        return (observation - unbiased_mean) / (unbiased_std + 1e-8)

# Per-episode step limit passed to gym.make.
max_episode_steps = 1000000


def create_atari_env(env_id, render_mode=None):
    """Create an env with 42x42 preprocessing followed by running normalization.

    Args:
        env_id: Gymnasium environment id, e.g. "ALE/Pong-v5".
        render_mode: forwarded to ``gym.make`` when given (e.g. "human" to open a
            window); None leaves ``gym.make``'s default untouched.

    Returns:
        The wrapped environment.
    """
    # Only pass render_mode when requested so the default call is unchanged.
    render_kwargs = {} if render_mode is None else {"render_mode": render_mode}
    env = gym.make(env_id, max_episode_steps=max_episode_steps, **render_kwargs)
    env = AtariRescale42x42(env)
    env = NormalizedEnv(env)
    return env


env = create_atari_env("ALE/Pong-v5")

In [3]:
# Observation shape declared by the wrappers and the size of the discrete action set.
state_dim, action_dim = env.observation_space.shape, env.action_space.n
state_dim, action_dim

((1, 42, 42), 6)

In [4]:
class ActorCritic(torch.nn.Module):
    """Shared-trunk actor-critic network sized for 42x42 inputs.

    Four 3x3, stride-2, padding-1 convolutions (ELU after each) shrink a 42x42
    input to 3x3 (42 -> 21 -> 11 -> 6 -> 3). The flattened 32*3*3 features
    pass through one linear layer, ``fc``. From there, a softmax policy head
    and a scalar value head branch off.

    NOTE(review): no activation follows ``fc``, so ``fc`` plus either head is a
    composition of two linear maps. Adding one would change what existing
    checkpoints compute.
    """

    def __init__(self, num_inputs=state_dim[0], action_dim=action_dim):
        super(ActorCritic, self).__init__()
        # Registration order (conv1..conv4, fc, critic, actor) is kept as-is: it
        # fixes both the random-initialization sequence and the state_dict keys.
        self.conv1 = nn.Conv2d(num_inputs, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

        self.fc = nn.Linear(32 * 3 * 3, 256)

        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, action_dim)
        self.soft = nn.Softmax(dim=-1)

        self.train()

    def forward(self, state):
        """Return ``(action_probs, state_val)`` for ``state``.

        Accepts a single (C, H, W) tensor or a batch (N, C, H, W); either way the
        features are flattened to rows of 32*3*3.
        """
        features = state
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.elu(conv(features))

        hidden = self.fc(features.view(-1, 32 * 3 * 3))
        state_val = self.critic_linear(hidden)
        action_probs = self.soft(self.actor_linear(hidden))
        return action_probs, state_val

    def act(self, state):
        """Sample an action; return detached ``(action, log_prob, state_value)``."""
        probs, value = self(state)
        policy = Categorical(probs)
        sampled = policy.sample()
        return sampled.detach(), policy.log_prob(sampled).detach(), value.detach()

    def evaluate(self, state, action):
        """Score ``action`` under the current policy: ``(log_probs, state_values, entropy)``."""
        probs, value = self(state)
        policy = Categorical(probs)
        return policy.log_prob(action), value, policy.entropy()


# Take one preprocessed observation from a fresh episode for inspection.
state, _ = env.reset()

In [5]:
# The wrapped env should emit frames matching its declared observation shape.
assert state.shape == state_dim, f"unexpected observation shape {state.shape}, expected {state_dim}"
state.shape

(1, 42, 42)

In [6]:
class RolloutBuffer:
    """Per-step storage for one rollout: each attribute is a list indexed by step."""

    def __init__(self):
        self.actions = []       # sampled actions
        self.states = []        # observations the actions were chosen from
        self.logprobs = []      # log-probabilities of the sampled actions
        self.rewards = []       # per-step rewards (filled by the training loop)
        self.state_values = []  # critic estimates recorded at collection time
        self.is_terminals = []  # per-step episode-end flags (filled by the training loop)

    def clear(self):
        """Empty every list in place, so outside references see the cleared lists."""
        for column in (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.state_values,
            self.is_terminals,
        ):
            column.clear()

In [7]:
class PPO:
    """Clipped-objective PPO agent built around a shared-trunk ``ActorCritic``.

    ``policy`` is the network being optimized. ``policy_old`` is the copy that
    ``select_action`` uses to collect experience, and it is re-synced to
    ``policy`` at the end of every ``update``.

    Args:
        state_dim: input channels for ``ActorCritic`` (defaults to the channel
            count of the notebook's observation space).
        action_dim: number of discrete actions.
        lr_actor: Adam learning rate, applied to every network parameter.
        lr_critic: NOTE(review): currently unused. Actor and critic share one
            optimizer at ``lr_actor``; the parameter is kept so existing calls
            keep working.
        gamma: discount factor for the reward-to-go returns.
        K_epochs: optimization passes over each collected rollout.
        eps_clip: clipping range for the probability ratio.
    """

    def __init__(self, state_dim=state_dim[0], action_dim=action_dim, lr_actor=0.0003, lr_critic=0.001, gamma=0.99, K_epochs=80, eps_clip=0.2):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()
        # Not read by any of the methods below; kept for compatibility.
        self.time_buffer_clear = 0

        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr_actor)

        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state):
        """Sample an action for one observation and record the step in ``self.buffer``.

        Args:
            state: a single NumPy observation from the wrapped env.

        Returns:
            The sampled action as a Python int.
        """
        with torch.no_grad():
            state = torch.from_numpy(state).to(device)
            action, action_logprob, state_val = self.policy_old.act(state)
            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.item()

    def _discounted_returns(self):
        """Reward-to-go ``G_t = r_t + gamma * G_{t+1}``, restarted after terminal steps.

        If ``buffer.is_terminals`` is empty, the whole buffer is treated as one
        episode.

        Raises:
            ValueError: if ``is_terminals`` is non-empty but its length differs
                from ``rewards``.
        """
        rewards = self.buffer.rewards
        terminals = self.buffer.is_terminals
        if not terminals:
            terminals = [False] * len(rewards)
        elif len(terminals) != len(rewards):
            raise ValueError(
                f"buffer has {len(rewards)} rewards but {len(terminals)} terminal flags"
            )

        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(terminals)):
            if done:
                # Step t ended its episode: don't bootstrap from the next episode.
                running = 0.0
            running = float(reward) + self.gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32, device=device)

    def update(self):
        """Run ``K_epochs`` of clipped-PPO optimization on the buffered rollout, then clear it.

        Works for any buffer length (the returns are computed step by step rather
        than with a fixed-size discount matrix).
        """
        returns = self._discounted_returns()
        # Normalizing a single return would divide by a NaN std.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        # reshape(-1) rather than squeeze(): stays 1-D even for a one-step buffer.
        old_states = torch.stack(self.buffer.states, dim=0).detach().to(device)
        old_actions = torch.stack(self.buffer.actions, dim=0).detach().to(device).reshape(-1)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).detach().to(device).reshape(-1)
        old_state_values = torch.stack(self.buffer.state_values, dim=0).detach().to(device).reshape(-1)

        advantages = returns - old_state_values

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.reshape(-1)
            ratios = torch.exp(logprobs - old_logprobs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, returns) - 0.01 * dist_entropy
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def save(self, checkpoint_path):
        """Write ``policy_old``'s weights to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load weights from ``checkpoint_path`` into both ``policy_old`` and ``policy``.

        Tensors are mapped to the CPU on read; ``load_state_dict`` then copies
        them into each network's (possibly GPU) parameters.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)
        
# Pre-trained weights; this file is not produced by any of the cells shown here.
CHECKPOINT_PATH = "ppo_actor_kaggle.pt"

ppo_agent = PPO()
ppo_agent.load(CHECKPOINT_PATH)
# Step counter; not read by the cells shown here.
time_step = 0

In [8]:
def create_atari_env(env_id, render_mode="human"):
    """Evaluation variant of ``create_atari_env``: same wrappers, rendered on screen by default.

    This redefinition shadows the earlier headless version for the rest of the
    session.
    """
    env = gym.make(env_id, max_episode_steps=max_episode_steps, render_mode=render_mode)
    env = AtariRescale42x42(env)
    env = NormalizedEnv(env)
    return env


# Number of episodes to watch; None keeps playing until the kernel is interrupted.
n_eval_episodes = None

# Release the previously created env before replacing it with a rendering one.
env.close()
env = create_atari_env("ALE/Pong-v5")
state, info = env.reset()

episode_reward = 0
episodes_done = 0

try:
    while n_eval_episodes is None or episodes_done < n_eval_episodes:
        action = ppo_agent.select_action(state)
        # select_action appends every step to the rollout buffer. Nothing is
        # trained here, so drop it immediately instead of accumulating tensors forever.
        ppo_agent.buffer.clear()
        state, reward, terminated, truncated, info = env.step(action)
        episode_reward += reward
        if reward != 0:
            print(reward)
        if terminated or truncated:
            # Continue from the new episode's first observation, not the terminal frame.
            state, info = env.reset()
            print(episode_reward)
            episode_reward = 0
            episodes_done += 1
finally:
    # Runs on normal exit and on KeyboardInterrupt, so the env/window is always released.
    env.close()

-1.0
1.0
-1.0
1.0
1.0
-1.0
1.0
1.0
-1.0
1.0
-1.0
1.0
1.0
1.0
1.0
1.0
-1.0
1.0
-1.0
1.0
-1.0
-1.0
1.0
-1.0
1.0
1.0
1.0
1.0
-1.0
1.0
1.0
-1.0
1.0
9.0


In [None]:
# import gymnasium as gym
# from gymnasium.spaces import Box
# import numpy as np
# import cv2

# def _process_frame42(frame):
#     frame = frame[34:34 + 160, :160]
#     frame = cv2.resize(frame, (80, 80))
#     frame = cv2.resize(frame, (42, 42))
#     frame = frame.mean(2, keepdims=True)
#     frame = frame.astype(np.float32)
#     frame *= (1.0 / 255.0)
#     frame = np.moveaxis(frame, -1, 0)
#     return frame

# class AtariRescale42x42(gym.ObservationWrapper):
#     def __init__(self, env=None):
#         super(AtariRescale42x42, self).__init__(env)
#         self.observation_space = Box(0.0, 1.0, [1, 42, 42])

#     def observation(self, observation):
#         return _process_frame42(observation)

# class NormalizedEnv(gym.ObservationWrapper):
#     def __init__(self, env=None):
#         super(NormalizedEnv, self).__init__(env)
#         self.state_mean = 0
#         self.state_std = 0
#         self.alpha = 0.9999
#         self.num_steps = 0

#     def observation(self, observation):
#         self.num_steps += 1
#         self.state_mean = self.state_mean * self.alpha + \
#             observation.mean() * (1 - self.alpha)
#         self.state_std = self.state_std * self.alpha + \
#             observation.std() * (1 - self.alpha)

#         unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
#         unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))

#         return (observation - unbiased_mean) / (unbiased_std + 1e-8)

# class CustomPongEnv(gym.Wrapper):
#     def __init__(self, env):
#         super(CustomPongEnv, self).__init__(env)
#         self.last_hit_by_player = False
#         self.ball_x_previous = None

#     def reset(self, **kwargs):
#         self.last_hit_by_player = False
#         self.ball_x_previous = None
#         return super().reset(**kwargs)

#     def get_ball_position(self, observation):
#         # The observation is now a 1x42x42 array
#         # We'll use the middle row to estimate the ball's x position
#         middle_row = observation[0, 21, :]
#         ball_x = np.argmax(middle_row)
#         return ball_x if middle_row[ball_x] > 0 else None

#     def step(self, action):
#         observation, reward, terminated, truncated, info = self.env.step(action)
        
#         ball_x = self.get_ball_position(observation)
        
#         if ball_x is not None and self.ball_x_previous is not None:
#             if ball_x < self.ball_x_previous and self.last_hit_by_player:
#                 # Ball is moving left after player hit, opponent hit the ball
#                 reward = -1
#                 self.last_hit_by_player = False
#             elif ball_x > self.ball_x_previous and not self.last_hit_by_player:
#                 # Ball is moving right after opponent hit, player hit the ball
#                 reward = 1
#                 self.last_hit_by_player = True
        
#         self.ball_x_previous = ball_x
        
#         return observation, reward, terminated, truncated, info

# def create_custom_pong_env(env_id, max_episode_steps=1000000, render_mode='human'):
#     env = gym.make(env_id, max_episode_steps=max_episode_steps, render_mode=render_mode)
#     env = AtariRescale42x42(env)
#     env = NormalizedEnv(env)
#     env = CustomPongEnv(env)
#     return env

# # Usage
# env = create_custom_pong_env("ALE/Pong-v5")
# state, info = env.reset()

# for _ in range(10000):  # Run for more steps to see the effects
#     # action = env.action_space.sample()  # Replace with your agent's action
#     action = ppo_agent.select_action(state)
#     state, reward, terminated, truncated, info = env.step(action)
    
#     if reward != 0:
#         print(f"Reward: {reward}")
    
#     if terminated or truncated:
#         state, info = env.reset()

# env.close()