In [1]:
from mss import mss # screen grabber
import pyautogui
from PIL import Image, ImageTk
import time # For delay
import cv2
import random 
import os
import numpy as np

In [20]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

class YourEnvironment(gym.Env):
    """Gym environment that plays a game through screen capture and key presses.

    Every step grabs a screenshot, runs colour-masked template matching on it
    to derive a handful of flags/counters (target lock, radar contact, red bar,
    collision marker), optionally emits key presses through ``pyautogui`` and
    computes a shaped reward from those flags.

    Actions (``Discrete(4)``):
        0 - do nothing
        1 - fire (only when a target is locked and ``reload_1 > 10``)
        2 - random movement (only when no target is locked)
        3 - random view adjustment (only when no target is locked)
    """

    # Screen area captured on every step, in pixels.
    SCREEN_REGION = {'left': 0, 'top': 0, 'width': 950, 'height': 566}
    # Labels used only for the per-step log line, indexed by action id.
    ACTION_NAMES = ['none', 'fire', 'move', 'view']
    # Fraction of the frame, measured from the top-left corner, that
    # checkRadar searches.
    RADAR_CROP_TOP = 0.4
    RADAR_CROP_LEFT = 0.25

    def __init__(self):
        """Declare the action/observation spaces and initialise the state."""
        self.action_space = gym.spaces.Discrete(4)  # none / fire / move / view
        # 11 entries, one per value returned by get_state().
        # NOTE(review): the declared [0, 1] bounds do not cover every entry
        # (ship_health_percentage starts at 100.0 and the counters grow past
        # 1). The declaration is kept as-is because an existing saved model
        # may have been created with this exact space -- confirm before
        # widening it.
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(11,), dtype=float)
        self.reset()

    def reset(self):
        """Reset every tracked value to its initial state and return the observation."""
        self.target_locked = 0
        self.target_locked_time = 0
        self.damage_done = 0
        self.movement_horizontal = 0
        self.movement_vertical = 0
        self.adjust_view = 0
        self.ship_health_percentage = 100.0
        self.time_until_last_locked_target = 0
        self.enemy_on_radar = 0
        self.red_bar_on_view = 0
        self.reload_1 = 0
        self.current_screenshot = None
        self.continuous_lockon_enemy = 0
        self.collision_time = 0
        self.reward = 0

        return self.get_state()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        The reward is computed before the termination check, so it reflects
        the state produced by this step. Because ``check_done`` resets the
        environment when it reports termination, the observation (and the
        idle time in the log line) returned together with ``done=True`` is
        the freshly reset state.
        """
        self.update_states(action)
        reward = self.calculate_reward()
        done = self.check_done()
        print(f"Action: {self.ACTION_NAMES[action]}, Reward: {reward}, Done: {done}, Idle Time: {self.time_until_last_locked_target}")

        return self.get_state(), reward, done, {}

    def update_states(self, action):
        """Capture a new frame, refresh all detection flags, then perform ``action``."""
        self.current_screenshot = self.update_screenshot()

        # Target lock indicator.
        _, locked = self.matchImage("./images/aim_complete.png", 0.3, np.array([50, 218, 183]), np.array([175, 255, 255]), (0, 0, 255))
        self.target_locked = int(locked)
        if locked:
            self.target_locked_time += 1
        else:
            self.target_locked_time = 0

        # Enemy marker inside the radar area (top-left part of the frame).
        _, on_radar = self.checkRadar("./images/radar_tracker.png", 0.17, np.array([0, 187, 230]), np.array([0, 255, 255]), (0, 255, 0))
        self.enemy_on_radar = int(on_radar)

        # Red bar in the main view; count consecutive frames it stays visible.
        _, red_bar = self.matchImage("./images/probable_enemy.png", 0.85, np.array([0, 187, 230]), np.array([0, 255, 255]), (255, 0, 0))
        self.red_bar_on_view = int(red_bar)
        if red_bar:
            self.continuous_lockon_enemy += 1
        else:
            self.continuous_lockon_enemy = 0

        # Collision marker; count consecutive frames it stays visible.
        _, colliding = self.matchImage("./images/collide_tracker.png", 0.50, np.array([0, 0, 145]), np.array([161, 43, 255]), (0, 255, 0))
        if colliding:
            self.collision_time += 1
        else:
            # The original wrote to a misspelled attribute here, so the
            # counter could never return to zero.
            self.collision_time = 0

        if not locked:
            self.time_until_last_locked_target += 1
            self.reload_1 += 1

        # Clear the action bonus every step. Previously it was only cleared on
        # some branches, so a +40 fire bonus kept being paid on later steps in
        # which nothing was fired.
        self.reward = 0

        if action == 1:  # fire
            if locked and self.reload_1 > 10:
                self.time_until_last_locked_target = 0
                self.reload_1 = 0
                self.fire_key_1()
                # NOTE(review): damage_done is only cleared by reset(), so it
                # stays 1 for the rest of the episode -- confirm this is intended.
                self.damage_done = 1
                self.reward = 40
        elif action == 2:  # random movement, only while searching for a target
            if not locked:
                self.do_random_movement()
        elif action == 3:  # random view adjustment, only while searching
            if not locked:
                self.do_random_viewing()
        # action == 0: do nothing.

    def get_state(self):
        """Return the current state as an 11-element float array.

        The order matches the original list: target_locked, damage_done,
        movement_horizontal, movement_vertical, adjust_view,
        ship_health_percentage, time_until_last_locked_target, enemy_on_radar,
        red_bar_on_view, continuous_lockon_enemy, collision_time.
        """
        return np.array(
            [
                self.target_locked,
                self.damage_done,
                self.movement_horizontal,
                self.movement_vertical,
                self.adjust_view,
                self.ship_health_percentage,
                self.time_until_last_locked_target,
                self.enemy_on_radar,
                self.red_bar_on_view,
                self.continuous_lockon_enemy,
                self.collision_time,
            ],
            dtype=float,  # same dtype as the declared observation space
        )

    def calculate_reward(self):
        """Combine the action bonus with the shaping terms for the current state.

        Terms:
            * target locked: +30 and +7 per consecutive locked frame
            * target not locked: -15
            * enemy on radar: +10; otherwise -3 per idle step
            * red bar visible: +3 per consecutive frame
            * collision marker: +1 when absent, -6 when seen for 2-3
              consecutive frames, -25 when seen for more than 3
        """
        reward = self.reward

        if self.target_locked == 1:
            reward += 7 * self.target_locked_time
            reward += 30

        if self.enemy_on_radar == 1:
            reward += 10

        if self.target_locked == 0:
            reward -= 15

        if self.enemy_on_radar == 0:
            reward -= 3 * self.time_until_last_locked_target

        if self.red_bar_on_view == 1:
            reward += 3 * self.continuous_lockon_enemy

        # The larger threshold must be tested first; with "> 1" in front the
        # "> 3" branch could never run.
        if self.collision_time == 0:
            reward += 1
        elif self.collision_time > 3:
            reward -= 25
        elif self.collision_time > 1:
            reward -= 6

        return reward

    def check_done(self):
        """Return True (and reset the environment) after more than 80 idle steps.

        NOTE(review): resetting inside this predicate is kept from the
        original so that manual loops relying on the auto-reset keep working.
        """
        if self.time_until_last_locked_target > 80:
            self.reset()
            return True
        return False

    def update_screenshot(self):
        """Grab ``SCREEN_REGION`` and return it as a BGR image array."""
        with mss() as sct:
            shot = sct.grab(dict(self.SCREEN_REGION))
            rgb_image = Image.frombytes('RGB', shot.size, shot.rgb)
            return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _tap_key(key, hold_seconds):
        """Hold ``key`` down for ``hold_seconds`` and always release it again."""
        # NOTE(review): key names are passed exactly as in the original
        # (upper-case letters); confirm the game receives them as intended.
        pyautogui.keyDown(key)
        try:
            time.sleep(hold_seconds)
        finally:
            # Release even if the sleep is interrupted (e.g. Ctrl+C while
            # training), so the key is not left stuck down.
            pyautogui.keyUp(key)

    def fire_key_1(self):
        """Press the fire key ("Num1") for 0.1 s."""
        self._tap_key("Num1", 0.1)

    def move_left(self):
        """Press "A" for 0.2 s."""
        self._tap_key("A", 0.2)

    def move_right(self):
        """Press "D" for 0.2 s."""
        self._tap_key("D", 0.2)

    def steer_up(self):
        """Press "W" for 0.1 s."""
        self._tap_key("W", 0.1)

    def steer_down(self):
        """Press "S" for 0.1 s."""
        self._tap_key("S", 0.1)

    def view_right(self, t):
        """Press "L" for ``t`` seconds."""
        self._tap_key("L", t)

    def view_left(self, t):
        """Press "J" for ``t`` seconds."""
        self._tap_key("J", t)

    def do_random_movement(self):
        """Pick a random horizontal and a random vertical move and perform them.

        Each axis independently chooses between -1, 0 and 1; 0 means no key
        press on that axis.
        """
        horizontal_movement = random.choice([-1, 0, 1])
        vertical_movement = random.choice([-1, 0, 1])

        # NOTE(review): movement_horizontal is part of the state but is never
        # updated anywhere in this class -- confirm whether it should be
        # tracked the same way as movement_vertical.
        if horizontal_movement == -1:
            self.move_left()
        elif horizontal_movement == 1:
            self.move_right()

        # The original assigned to "vertical_movement", an attribute that
        # get_state() never reads; the tracked value is movement_vertical.
        # NOTE(review): the clamp range [-1, 2] is kept from the original;
        # confirm the upper bound.
        if vertical_movement == -1:
            self.steer_down()
            self.movement_vertical = max(self.movement_vertical - 1, -1)
        elif vertical_movement == 1:
            self.steer_up()
            self.movement_vertical = min(self.movement_vertical + 1, 2)

    def do_random_viewing(self):
        """Rotate the view left or right (chosen at random) for 0.2 s."""
        if random.choice([-1, 1]) == -1:
            self.view_left(0.2)
        else:
            self.view_right(0.2)

    @staticmethod
    def _find_template(frame, pattern_url, threshold, lower_range, upper_range, rect_color):
        """Colour-mask ``frame`` and look for the template stored at ``pattern_url``.

        Args:
            frame: BGR image to search. It is never modified.
            pattern_url: Path of the template image.
            threshold: Minimum ``TM_CCOEFF_NORMED`` score that counts as a match.
            lower_range, upper_range: HSV bounds passed to ``cv2.inRange``.
            rect_color: BGR colour of the rectangle drawn around the first match.

        Returns:
            ``(annotated, is_found)`` where ``annotated`` is a copy of
            ``frame`` with the first match outlined (unchanged copy when
            nothing matched) and ``is_found`` is a bool.

        Raises:
            FileNotFoundError: if the template image cannot be read.
        """
        pattern = cv2.imread(pattern_url)
        if pattern is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError(f"Template image could not be read: {pattern_url}")
        template_gray = cv2.cvtColor(pattern, cv2.COLOR_BGR2GRAY)

        # Keep only pixels inside the HSV range, then match in grayscale.
        hsv_image = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv_image, lower_range, upper_range)
        masked = cv2.bitwise_and(frame, frame, mask=mask)
        gray_image = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        scores = cv2.matchTemplate(gray_image, template_gray, cv2.TM_CCOEFF_NORMED)

        rows, cols = np.where(scores >= threshold)

        # Draw on a copy: the frame is shared by every detector run during a
        # step, and a rectangle painted on it would show up in the colour
        # masks of the detectors that run afterwards.
        annotated = frame.copy()
        if len(rows) == 0:
            return annotated, False

        # Outline the first match only.
        top_left = (int(cols[0]), int(rows[0]))
        bottom_right = (top_left[0] + pattern.shape[1], top_left[1] + pattern.shape[0])
        cv2.rectangle(annotated, top_left, bottom_right, rect_color, 2)
        return annotated, True

    def checkRadar(self, pattern_url, threshold, lower_range, upper_range, rect_color):
        """Search for a template inside the radar area of the current screenshot.

        The radar area is the top ``RADAR_CROP_TOP`` fraction of the rows and
        the left ``RADAR_CROP_LEFT`` fraction of the columns.

        Returns:
            ``(annotated_crop, is_found)``; see ``_find_template``.
        """
        frame = self.current_screenshot
        height, width, _ = frame.shape
        crop_top = int(height * self.RADAR_CROP_TOP)
        crop_left = int(width * self.RADAR_CROP_LEFT)
        radar_area = frame[:crop_top, :crop_left]

        return self._find_template(radar_area, pattern_url, threshold, lower_range, upper_range, rect_color)

    def matchImage(self, pattern_url, threshold, lower_range, upper_range, rect_color):
        """Search for a template in the whole current screenshot.

        Returns:
            ``(annotated_frame, is_found)``; see ``_find_template``.
        """
        return self._find_template(self.current_screenshot, pattern_url, threshold, lower_range, upper_range, rect_color)

class EpisodeTerminationCallback(BaseCallback):
    """Stop training once ``episode_limit`` episodes have finished.

    ``_on_step`` is invoked once per environment step, so the previous
    implementation, which incremented the counter on every call, stopped
    training after ``episode_limit`` *steps* rather than episodes. The counter
    now advances only when an environment reports ``done``.

    Args:
        episode_limit: Number of finished episodes after which training stops.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, episode_limit, verbose=1):
        super().__init__(verbose)
        self.episode_limit = episode_limit
        self.episode_count = 0

    def _on_step(self) -> bool:
        """Count finished episodes; return False to stop training at the limit."""
        # "dones" is read from the training loop's locals exposed to the
        # callback; one flag per environment. If it is absent, nothing is
        # counted for this step.
        dones = self.locals.get("dones", ())
        self.episode_count += int(np.sum(dones))

        if self.episode_count >= self.episode_limit:
            # Reset so the same callback instance can be reused for another run.
            self.episode_count = 0
            print(f"Terminating training after {self.episode_limit} episodes.")
            return False  # Return False to stop training
        return True  # Return True to continue training

# Build the environment and a fresh PPO agent (MLP policy) that acts on it.
env = YourEnvironment()
model = PPO(policy="MlpPolicy", env=env, verbose=0)


# # Train the model
# model.learn(total_timesteps=48000, callback=episode_limit_callback)



In [22]:
def first_iteration():
    """Train the module-level ``model`` and save it to ``trained_model.zip``.

    Waits 5 seconds before starting, then runs ``model.learn`` for up to 48000
    timesteps with an ``EpisodeTerminationCallback(240)`` attached.
    """
    print("First iteration starting in 5 seconds.")
    time.sleep(5)
    model.learn(
        total_timesteps=48000,
        callback=EpisodeTerminationCallback(240),
    )
    model.save("trained_model.zip")

In [31]:
def start_train(max_itr):
    """Run ``max_itr`` training iterations, resuming from the saved model if present.

    Waits 20 seconds before the first iteration. On each iteration, if
    ``trained_model.zip`` exists it is loaded onto a new environment, trained
    further and saved back; otherwise ``first_iteration`` is called.

    Args:
        max_itr: Number of iterations to run; nothing happens when it is
            not positive.
    """
    checkpoint_path = "trained_model.zip"

    # One-off delay before any iteration runs.
    if 0 < max_itr:
        time.sleep(20)

    completed = 0
    while completed < max_itr:
        if not os.path.exists(checkpoint_path):
            print("File not found, starting first iteration")
            first_iteration()
        else:
            print("File found, starting model training on the base of the saved model.")
            resumed_env = YourEnvironment()
            resumed_model = PPO.load(checkpoint_path, resumed_env, verbose=0)
            resumed_model.learn(
                total_timesteps=48000,
                callback=EpisodeTerminationCallback(240),
            )
            resumed_model.save(checkpoint_path)

        completed += 1

In [1]:
# Entry point of the notebook: run three training iterations via start_train.
start_train(3)