In [59]:
import pygame
import random
import torch.nn as nn
import torch
from collections import deque
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [60]:
class Game:
    """Single-paddle Pong environment with a DQN training loop.

    The paddle is driven by an epsilon-greedy policy over a Q-network. Each step's
    transition is stored in a replay memory, and the network is trained on random
    minibatches drawn from it.
    """

    def __init__(self, screen_width, screen_height, screen, paddle, ball):
        pygame.init()
        self.paddle = paddle
        self.ball = ball
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.screen = screen
        self.clock = pygame.time.Clock()

    def run(self, episode_number, epsilon, batch_size, gamma, initial_state, memory, Q, optim):
        """Play and render `episode_number` episodes while training `Q`.

        Args:
            episode_number: number of episodes to play.
            epsilon: probability of taking a uniformly random action (exploration).
            batch_size: minibatch size; training starts once `memory` holds this many transitions.
            gamma: discount factor for the bootstrapped target.
            initial_state: 4-element state list that every episode starts from.
                It is copied, never mutated.
            memory: replay buffer (e.g. a deque) that receives
                ``(state, action, reward, next_state, done)`` tuples.
            Q: network mapping a (batch, 4) float tensor to (batch, 3) action values.
            optim: torch optimizer over Q's parameters.

        Returns:
            List of training losses (floats), one per optimisation step. If the window
            is closed, the losses collected so far are returned early.
        """
        losses = []
        for _ in range(episode_number):
            state = self._reset_episode(initial_state)
            while True:
                action = self._select_action(state, epsilon, Q)

                # get_state may mutate its argument; pass a copy so the stored transition
                # keeps an independent snapshot of the pre-action state.
                prev_state = list(state)
                # NOTE(review): get_state scores the positions *before* this step's physics
                # update, so each reward reflects the previous action's outcome.
                next_state, reward = self.paddle.get_state(list(state), self.ball, action)
                next_state = list(next_state)

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        return losses

                self._step_physics()

                # The ball passing the left edge ends the episode. Mark that transition as
                # terminal so its target does not bootstrap from a state that never happens.
                done = self.ball.x < 0
                memory.append((prev_state, action, reward, next_state, done))
                state = next_state

                if len(memory) >= batch_size:
                    losses.append(self._train_step(memory, batch_size, gamma, Q, optim))

                if done:
                    break

                pygame.display.update()
                self.clock.tick(120)  # cap the loop at 120 frames per second
        return losses

    def _reset_episode(self, initial_state):
        """Randomise paddle/ball positions and return a fresh start state for an episode."""
        paddle, ball = self.paddle, self.ball
        # Paddle.update draws its rect with the top at y + height//2. Choose y so the
        # drawn rect starts fully on screen, matching the clamp in _step_physics.
        paddle.y = random.randint(0, self.screen_height - paddle.height) - paddle.height // 2
        paddle.vel = 0  # do not carry motion over from the previous episode
        ball.x = random.randint(paddle.width + 10, self.screen_width - 2 * ball.length)
        ball.y = random.randint(0, self.screen_height - 2 * ball.length)
        ball.velx = random.choice([-5, 5])
        ball.vely = random.choice([-5, 5])

        state = list(initial_state)
        state[0] = state[1] = 0  # paddle at rest, consistent with vel = 0 above
        state[2] = 1 if ball.y < (paddle.y - paddle.height) else 0
        state[3] = 1 - state[2]

        ball.update()
        paddle.update()
        return state

    def _select_action(self, state, epsilon, Q):
        """Epsilon-greedy action selection; returns a one-hot list over the 3 actions."""
        action = [0, 0, 0]
        if random.uniform(0, 1) < epsilon:
            move = random.choice([0, 1, 2])
        else:
            with torch.no_grad():
                move = torch.argmax(Q(torch.tensor([state], dtype=torch.float32))).item()
        action[move] = 1
        return action

    def _step_physics(self):
        """Clear the screen, move/draw paddle and ball, and resolve wall and paddle bounces."""
        paddle, ball = self.paddle, self.ball
        self.screen.fill((0, 0, 0))

        paddle.update()
        # Keep the drawn rect (top = y + height//2) inside the screen.
        # NOTE(review): paddle.rect was built before this clamp and can lag paddle.y by a frame.
        paddle.y = max(0, min(paddle.y + paddle.height // 2,
                              self.screen_height - paddle.height)) - paddle.height // 2

        ball.update()
        # Bounce off the top/bottom walls and the right wall.
        if ball.y <= 0 or ball.y >= self.screen_height - ball.length:
            ball.vely *= -1
        if ball.x >= self.screen_width - ball.length:
            ball.velx *= -1
        # Bounce off the paddle on the left.
        if ball.x <= paddle.width and ball.rect.colliderect(paddle.rect):
            ball.velx = -ball.velx

    def _train_step(self, memory, batch_size, gamma, Q, optimizer):
        """Run one DQN update on a random minibatch and return the loss as a float."""
        batch = random.sample(memory, batch_size)
        states = torch.tensor([t[0] for t in batch], dtype=torch.float32)
        action_idx = torch.tensor([t[1].index(1) for t in batch], dtype=torch.long)
        rewards = torch.tensor([t[2] for t in batch], dtype=torch.float32)
        next_states = torch.tensor([t[3] for t in batch], dtype=torch.float32)
        dones = torch.tensor([t[4] for t in batch], dtype=torch.float32)

        q_values = Q(states)
        # The target is a constant: no gradient may flow through it (semi-gradient TD).
        with torch.no_grad():
            next_max = Q(next_states).max(dim=1).values
            target = q_values.detach().clone()
            target[torch.arange(len(batch)), action_idx] = rewards + gamma * next_max * (1.0 - dones)

        # Entries for untaken actions equal q_values, so only the taken action contributes error.
        loss = nn.MSELoss()(q_values, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()
        


In [61]:
class Paddle:
    """Agent-controlled paddle that also encodes the RL state and reward.

    Positions are in screen pixels. `update()` must be called at least once before
    `get_state()`, because `get_state()` reads `self.rect`, which `update()` creates.
    """

    # One-hot action index -> (velocity, state[0], state[1]).
    # Index 0 decreases y by 5 px/frame (up on screen in pygame coordinates), index 1
    # increases it by 5; any other index (2 = stay) stops the paddle.
    _ACTION_EFFECTS = {0: (-5, 1, 0), 1: (5, 0, 1)}
    _STOP_EFFECT = (0, 0, 0)

    def __init__(self, width, height, x, y, vel, screen):
        self.height = height
        self.width = width
        self.x = x
        self.y = y
        self.vel = vel
        self.screen = screen

    def update(self):
        """Move by `vel`, rebuild `self.rect` and draw it in white on `self.screen`.

        The rect's left edge is x - width//2 and its top edge is y + height//2.
        """
        self.y += self.vel
        self.rect = pygame.Rect(self.x - self.width // 2, self.y + self.height // 2, self.width, self.height)
        pygame.draw.rect(self.screen, (255, 255, 255), rect=self.rect)

    def get_state(self, state, ball, action):
        """Apply `action` to the paddle and return ``(new_state, reward)``.

        Args:
            state: 4-element sequence [moving up, moving down, ball-above flag, ball-below flag].
                It is copied, never mutated, so transitions already stored elsewhere keep
                their own snapshot.
            ball: object with `rect`, `x`, `y` and `length` attributes.
            action: one-hot list of length 3 (up / down / stay).

        Returns:
            ``(new_state, reward)``. `new_state` is a fresh list. `reward` is 1 if the
            ball's rect overlaps the paddle's, -1 if ball.x <= ball.length, else 0.

        Raises:
            ValueError: if `action` contains no 1.
        """
        new_state = list(state)

        if ball.rect.colliderect(self.rect):
            reward = 1
        elif ball.x <= ball.length:
            reward = -1
        else:
            reward = 0

        self.vel, new_state[0], new_state[1] = self._ACTION_EFFECTS.get(action.index(1), self._STOP_EFFECT)

        # NOTE(review): the drawn rect's top is self.y + self.height // 2 (see update()),
        # but this compares against self.y - self.height; confirm the intended threshold.
        new_state[2] = 1 if ball.y < (self.y - self.height) else 0
        new_state[3] = 1 - new_state[2]
        return new_state, reward
        
class Ball:
    """Square ball that moves by (velx, vely) pixels per frame and draws itself."""

    def __init__(self, x, y, velx, vely, length, screen):
        self.screen = screen
        self.length = length
        self.x, self.y = x, y
        self.velx, self.vely = velx, vely

    def update(self):
        """Advance one frame, rebuild `self.rect` and draw it in white on `self.screen`.

        The rect's left edge is x - length//2 and its top edge is y + length//2.
        """
        self.x, self.y = self.x + self.velx, self.y + self.vely
        half = self.length // 2
        self.rect = pygame.Rect(self.x - half, self.y + half, self.length, self.length)
        pygame.draw.rect(self.screen, (255, 255, 255), rect=self.rect)

In [62]:
# Seed the RNGs used below so a fresh Restart & Run All reproduces the same run:
# `random` drives ball direction and exploration, `torch` drives network initialisation.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Initialise pygame before opening the window (pygame.init is safe to call again later).
pygame.init()

# Set the screen size
screen_width = 800
screen_height = 600
screen = pygame.display.set_mode((screen_width, screen_height))
# Set the paddle and ball sizes
paddle_width = 20
paddle_height = 100
ball_size = 20

# Initialize the paddle and ball positions
paddle_x = 0
paddle_y = screen_height // 2 - paddle_height // 2
ball_x = screen_width // 2
ball_y = screen_height // 2

# Initialize the paddle and ball velocities
paddle_vel = 0
ball_max_vel = 5
ball_velx = random.choice([-ball_max_vel, ball_max_vel])
ball_vely = random.choice([-ball_max_vel, ball_max_vel])

# RL state layout: [paddle moving up, paddle moving down, ball-above flag, ball-below flag]
state = [0, 0, 0, 0]
paddle = Paddle(paddle_width, paddle_height, paddle_x, paddle_y, paddle_vel, screen)
ball = Ball(ball_x, ball_y, ball_velx, ball_vely, ball_size, screen)

# Replay memory: only the most recent 1000 transitions are kept.
memory = deque(maxlen=1000)

In [63]:
class Network(nn.Module):
    """Fully connected Q-network: state vector in, one raw value per action out.

    Architecture: Linear(input_states, N) -> ReLU -> Linear(N, N) -> ReLU -> Linear(N, output_actions).
    """

    def __init__(self, input_states, output_actions, N):
        super().__init__()
        # Attribute names (m1..m3) and creation order are kept deliberately: they define
        # the state_dict keys and the order in which parameters are initialised.
        self.m1 = nn.Linear(input_states, N)
        self.m2 = nn.Linear(N, N)
        self.m3 = nn.Linear(N, output_actions)

    def forward(self, x):
        """Return unactivated action values for input `x`."""
        return self.m3(F.relu(self.m2(F.relu(self.m1(x)))))
# len(state) input features, 3 actions out (up / down / stay), 80 hidden units per layer.
Q = Network(len(state), 3, 80)

# Build the optimizer through `torch.optim` rather than the `optim` alias. This cell rebinds
# `optim` to the optimizer instance (the training cell passes it under that name), so
# `optim.Adam` would raise AttributeError if the cell were re-run.
optim = torch.optim.Adam(Q.parameters(), lr=0.001)

In [64]:
# Training hyper-parameters: episodes to play, probability of a random (exploratory)
# action, replay minibatch size, and discount factor.
episode_number, epsilon, batch_size, gamma = 1000, 0.1, 100, 0.9

game = Game(screen_width, screen_height, screen, paddle, ball)
loss = game.run(
    episode_number, epsilon, batch_size, gamma,
    state, memory, Q, optim,
)