In [6]:
import numpy as np
import gymnasium as gym
import random
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

In [15]:
class Ball(object):
    def __init__(self, x: int | float, y: int | float,
                 spdx: int | float = 1, spdy: int | float = 1,
                 radius: int | float = 10):
        self.x, self.y = x, y
        self.spdx, self.spdy = spdx, spdy
        self.r = radius


class Paddle(object):
    def __init__(self, x: int | float, y: int | float,
                 w: int | float, h: int | float,
                 spd: int | float = 1):
        self.x, self.y = x, y
        self.w, self.h = w, h
        self.left, self.right = self.x - self.w / 2, self.x + self.w / 2
        self.top, self.bottom = self.y - self.h / 2, self.y + self.h / 2
        self.spd = spd
        self.score = 0

    def move(self, dx: int, dy: int):
        self.x += dx
        self.y += dy
        self.left, self.right = self.x - self.w / 2, self.x + self.w / 2
        self.top, self.bottom = self.y - self.h / 2, self.y + self.h / 2


class PongEnv(gym.Env):
    """Single-player Pong against a scripted opponent, as a Gymnasium env.

    The agent controls the right-hand paddle (x = width - 100). The left-hand
    opponent (x = 100) copies the ball's y coordinate every frame.

    Observation: float32 vector
        ``[ball.x, ball.y, ball.spdx, ball.spdy,
        player.x, player.y, opponent.x, opponent.y]`` in raw pixel units.
    Actions: 0 = stay, 1 = move up (decrease y), 2 = move down (increase y).
    """

    def __init__(self, max_frame: int = 1200):
        self.width, self.height = 900, 600
        self.ball_speed = 6
        self.paddle_speed = 6
        self.ball = None
        self.player = None
        self.opponent = None
        self.frame = 0
        self.max_frame = max_frame

        self.action_space = gym.spaces.Discrete(3)
        self.observation_space = gym.spaces.Box(-1e7, 1e7, shape=(8,), dtype=np.float32)

        self.reset()

    def _reset_ball(self):
        """Re-spawn the ball at the centre with a random diagonal direction."""
        # Use the env's own RNG (re-seeded by reset(seed=...)) rather than the
        # global `random` module; otherwise seeded resets are not reproducible.
        dir_x = int(self.np_random.choice([1, -1]))
        dir_y = int(self.np_random.choice([1, -1]))
        self.ball = Ball(
            self.width / 2, self.height / 2,
            dir_x * self.ball_speed,
            dir_y * self.ball_speed
        )

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode: new ball, fresh paddles (scores reset to 0), frame 0.

        Returns ``(observation, info)``; ``options`` is accepted but unused.
        """
        super().reset(seed=seed)
        self._reset_ball()
        self.player = Paddle(self.width - 100, self.height / 2, 10, 100, self.paddle_speed)
        self.opponent = Paddle(100, self.height / 2, 10, 100, self.paddle_speed)
        self.frame = 0

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """Advance the game by one frame.

        Returns ``(observation, reward, terminated, truncated, info)``.
        Rewards: +5 when the ball hits the player's paddle, +10 when the
        player scores, -10 when the opponent scores.
        """
        reward = 0

        # Player movement, clamped so the paddle stays fully on screen
        if action == 1:
            self.player.y = max(0 + self.player.h / 2, self.player.y - self.player.spd)
        elif action == 2:
            self.player.y = min(self.height - self.player.h / 2, self.player.y + self.player.spd)
        self.player.move(0, 0)  # Recalculate bounds

        # Ball bounce off top/bottom. Point the vertical velocity back into the
        # field instead of toggling its sign. Toggling can send a ball that is
        # already past a wall but moving inward (e.g. a paddle hit just flipped
        # spdy) back out again, leaving it oscillating outside the field.
        if self.ball.y <= 0:
            self.ball.spdy = abs(self.ball.spdy)
        elif self.ball.y >= self.height:
            self.ball.spdy = -abs(self.ball.spdy)

        # Ball collision with player paddle (only while moving towards it)
        if self.ball.spdx > 0 and \
           self.player.left - self.ball.r <= self.ball.x <= self.player.right + self.ball.r and \
           self.player.top - self.ball.r <= self.ball.y <= self.player.bottom + self.ball.r:
            self.ball.spdx = -self.ball_speed
            self.ball.spdy *= -1
            reward += 5

        # Ball collision with opponent paddle
        if self.ball.spdx < 0 and \
           self.opponent.left - self.ball.r <= self.ball.x <= self.opponent.right + self.ball.r and \
           self.opponent.top - self.ball.r <= self.ball.y <= self.opponent.bottom + self.ball.r:
            self.ball.spdx = self.ball_speed
            self.ball.spdy *= -1

        # Ball out of bounds — scoring
        if self.ball.x < 0:
            self.player.score += 1
            reward += 10
            self._reset_ball()
        elif self.ball.x > self.width:
            self.opponent.score += 1
            reward -= 10
            self._reset_ball()

        # AI opponent tracks ball perfectly
        self.opponent.y = self.ball.y
        self.opponent.move(0, 0)  # Update bounds

        # Update ball position
        self.ball.x += self.ball.spdx
        self.ball.y += self.ball.spdy

        # Frame count
        self.frame += 1
        # NOTE(review): the frame cap is the only way an episode ends. Gymnasium
        # would normally report that as truncation (terminated=False). Both flags
        # are kept equal because existing callers read the first one as `done`.
        done = self.frame >= self.max_frame

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, done, done, info

    def _get_obs(self):
        """Pack ball and paddle state into the 8-element float32 observation."""
        return np.array([
            self.ball.x, self.ball.y, self.ball.spdx, self.ball.spdy,
            self.player.x, self.player.y, self.opponent.x, self.opponent.y
        ], dtype=np.float32)

    def _get_info(self):
        """Return the running scores of both sides."""
        return {
            'Player score': self.player.score,
            'Opponent score': self.opponent.score
        }

# Expose PongEnv as gym.make("Pong_custom", ...). The registry wraps it in a
# TimeLimit of 1200 steps, the same as PongEnv's default max_frame, so
# episodes cannot run forever.
gym.register("Pong_custom", entry_point=PongEnv, max_episode_steps=1200)

In [16]:
class DQNAgent:
    """Epsilon-greedy DQN agent with a uniform experience-replay buffer.

    Args:
        action_dim: number of discrete actions.
        lr: Adam learning rate.
        gamma: discount factor for the bootstrapped target.
        epsilon: initial exploration probability.
        epsilon_decay: multiplicative decay applied once per training step
            (i.e. per ``replay`` call that actually trains).
        buffer_size: replay capacity; the oldest transitions are dropped first.
        q_network: Q-network to train. Defaults to the notebook-level ``model``
            global, which keeps the original call signature working.
        torch_device: device for input tensors. Defaults to the ``device`` global.
        epsilon_min: floor for epsilon decay. Values already at or below it
            are left untouched, so e.g. epsilon=0 stays greedy.
        target_sync_every: if set, bootstrap targets come from a frozen copy of
            the network that is re-synced every this many training steps.
            ``None`` (default) uses the online network, as before.
    """

    def __init__(self, action_dim, lr, gamma, epsilon, epsilon_decay, buffer_size,
                 q_network: Optional[nn.Module] = None,
                 torch_device: Optional[torch.device] = None,
                 epsilon_min: float = 0.01,
                 target_sync_every: Optional[int] = None):
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.memory = deque(maxlen=buffer_size)

        # Bind the network/device once instead of re-reading notebook globals on
        # every call. The conditional only evaluates the global when needed.
        self.model = q_network if q_network is not None else model
        self.device = torch_device if torch_device is not None else device
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if target_sync_every is not None and target_sync_every < 1:
            raise ValueError("target_sync_every must be a positive integer or None")
        self.target_sync_every = target_sync_every
        self._train_steps = 0
        if target_sync_every is not None:
            import copy
            self.target_model = copy.deepcopy(self.model)
        else:
            self.target_model = self.model

    def act(self, state):
        """Choose an action epsilon-greedily for one state.

        ``state`` is assumed to be a 1-D array matching the network input
        (the 8-value PongEnv observation).
        """
        if np.random.rand() <= self.epsilon:
            return int(np.random.choice(self.action_dim))
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            q_values = self.model(state_tensor)
        return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size):
        """Run one gradient step on a random minibatch; no-op until the buffer has ``batch_size`` items."""
        if len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        def to_tensor(values, np_dtype):
            # Stack into one ndarray first: torch.tensor() on a tuple of
            # ndarrays is very slow and PyTorch warns about it.
            return torch.as_tensor(np.asarray(values, dtype=np_dtype), device=self.device)

        states = to_tensor(states, np.float32)
        actions = to_tensor(actions, np.int64).unsqueeze(1)
        rewards = to_tensor(rewards, np.float32).unsqueeze(1)
        next_states = to_tensor(next_states, np.float32)
        dones = to_tensor(dones, np.float32).unsqueeze(1)

        # Q(s, a) for the actions that were actually taken
        current_q = self.model(states).gather(1, actions)

        with torch.no_grad():
            max_next_q = self.target_model(next_states).max(dim=1).values.unsqueeze(1)
            # No bootstrap past a terminal transition
            target_q = rewards + (1 - dones) * self.gamma * max_next_q

        loss = nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._train_steps += 1
        if self.target_sync_every is not None and self._train_steps % self.target_sync_every == 0:
            self.target_model.load_state_dict(self.model.state_dict())

        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


class DQN(nn.Module):
    """Fully connected Q-network.

    Two 512-unit ReLU hidden layers feed a linear head that outputs one
    Q-value per action. The layer attribute names ``fc1``/``fc2``/``fc3`` are
    kept stable because they become the ``state_dict`` keys of saved weights.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        hidden = 512
        self.fc1 = nn.Linear(n_observations, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, n_actions)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        return self.fc3(x)

In [17]:
# Seed the global RNGs the notebook draws from (Python `random`, NumPy's legacy
# global RNG and PyTorch) before weight init, so runs are repeatable on the
# same hardware.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 8 observation values and 3 actions, matching PongEnv's spaces
model = DQN(8, 3).to(device)


In [None]:
env = gym.make('Pong_custom', max_frame = 1200)
# Read the action count from the env instead of hard-coding it
action_dim = int(env.action_space.n)
agent = DQNAgent(action_dim, lr=0.01, gamma=0.99, epsilon=1.0, epsilon_decay=0.999, buffer_size=10000)

batch_size = 32
num_episodes = 1000
try:
    for episode in range(num_episodes):
        state, info = env.reset()
        total_reward = 0
        done = False
        t0 = time.perf_counter()
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            # Only a true terminal state should stop bootstrapping in the DQN
            # target; a time-limit cut-off (`truncated`) is not the end of the game.
            agent.remember(state, action, reward, next_state, terminated)
            state = next_state
            total_reward += reward
            agent.replay(batch_size)
            # Leave the episode on either flag: the TimeLimit wrapper added by
            # gym.make(max_episode_steps=...) signals its cut-off via `truncated`.
            done = terminated or truncated
        print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Time Spent: {time.perf_counter() - t0}")
        print(info)
        if (episode + 1) % 100 == 0:
            torch.save(model.state_dict(), f'model_weights_{episode+1}.pth')
finally:
    # Release env resources even if training is interrupted
    env.close()

Episode: 1, Total Reward: -80, Time Spent: 2.5571506023406982
{'Player score': 0, 'Opponent score': 8}
Episode: 2, Total Reward: -45, Time Spent: 2.7774507999420166
{'Player score': 0, 'Opponent score': 6}
Episode: 3, Total Reward: 0, Time Spent: 2.9596400260925293
{'Player score': 0, 'Opponent score': 2}
Episode: 4, Total Reward: -15, Time Spent: 2.941955327987671
{'Player score': 0, 'Opponent score': 3}
Episode: 5, Total Reward: 15, Time Spent: 2.891503095626831
{'Player score': 0, 'Opponent score': 1}
Episode: 6, Total Reward: -30, Time Spent: 2.8913817405700684
{'Player score': 0, 'Opponent score': 4}
Episode: 7, Total Reward: -65, Time Spent: 2.858070135116577
{'Player score': 0, 'Opponent score': 7}
Episode: 8, Total Reward: 0, Time Spent: 2.821485996246338
{'Player score': 0, 'Opponent score': 2}
Episode: 9, Total Reward: -75, Time Spent: 2.8381922245025635
{'Player score': 0, 'Opponent score': 8}
Episode: 10, Total Reward: 15, Time Spent: 2.8708016872406006
{'Player score': 0, 