In [31]:
import pygame
import random
import torch.nn as nn
import torch
from collections import deque
import numpy as np
import torch.optim as optim
import torch.nn.functional as F

In [32]:
class Game:
    """Single-paddle Pong environment fused with an epsilon-greedy DQN training loop.

    The paddle sits on the left edge; the ball bounces off the top, bottom and
    right walls and off the paddle. A round ends when the ball leaves the
    screen on the left (``ball.x < 0``).
    """

    # Discrete actions handed to ``paddle.get_state``. The position of an
    # action in this tuple is the index of its Q-network output.
    ACTIONS = (-1, 0, 1)
    # Per-axis speed given to the ball whenever a round is reset.
    BALL_SPEED = 5
    # Frame-rate cap passed to ``clock.tick``.
    FPS = 60

    def __init__(self,screen_width,screen_height,screen,paddle,ball):
        """Initialise pygame and remember the screen, paddle and ball to drive."""
        pygame.init()
        self.paddle = paddle
        self.ball = ball
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.screen = screen
        self.clock = pygame.time.Clock()

    @staticmethod
    def _select_action(Q, state, epsilon):
        """Pick an action epsilon-greedily: random with probability ``epsilon``, else argmax of Q."""
        if random.uniform(0,1) < epsilon:
            return random.choice(Game.ACTIONS)
        with torch.no_grad():
            q_values = Q(torch.tensor([state],dtype=torch.float32))
        return Game.ACTIONS[torch.argmax(q_values).item()]

    @staticmethod
    def _td_loss(Q, batch, gamma):
        """Return the one-step TD (MSE) loss of ``Q`` on a batch of transitions.

        Each transition is ``(state, action, reward, next_state, done)`` where
        ``action`` is one of ``Game.ACTIONS``. Every tensor is kept at shape
        (B, 1) so that the target and the prediction line up element-wise
        instead of broadcasting to (B, B).
        """
        states = torch.tensor([t[0] for t in batch], dtype=torch.float32)
        action_idx = torch.tensor(
            [Game.ACTIONS.index(t[1]) for t in batch], dtype=torch.int64
        ).unsqueeze(1)
        rewards = torch.tensor([t[2] for t in batch], dtype=torch.float32).unsqueeze(1)
        next_states = torch.tensor([t[3] for t in batch], dtype=torch.float32)
        dones = torch.tensor([float(t[4]) for t in batch], dtype=torch.float32).unsqueeze(1)

        # The bootstrap target is a constant for this update: no gradient flows
        # through it, and terminal transitions do not bootstrap at all.
        with torch.no_grad():
            q_next = Q(next_states).max(dim=1, keepdim=True)[0]
            target = rewards + gamma * q_next * (1.0 - dones)

        # Regress the Q-value of the action that was actually taken.
        q_taken = Q(states).gather(1, action_idx)
        return nn.MSELoss()(q_taken, target)

    def _reset_round(self):
        """Re-centre the ball with a random diagonal velocity and re-place the paddle."""
        self.paddle.x = 0
        self.paddle.y = random.randint(self.paddle.height,self.screen_height)
        self.ball.x = self.screen_width // 2
        self.ball.y = self.screen_height // 2
        self.paddle.vel = 0
        self.ball.velx = random.choice([-self.BALL_SPEED, self.BALL_SPEED])
        self.ball.vely = random.choice([-self.BALL_SPEED, self.BALL_SPEED])

    def _advance_frame(self, live_state):
        """Move and draw paddle and ball, resolve collisions; return True if the ball was missed.

        ``live_state`` is the state list most recently returned by
        ``paddle.get_state``; its paddle-y entry is re-synchronised after the
        paddle has been clamped to the screen so that the observation cannot
        drift away from the real paddle position.
        """
        self.paddle.update()

        # Keep the paddle rectangle (whose top is y + height//2) inside the screen.
        self.paddle.y = max(0, min(self.paddle.y+self.paddle.height//2,self.screen_height-self.paddle.height)) - self.paddle.height//2
        live_state[1] = self.paddle.y

        self.ball.update()

        # Bounce off the top and bottom walls.
        if self.ball.y <= 0 or self.ball.y >= self.screen_height - self.ball.length:
            self.ball.vely *= -1
        # Bounce off the right wall.
        if self.ball.x >= self.screen_width - self.ball.length:
            self.ball.velx *= -1
        # Bounce off the paddle.
        if self.ball.x <= self.paddle.width and self.ball.rect.colliderect(self.paddle.rect):
            self.ball.velx = -self.ball.velx

        # Ball left the screen on the paddle side: the round is over.
        if self.ball.x < 0:
            self._reset_round()
            return True
        return False

    def run(self,episode_number,epsilon,batch_size,gamma,initial_state,memory,Q,optim):
        """Play ``episode_number`` rounds while training ``Q`` from replayed transitions.

        Args:
            episode_number: number of rounds to play.
            epsilon: probability of taking a random action.
            batch_size: replay sample size; learning starts once ``memory``
                holds at least this many transitions.
            gamma: discount factor of the TD target.
            initial_state: mutable state list; index 1 is the paddle y position.
            memory: replay buffer supporting ``append``, ``len`` and ``random.sample``.
            Q: callable mapping a (B, state_dim) float tensor to (B, 3) action values.
            optim: optimizer over the parameters of ``Q``.

        Returns early (without finishing the remaining episodes) when the
        window receives a ``pygame.QUIT`` event.
        """
        for _ in range(episode_number):
            running = True
            done = False
            self.paddle.y = random.randint(self.paddle.height,self.screen_height)
            initial_state[1] = self.paddle.y
            while running:
                # get_state mutates and returns one shared list, so snapshot the
                # state on both sides of the step; otherwise every stored
                # transition would alias the same (latest) values.
                state_before = list(initial_state)
                action = self._select_action(Q, state_before, epsilon)

                next_state, reward = self.paddle.get_state(self.ball,action)
                # ``done`` was raised by the previous frame, so the transition
                # recorded right after a miss is the terminal one.
                memory.append((state_before,action,reward,list(next_state),done))

                if done:
                    running = False
                initial_state = next_state

                if len(memory) >= batch_size:
                    batch = random.sample(memory,batch_size)
                    loss = self._td_loss(Q, batch, gamma)

                    # Backpropagate the error and optimize the Q-network
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Clear the screen before this frame is drawn.
                self.screen.fill((0, 0, 0))

                # Closing the window stops training altogether instead of
                # merely skipping to the next episode.
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        return

                done = self._advance_frame(initial_state)

                # Update the screen
                pygame.display.update()

                # Control the frame rate
                self.clock.tick(self.FPS)
                


In [33]:
class Paddle:
    """Player paddle that also owns the shared observation list.

    ``state`` is laid out as
    ``[paddle_x, paddle_y, paddle_vel, ball_x, ball_y, ball_velx, ball_vely]``
    and is mutated in place by ``get_state``.
    """

    def __init__(self,width,height,state,screen):
        """Read the initial position and velocity from ``state`` and build the hit box."""
        self.state = state
        self.height = height
        self.width = width
        self.x = state[0]
        self.y = state[1]
        self.vel = state[2]
        self.screen = screen
        # Build the hit box immediately so that collision tests made before
        # the first update() do not fail with AttributeError.
        self.rect = self._make_rect()

    def _make_rect(self):
        """Return the paddle rectangle for the current position."""
        return pygame.Rect(self.x-self.width//2,self.y+self.height//2,self.width,self.height)

    def update(self):
        """Advance the paddle by its velocity, refresh its rectangle and draw it."""
        self.y += self.vel
        self.rect = self._make_rect()
        pygame.draw.rect(self.screen, (255, 255, 255), rect=self.rect)

    def get_state(self,ball,action):
        """Score the current frame, apply ``action`` and refresh the state list.

        Args:
            ball: object exposing ``x``, ``y``, ``velx``, ``vely`` and, when it
                is within paddle reach, ``rect``.
            action: ``1`` sets the velocity to -5, ``-1`` sets it to +5,
                anything else stops the paddle.

        Returns:
            ``(state, reward)``. ``state`` is the paddle's own list object, not
            a copy, so callers that need to keep a snapshot must copy it.
            ``reward`` is 10 when the ball touches the paddle, -10 when the
            ball has reached ``x <= 0`` without touching it, otherwise 0.
        """
        if ball.x <= self.width and ball.rect.colliderect(self.rect):
            reward = 10
        elif ball.x <= 0:
            reward = -10
        else:
            reward = 0

        if action == 1:
            self.vel = -5
        elif action == -1:
            self.vel = 5
        else:
            self.vel = 0

        # The paddle-y entry is advanced here; update() applies the same step
        # to self.y when the frame is drawn.
        self.state[2] = self.vel
        self.state[1] += self.vel
        self.state[3] = ball.x
        self.state[4] = ball.y
        self.state[5] = ball.velx
        self.state[6] = ball.vely
        return self.state,reward
        
class Ball:
    """Square ball whose start position and velocity are read from the state list.

    ``state[3:7]`` holds ``ball_x, ball_y, ball_velx, ball_vely``. The ball
    keeps a reference to ``state`` but never writes to it.
    """

    def __init__(self,length,state,screen):
        """Read the initial position and velocity from ``state`` and build the hit box."""
        self.state = state
        self.length = length
        self.x = state[3]
        self.y = state[4]
        self.velx = state[5]
        self.vely = state[6]
        self.screen = screen
        # Build the hit box immediately so that collision tests made before
        # the first update() do not fail with AttributeError.
        self.rect = self._make_rect()

    def _make_rect(self):
        """Return the ball rectangle for the current position."""
        return pygame.Rect(self.x-self.length//2, self.y+self.length//2, self.length, self.length)

    def update(self):
        """Advance the ball by its velocity, refresh its rectangle and draw it."""
        self.x += self.velx
        self.y += self.vely
        self.rect = self._make_rect()
        pygame.draw.rect(self.screen, (255, 255, 255), rect=self.rect)

In [34]:
# Display surface that both sprites draw onto.
screen_width, screen_height = 800, 600
screen = pygame.display.set_mode((screen_width, screen_height))

# Sprite dimensions in pixels.
paddle_width, paddle_height = 20, 100
ball_size = 20

# Starting layout: paddle on the left edge at mid height, ball at screen centre.
paddle_x = 0
paddle_y = screen_height // 2 - paddle_height // 2
ball_x, ball_y = screen_width // 2, screen_height // 2

# Starting velocities: paddle at rest, ball on a random diagonal.
paddle_vel = 0
ball_max_vel = 5
ball_velx, ball_vely = (
    random.choice([-ball_max_vel, ball_max_vel]) for _ in range(2)
)

# One list is handed to both sprites and later to the training loop, in the
# order [paddle_x, paddle_y, paddle_vel, ball_x, ball_y, ball_velx, ball_vely].
state = [paddle_x, paddle_y, paddle_vel, ball_x, ball_y, ball_velx, ball_vely]
paddle = Paddle(paddle_width, paddle_height, state, screen)
ball = Ball(ball_size, state, screen)

# Replay buffer; once 1000 transitions are held the oldest ones are dropped.
memory = deque(maxlen=1000)

In [35]:
class Network(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear."""

    def __init__(self, input_states, output_actions, N):
        """Create the layers.

        Args:
            input_states: number of input features.
            output_actions: number of output values.
            N: width of the hidden layer.
        """
        super().__init__()
        self.m1 = nn.Linear(in_features=input_states, out_features=N)
        self.m2 = nn.Linear(in_features=N, out_features=output_actions)

    def forward(self, x):
        """Return the output-layer activations for the input batch ``x``."""
        hidden = F.relu(self.m1(x))
        return self.m2(hidden)
# Q-network: one input per state entry, one output per action, 10 hidden units.
Q = Network(len(paddle.state),3,10)

# Build the optimizer through ``torch.optim`` rather than the ``optim`` alias.
# Assigning to ``optim`` rebinds that alias to the optimizer instance, so the
# earlier ``optim.Adam(...)`` spelling raised AttributeError whenever this cell
# was executed a second time. The variable keeps its name because the
# training cell passes ``optim`` to ``Game.run``.
optim = torch.optim.Adam(Q.parameters(),lr=0.001)

In [36]:
# Training hyper-parameters: rounds to play, exploration rate,
# replay sample size and discount factor.
episode_number, epsilon = 100, 0.5
batch_size, gamma = 10, 0.9

# Wire the environment together and start the training loop.
game = Game(screen_width, screen_height, screen, paddle, ball)
game.run(
    episode_number=episode_number,
    epsilon=epsilon,
    batch_size=batch_size,
    gamma=gamma,
    initial_state=state,
    memory=memory,
    Q=Q,
    optim=optim,
)

KeyboardInterrupt: 