In [5]:
import gymnasium
import flappy_bird_gymnasium
import pickle
import gc
import numpy as np
import pygame
import itertools
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from enum import IntEnum
from torchvision.transforms import Compose, ToTensor, Resize, Grayscale
from flappy_bird_gymnasium.envs.flappy_bird_env import FlappyBirdEnv
from flappy_bird_gymnasium.envs.flappy_bird_env import Actions
from flappy_bird_gymnasium.envs.lidar import LIDAR
from flappy_bird_gymnasium.envs.constants import (
    PLAYER_FLAP_ACC,
    PLAYER_ACC_Y,
    PLAYER_MAX_VEL_Y,
    PLAYER_HEIGHT,
    PLAYER_VEL_ROT,
    PLAYER_WIDTH,
    PIPE_WIDTH,
    PIPE_VEL_X,
)

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
from matplotlib.font_manager import FontProperties



def new_render(self):
    """Draw the next frame without LIDAR rays.

    In ``"rgb_array"`` mode the frame is drawn without the score and returned
    as an array whose first two axes are swapped relative to what
    ``pygame.surfarray.array3d`` produces. In every other mode the frame is
    drawn with the score, pushed to the display (which is created on first
    use), the FPS clock is ticked at ``metadata["render_fps"]``, and ``None``
    is returned.
    """
    offscreen = self.render_mode == "rgb_array"

    # The score overlay is only shown when rendering to a window.
    self._draw_surface(show_score=not offscreen, show_rays=False)

    if offscreen:
        # Swap the first two axes so the image comes out with the right aspect.
        pixels = pygame.surfarray.array3d(self._surface)
        return np.transpose(pixels, axes=(1, 0, 2))

    if self._display is None:
        self._make_display()

    self._update_display()
    self._fps_clock.tick(self.metadata["render_fps"])


# Monkeypatch the class-level render method so FlappyBirdEnv instances draw
# through new_render (no LIDAR rays; score hidden in "rgb_array" mode).
FlappyBirdEnv.render = new_render

In [17]:
################
######DQN++######
################

class DQN_pp(nn.Module):
    """Fully connected Q-network.

    Three 128-unit hidden layers with ReLU activations followed by a linear
    head that emits one Q-value per action.

    Args:
        input_dim: Number of features in a single observation.
        action_space: Number of discrete actions (width of the output layer).
    """

    def __init__(self, input_dim, action_space):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)      # first hidden layer
        self.fc2 = nn.Linear(hidden_units, hidden_units)   # second hidden layer
        self.fc3 = nn.Linear(hidden_units, hidden_units)   # third hidden layer
        self.fc4 = nn.Linear(hidden_units, action_space)   # Q-value head

    def forward(self, x):
        """Return the Q-values for the batch of observations ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.fc4(hidden)


class ReplayMemory:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are held, appending a new one drops the
    oldest.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions."""
        minibatch = random.sample(self.memory, batch_size)
        return minibatch

    def __len__(self):
        """Number of transitions currently stored."""
        stored = len(self.memory)
        return stored


class DQN_pp_Agent:
    """DQN agent with an action-correcting look-ahead and epsilon re-warming.

    The agent keeps a policy network and a periodically synchronised target
    network (both ``DQN_pp``), a replay memory, and an epsilon-greedy policy
    whose chosen action is passed through :meth:`oracle` before being used.

    Args:
        env: Environment exposing ``observation_space.shape``,
            ``action_space.n``, ``reset()``, ``step()`` and ``close()``.
        hyper: Optional dict of hyperparameters. Any key that is missing falls
            back to the value in ``DEFAULT_HYPER``.
    """

    # Fallback value for every hyperparameter the caller does not supply.
    DEFAULT_HYPER = {
        "learning_rate": 0.001,
        "discount_factor": 0.99,
        "epsilon": 1.0,
        "epsilon_decay": 0.999,
        "epsilon_min": 0.01,
        "batch_size": 64,
        "memory_size": 10000,
        "episodes": 100000,
        "target_update_freq": 10,
        "rho": 1.0,
        "kappa": 1.0,
        "eps_update_freq": 100,
    }

    def __init__(self, env, hyper=None):
        # Merge into a fresh dict so neither the class defaults nor the
        # caller's dict is shared with (or mutated through) this instance.
        settings = dict(self.DEFAULT_HYPER)
        if hyper is not None:
            settings.update(hyper)

        # Environment
        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_space = env.action_space.n

        # Device setup for GPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # Hyperparameters
        self.learning_rate = settings["learning_rate"]
        self.discount_factor = settings["discount_factor"]
        self.epsilon = settings["epsilon"]
        self.epsilon_decay = settings["epsilon_decay"]
        self.epsilon_min = settings["epsilon_min"]
        self.batch_size = settings["batch_size"]
        self.memory_size = settings["memory_size"]
        self.episodes = settings["episodes"]
        self.target_update_freq = settings["target_update_freq"]
        self.rho = settings["rho"]  # value epsilon is reset to when re-warming
        self.kappa = settings["kappa"]  # reward threshold used by the re-warm check
        self.eps_update_freq = settings["eps_update_freq"]

        # Initialize policy and target networks. The network class defined in
        # this notebook is DQN_pp (the previous name, DQN_plus, does not exist).
        self.policy_net = DQN_pp(self.state_dim, self.action_space).to(self.device)
        self.target_net = DQN_pp(self.state_dim, self.action_space).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        # Optimizer and Replay Memory
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.memory = ReplayMemory(self.memory_size)

    def select_action(self, state, testing=False, env=None):
        """Pick an action epsilon-greedily, then filter it through the oracle.

        Args:
            state: Current observation.
            testing: When true, exploration is skipped and the greedy action
                is always proposed.
            env: Environment the oracle should look ahead in. Defaults to
                ``self.env``.

        Returns:
            The action returned by :meth:`oracle`.
        """
        if not testing and random.random() < self.epsilon:
            proposed = random.randint(0, self.action_space - 1)  # Random action
        else:
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
                proposed = self.policy_net(state).argmax(dim=1).item()
        return self.oracle(proposed, env=env)

    def oracle(self, action, env=None):
        """Try ``action`` in the environment and flip it if the step terminates.

        NOTE(review): this really advances the environment; nothing is
        snapshotted or restored. Callers step the environment again with the
        returned action, so every decision consumes two environment steps.

        Args:
            action: Proposed action.
            env: Environment to step. Defaults to ``self.env``.

        Returns:
            ``action`` unchanged, or ``(action + 1) % 2`` when the trial step
            reported termination (this assumes exactly two actions).
        """
        target_env = self.env if env is None else env
        _, _, terminated, _, _ = target_env.step(action)
        if terminated:
            action = (action + 1) % 2

        return action

    def optimize_model(self):
        """Sample a batch from memory and optimize the policy network."""
        if len(self.memory) < self.batch_size:
            return

        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack observations into one ndarray first: building a tensor from a
        # tuple of separate arrays is much slower.
        states = torch.tensor(np.array(states), dtype=torch.float32).to(self.device)
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(self.device)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(self.device)

        # Compute Q-values and targets
        q_values = self.policy_net(states).gather(1, actions)
        # Targets are constants for the loss: no graph is needed through the
        # target network, and without this its .grad would grow unchecked.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1, keepdim=True)[0]
            targets = rewards + (self.discount_factor * next_q_values * (1 - dones))

        # Loss and backpropagation
        loss = nn.MSELoss()(q_values, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self):
        """Train the agent for ``self.episodes`` episodes.

        Closes ``self.env`` when finished.

        Returns:
            List with the total reward collected in each episode.
        """
        res = []
        reward_trace = 0
        old_reward_trace = 0
        for episode in range(self.episodes):
            state, _ = self.env.reset()
            done = False
            total_reward = 0

            while not done:
                # Select and execute action
                action = self.select_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                # Only true termination disables bootstrapping in the target.
                self.memory.push(state, action, reward, next_state, terminated)
                state = next_state
                total_reward += reward
                # A truncated episode must end too, or the loop would keep
                # stepping an environment that expects a reset.
                done = terminated or truncated

                # Optimize model
                self.optimize_model()

            # Decay epsilon
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            # Update target network periodically
            if episode % self.target_update_freq == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            # Epsilon re-warming: once epsilon sits at its floor, accumulate
            # reward and, every eps_update_freq episodes, reset epsilon to rho
            # if the accumulated gain per episode is below kappa.
            # NOTE(review): reward_trace is zeroed while old_reward_trace keeps
            # the previous sum, so later comparisons are against a stale
            # baseline - confirm this is the intended bookkeeping.
            if abs(self.epsilon - self.epsilon_min) < 1e-4:
                reward_trace += total_reward
                if (reward_trace - old_reward_trace) / self.eps_update_freq < self.kappa and episode % self.eps_update_freq == 0:
                    self.epsilon = self.rho

                    old_reward_trace = reward_trace
                    reward_trace = 0

            res.append(total_reward)
            print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Epsilon: {self.epsilon:.4f}")

        self.env.close()
        print("Training complete!")

        return res

    def test(self, num_episodes=10, render="human"):
        """Run the greedy policy in a fresh FlappyBird environment.

        Args:
            num_episodes: Number of evaluation episodes.
            render: ``render_mode`` passed to ``gymnasium.make``.

        Returns:
            List with the total reward of each test episode.
        """
        print("\nTesting the trained policy...\n")
        self.epsilon = 0.0  # Disable exploration
        test_env = gymnasium.make("FlappyBird-v0", render_mode=render, use_lidar=True)
        total_rewards = []

        try:
            for episode in range(num_episodes):
                state, _ = test_env.reset()
                done = False
                total_reward = 0

                while not done:
                    # Look ahead in the environment being evaluated, not in
                    # the training environment (which train() has closed).
                    action = self.select_action(state, testing=True, env=test_env)
                    next_state, reward, terminated, truncated, _ = test_env.step(action)
                    state = next_state
                    total_reward += reward
                    done = terminated or truncated

                total_rewards.append(total_reward)

                print(f"Test Episode: {episode + 1}, Total Reward: {total_reward}")

            avg_reward = np.mean(total_rewards)
            print(f"\nAverage Reward over {num_episodes} Test Episodes: {avg_reward}")
        finally:
            # Release the environment even if an episode raises or is interrupted.
            test_env.close()

        return total_rewards

In [18]:
import gymnasium
import itertools
import pickle
import gc


# Grid search over the epsilon re-warming parameters (rho, kappa,
# eps_update_freq). Each combination trains a fresh agent, evaluates it, and
# checkpoints the accumulated results to disk.

# Hyperparameters: scalar entries are shared by every run; the list-valued
# entries are the axes of the grid.
hyper = {
    "learning_rate": 0.001,
    "discount_factor": 0.99,
    "epsilon": 1.0,
    "epsilon_decay": 0.7,
    "epsilon_min": 0.01,
    "batch_size": 64,
    "memory_size": 10000,
    "episodes": 40000,
    "target_update_freq": 10,
    "rho": [0.1, 0.2],
    "kappa": [0.25, 0.5, 0.75],
    "eps_update_freq": [10, 100, 1000],
}

param_combinations = itertools.product(
    hyper["rho"],
    hyper["kappa"],
    hyper["eps_update_freq"],
)

# Initialize result dictionaries
exp_res_DQN_pp = {}
test_res_DQN_pp = {}

# Iterate through parameter combinations
for rho, kappa, eps_update_freq in param_combinations:
    # Shared settings with the grid axes replaced by this run's scalar values.
    current_hyperparams = {
        **hyper,
        "rho": rho,
        "kappa": kappa,
        "eps_update_freq": eps_update_freq,
    }

    # The key must name every swept parameter; without eps_update_freq the
    # three frequencies of a (rho, kappa) pair would overwrite one another.
    exp_key = f"rho={rho}_kappa={kappa}_eps_update_freq={eps_update_freq}"

    # Each run gets its own environment: agent.train() closes the env it was
    # given, so a single shared instance would be closed after the first run.
    env = gymnasium.make("FlappyBird-v0", render_mode="rgb_array", use_lidar=True)

    try:
        # Train the agent
        agent = DQN_pp_Agent(env, current_hyperparams)
        exp_res_DQN_pp[exp_key] = agent.train()

        # Test the agent
        test_res_DQN_pp[exp_key] = agent.test(num_episodes=5000, render=None)

        # Save intermediate results
        with open("exp_res_DQN_pp.pkl", "wb") as f:
            pickle.dump(exp_res_DQN_pp, f)
        with open("test_res_DQN_pp.pkl", "wb") as f:
            pickle.dump(test_res_DQN_pp, f)

        print(f"Finished training and testing for: {exp_key}")

    finally:
        # Free resources
        env.close()
        gc.collect()


# Reload the checkpoints written above (files produced by this notebook; do
# not point these loads at pickles from an untrusted source).
with open("exp_res_DQN_pp.pkl", "rb") as f:
    exp_res_DQN_pp = pickle.load(f)
with open("test_res_DQN_pp.pkl", "rb") as f:
    test_res_DQN_pp = pickle.load(f)

print("All parameter combinations processed.")


cuda
Episode: 1, Total Reward: -1.5999999999999994, Epsilon: 0.7000
Episode: 2, Total Reward: -4.6, Epsilon: 0.4900
Episode: 3, Total Reward: -5.2, Epsilon: 0.3430
Episode: 4, Total Reward: -3.3999999999999995, Epsilon: 0.2401
Episode: 5, Total Reward: -3.3999999999999995, Epsilon: 0.1681
Episode: 6, Total Reward: -2.1999999999999993, Epsilon: 0.1176
Episode: 7, Total Reward: -2.1999999999999993, Epsilon: 0.0824
Episode: 8, Total Reward: -0.6999999999999993, Epsilon: 0.0576
Episode: 9, Total Reward: -0.9999999999999996, Epsilon: 0.0404
Episode: 10, Total Reward: -2.2999999999999994, Epsilon: 0.0282
Episode: 11, Total Reward: -0.9999999999999996, Epsilon: 0.0198
Episode: 12, Total Reward: -0.9999999999999996, Epsilon: 0.0138
Episode: 13, Total Reward: -0.9999999999999996, Epsilon: 0.0100
Episode: 14, Total Reward: -0.9999999999999996, Epsilon: 0.0100
Episode: 15, Total Reward: -0.9999999999999996, Epsilon: 0.0100
Episode: 16, Total Reward: -0.9999999999999996, Epsilon: 0.0100
Episode: 1

KeyboardInterrupt: 