# Pong Game with AI Agent
This notebook demonstrates how to build a Pong game using Pygame and train an AI agent using Deep Q-Learning with Stable-Baselines3.
The notebook is structured as follows:
1. Import necessary libraries.
2. Define a custom environment for the Pong game.
3. Train the agent using the Stable-Baselines3 DQN algorithm.
4. Evaluate the trained agent.
5. Integrate the agent into a playable Pygame-based Pong game.

<img src="./Game.png" alt="Pong game screenshot" title="Pong game" style="width:70%; height:auto;">


## Step 1: Import Libraries

In [None]:
import pygame
import numpy as np
import random
import gymnasium
from gymnasium import spaces
from stable_baselines3 import DQN
from stable_baselines3.common.env_checker import check_env

## Step 2: Define the Pong Environment
The custom environment is built on the Gymnasium interface (the maintained successor to OpenAI Gym). It simulates the Pong game and defines the actions, observations, and rewards for the agent.

In [None]:
class PongEnv(gymnasium.Env):
    """Single-agent Pong environment following the Gymnasium ``Env`` API.

    The agent moves the right-hand paddle (``player_y``); the left-hand
    paddle (``opponent_y``) is driven by a scripted rule that follows the
    ball's vertical position.

    Actions (``Discrete(3)``): 0 - stay, 1 - up, 2 - down.

    Observation (``Box`` of 6 float32 values):
    ``[player_y, opponent_y, ball_x, ball_y, ball_speed_x, ball_speed_y]``.

    Reward: +1 when the ball leaves through the left edge (the opponent
    missed), -1 when it leaves through the right edge (the agent missed),
    0 otherwise. Either event terminates the episode.
    """

    metadata = {"render_modes": ["human"], "render_fps": 60}

    def __init__(self):
        super().__init__()

        # Screen dimensions
        self.SCREEN_WIDTH = 800
        self.SCREEN_HEIGHT = 600

        # Paddle dimensions
        self.PADDLE_WIDTH = 10
        self.PADDLE_HEIGHT = 100

        # Ball dimensions
        self.BALL_SIZE = 10

        # Define action and observation space
        # Actions: 0 - Stay, 1 - Up, 2 - Down
        self.action_space = spaces.Discrete(3)

        # Observation space: [player_y, opponent_y, ball_x, ball_y, ball_speed_x, ball_speed_y]
        # The bounds are built as float32 up front so Box does not have to
        # down-cast float64 arrays.
        low = np.array([0, 0, 0, 0, -np.inf, -np.inf], dtype=np.float32)
        high = np.array(
            [self.SCREEN_HEIGHT, self.SCREEN_HEIGHT, self.SCREEN_WIDTH,
             self.SCREEN_HEIGHT, np.inf, np.inf],
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)

        # The window and clock are created on the first render() call, so
        # constructing or training on the environment does not open a display.
        self.screen = None
        self.clock = None

        self.reset()

    def _get_obs(self):
        """Return the current state as a float32 vector inside the observation bounds."""
        obs = np.array(
            [self.player_y, self.opponent_y, self.ball_x, self.ball_y,
             self.ball_speed_x, self.ball_speed_y],
            dtype=np.float32,
        )
        # On the terminal step the ball may have overshot the left edge by a
        # few pixels; clip so the observation always lies inside the Box.
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    def reset(self, seed=None, options=None):
        """Centre both paddles and the ball and pick a random diagonal direction.

        Args:
            seed: Optional seed forwarded to ``gymnasium.Env.reset``; it seeds
                ``self.np_random``, which is used to choose the ball direction.
            options: Unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.player_y = self.SCREEN_HEIGHT // 2 - self.PADDLE_HEIGHT // 2
        self.opponent_y = self.SCREEN_HEIGHT // 2 - self.PADDLE_HEIGHT // 2
        self.ball_x = self.SCREEN_WIDTH // 2 - self.BALL_SIZE // 2
        self.ball_y = self.SCREEN_HEIGHT // 2 - self.BALL_SIZE // 2
        # Draw from the environment's own generator so that `seed` makes
        # episodes reproducible.
        self.ball_speed_x = 7 * int(self.np_random.choice((1, -1)))
        self.ball_speed_y = 7 * int(self.np_random.choice((1, -1)))

        return self._get_obs(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 (stay), 1 (move up) or 2 (move down) for the right paddle.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always False and ``info`` is an empty dict.
        """
        max_paddle_y = self.SCREEN_HEIGHT - self.PADDLE_HEIGHT
        max_ball_y = self.SCREEN_HEIGHT - self.BALL_SIZE

        # Player action (kept on screen)
        if action == 1:
            self.player_y = max(0, self.player_y - 10)
        elif action == 2:
            self.player_y = min(max_paddle_y, self.player_y + 10)

        # Ball movement
        self.ball_x += self.ball_speed_x
        self.ball_y += self.ball_speed_y

        # Ball collision with top and bottom: snap back onto the wall and
        # force the velocity away from it.
        if self.ball_y <= 0:
            self.ball_y = 0
            self.ball_speed_y = abs(self.ball_speed_y)
        elif self.ball_y >= max_ball_y:
            self.ball_y = max_ball_y
            self.ball_speed_y = -abs(self.ball_speed_y)

        # Ball collision with paddles. The direction checks ensure a ball that
        # is already travelling away from a paddle is not flipped again.
        hits_opponent = (
            self.ball_speed_x < 0
            and self.ball_x <= 20
            and self.opponent_y <= self.ball_y <= self.opponent_y + self.PADDLE_HEIGHT
        )
        hits_player = (
            self.ball_speed_x > 0
            and self.ball_x >= self.SCREEN_WIDTH - 20 - self.BALL_SIZE
            and self.player_y <= self.ball_y <= self.player_y + self.PADDLE_HEIGHT
        )
        if hits_opponent or hits_player:
            self.ball_speed_x *= -1

        # Opponent AI: follow the ball. The second check reverts a move that
        # would carry the paddle centre past the ball.
        if self.opponent_y + self.PADDLE_HEIGHT // 2 < self.ball_y:
            self.opponent_y += 10
        if self.opponent_y + self.PADDLE_HEIGHT // 2 > self.ball_y:
            self.opponent_y -= 10
        self.opponent_y = max(0, min(max_paddle_y, self.opponent_y))

        # Calculate reward
        reward = 0.0
        terminated = False
        truncated = False
        if self.ball_x <= 0:
            # Ball passed the opponent's (left) paddle: the agent scores.
            terminated = True
            reward = 1.0
        elif self.ball_x >= self.SCREEN_WIDTH - self.BALL_SIZE:
            # Ball passed the agent's (right) paddle: the agent concedes.
            terminated = True
            reward = -1.0

        return self._get_obs(), reward, terminated, truncated, {}

    def render(self, mode='human'):
        """Draw the current frame in a Pygame window, capped at 60 frames per second.

        Args:
            mode: Kept for backward compatibility; not used.
        """
        if self.screen is None:
            # Initialize Pygame on first use
            pygame.init()
            self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
            pygame.display.set_caption('Pong')
            self.clock = pygame.time.Clock()

        self.screen.fill((0, 0, 0))

        # Draw paddles and ball
        pygame.draw.rect(self.screen, (255, 255, 255), (self.SCREEN_WIDTH - 20, self.player_y, self.PADDLE_WIDTH, self.PADDLE_HEIGHT))
        pygame.draw.rect(self.screen, (255, 255, 255), (10, self.opponent_y, self.PADDLE_WIDTH, self.PADDLE_HEIGHT))
        pygame.draw.rect(self.screen, (255, 255, 255), (self.ball_x, self.ball_y, self.BALL_SIZE, self.BALL_SIZE))

        pygame.display.flip()
        self.clock.tick(60)

    def close(self):
        """Shut Pygame down; a later render() call re-creates the window."""
        pygame.quit()
        self.screen = None
        self.clock = None

## Step 3: Train the Agent
Using the Stable-Baselines3 DQN implementation, we train an agent to play Pong.

In [None]:
# Build the environment and validate it against the Gymnasium API before training
env = PongEnv()
check_env(env)

# DQN with an MLP policy; gradient updates begin after the first 1000 collected steps
model = DQN(policy='MlpPolicy', env=env, learning_starts=1000, verbose=1)

# Train for 200k environment steps, then persist the weights to disk
model.learn(total_timesteps=200000)
model.save("dqn_pong")

## Step 4: Evaluate the Agent
After training, we test the agent's performance by letting it play multiple episodes.

In [4]:
# Load the trained model
model = DQN.load("dqn_pong")

# Evaluate the agent for 1000 rendered steps, starting a new episode whenever
# the current one ends.
obs, _ = env.reset()
try:
    for _ in range(1000):
        # Greedy actions: evaluation should not include exploration noise.
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()

        # Drain the event queue every frame so the OS does not flag the window
        # as unresponsive, and stop early if the user closes it.
        if any(event.type == pygame.QUIT for event in pygame.event.get()):
            break

        if terminated or truncated:
            obs, _ = env.reset()
finally:
    # Release the Pygame window even if evaluation is interrupted.
    env.close()

## Step 5: Playable Game with AI Integration
We integrate the AI agent into a Pygame-based Pong game. The agent controls one paddle, while the player controls the other.

In [None]:
# Restore the trained DQN agent and attach the training environment to it
model = DQN.load("dqn_pong")
model.set_env(env)

# Pygame must be initialized before any display or font call
pygame.init()

# Window size in pixels
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600

# RGB colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)

# Paddle and ball sizes in pixels
PADDLE_WIDTH, PADDLE_HEIGHT = 10, 100
BALL_SIZE = 10

# Open the game window
window_size = (SCREEN_WIDTH, SCREEN_HEIGHT)
screen = pygame.display.set_mode(window_size)
pygame.display.set_caption('Pong')

# Paddle class
class Paddle:
    """A vertical paddle drawn as a filled rectangle on the module-level ``screen``."""

    def __init__(self, x, y, color):
        """Create a paddle whose top-left corner is at ``(x, y)``.

        Args:
            x: Left edge in pixels.
            y: Top edge in pixels.
            color: RGB tuple used when drawing.
        """
        self.rect = pygame.Rect(x, y, PADDLE_WIDTH, PADDLE_HEIGHT)
        self.speed = 10
        self.color = color

    def move(self, up=True):
        """Move the paddle by ``self.speed`` pixels, staying inside the screen.

        Args:
            up: Move up when True, down when False.
        """
        step = -self.speed if up else self.speed
        # Clamp so the paddle cannot slide past the top or bottom edge.
        lowest_top = SCREEN_HEIGHT - self.rect.height
        self.rect.y = max(0, min(lowest_top, self.rect.y + step))

    def draw(self):
        """Draw the paddle on the screen."""
        pygame.draw.rect(screen, self.color, self.rect)


# Ball class
class Ball:
    """Square ball that travels diagonally at 7 pixels per frame on each axis."""

    def __init__(self, x, y):
        """Place the ball with its top-left corner at ``(x, y)`` and pick a random direction."""
        self.rect = pygame.Rect(x, y, BALL_SIZE, BALL_SIZE)
        self._randomize_velocity()

    def _randomize_velocity(self):
        """Choose the sign of each velocity component independently (x first, then y)."""
        self.speed_x = 7 * random.choice((1, -1))
        self.speed_y = 7 * random.choice((1, -1))

    def move(self):
        """Advance the ball by one frame of its current velocity."""
        self.rect.move_ip(self.speed_x, self.speed_y)

    def draw(self):
        """Draw the ball in white on the module-level screen."""
        pygame.draw.rect(screen, WHITE, self.rect)

    def reset(self):
        """Put the ball back at the centre of the screen with a new random direction."""
        self.rect.topleft = (
            SCREEN_WIDTH // 2 - BALL_SIZE // 2,
            SCREEN_HEIGHT // 2 - BALL_SIZE // 2,
        )
        self._randomize_velocity()

# Initialize paddles and ball
player = Paddle(SCREEN_WIDTH - 20, SCREEN_HEIGHT // 2 - PADDLE_HEIGHT // 2, RED)  # Red: human player (right side)
opponent = Paddle(10, SCREEN_HEIGHT // 2 - PADDLE_HEIGHT // 2, WHITE)  # White: trained agent (left side)
ball = Ball(SCREEN_WIDTH // 2 - BALL_SIZE // 2, SCREEN_HEIGHT // 2 - BALL_SIZE // 2)

# Initialize scores
player_score = 0
opponent_score = 0

# Build the score font once instead of on every frame
font = pygame.font.Font(None, 74)

# Game loop
running = True
clock = pygame.time.Clock()

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

    # Human input: arrow keys move the right-hand paddle
    keys = pygame.key.get_pressed()
    if keys[pygame.K_UP]:
        player.move(up=True)
    if keys[pygame.K_DOWN]:
        player.move(up=False)

    # Agent input. The model was trained in PongEnv controlling the right-hand
    # paddle, so the scene is mirrored horizontally before it is passed in:
    # the agent's own paddle comes first, ball_x is reflected and the
    # horizontal speed is negated.
    agent_obs = np.array(
        [
            opponent.rect.y,
            player.rect.y,
            SCREEN_WIDTH - BALL_SIZE - ball.rect.x,
            ball.rect.y,
            -ball.speed_x,
            ball.speed_y,
        ],
        dtype=np.float32,
    )
    agent_action, _ = model.predict(agent_obs, deterministic=True)
    if agent_action == 1:
        opponent.move(up=True)
    elif agent_action == 2:
        opponent.move(up=False)

    ball.move()

    # Ball collision with top and bottom: always send the ball back into the
    # field so it cannot flip back and forth while overlapping a wall.
    if ball.rect.top <= 0:
        ball.speed_y = abs(ball.speed_y)
    elif ball.rect.bottom >= SCREEN_HEIGHT:
        ball.speed_y = -abs(ball.speed_y)

    # Ball collision with paddles: only reflect a ball that is moving towards
    # the paddle, otherwise an overlapping ball would flip on every frame.
    if ball.speed_x > 0 and ball.rect.colliderect(player.rect):
        ball.speed_x = -ball.speed_x
    elif ball.speed_x < 0 and ball.rect.colliderect(opponent.rect):
        ball.speed_x = -ball.speed_x

    # Check if the ball goes out of bounds
    if ball.rect.left <= 0:
        player_score += 1
        ball.reset()
    elif ball.rect.right >= SCREEN_WIDTH:
        opponent_score += 1
        ball.reset()

    # Clear screen
    screen.fill(BLACK)

    # Draw paddles and ball
    player.draw()
    opponent.draw()
    ball.draw()

    # Display scores
    player_text = font.render(str(player_score), True, RED)
    screen.blit(player_text, (SCREEN_WIDTH - 100, 10))
    opponent_text = font.render(str(opponent_score), True, WHITE)
    screen.blit(opponent_text, (10, 10))

    # Update display
    pygame.display.flip()
    clock.tick(60)

pygame.quit()

<h2 style="text-align: center;">Thanks</h2>