In [3]:
from __future__ import annotations
from typing import Union, Tuple, List, Type
import random
from itertools import count
import gym
from gym.spaces import Discrete, Box
import numpy as np
import atari_py
from gym.envs.atari import AtariEnv
ACTION = int  # an action is an integer index into PongGame._action_set / _action_set2 (see PongGame.step)


class PongGame(AtariEnv):
    """Atari Pong with two controllable paddles and an alternating-turn API.

    Players act one at a time through ``act``: player 0's action is buffered
    and the emulator only advances once player 1 has acted as well.
    ``FIRE``, ``UP``, ``DOWN``, ``RAM_BALL_Y_POS``, ``P_RIGHT_SCORE`` and
    ``check_if_should_take_action`` are not defined in this cell and must be
    provided elsewhere in the notebook.
    """

    _game_count = count(0)  # class-wide counter handing out a unique id per game

    def __init__(self, second_player):
        """Create the emulator.

        :param second_player: class of an automated player-2 bot, constructed as
            ``second_player(action_space, player=2)``; ``None`` for no bot.
        """
        super().__init__(frameskip=1)
        self._game_id = next(self._game_count)
        self._seconf_player_class = second_player  # (sic) attribute name kept for compatibility
        self._is_multiplayer = second_player is not None
        self._action_set = self.ale.getMinimalActionSet()
        # Player-2 codes are the player-1 codes shifted by 18.
        # NOTE(review): assumes ALE numbers the second player's actions from 18 — confirm.
        self._action_set2 = [x + 18 for x in self._action_set]
        self.current_player = 0  # 0 (P1) or 1 (P2)
        self.done = False
        self.player_1_action = None  # P1's buffered action, applied when P2 acts
        self._player2_bot = second_player(self.action_space, player=2) if self._is_multiplayer else None

    def step(self, a1: ACTION, a2: Union[ACTION, None] = None):
        """Advance the emulator one frame with both players' actions.

        :param a1: player-1 action index into ``_action_set``.
        :param a2: player-2 action index into ``_action_set2``; defaults to ``a1``
            only when omitted (``None``). Index 0 is a valid action and is kept.
        :returns: ``(observation, reward)`` where reward comes from ``ale.act2``.
        """
        action1 = self._action_set[a1]
        if a2 is None:  # `a2 or a1` would wrongly replace action index 0
            a2 = a1
        action2 = self._action_set2[a2]
        reward = self.ale.act2(action1, action2)
        return self._get_obs(), reward

    # Do not make this static because MCTS requires it
    def possible_actions(self, player=None) -> List[ACTION]:
        """Legal actions; with ``player`` given, only moves deemed useful right now."""
        if player is not None:
            ob = self._get_obs()
            if check_if_should_take_action(ob, player=player):
                return [DOWN, UP]
            return [FIRE]
        return [FIRE, DOWN, UP]

    def act(self, action: ACTION) -> int:
        """Submit the current player's action.

        Player 0's action is only buffered (returns 0). Player 1's action triggers
        an emulator step with both actions. A nonzero reward marks the game done
        and is returned; otherwise the turn passes back to player 0 and 0 is returned.
        """
        if self.current_player == 0:
            self.player_1_action = action
            self.current_player = 1
            return 0

        ob, reward = self.step(self.player_1_action, action)

        # reward could be only -1, 0 and 1 (-1 and 1 means there is a point scored by one of the sides)
        if reward != 0:
            self.current_player = 0 if reward == 1 else 1
            self.done = True
            return reward

        self.current_player = 0
        return 0

    def act_random(self) -> int:
        """Play a uniformly random legal action for the current player."""
        return self.act(random.choice(self.possible_actions(player=self.current_player)))

    def reset(self):
        """Reset the emulator, start a serve, and clear the turn bookkeeping.

        :returns: the observation after the ball has appeared.
        """
        super().reset()
        self.ale.press_select()
        self.ale.press_select()
        self.ale.press_select()
        self.ale.soft_reset()
        self.step(FIRE, FIRE)
        # Keep serving until the ball's RAM y-position becomes nonzero.
        while self._get_ram()[RAM_BALL_Y_POS] == 0:
            self.step(FIRE, FIRE)
        # A freshly reset game is not finished and starts with player 0 to move.
        self.done = False
        self.current_player = 0
        self.player_1_action = None
        return self._get_obs()

    def copy(self) -> PongGame:
        """Return an independent game with the same emulator state and turn state."""
        new_game = PongGame(self._seconf_player_class)
        new_game.restore_full_state(self.clone_full_state())
        # Turn bookkeeping lives outside the emulator and must be copied by hand.
        new_game.done = self.done
        new_game.current_player = self.current_player
        new_game.player_1_action = self.player_1_action
        return new_game

    def get_state(self):
        """Snapshot as ``(emulator_state, buffered_player_1_action)``."""
        return (self.clone_full_state(), self.player_1_action)

    def set_state(self, state, done, current_player):
        """Restore a snapshot produced by ``get_state`` plus the turn flags."""
        self.done = done
        self.current_player = current_player
        self.restore_full_state(state[0])
        self.player_1_action = state[1]

    def get_winner(self) -> int:
        """0 if the observation's ``P_RIGHT_SCORE`` entry is positive, else 1.

        NOTE(review): presumably ``P_RIGHT_SCORE`` indexes the right player's score — confirm.
        """
        ob = self._get_obs()
        return 0 if ob[P_RIGHT_SCORE] > 0 else 1

In [1]:
import sys
import random
import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import pygame
pygame.init()
import time

def add_coordinates_with_speed(a, b, s):
    """Return the 2-D point ``a`` displaced by vector ``b`` scaled by ``s``, as a list."""
    x, y = a[0], a[1]
    dx, dy = b[0], b[1]
    return [x + dx * s, y + dy * s]

class Colors:
    """Named RGB colours (each a fresh list per instance) used when drawing."""

    def __init__(this):
        this.white, this.red, this.green, this.black, this.blue = (
            [255, 255, 255],
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 0],
            [0, 0, 255],
        )

class Slab:
    """A paddle that moves vertically in fixed steps, clamped inside the screen.

    ``position`` is the paddle's centre; it is kept between ``buffer_top_size``
    and ``buffer_bottom_size`` (one half-height, ``pixel_slab_size``, from each edge).
    """

    def __init__(this, size, step_size, orientation, display_height, pixel_size, position):
        this.size = size
        this.step_size = step_size
        this.pixel_size = pixel_size
        this.orientation = orientation
        this.display_height = display_height
        this.position = position
        half_height = this.size * this.pixel_size
        this.pixel_slab_size = half_height
        this.buffer_top_size = 0 + half_height
        this.buffer_bottom_size = this.display_height - half_height

    def move(this, movement):
        """Step up (``movement == 1``) or down (``movement == 0``); other values do nothing."""
        if movement == 1:
            this.position = max(this.position - this.step_size, this.buffer_top_size)
        elif movement == 0:
            this.position = min(this.position + this.step_size, this.buffer_bottom_size)

class Ball:
    """The ball: moves ``speed`` pixels per axis per step and bounces off walls and slabs.

    ``winner`` starts at -1. ``step`` sets it to 0 when the ball leaves on the left
    and to 1 when it leaves on the right.
    """

    def __init__(this, speed, display_width, display_height, pixel_size, position, direction):
        this.winner = -1
        # Copy the caller's sequences: ``step`` mutates ``direction`` in place, and
        # callers pass lists owned by other objects (e.g. another game's ball).
        this.direction = list(direction)
        this.pixel_size = pixel_size
        this.speed = speed
        this.display_width = display_width
        this.display_height = display_height
        this.radius = this.pixel_size / 2
        this.position = list(position)
        this.buffer_top_size = 0 + (this.radius)
        this.buffer_bottom_size = this.display_height - (this.radius)
        this.buffer_left_size = 0 + (this.radius)
        this.buffer_right_size = this.display_width - (this.radius)
        # x-limits at which the ball touches a slab face (slabs are one pixel_size wide).
        this.buffer_left_slab_size = 0 + (this.radius) + this.pixel_size
        this.buffer_right_slab_size = this.display_width - (this.radius) - this.pixel_size

    def step(this, left_slab, right_slab):
        """Advance the ball one step.

        :returns: ``'Left'``/``'Right'`` if the ball has already left that side
            (the ball is not moved); otherwise 1 if it bounced off a slab this step, else 0.
        """
        if this.position[0] < this.buffer_left_size:
            this.winner = 0
            return 'Left'
        if this.position[0] > this.buffer_right_size:
            this.winner = 1
            return 'Right'
        hit = 0
        if this.position[1] <= this.buffer_top_size:
            this.direction[1] = 1
        if this.position[1] >= this.buffer_bottom_size:
            this.direction[1] = -1
        y = this.position[1]
        if (this.position[0] <= this.buffer_left_slab_size
                and left_slab.position - left_slab.pixel_slab_size <= y <= left_slab.position + left_slab.pixel_slab_size):
            this.direction[0] = 1
            hit = 1
        if (this.position[0] >= this.buffer_right_slab_size
                and right_slab.position - right_slab.pixel_slab_size <= y <= right_slab.position + right_slab.pixel_slab_size):
            this.direction[0] = -1
            hit = 1
        this.position = add_coordinates_with_speed(this.position, this.direction, this.speed)
        return hit

class Display():
    """Thin wrapper around the pygame window the game is drawn into."""

    def __init__(this, display_width, display_height):
        this.display_width, this.display_height = display_width, display_height
        window_size = [this.display_width, this.display_height]
        this.screen = pygame.display.set_mode(window_size)
        pygame.display.set_caption("Ping-Pong Tron")

    def fill(this, color):
        """Paint the whole window with ``color``."""
        screen = this.screen
        screen.fill(color)

    def update(this):
        """Flip what has been drawn onto the visible window."""
        pygame.display.update()

    def circle(this, coordinates, radius, color):
        """Draw a filled circle centred at ``coordinates``."""
        target = this.screen
        pygame.draw.circle(target, color, coordinates, radius)

    def rect(this, coordinates, color):
        """Draw a filled rectangle; ``coordinates`` is ``[x, y, width, height]``."""
        target = this.screen
        pygame.draw.rect(target, color, coordinates)

    def quit(this):
        """Close the window and shut pygame down."""
        pygame.display.quit()
        pygame.quit()

class Environment:
    """Two-slab ping-pong game: a ball bouncing between a left and a right slab.

    With ``headless=True`` no pygame window is created and ``render`` does nothing.
    ``play`` runs an interactive loop. With ``controls=True``, W/S move the left slab
    and Up/Down move the right one. With ``controls=False`` both slabs track the
    ball's height. Q quits in either mode.
    """

    def __init__(this, time_delay = 0.01, ball_speed = 10, slab_step_size = 10, pixel_size = 30, slab_size = 2, display_width = 1500, display_height = 900, headless = False, controls = True, left_slab_position = 900 / 2, right_slab_position = 900 / 2, ball_position = None, ball_direction = None):
        """Build the slabs and the ball (and a window unless ``headless``).

        ``ball_position`` defaults to a random height at the horizontal centre, and
        ``ball_direction`` to a random diagonal. Both are drawn fresh for every instance.
        """
        # These defaults used to be list literals in the signature. Python evaluates
        # those once at definition time, so every default game shared one "random"
        # start, plus one direction list that Ball.step mutates in place.
        if ball_position is None:
            ball_position = [display_width / 2, random.randint(100, max(100, display_height - 100))]
        if ball_direction is None:
            ball_direction = [random.choice([1, -1]), random.choice([1, -1])]
        this.headless = headless
        this.score = 0  # number of slab hits so far
        this.time_delay = time_delay
        this.controls = controls
        this.ball_speed = ball_speed
        this.colors = Colors()
        this.slab_step_size = slab_step_size
        this.pixel_size = pixel_size
        this.slab_size = slab_size
        this.display_width = display_width
        this.display_height = display_height
        this.left_slab = Slab(this.slab_size, this.slab_step_size, 1, this.display_height, this.pixel_size, left_slab_position)
        this.right_slab = Slab(this.slab_size, this.slab_step_size, 2, this.display_height, this.pixel_size, right_slab_position)
        # Copy so this game never shares (and mutates) another object's lists.
        this.ball = Ball(this.ball_speed, this.display_width, this.display_height, this.pixel_size, list(ball_position), list(ball_direction))
        this.game_done = False
        if not headless:
            this.display = Display(this.display_width, this.display_height)

    def play(this):
        """Run the game loop until a side loses, the window is closed, or Q is pressed."""
        if not this.headless:
            this.render()
        running = True
        while running:
            time.sleep(this.time_delay)
            outcome = this.step()
            if outcome == 'Left' or outcome == 'Right':
                this.game_done = True
                this.game_over(outcome)
                break
            this.render()
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
            keys = pygame.key.get_pressed()
            if this.controls:
                if keys[pygame.K_UP]:
                    this.right_slab.move(1)
                if keys[pygame.K_DOWN]:
                    this.right_slab.move(0)
                if keys[pygame.K_w]:
                    this.left_slab.move(1)
                if keys[pygame.K_s]:
                    this.left_slab.move(0)
                if keys[pygame.K_q]:
                    running = False
            else:
                if keys[pygame.K_q]:
                    running = False
                # Each slab steps toward the ball's current height.
                if this.right_slab.position > this.ball.position[1]:
                    this.right_slab.move(1)
                else:
                    this.right_slab.move(0)
                if this.left_slab.position > this.ball.position[1]:
                    this.left_slab.move(1)
                else:
                    this.left_slab.move(0)

    def step(this):
        """Advance the ball once; slab hits add to ``score``.

        :returns: whatever ``Ball.step`` returns (0, 1, ``'Left'`` or ``'Right'``).
        """
        outcome = this.ball.step(this.left_slab, this.right_slab)
        if outcome == 0 or outcome == 1:
            this.score += outcome
        return outcome

    def game_over(this, orientation):
        """Report which side lost and the hit streak."""
        print(orientation, 'lost with a streak of', this.score)

    def render(this):
        """Draw both slabs and the ball (no-op when headless)."""
        if not this.headless:
            this.display.fill(this.colors.black)
            this.display.rect([0, this.left_slab.position - this.left_slab.pixel_slab_size, this.pixel_size, 2 * this.left_slab.pixel_slab_size], this.colors.red)
            this.display.rect([this.display_width - this.pixel_size, this.right_slab.position - this.right_slab.pixel_slab_size, this.pixel_size, 2 * this.right_slab.pixel_slab_size], this.colors.red)
            this.display.circle([this.ball.position[0], this.ball.position[1]], this.ball.radius, this.colors.green)
            this.display.update()

In [2]:
import copy
import math
# Moves available to the left slab in the search tree, stored as strings.
# Callers convert with int() before Slab.move (1 decreases the position, 0 increases it).
actions = ["1","0"]
class Node:
    """One search-tree node: a game snapshot plus its visit/win statistics."""

    def __init__(self, state):
        self.state = state
        self.is_leaf = state.game_done  # terminal iff the wrapped game had ended
        self.children = {}              # action -> child Node
        self.wins = 0
        self.simulations = 0
        
class MCTS:
    """UCT Monte Carlo tree search that chooses moves for the left slab.

    Nodes hold headless ``Environment`` snapshots. Following an edge applies one
    game tick. The left slab takes the edge's action, an entry of the module-level
    ``actions`` converted with ``int``. The right slab steps toward the ball's
    height, as in ``Environment.play`` with ``controls=False``. Then the ball
    advances once. A rollout scores 1 if the left slab has not lost, otherwise 0.
    """

    def __init__(self, init_state, constant, rollout_steps=1000):
        """
        :param init_state: ``Environment`` to search from; it is only read, never stepped.
        :param constant: UCB1 exploration constant.
        :param rollout_steps: maximum ticks per random rollout.
        """
        self.root = Node(init_state)
        self.constant = constant
        self.rollout_steps = rollout_steps

    @staticmethod
    def _clone_env(env):
        """Independent headless copy of ``env``'s slabs, ball, score and end state."""
        clone = Environment(headless = True, controls = False,
                            left_slab_position = env.left_slab.position,
                            right_slab_position = env.right_slab.position,
                            # Fresh lists: Ball.step mutates its direction in place.
                            ball_position = list(env.ball.position),
                            ball_direction = list(env.ball.direction))
        clone.score = env.score
        clone.game_done = env.game_done
        clone.ball.winner = env.ball.winner
        return clone

    @staticmethod
    def _advance(env, movement):
        """Apply one tick to ``env`` and flag ``game_done`` when a side has lost."""
        env.left_slab.move(movement)
        env.right_slab.move(1 if env.right_slab.position > env.ball.position[1] else 0)
        outcome = env.step()
        if outcome == 'Left' or outcome == 'Right':
            env.game_done = True
        return outcome

    def selection(self):
        """Descend by UCB1 until a terminal node or one with untried actions."""
        path = [self.root]
        node = self.root
        while not node.is_leaf and all(a in node.children for a in actions):
            log_parent = math.log(node.simulations)
            best_action, best_value = None, -math.inf
            for action in actions:
                child = node.children[action]
                value = child.wins / child.simulations + self.constant * math.sqrt(log_parent / child.simulations)
                if value > best_value:  # strict: ties keep the earlier action
                    best_action, best_value = action, value
            node = node.children[best_action]
            path.append(node)
        return path

    def expand(self, leaf):
        """Add one child for an untried action; ``None`` if ``leaf`` cannot be expanded."""
        if leaf.is_leaf:
            return None
        untried = [a for a in actions if a not in leaf.children]
        if not untried:
            return None
        action = random.choice(untried)
        child_env = self._clone_env(leaf.state)
        self._advance(child_env, int(action))
        child_node = Node(child_env)
        leaf.children[action] = child_node
        return child_node

    def simulate(self, leaf):
        """Random rollout from ``leaf``; 1 if the left slab did not lose, else 0."""
        env = self._clone_env(leaf.state)
        for _ in range(self.rollout_steps):
            if env.game_done:
                break
            self._advance(env, random.randint(0, 1))
        # Ball.winner is 0 only when the ball left on the left side.
        return 0 if env.ball.winner == 0 else 1

    def backpropagation(self, path, score):
        """Add one visit and the rollout ``score`` to every node on ``path``."""
        for node in path:
            node.simulations += 1
            node.wins += score

    def choose_action(self):
        """Most-visited root action; a random move if the root has no children."""
        if not self.root.children:
            return random.choice(actions)
        return max(self.root.children, key=lambda a: self.root.children[a].simulations)

    def run(self, steps):
        """Perform ``steps`` select/expand/simulate/backpropagate iterations."""
        for _ in range(steps):
            path = self.selection()
            child_node = path[-1]
            expand_node = self.expand(child_node)
            if expand_node:
                path.append(expand_node)
            else:
                expand_node = child_node
            score = self.simulate(expand_node)
            self.backpropagation(path, score)

In [3]:
# Play one on-screen game with MCTS steering the left slab while the right slab
# steps toward the ball. After every tick the search is re-rooted on a fresh
# headless copy of the real game, so the tree never drifts from what is shown.
MAX_MOVES = 1000        # safety cap on the number of ticks played
SEARCH_ITERATIONS = 50  # MCTS iterations per tick
EXPLORATION = 1         # UCB1 exploration constant

environment = Environment(controls = False)
tree = MCTS(environment, EXPLORATION)
for _ in range(MAX_MOVES):
    if environment.game_done:
        break
    tree.run(SEARCH_ITERATIONS)
    action = tree.choose_action()

    # One tick of the real game: chosen move, opponent tracks the ball, ball advances.
    environment.left_slab.move(int(action))
    if environment.right_slab.position > environment.ball.position[1]:
        environment.right_slab.move(1)
    else:
        environment.right_slab.move(0)
    outcome = environment.step()
    environment.render()
    pygame.event.pump()  # play() is bypassed, so service the window's event queue here

    if outcome == 'Left' or outcome == 'Right':
        environment.game_done = True
        environment.game_over(outcome)
        break

    search_env = Environment(headless = True, controls = False,
                             left_slab_position = environment.left_slab.position,
                             right_slab_position = environment.right_slab.position,
                             ball_position = list(environment.ball.position),
                             ball_direction = list(environment.ball.direction))
    search_env.score = environment.score
    tree.root = Node(search_env)

KeyboardInterrupt: 