In [21]:
import numpy as np
import gymnasium as gym
import random
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

In [22]:
class Ball(object):
    def __init__(self, x: int | float, y: int | float,
                 spdx: int | float = 1, spdy: int | float = 1,
                 radius: int | float = 10):
        self.x, self.y = x, y
        self.spdx, self.spdy = spdx, spdy
        self.r = radius


class Paddle(object):
    def __init__(self, x: int | float, y: int | float,
                 w: int | float, h: int | float,
                 spd: int | float = 1):
        self.x, self.y = x, y
        self.w, self.h = w, h
        self.left, self.right = self.x - self.w / 2, self.x + self.w / 2
        self.top, self.bottom = self.x - self.h / 2, self.x + self.h / 2
        self.spd = spd
        self.score = 0


class PongEnv(gym.Env):
    """Minimal two-paddle Pong as a Gymnasium environment.

    Observation (``float32``, shape ``(8,)``):
        ball x, ball y, ball x-speed, ball y-speed,
        player x, player y, opponent x, opponent y.

    Actions (``Discrete(3)``):
        0 leaves the player paddle in place, 1 decreases its ``y``,
        2 increases its ``y``.

    Rewards:
        +100 when the ball reaches ``x <= 0`` (player scores),
        -100 when the ball reaches ``x >= width`` (opponent scores),
        +1 when the player's paddle returns the ball, 0 otherwise.

    ``step`` never reports termination or truncation itself; episode length
    has to be bounded from outside (e.g. ``max_episode_steps`` at registration).
    """

    def __init__(self):
        super().__init__()
        self.width, self.height = 900, 600
        self.ball_speed = 6
        self.paddle_speed = 6
        self.action_space = gym.spaces.Discrete(3)
        self.observation_space = gym.spaces.Box(-10000000.0, 10000000.0, shape=(8,), dtype=np.float32)
        self._spawn_entities()

    def _new_ball(self):
        """Return a ball at the centre of the field heading in a random diagonal.

        The direction is drawn from ``self.np_random`` so that
        ``reset(seed=...)`` makes the serve reproducible.
        """
        dir_x = int(self.np_random.choice((-1, 1)))
        dir_y = int(self.np_random.choice((-1, 1)))
        return Ball(self.width / 2, self.height / 2,
                    dir_x * self.ball_speed, dir_y * self.ball_speed)

    def _spawn_entities(self):
        """(Re)create ball and paddles in their start positions and reset the frame counter."""
        self.ball = self._new_ball()
        # Both paddles are built with the configured paddle speed so the player
        # and the scripted opponent move at the same rate.
        self.player = Paddle(self.width - 100, self.height / 2, 10, 100, self.paddle_speed)
        self.opponent = Paddle(100, self.height / 2, 10, 100, self.paddle_speed)
        self.frame = 0

    def _ball_hits(self, paddle):
        """Return True when the ball centre lies inside the paddle's hit zone.

        The zone is computed from the paddle's *current* centre and size on
        every call, so it follows the paddle as it moves. Horizontally it
        spans from one ball radius before the paddle centre to the paddle's
        right edge; vertically it is the paddle's extent widened by one ball
        radius on each side.
        """
        half_w, half_h = paddle.w / 2, paddle.h / 2
        in_x = paddle.x - self.ball.r <= self.ball.x <= paddle.x + half_w
        in_y = (paddle.y - half_h - self.ball.r
                <= self.ball.y
                <= paddle.y + half_h + self.ball.r)
        return in_x and in_y

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which decides the serve direction.
            options: Unused.

        Returns:
            tuple: ``(observation, info)`` for the initial state.
        """
        super().reset(seed=seed)
        self._spawn_entities()
        return self._get_obs(), self._get_info()

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 (stay), 1 (decrease player y) or 2 (increase player y).

        Returns:
            tuple: ``(observation, reward, terminated, truncated, info)``;
            ``terminated`` and ``truncated`` are always ``False``.
        """
        reward = 0

        # Player paddle, clamped so its centre stays inside the field.
        if action == 1:
            self.player.y = max(0, self.player.y - self.player.spd)
        elif action == 2:
            self.player.y = min(self.height, self.player.y + self.player.spd)

        # Bounce off the horizontal walls: force the vertical speed to point
        # back into the field (set, not multiply, so the magnitude is stable).
        if self.ball.y >= self.height:
            self.ball.spdy = -self.ball_speed
        elif self.ball.y <= 0:
            self.ball.spdy = self.ball_speed

        # Scoring: the ball left the field on one side; re-serve from the centre.
        if self.ball.x <= 0:
            reward = 100
            self.ball = self._new_ball()
            self.player.score += 1
        elif self.ball.x >= self.width:
            reward = -100
            self.ball = self._new_ball()
            self.opponent.score += 1

        # Paddle hits. A hit only counts while the ball is travelling towards
        # that paddle, so one contact cannot be rewarded on several
        # consecutive frames while the ball is leaving the hit zone.
        if self.ball.spdx > 0 and self._ball_hits(self.player):
            self.ball.spdx = -self.ball_speed
            reward = 1
        if self.ball.spdx < 0 and self._ball_hits(self.opponent):
            self.ball.spdx = self.ball_speed

        # Scripted opponent: follow the ball's y, skipping every 8th frame
        # (``frame % 8`` is falsy exactly when frame is a multiple of 8).
        if self.frame % 8:
            if self.opponent.y < self.ball.y:
                self.opponent.y = min(self.height, self.opponent.y + self.opponent.spd)
            else:
                self.opponent.y = max(0, self.opponent.y - self.opponent.spd)

        self.ball.x += self.ball.spdx
        self.ball.y += self.ball.spdy

        self.frame += 1

        return self._get_obs(), reward, False, False, self._get_info()

    def _get_obs(self):
        """Return the 8-element ``float32`` observation vector described in the class docstring."""
        return np.array([
            self.ball.x, self.ball.y, self.ball.spdx, self.ball.spdy,
            self.player.x, self.player.y, self.opponent.x, self.opponent.y
        ], dtype=np.float32)

    def _get_info(self):
        """Return auxiliary information for debugging.

        Returns:
            dict: Current scores under the keys ``'Player score'`` and
            ``'Opponent score'``.
        """
        return {
            'Player score': self.player.score,
            'Opponent score': self.opponent.score
        }

# Expose the environment under the id "Pong_custom" so it can be built with
# gym.make(). Episodes are capped at 1200 steps to prevent infinite episodes.
gym.register(id="Pong_custom", entry_point=PongEnv, max_episode_steps=1200)

In [23]:
class DQNAgent:
    """Epsilon-greedy DQN agent with a uniformly sampled replay buffer.

    Args:
        action_dim: Number of discrete actions.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor used in the bootstrapped target.
        epsilon: Initial probability of picking a random action.
        epsilon_decay: Factor applied to ``epsilon`` after each training step.
        buffer_size: Maximum number of transitions kept in memory.
        model: Q-network mapping a state (or batch of states) to one value per
            action. When omitted, the module-level variable ``model`` is used,
            which keeps the original notebook call working.
        device: Device the tensors are created on. When omitted, the device of
            the model's first parameter is used.
        epsilon_min: Floor below which ``epsilon`` is never decayed.
    """

    def __init__(self, action_dim, lr, gamma, epsilon, epsilon_decay, buffer_size,
                 model=None, device=None, epsilon_min=0.01):
        if model is None:
            # Backward compatibility with the notebook, where the network
            # lives in a global named `model`.
            try:
                model = globals()["model"]
            except KeyError:
                raise NameError(
                    "name 'model' is not defined; pass model=... to DQNAgent"
                ) from None
        if device is None:
            device = next(model.parameters()).device

        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.memory = deque(maxlen=buffer_size)

        self.model = model
        self.device = device
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def act(self, state):
        """Pick an action for ``state``.

        With probability ``epsilon`` a uniformly random action is returned,
        otherwise the action with the highest predicted Q-value.

        Returns:
            int: Index of the chosen action.
        """
        if np.random.rand() <= self.epsilon:
            return int(np.random.choice(self.action_dim))
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        # Inference only: no autograd graph is needed to pick an action.
        with torch.no_grad():
            q_values = self.model(state_tensor)
        return int(torch.argmax(q_values).item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (oldest entries are evicted)."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size):
        """Run one gradient step on a random minibatch and decay ``epsilon``.

        Does nothing while fewer than ``batch_size`` transitions are stored.
        The target for each sampled transition is ``reward`` when ``done`` is
        true and ``reward + gamma * max_a Q(next_state, a)`` otherwise; the
        loss is the mean squared error between that target and the predicted
        Q-value of the action that was actually taken.
        """
        if batch_size < 1 or len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        states_t = torch.as_tensor(np.asarray(states, dtype=np.float32), device=self.device)
        next_states_t = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device)

        # Targets are constants for the optimizer: compute them without
        # tracking gradients. (1 - done) removes the bootstrap term for
        # transitions that ended the episode.
        with torch.no_grad():
            next_q_max = self.model(next_states_t).max(dim=1).values
            targets = rewards_t + self.gamma * next_q_max * (1.0 - dones_t)

        # Q(s, a) for the actions that were taken.
        q_taken = self.model(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        loss = self.loss_fn(q_taken, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration, but never below the configured floor.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output.

    Args:
        state_dim: Size of the input state vector.
        hidden_dim: Width of both hidden layers.
        action_dim: Number of outputs, one Q-value per action.

    The defaults (8, 128, 3) reproduce the original fixed architecture.
    """

    def __init__(self, state_dim=8, hidden_dim=128, action_dim=3):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Map states of shape ``(..., state_dim)`` to Q-values of shape ``(..., action_dim)``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the output: Q-values are unbounded.
        return self.fc3(x)

In [24]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the Q-network and move its parameters to the selected device.
model = DQN()
model = model.to(device)

device

device(type='cuda')

In [None]:
# Train the agent on the custom Pong environment.
env = gym.make('Pong_custom')
# Take the number of actions from the environment instead of hard-coding it.
action_dim = int(env.action_space.n)
agent = DQNAgent(action_dim, lr=0.01, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, buffer_size=10000)

batch_size = 32
num_episodes = 10
for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Only a real termination removes the bootstrap term from the target;
        # a time-limit truncation merely stops data collection.
        agent.remember(state, action, reward, next_state, terminated)
        state = next_state
        total_reward += reward
        agent.replay(batch_size)
        # The episode ends on termination or when the registered step limit
        # truncates it, so the environment is never stepped past its limit.
        done = terminated or truncated
    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")

env.close()